Bring support for the Python specification for DLPack #37

Merged · 7 commits · Jan 29, 2024
29 changes: 19 additions & 10 deletions Project.toml
@@ -6,24 +6,33 @@ version = "0.1.2"
[deps]
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[extensions]
DLPackCUDA = "CUDA"
DLPackPyCall = "PyCall"
DLPackPythonCall = "PythonCall"
CUDAExt = "CUDA"
PyCallExt = "PyCall"
PythonCallExt = "PythonCall"

[compat]
CUDA = "1.3 - 5"
Aqua = "0.8"
CUDA = "≥ 1.3"
PyCall = "1.92"
PythonCall = "0.6.1 - 0"
PythonCall = "≥ 0.7"
Requires = "1.3"
TestItemRunner = "0.2.2"
julia = "1.3"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"

[targets]
test = ["Aqua", "CondaPkg", "CUDA", "PyCall", "PythonCall", "Test", "TestItemRunner"]
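For context, the `[weakdeps]`/`[extensions]` tables above adopt the conditional-loading mechanism introduced in Julia 1.9, with `Requires` kept as the fallback for older releases. A minimal sketch (illustrative, not part of the PR) of checking whether one of these extensions is active:

```julia
using DLPack
using CUDA  # loading the weak dependency activates CUDAExt

# `Base.get_extension` (Julia ≥ 1.9) returns the extension module once both
# packages are loaded, or `nothing` if the extension is not active.
ext = Base.get_extension(DLPack, :CUDAExt)
ext === nothing || @info "CUDAExt is active"
```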
14 changes: 10 additions & 4 deletions README.md
@@ -33,7 +33,9 @@ np = pyimport("jax.numpy")
dl = pyimport("jax.dlpack")

pyv = np.arange(10)
v = DLPack.wrap(pyv, o -> @pycall dl.to_dlpack(o)::PyObject)
v = from_dlpack(pyv)
# For older jax version use:
# v = DLPack.wrap(pyv, o -> @pycall dl.to_dlpack(o)::PyObject)

(pyv[1] == 1).item() # This is false since the first element is 0

@@ -44,7 +46,7 @@ v[1] = 1
```

If the python tensor has more than one dimension and the memory layout is
row-major the array returned by `DLPack.wrap` has its dimensions reversed.
row-major the array returned by `DLPack.from_dlpack` has its dimensions reversed.
Let us illustrate this now by importing a `torch.Tensor` via the
`PythonCall` package:

@@ -55,7 +57,9 @@ using PythonCall
torch = pyimport("torch")

pyv = torch.arange(1, 5).reshape(2, 2)
v = DLPack.wrap(pyv, torch.to_dlpack)
v = from_dlpack(pyv)
# For older torch releases use:
# v = DLPack.wrap(pyv, torch.to_dlpack)

Bool(v[2, 1] == 2 == pyv[0, 1]) # dimensions are reversed
```
@@ -84,7 +88,9 @@ using PyCall
cupy = pyimport("cupy")

pyv = cupy.arange(6).reshape(2, 3)
v = DLPack.wrap(pyv, o -> pycall(o.toDlpack, PyObject))
v = from_dlpack(pyv)
# For older versions of cupy use:
# v = DLPack.wrap(pyv, o -> pycall(o.toDlpack, PyObject))

v .= 1
pyv.sum().item() == 6 # true
4 changes: 2 additions & 2 deletions ext/DLPackCUDA.jl → ext/CUDAExt.jl
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md at https://github.com/pabloferz/DLPack.jl

module DLPackCUDA
module CUDAExt


## Dependencies ##
@@ -50,4 +50,4 @@ function Base.unsafe_wrap(::Type{<: CUDA.CuArray}, manager::DLPack.DLManager{T})
end


end # module DLPackCUDA
end # module CUDAExt
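As a usage note (not part of the diff), the renamed `CUDAExt` is what lets GPU arrays cross the language boundary. A hedged sketch assuming a working CUDA setup and a recent CuPy that provides `cupy.from_dlpack`:

```julia
using DLPack, CUDA, PyCall

cupy = pyimport("cupy")

cu = CUDA.rand(Float32, 4)
pycu = DLPack.share(cu, cupy.from_dlpack)  # zero-copy cupy.ndarray

cu .= 1f0                       # mutate on the Julia side
@assert pycu.sum().item() == 4  # the change is visible from Python
```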
92 changes: 73 additions & 19 deletions ext/DLPackPyCall.jl → ext/PyCallExt.jl
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md at https://github.com/pabloferz/DLPack.jl

module DLPackPyCall
module PyCallExt


## Dependencies ##
@@ -16,6 +16,9 @@ end

## Extensions ##

const DLArray = PyCall.PyNULL()


"""
DLManagedTensor(po::PyObject)

@@ -63,14 +66,30 @@ function DLPack.DLManagedTensor(po::PyCall.PyObject)
return tensor
end

# Docstring in src/DLPack.jl
function DLPack.from_dlpack(o::PyCall.PyObject)
tensor = DLPack.DLManagedTensor(PyCall.@pycall o.__dlpack__()::PyCall.PyObject)
return DLPack.unsafe_wrap(tensor, o)
end

"""
from_dlpack(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::PyObject)

Type-inferrable alternative to `from_dlpack`.
"""
function DLPack.from_dlpack(::Type{A}, ::Type{M}, o::PyCall.PyObject) where {
T, N, A <: AbstractArray{T, N}, M
}
tensor = DLPack.DLManagedTensor(PyCall.@pycall o.__dlpack__()::PyCall.PyObject)
return DLPack.unsafe_wrap(A, M, tensor, o)
end
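A hedged usage sketch of the typed form above; the `DLPack.ColMajor` layout type is an assumption drawn from the `MemoryLayout` parameter in the signature (a 1-D buffer is both row- and column-major), and NumPy ≥ 1.22 is assumed so arrays implement `__dlpack__`:

```julia
using DLPack, PyCall

np = pyimport("numpy")
pyv = np.arange(6.0)

# The return type is fully inferrable as `Vector{Float64}`.
v = DLPack.from_dlpack(Vector{Float64}, DLPack.ColMajor, pyv)
@assert v[1] == 0.0
```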

"""
wrap(o::PyObject, to_dlpack)

Takes a tensor `o::PyObject` and a `to_dlpack` function that generates a
`DLManagedTensor` bundled in a PyCapsule, and returns a zero-copy
`array::AbstractArray` pointing to the same data in `o`.
For tensors with row-major ordering the resulting array will have all
dimensions reversed.
Similar to `from_dlpack`, but works for Python arrays that do not implement a `__dlpack__`
method. `to_dlpack` must be a function that, when applied to `o`, generates a
`DLManagedTensor` bundled into a `PyCapsule`.
"""
function DLPack.wrap(o::PyCall.PyObject, to_dlpack::Union{PyCall.PyObject, Function})
return DLPack.unsafe_wrap(DLPack.DLManagedTensor(to_dlpack(o)), o)
@@ -79,7 +98,7 @@ end
"""
wrap(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::PyObject, to_dlpack)

Type-inferrable alternative to `wrap(o, to_dlpack)`.
Type-inferrable alternative to `wrap`.
"""
function DLPack.wrap(::Type{A}, ::Type{M}, o::PyCall.PyObject, to_dlpack) where {
T, N, A <: AbstractArray{T, N}, M
@@ -96,16 +115,6 @@ The resulting tensor will have all dimensions reversed with respect
to the Julia array.
"""
function DLPack.share(A::StridedArray, from_dlpack::PyCall.PyObject)
return DLPack.share(A, PyCall.PyObject, from_dlpack)
end

"""
share(A::StridedArray, ::Type{PyObject}, from_dlpack)

Similar to `share(A, from_dlpack::PyObject)`. Use when there is a need to
disambiguate the return type.
"""
function DLPack.share(A::StridedArray, ::Type{PyCall.PyObject}, from_dlpack)
capsule = DLPack.share(A)
tensor = capsule.tensor
tensor_ptr = pointer_from_objref(tensor)
@@ -122,8 +131,53 @@ function DLPack.share(A::StridedArray, ::Type{PyCall.PyObject}, from_dlpack)
tensor_ptr, DLPack.PYCAPSULE_NAME, C_NULL
))

return from_dlpack(pycapsule)
return try
from_dlpack(pycapsule)
catch e
if !(e isa KeyError && any(e.key .== (:__dlpack__, :__dlpack_device__)))
rethrow()
end
dl_array = DLArray()
ctx = DLPack.dldevice(tensor)
dl_array.capsule = pycapsule
dl_array.device = (Int(ctx.device_type), ctx.device_id)
from_dlpack(dl_array)
end
end
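In short: `share` first hands the capsule straight to the consumer's `from_dlpack` and, if that consumer instead requires objects implementing `__dlpack__`/`__dlpack_device__`, retries through the `DLArray` shim defined in `__init__` below. A hedged example using PyTorch, which provides `torch.from_dlpack` from version 1.10 on:

```julia
using DLPack, PyCall

torch = pyimport("torch")

v = rand(Float32, 2, 2)
pyv = DLPack.share(v, torch.from_dlpack)  # zero-copy torch.Tensor

v .= 1f0
@assert pyv.sum().item() == 4  # shared memory (dimensions reversed)
```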


## Deprecations ##

# NOTE: replace by the following when our Julia lower bound gets to ≥ v"1.9".
# @deprecate(
# DLPack.share(A::StridedArray, ::Type{PyCall.PyObject}, from_dlpack),
# DLPack.share(A, PyCall.pyfunction(from_dlpack, PyCall.PyObject)),
# false
# )
function DLPack.share(A::StridedArray, ::Type{PyCall.PyObject}, from_dlpack)
Base.depwarn("""
`DLPack.share(A, ::Type{PyObject}, from_dlpack)` is deprecated, use
`DLPack.share(A, from_dlpack)` instead. If `from_dlpack` is a Julia function,
use `pyfunction` to wrap it.
""",
:share
)
DLPack.share(A, PyCall.pyfunction(from_dlpack, PyCall.PyObject))
end


## Extension initialization ##

function __init__()
copy!(DLArray,
PyCall.@pydef_object mutable struct DLArray
capsule = nothing
device = nothing
__dlpack__(self; stream = nothing) = self."capsule"
__dlpack_device__(self) = self."device"
end
)
end


end # module DLPackPyCall
end # module PyCallExt
107 changes: 89 additions & 18 deletions ext/DLPackPythonCall.jl → ext/PythonCallExt.jl
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md at https://github.com/pabloferz/DLPack.jl

module DLPackPythonCall
module PythonCallExt


## Dependencies ##
@@ -17,6 +17,7 @@ end
## Extensions ##

const CPython = PythonCall.C
const DLArray = PythonCall.pynew()


"""
@@ -54,14 +55,30 @@ function DLPack.DLManagedTensor(po::PythonCall.Py)
return tensor
end

# Docstring in src/DLPack.jl
function DLPack.from_dlpack(o::PythonCall.Py)
tensor = DLPack.DLManagedTensor(o.__dlpack__())
return DLPack.unsafe_wrap(tensor, o)
end

"""
from_dlpack(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::Py)

Type-inferrable alternative to `from_dlpack`.
"""
function DLPack.from_dlpack(::Type{A}, ::Type{M}, o::PythonCall.Py) where {
T, N, A <: AbstractArray{T, N}, M
}
tensor = DLPack.DLManagedTensor(o.__dlpack__())
return DLPack.unsafe_wrap(A, M, tensor, o)
end
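A hedged sketch of why the typed variant matters: the return type of the one-argument `from_dlpack` depends on runtime metadata, so code behind this form can specialize. `DLPack.ColMajor` is assumed from the `MemoryLayout` parameter above, and NumPy ≥ 1.22 is assumed for `__dlpack__` support:

```julia
using DLPack, PythonCall

np = pyimport("numpy")

# A function barrier: the compiler infers `Vector{Float64}` here instead of
# an abstract array type, so downstream code specializes.
load(o) = DLPack.from_dlpack(Vector{Float64}, DLPack.ColMajor, o)

v = load(np.arange(4.0))
@assert sum(v) == 6.0
```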

"""
wrap(o::Py, to_dlpack)

Takes a tensor `o::Py` and a `to_dlpack` function that generates a
`DLManagedTensor` bundled in a PyCapsule, and returns a zero-copy
`array::AbstractArray` pointing to the same data in `o`.
For tensors with row-major ordering the resulting array will have all
dimensions reversed.
Similar to `from_dlpack`, but works for Python arrays that do not implement a `__dlpack__`
method. `to_dlpack` must be a function that, when applied to `o`, generates a
`DLManagedTensor` bundled into a `PyCapsule`.
"""
function DLPack.wrap(o::PythonCall.Py, to_dlpack::Union{PythonCall.Py, Function})
return DLPack.unsafe_wrap(DLPack.DLManagedTensor(to_dlpack(o)), o)
@@ -70,7 +87,7 @@ end
"""
wrap(::Type{<: AbstractArray{T, N}}, ::Type{<: MemoryLayout}, o::Py, to_dlpack)

Type-inferrable alternative to `wrap(o, to_dlpack)`.
Type-inferrable alternative to `wrap`.
"""
function DLPack.wrap(::Type{A}, ::Type{M}, o::PythonCall.Py, to_dlpack) where {
T, N, A <: AbstractArray{T, N}, M
@@ -86,15 +103,7 @@ following the DLPack protocol. Returns a Python tensor that shares the data with
The resulting tensor will have all dimensions reversed with respect
to the Julia array.
"""
DLPack.share(A::StridedArray, from_dlpack::PythonCall.Py) = DLPack.share(A, PythonCall.Py, from_dlpack)

"""
share(A::StridedArray, ::Type{Py}, from_dlpack)

Similar to `share(A, from_dlpack::Py)`. Use when there is a need to
disambiguate the return type.
"""
function DLPack.share(A::StridedArray, ::Type{PythonCall.Py}, from_dlpack)
function DLPack.share(A::StridedArray, from_dlpack::PythonCall.Py)
capsule = DLPack.share(A)
tensor = capsule.tensor
tensor_ptr = pointer_from_objref(tensor)
@@ -109,8 +118,70 @@ function DLPack.share(A::StridedArray, ::Type{PythonCall.Py}, from_dlpack)
CPython.PyCapsule_New(tensor_ptr, DLPack.PYCAPSULE_NAME, C_NULL)
)

return from_dlpack(pycapsule)
return try
from_dlpack(pycapsule)
catch e
if !(
PythonCall.pyisinstance(e, PythonCall.pybuiltins.AttributeError) &&
any(contains.(string(PythonCall.Py(e)), ("__dlpack__", "__dlpack_device__")))
)
rethrow()
end
ctx = DLPack.dldevice(tensor)
device = (Int(ctx.device_type), ctx.device_id)
from_dlpack(DLArray(pycapsule, device))
end
end
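This mirrors the PyCall fallback: try the consumer's `from_dlpack` directly, and only wrap the capsule in the `DLArray` shim when an `AttributeError` shows the consumer wanted `__dlpack__`/`__dlpack_device__`. A hedged example with NumPy (`numpy.from_dlpack` requires NumPy ≥ 1.22):

```julia
using DLPack, PythonCall

np = pyimport("numpy")

v = collect(1.0:4.0)
pyv = DLPack.share(v, np.from_dlpack)  # zero-copy numpy.ndarray

v[1] = 10.0
@assert Bool(pyv[0] == 10.0)  # PythonCall indexing is 0-based
```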


## Deprecations ##

# NOTE: replace by the following when our Julia lower bound gets to ≥ v"1.9".
# @deprecate(
# DLPack.share(A::StridedArray, ::Type{PythonCall.Py}, from_dlpack),
# DLPack.share(A, PythonCall.pyfunc(from_dlpack)),
# #= export_old =# false
# )
function DLPack.share(A::StridedArray, ::Type{PythonCall.Py}, from_dlpack)
Base.depwarn("""
`DLPack.share(A, ::Type{Py}, from_dlpack)` is deprecated, use
`DLPack.share(A, from_dlpack)` instead. If `from_dlpack` is a Julia function,
use `pyfunc` to wrap it.
""",
:share
)
DLPack.share(A, PythonCall.pyfunc(from_dlpack))
end


## Extension initialization ##

function __init__()
PythonCall.pycopy!(DLArray,
PythonCall.pytype("DLArray", (), [
"__module__" => "__main__",

PythonCall.pyfunc(
name = "__init__",
(self, capsule, device) -> begin
self.capsule = capsule
self.device = device
nothing
end,
),

PythonCall.pyfunc(
name = "__dlpack__",
(self; stream = nothing) -> self.capsule,
),

PythonCall.pyfunc(
name = "__dlpack_device__",
(self) -> self.device,
)
])
)
end


end # module DLPackPythonCall
end # module PythonCallExt