diff --git a/README.md b/README.md
index db64ceb..1990a40 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,9 @@ np = pyimport("jax.numpy")
 dl = pyimport("jax.dlpack")
 
 pyv = np.arange(10)
-v = DLPack.wrap(pyv, o -> @pycall dl.to_dlpack(o)::PyObject)
+v = from_dlpack(pyv)
+# For older jax versions use:
+# v = DLPack.wrap(pyv, o -> @pycall dl.to_dlpack(o)::PyObject)
 
 (pyv[1] == 1).item()  # This is false since the first element is 0
 
@@ -44,7 +46,7 @@ v[1] = 1
 ```
 
 If the python tensor has more than one dimension and the memory layout is
-row-major the array returned by `DLPack.wrap` has its dimensions reversed.
+row-major the array returned by `DLPack.from_dlpack` has its dimensions reversed.
 Let us illustrate this now by importing a `torch.Tensor` via the `PythonCall`
 package:
 
@@ -55,7 +57,9 @@ using PythonCall
 torch = pyimport("torch")
 
 pyv = torch.arange(1, 5).reshape(2, 2)
-v = DLPack.wrap(pyv, torch.to_dlpack)
+v = from_dlpack(pyv)
+# For older torch releases use:
+# v = DLPack.wrap(pyv, torch.to_dlpack)
 
 Bool(v[2, 1] == 2 == pyv[0, 1])  # dimensions are reversed
 ```
@@ -84,7 +88,9 @@ using PyCall
 cupy = pyimport("cupy")
 
 pyv = cupy.arange(6).reshape(2, 3)
-v = DLPack.wrap(pyv, o -> pycall(o.toDlpack, PyObject))
+v = from_dlpack(pyv)
+# For older versions of cupy use:
+# v = DLPack.wrap(pyv, o -> pycall(o.toDlpack, PyObject))
 
 v .= 1
 pyv.sum().item() == 6  # true
diff --git a/ext/PyCallExt.jl b/ext/PyCallExt.jl
index 33b544d..eadd35c 100644
--- a/ext/PyCallExt.jl
+++ b/ext/PyCallExt.jl
@@ -66,14 +66,30 @@ function DLPack.DLManagedTensor(po::PyCall.PyObject)
     return tensor
 end
 
+# Docstring in src/DLPack.jl
+function DLPack.from_dlpack(o::PyCall.PyObject)
+    tensor = DLPack.DLManagedTensor(PyCall.@pycall o.__dlpack__()::PyCall.PyObject)
+    return DLPack.unsafe_wrap(tensor, o)
+end
+
+"""
+    from_dlpack(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::PyObject)
+
+Type-inferrable alternative to `from_dlpack`.
+"""
+function DLPack.from_dlpack(::Type{A}, ::Type{M}, o::PyCall.PyObject) where {
+    T, N, A <: AbstractArray{T, N}, M
+}
+    tensor = DLPack.DLManagedTensor(PyCall.@pycall o.__dlpack__()::PyCall.PyObject)
+    return DLPack.unsafe_wrap(A, M, tensor, o)
+end
+
 """
     wrap(o::PyObject, to_dlpack)
 
-Takes a tensor `o::PyObject` and a `to_dlpack` function that generates a
-`DLManagedTensor` bundled in a PyCapsule, and returns a zero-copy
-`array::AbstractArray` pointing to the same data in `o`.
-For tensors with row-major ordering the resulting array will have all
-dimensions reversed.
+Similar to `from_dlpack`, but works for python arrays that do not implement a `__dlpack__`
+method. `to_dlpack` must be a function that, when applied to `o`, generates a
+`DLManagedTensor` bundled into a `PyCapsule`.
 """
 function DLPack.wrap(o::PyCall.PyObject, to_dlpack::Union{PyCall.PyObject, Function})
     return DLPack.unsafe_wrap(DLPack.DLManagedTensor(to_dlpack(o)), o)
@@ -82,7 +98,7 @@ end
 """
     wrap(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::PyObject, to_dlpack)
 
-Type-inferrable alternative to `wrap(o, to_dlpack)`.
+Type-inferrable alternative to `wrap`.
""" function DLPack.wrap(::Type{A}, ::Type{M}, o::PyCall.PyObject, to_dlpack) where { T, N, A <: AbstractArray{T, N}, M @@ -117,7 +133,10 @@ function DLPack.share(A::StridedArray, from_dlpack::PyCall.PyObject) return try from_dlpack(pycapsule) - catch + catch e + if !(e isa KeyError && any(e.key .== (:__dlpack__, :__dlpack_device__))) + rethrow() + end dl_array = DLArray() ctx = DLPack.dldevice(tensor) dl_array.capsule = pycapsule diff --git a/ext/PythonCallExt.jl b/ext/PythonCallExt.jl index 6a99ed5..3365cb0 100644 --- a/ext/PythonCallExt.jl +++ b/ext/PythonCallExt.jl @@ -55,14 +55,30 @@ function DLPack.DLManagedTensor(po::PythonCall.Py) return tensor end +# Docstring in src/DLPack.jl +function DLPack.from_dlpack(o::PythonCall.Py) + tensor = DLPack.DLManagedTensor(o.__dlpack__()) + return DLPack.unsafe_wrap(tensor, o) +end + +""" + from_dlpack(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::Py) + +Type-inferrable alternative to `from_dlpack`. +""" +function DLPack.from_dlpack(::Type{A}, ::Type{M}, o::PythonCall.Py) where { + T, N, A <: AbstractArray{T, N}, M +} + tensor = DLPack.DLManagedTensor(o.__dlpack__()) + return DLPack.unsafe_wrap(A, M, tensor, o) +end + """ wrap(o::Py, to_dlpack) -Takes a tensor `o::Py` and a `to_dlpack` function that generates a -`DLManagedTensor` bundled in a PyCapsule, and returns a zero-copy -`array::AbstractArray` pointing to the same data in `o`. -For tensors with row-major ordering the resulting array will have all -dimensions reversed. +Similar to `from_dlpack`, but works for python arrays that do not implement a `__dlpack__` +method. `to_dlpack` must be a function that, when applied to `o`, generates a +`DLManagedTensor` bundled into a `PyCapsule`. """ function DLPack.wrap(o::PythonCall.Py, to_dlpack::Union{PythonCall.Py, Function}) return DLPack.unsafe_wrap(DLPack.DLManagedTensor(to_dlpack(o)), o) @@ -71,7 +87,7 @@ end """ wrap(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::Py, to_dlpack) -Type-inferrable alternative to `wrap(o, to_dlpack)`. +Type-inferrable alternative to `wrap`. """ function DLPack.wrap(::Type{A}, ::Type{M}, o::PythonCall.Py, to_dlpack) where { T, N, A <: AbstractArray{T, N}, M @@ -104,7 +120,13 @@ function DLPack.share(A::StridedArray, from_dlpack::PythonCall.Py) return try from_dlpack(pycapsule) - catch + catch e + if !( + PythonCall.pyisinstance(e, PythonCall.pybuiltins.AttributeError) && + any(contains.(string(PythonCall.Py(e)), ("__dlpack__", "__dlpack_device__"))) + ) + rethrow() + end ctx = DLPack.dldevice(tensor) device = (Int(ctx.device_type), ctx.device_id) from_dlpack(DLArray(pycapsule, device)) diff --git a/src/DLPack.jl b/src/DLPack.jl index c21a16e..7a5f6d4 100644 --- a/src/DLPack.jl +++ b/src/DLPack.jl @@ -20,11 +20,6 @@ module DLPack using Requires -## Exports ## - -export RowMajor, ColMajor - - ## Types ## @enum DLDeviceType::Cint begin @@ -128,7 +123,12 @@ end abstract type MemoryLayout end struct ColMajor <: MemoryLayout end +# +export ColMajor + struct RowMajor <: MemoryLayout end +# +export RowMajor """ DLManager{T, N} @@ -184,6 +184,19 @@ const SHARES_POOL = Dict{Ptr{Cvoid}, Tuple{Capsule, Any}}() ## Wrapping and sharing ## +""" + from_dlpack(o) + +If `o` follows the DLPack specification, it returns a zero-copy `array::AbstractArray` +pointing to the same data in `o`. For arrays with row-major ordering the resulting array +will have all dimensions reversed. 
+""" +from_dlpack(o) = throw(ArgumentError("The input does not follow the DLPack specification")) +# +export from_dlpack + +# Similar to `from_dlpack`, but takes a second argument that generates a `DLManagedTensor` +# possibly bundled in another data structure. function wrap end """ diff --git a/test/pycall.jl b/test/pycall.jl index 25b6b56..0a48dc1 100644 --- a/test/pycall.jl +++ b/test/pycall.jl @@ -15,7 +15,8 @@ to_dlpack = o -> @pycall dlpack.to_dlpack(o)::PyObject v = jnp.asarray([1.0, 2.0, 3.0], dtype = jnp.float32) - jv = DLPack.wrap(v, to_dlpack) + follows_dlpack_spec = hasproperty(v, :__dlpack__) + jv = follows_dlpack_spec ? DLPack.from_dlpack(v) : DLPack.wrap(v, to_dlpack) dlv = DLPack.DLManagedTensor(to_dlpack(v)) opaque_tensor = dlv.dl_tensor @@ -24,7 +25,11 @@ if DLPack.device_type(opaque_tensor) == DLPack.kDLCPU jv[1] = 0 # mutate a jax tensor - @inferred DLPack.wrap(Vector{Float32}, ColMajor, v, to_dlpack) + if follows_dlpack_spec + @inferred DLPack.from_dlpack(Vector{Float32}, ColMajor, v) + else + @inferred DLPack.wrap(Vector{Float32}, ColMajor, v, to_dlpack) + end elseif DLPack.device_type(opaque_tensor) == DLPack.kDLCUDA jv[1:1] .= 0 # mutate a jax tensor end @@ -32,7 +37,7 @@ @test py"$jnp.all($v == $jnp.asarray([0.0, 2.0, 3.0])).item()" w = jnp.asarray([1 2; 3 4], dtype = jnp.int64) - jw = DLPack.wrap(w, to_dlpack) + jw = follows_dlpack_spec ? DLPack.from_dlpack(w) : DLPack.wrap(w, to_dlpack) dlw = DLPack.DLManagedTensor(to_dlpack(w)) opaque_tensor = dlw.dl_tensor @@ -41,7 +46,11 @@ if DLPack.device_type(opaque_tensor) == DLPack.kDLCPU @test jw[1, 2] == 3 # dimensions are reversed - @inferred DLPack.wrap(Matrix{Int64}, RowMajor, w, to_dlpack) + if follows_dlpack_spec + @inferred DLPack.from_dlpack(Matrix{Int64}, RowMajor, w) + else + @inferred DLPack.wrap(Matrix{Int64}, RowMajor, w, to_dlpack) + end elseif DLPack.device_type(opaque_tensor) == DLPack.kDLCUDA @test all(view(jw, 1, 2) .== 3) # dimensions are reversed end diff --git a/test/pythoncall.jl b/test/pythoncall.jl index 1c2a0d3..a910ef5 100644 --- a/test/pythoncall.jl +++ b/test/pythoncall.jl @@ -10,7 +10,8 @@ @testset "wrap" begin v = torch.ones((2, 4), dtype = torch.float64) - jv = DLPack.wrap(v, torch.to_dlpack) + follows_dlpack_spec = hasproperty(v, :__dlpack__) + jv = follows_dlpack_spec ? DLPack.from_dlpack(v) : DLPack.wrap(v, torch.to_dlpack) dlv = DLPack.DLManagedTensor(torch.to_dlpack(v)) opaque_tensor = dlv.dl_tensor