Add and export from_dlpack
pabloferz committed Jan 27, 2024
1 parent f2386a3 commit 29f708c
Showing 6 changed files with 98 additions and 28 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -33,7 +33,9 @@ np = pyimport("jax.numpy")
dl = pyimport("jax.dlpack")

pyv = np.arange(10)
v = DLPack.wrap(pyv, o -> @pycall dl.to_dlpack(o)::PyObject)
v = from_dlpack(pyv)
# For older jax version use:
# v = DLPack.wrap(pyv, o -> @pycall dl.to_dlpack(o)::PyObject)

(pyv[1] == 1).item() # This is false since the first element is 0

@@ -44,7 +46,7 @@ v[1] = 1
```

If the python tensor has more than one dimension and the memory layout is
row-major the array returned by `DLPack.wrap` has its dimensions reversed.
row-major the array returned by `DLPack.from_dlpack` has its dimensions reversed.
Let us illustrate this now by importing a `torch.Tensor` via the
`PythonCall` package:

@@ -55,7 +57,9 @@ using PythonCall
torch = pyimport("torch")

pyv = torch.arange(1, 5).reshape(2, 2)
v = DLPack.wrap(pyv, torch.to_dlpack)
v = from_dlpack(pyv)
# For older torch releases use:
# v = DLPack.wrap(pyv, torch.to_dlpack)

Bool(v[2, 1] == 2 == pyv[0, 1]) # dimensions are reversed
```
@@ -84,7 +88,9 @@ using PyCall
cupy = pyimport("cupy")

pyv = cupy.arange(6).reshape(2, 3)
v = DLPack.wrap(pyv, o -> pycall(o.toDlpack, PyObject))
v = from_dlpack(pyv)
# For older versions of cupy use:
# v = DLPack.wrap(pyv, o -> pycall(o.toDlpack, PyObject))

v .= 1
pyv.sum().item() == 6 # true
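For reference, a minimal sketch of the usage pattern these README changes describe, combining the new `from_dlpack` entry point with the older `DLPack.wrap` fallback (assuming PyCall, a working jax install, and a CPU-backed array):

```julia
using DLPack
using PyCall

jnp = pyimport("jax.numpy")
dl = pyimport("jax.dlpack")

pyv = jnp.arange(10)

# Prefer the standard `__dlpack__` protocol when the Python object provides it;
# older jax releases need an explicit capsule producer instead.
v = if hasproperty(pyv, :__dlpack__)
    from_dlpack(pyv)
else
    DLPack.wrap(pyv, o -> @pycall dl.to_dlpack(o)::PyObject)
end

v[1] = 1  # zero copy: mutates the jax array in place (CPU backend assumed)
```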
33 changes: 26 additions & 7 deletions ext/PyCallExt.jl
@@ -66,14 +66,30 @@ function DLPack.DLManagedTensor(po::PyCall.PyObject)
return tensor
end

# Docstring in src/DLPack.jl
function DLPack.from_dlpack(o::PyCall.PyObject)
tensor = DLPack.DLManagedTensor(PyCall.@pycall o.__dlpack__()::PyCall.PyObject)
return DLPack.unsafe_wrap(tensor, o)
end

"""
from_dlpack(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::PyObject)
Type-inferrable alternative to `from_dlpack`.
"""
function DLPack.from_dlpack(::Type{A}, ::Type{M}, o::PyCall.PyObject) where {
T, N, A <: AbstractArray{T, N}, M
}
tensor = DLPack.DLManagedTensor(PyCall.@pycall o.__dlpack__()::PyCall.PyObject)
return DLPack.unsafe_wrap(A, M, tensor, o)
end

"""
wrap(o::PyObject, to_dlpack)
Takes a tensor `o::PyObject` and a `to_dlpack` function that generates a
`DLManagedTensor` bundled in a PyCapsule, and returns a zero-copy
`array::AbstractArray` pointing to the same data in `o`.
For tensors with row-major ordering the resulting array will have all
dimensions reversed.
Similar to `from_dlpack`, but works for python arrays that do not implement a `__dlpack__`
method. `to_dlpack` must be a function that, when applied to `o`, generates a
`DLManagedTensor` bundled into a `PyCapsule`.
"""
function DLPack.wrap(o::PyCall.PyObject, to_dlpack::Union{PyCall.PyObject, Function})
return DLPack.unsafe_wrap(DLPack.DLManagedTensor(to_dlpack(o)), o)
@@ -82,7 +98,7 @@ end
"""
wrap(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::PyObject, to_dlpack)
Type-inferrable alternative to `wrap(o, to_dlpack)`.
Type-inferrable alternative to `wrap`.
"""
function DLPack.wrap(::Type{A}, ::Type{M}, o::PyCall.PyObject, to_dlpack) where {
T, N, A <: AbstractArray{T, N}, M
@@ -117,7 +133,10 @@ function DLPack.share(A::StridedArray, from_dlpack::PyCall.PyObject)

return try
from_dlpack(pycapsule)
catch
catch e
if !(e isa KeyError && any(e.key .== (:__dlpack__, :__dlpack_device__)))
rethrow()
end
dl_array = DLArray()
ctx = DLPack.dldevice(tensor)
dl_array.capsule = pycapsule
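The three-argument method added above trades generality for inferability: the caller states the expected array type and memory layout up front. A short sketch of how it might be called through PyCall (assuming a CPU jax array whose element type and layout are known, mirroring the updated test in test/pycall.jl):

```julia
using DLPack
using PyCall

jnp = pyimport("jax.numpy")

pyv = jnp.asarray([1.0, 2.0, 3.0], dtype = jnp.float32)

# Spelling out the target type lets the compiler infer the result as
# Vector{Float32} instead of an abstract AbstractArray.
v = DLPack.from_dlpack(Vector{Float32}, ColMajor, pyv)
```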
36 changes: 29 additions & 7 deletions ext/PythonCallExt.jl
@@ -55,14 +55,30 @@ function DLPack.DLManagedTensor(po::PythonCall.Py)
return tensor
end

# Docstring in src/DLPack.jl
function DLPack.from_dlpack(o::PythonCall.Py)
tensor = DLPack.DLManagedTensor(o.__dlpack__())
return DLPack.unsafe_wrap(tensor, o)
end

"""
from_dlpack(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::Py)
Type-inferrable alternative to `from_dlpack`.
"""
function DLPack.from_dlpack(::Type{A}, ::Type{M}, o::PythonCall.Py) where {
T, N, A <: AbstractArray{T, N}, M
}
tensor = DLPack.DLManagedTensor(o.__dlpack__())
return DLPack.unsafe_wrap(A, M, tensor, o)
end

"""
wrap(o::Py, to_dlpack)
Takes a tensor `o::Py` and a `to_dlpack` function that generates a
`DLManagedTensor` bundled in a PyCapsule, and returns a zero-copy
`array::AbstractArray` pointing to the same data in `o`.
For tensors with row-major ordering the resulting array will have all
dimensions reversed.
Similar to `from_dlpack`, but works for python arrays that do not implement a `__dlpack__`
method. `to_dlpack` must be a function that, when applied to `o`, generates a
`DLManagedTensor` bundled into a `PyCapsule`.
"""
function DLPack.wrap(o::PythonCall.Py, to_dlpack::Union{PythonCall.Py, Function})
return DLPack.unsafe_wrap(DLPack.DLManagedTensor(to_dlpack(o)), o)
@@ -71,7 +87,7 @@ end
"""
wrap(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::Py, to_dlpack)
Type-inferrable alternative to `wrap(o, to_dlpack)`.
Type-inferrable alternative to `wrap`.
"""
function DLPack.wrap(::Type{A}, ::Type{M}, o::PythonCall.Py, to_dlpack) where {
T, N, A <: AbstractArray{T, N}, M
@@ -104,7 +120,13 @@ function DLPack.share(A::StridedArray, from_dlpack::PythonCall.Py)

return try
from_dlpack(pycapsule)
catch
catch e
if !(
PythonCall.pyisinstance(e, PythonCall.pybuiltins.AttributeError) &&
any(contains.(string(PythonCall.Py(e)), ("__dlpack__", "__dlpack_device__")))
)
rethrow()
end
ctx = DLPack.dldevice(tensor)
device = (Int(ctx.device_type), ctx.device_id)
from_dlpack(DLArray(pycapsule, device))
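In the other direction, `DLPack.share` hands a Julia array to a Python consumer by building a capsule and calling the consumer's `from_dlpack` on it, falling back to the `DLArray` wrapper when the consumer insists on the `__dlpack__` protocol. A minimal sketch with PythonCall and torch (assuming torch is importable and the data lives on the CPU):

```julia
using DLPack
using PythonCall

torch = pyimport("torch")

A = rand(Float32, 2, 3)

# The returned tensor aliases the Julia memory; for multidimensional
# arrays the dimensions appear reversed on the Python side.
t = DLPack.share(A, torch.from_dlpack)
Bool(t.shape == torch.Size((3, 2)))  # true
```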
23 changes: 18 additions & 5 deletions src/DLPack.jl
@@ -20,11 +20,6 @@ module DLPack
using Requires


## Exports ##

export RowMajor, ColMajor


## Types ##

@enum DLDeviceType::Cint begin
@@ -128,7 +123,12 @@ end
abstract type MemoryLayout end

struct ColMajor <: MemoryLayout end
#
export ColMajor

struct RowMajor <: MemoryLayout end
#
export RowMajor

"""
DLManager{T, N}
@@ -184,6 +184,19 @@ const SHARES_POOL = Dict{Ptr{Cvoid}, Tuple{Capsule, Any}}()

## Wrapping and sharing ##

"""
from_dlpack(o)
If `o` follows the DLPack specification, it returns a zero-copy `array::AbstractArray`
pointing to the same data in `o`. For arrays with row-major ordering the resulting array
will have all dimensions reversed.
"""
from_dlpack(o) = throw(ArgumentError("The input does not follow the DLPack specification"))
#
export from_dlpack

# Similar to `from_dlpack`, but takes a second argument that generates a `DLManagedTensor`
# possibly bundled in another data structure.
function wrap end

"""
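The generic `from_dlpack` method added here only guards the API: the concrete methods live in the PyCall/PythonCall extensions, so calling it on a value that no loaded extension understands raises the `ArgumentError` above. A small sketch of that behaviour (assuming neither Python extension handles the input):

```julia
using DLPack

# A plain Julia vector does not carry a DLPack capsule, so the generic
# fallback method throws.
try
    from_dlpack([1, 2, 3])
catch err
    err isa ArgumentError  # true
end
```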
17 changes: 13 additions & 4 deletions test/pycall.jl
@@ -15,7 +15,8 @@
to_dlpack = o -> @pycall dlpack.to_dlpack(o)::PyObject

v = jnp.asarray([1.0, 2.0, 3.0], dtype = jnp.float32)
jv = DLPack.wrap(v, to_dlpack)
follows_dlpack_spec = hasproperty(v, :__dlpack__)
jv = follows_dlpack_spec ? DLPack.from_dlpack(v) : DLPack.wrap(v, to_dlpack)
dlv = DLPack.DLManagedTensor(to_dlpack(v))
opaque_tensor = dlv.dl_tensor

@@ -24,15 +25,19 @@

if DLPack.device_type(opaque_tensor) == DLPack.kDLCPU
jv[1] = 0 # mutate a jax tensor
@inferred DLPack.wrap(Vector{Float32}, ColMajor, v, to_dlpack)
if follows_dlpack_spec
@inferred DLPack.from_dlpack(Vector{Float32}, ColMajor, v)
else
@inferred DLPack.wrap(Vector{Float32}, ColMajor, v, to_dlpack)
end
elseif DLPack.device_type(opaque_tensor) == DLPack.kDLCUDA
jv[1:1] .= 0 # mutate a jax tensor
end

@test py"$jnp.all($v == $jnp.asarray([0.0, 2.0, 3.0])).item()"

w = jnp.asarray([1 2; 3 4], dtype = jnp.int64)
jw = DLPack.wrap(w, to_dlpack)
jw = follows_dlpack_spec ? DLPack.from_dlpack(w) : DLPack.wrap(w, to_dlpack)
dlw = DLPack.DLManagedTensor(to_dlpack(w))
opaque_tensor = dlw.dl_tensor

@@ -41,7 +46,11 @@

if DLPack.device_type(opaque_tensor) == DLPack.kDLCPU
@test jw[1, 2] == 3 # dimensions are reversed
@inferred DLPack.wrap(Matrix{Int64}, RowMajor, w, to_dlpack)
if follows_dlpack_spec
@inferred DLPack.from_dlpack(Matrix{Int64}, RowMajor, w)
else
@inferred DLPack.wrap(Matrix{Int64}, RowMajor, w, to_dlpack)
end
elseif DLPack.device_type(opaque_tensor) == DLPack.kDLCUDA
@test all(view(jw, 1, 2) .== 3) # dimensions are reversed
end
3 changes: 2 additions & 1 deletion test/pythoncall.jl
@@ -10,7 +10,8 @@

@testset "wrap" begin
v = torch.ones((2, 4), dtype = torch.float64)
jv = DLPack.wrap(v, torch.to_dlpack)
follows_dlpack_spec = hasproperty(v, :__dlpack__)
jv = follows_dlpack_spec ? DLPack.from_dlpack(v) : DLPack.wrap(v, torch.to_dlpack)
dlv = DLPack.DLManagedTensor(torch.to_dlpack(v))
opaque_tensor = dlv.dl_tensor

