[NDTensors] Unwrap module for dispatching on unwrapped types (#1220)
kmp5VT authored Nov 1, 2023
1 parent 70967cd commit 66f5d39
Showing 59 changed files with 339 additions and 276 deletions.
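For orientation before the per-file diffs: the commit replaces `leaf_parenttype`-style dispatch with an `Exposed{Unwrapped,Object}` wrapper type, so a method can dispatch on the outer wrapper (`Adjoint`, `ReshapedArray`, `SubArray`, ...) and on the leaf storage type at the same time. A minimal self-contained sketch of the pattern, with names mirroring the committed module (illustrative only, not code from this commit):

using LinearAlgebra: Adjoint

# Tag an object with its unwrapped (leaf) storage type for dispatch.
struct Exposed{Unwrapped,Object}
  object::Object
end

# Recursively strip known wrappers down to the leaf array type.
unwrap_type(::Type{<:Adjoint{<:Any,P}}) where {P} = unwrap_type(P)
unwrap_type(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = unwrap_type(P)
unwrap_type(arraytype::Type{<:AbstractArray}) = arraytype

expose(object) = Exposed{unwrap_type(typeof(object)),typeof(object)}(object)
unexpose(E::Exposed) = E.object

# Methods can now constrain the wrapper and the storage independently.
describe(::Exposed) = "generic fallback"
describe(::Exposed{<:Matrix,<:Adjoint}) = "Adjoint over Matrix storage"

A = rand(2, 2)
describe(expose(A))   # "generic fallback"
describe(expose(A'))  # "Adjoint over Matrix storage"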
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
@@ -2,6 +2,7 @@ module NDTensorsCUDAExt
 
 using NDTensors
 using NDTensors.SetParameters
+using NDTensors.Unwrap
 using Adapt
 using Functors
 using LinearAlgebra
26 changes: 21 additions & 5 deletions NDTensors/ext/NDTensorsCUDAExt/indexing.jl
@@ -1,8 +1,24 @@
-function Base.getindex(::Type{<:CuArray}, T::DenseTensor{<:Number})
-  return CUDA.@allowscalar data(T)[]
+function Base.getindex(E::Exposed{<:CuArray})
+  return CUDA.@allowscalar unexpose(E)[]
 end
 
-function Base.setindex!(::Type{<:CuArray}, T::DenseTensor{<:Number}, x::Number)
-  CUDA.@allowscalar data(T)[] = x
-  return T
+function Base.setindex!(E::Exposed{<:CuArray}, x::Number)
+  CUDA.@allowscalar unexpose(E)[] = x
+  return unexpose(E)
 end
+
+function Base.getindex(E::Exposed{<:CuArray,<:Adjoint}, I...)
+  Ap = parent(E)
+  return expose(Ap)[I...]
+end
+
+function Base.copy(E::Exposed{<:CuArray,<:Base.ReshapedArray})
+  Ap = parent(E)
+  return copy(expose(Ap))
+end
+
+Base.any(f, E::Exposed{<:CuArray,<:NDTensors.Tensor}) = any(f, data(unexpose(E)))
+
+function Base.print_array(io::IO, E::Exposed{<:CuArray})
+  return Base.print_array(io, expose(NDTensors.cpu(E)))
+end
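Context for the `@allowscalar` calls above: single-element indexing of a `CuArray` is disallowed by default because every access is a host-device round trip, and `CUDA.@allowscalar` grants a scoped exception for the one-element reads and writes these methods need. A hedged illustration (requires CUDA.jl and a CUDA device):

using CUDA

x = CUDA.zeros(Float32, 1)
# Plain x[] = 1f0 here would raise a scalar-indexing error by default.
CUDA.@allowscalar x[] = 1f0  # scoped opt-in for a single write
CUDA.@allowscalar x[]        # 1.0f0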
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl
@@ -5,7 +5,7 @@ function NDTensors.svd_catch_error(A::CuMatrix; alg="JacobiAlgorithm")
     alg = CUDA.CUSOLVER.QRAlgorithm()
   end
   USV = try
-    svd(A; alg=alg)
+    svd(expose(A); alg=alg)
   catch
     return nothing
   end
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsMetalExt/imports.jl
@@ -1,5 +1,6 @@
 import NDTensors: mtl, set_ndims, set_eltype, set_eltype_if_unspecified
 import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter
 
+using NDTensors.Unwrap: Exposed, unwrap_type, unexpose, expose
 using Metal: DefaultStorageMode
 using NDTensors: adapt
10 changes: 5 additions & 5 deletions NDTensors/ext/NDTensorsMetalExt/indexing.jl
@@ -1,8 +1,8 @@
-function Base.getindex(::Type{<:MtlArray}, T::DenseTensor{<:Number})
-  return Metal.@allowscalar data(T)[]
+function Base.getindex(E::Exposed{<:MtlArray})
+  return Metal.@allowscalar unexpose(E)[]
 end
 
-function Base.setindex!(::Type{<:MtlArray}, T::DenseTensor{<:Number}, x::Number)
-  Metal.@allowscalar data(T)[] = x
-  return T
+function Base.setindex!(E::Exposed{<:MtlArray}, x::Number)
+  Metal.@allowscalar unexpose(E)[] = x
+  return unexpose(E)
 end
36 changes: 25 additions & 11 deletions NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl
@@ -1,16 +1,30 @@
-function NDTensors.qr(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
-  Q, R = NDTensors.qr(NDTensors.cpu(A))
-  return adapt(leaf_parenttype, Matrix(Q)), adapt(leaf_parenttype, R)
+function LinearAlgebra.qr(A::Exposed{<:MtlMatrix})
+  Q, R = qr(expose(NDTensors.cpu(A)))
+  return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), R)
 end
 
-function NDTensors.eigen(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
-  D, U = NDTensors.eigen(NDTensors.cpu(A))
-  return adapt(set_ndims(leaf_parenttype, ndims(D)), D), adapt(leaf_parenttype, U)
+function NDTensors.Unwrap.qr_positive(A::Exposed{<:MtlMatrix})
+  Q, R = qr_positive(expose(NDTensors.cpu(A)))
+  return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), R)
 end
 
-function NDTensors.svd(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
-  U, S, V = NDTensors.svd(NDTensors.cpu(A))
-  return adapt(leaf_parenttype, U),
-  adapt(set_ndims(leaf_parenttype, ndims(S)), S),
-  adapt(leaf_parenttype, V)
+function NDTensors.Unwrap.ql(A::Exposed{<:MtlMatrix})
+  Q, L = ql(expose(NDTensors.cpu(A)))
+  return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), L)
 end
+function NDTensors.Unwrap.ql_positive(A::Exposed{<:MtlMatrix})
+  Q, L = ql_positive(expose(NDTensors.cpu(A)))
+  return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), L)
+end
+
+function LinearAlgebra.eigen(A::Exposed{<:MtlMatrix})
+  D, U = eigen(expose(NDTensors.cpu(A)))
+  return adapt(set_ndims(unwrap_type(A), ndims(D)), D), adapt(unwrap_type(A), U)
+end
+
+function LinearAlgebra.svd(A::Exposed{<:MtlMatrix}; kwargs...)
+  U, S, V = svd(expose(NDTensors.cpu(A)); kwargs...)
+  return adapt(unwrap_type(A), U),
+  adapt(set_ndims(unwrap_type(A), ndims(S)), S),
+  adapt(unwrap_type(A), V)
 end
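Every Metal method above follows the same round trip: copy to the CPU, run the LAPACK-backed factorization there, and `adapt` the factors back to the original storage type, with `set_ndims` adjusting the type parameter for the 1-d vector of singular or eigenvalues. A standalone sketch of that round trip using plain `Array` in place of `MtlArray` (assumes Adapt.jl; `svd_via_cpu` is a hypothetical name, not part of the commit):

using LinearAlgebra
using Adapt

# Factorize on the CPU, then move the factors back to `storagetype`,
# which plays the role of unwrap_type(A) in the methods above.
function svd_via_cpu(storagetype::Type, A::AbstractMatrix)
  U, S, V = svd(Array(A))  # Array(A) is the "move to CPU" step
  return adapt(storagetype, U), adapt(storagetype, S), adapt(storagetype, V)
end

A = rand(4, 3)
U, S, V = svd_via_cpu(Array, A)
@assert U * Diagonal(S) * V' ≈ A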
13 changes: 5 additions & 8 deletions NDTensors/ext/NDTensorsMetalExt/mul.jl
@@ -1,15 +1,12 @@
 # This was calling generic matrix multiplication.
 # TODO: Raise an issue with `Metal.jl`.
-function NDTensors.mul!!(
-  ::Type{<:MtlArray},
-  CM::Transpose,
-  ::Type{<:MtlArray},
-  AM::AbstractMatrix,
-  ::Type{<:MtlArray},
-  BM::AbstractMatrix,
+function LinearAlgebra.mul!(
+  CM::Exposed{<:MtlArray,<:Transpose},
+  AM::Exposed{<:MtlArray},
+  BM::Exposed{<:MtlArray},
   α,
   β,
 )
   mul!(transpose(CM), transpose(BM), transpose(AM), α, β)
-  return CM
+  return unexpose(CM)
 end
15 changes: 5 additions & 10 deletions NDTensors/ext/NDTensorsMetalExt/permutedims.jl
@@ -1,12 +1,7 @@
-function NDTensors.permutedims!(
-  ::Type{<:MtlArray},
-  Adest::Base.ReshapedArray{<:Any,<:Any,<:SubArray},
-  ::Type{<:MtlArray},
-  A,
-  perm,
+function Base.permutedims!(
+  Edest::Exposed{<:MtlArray,<:Base.ReshapedArray}, Esrc::Exposed{<:MtlArray}, perm
 )
-  Aperm = permutedims(A, perm)
-  Adest_parent = parent(Adest)
-  copyto!(Adest_parent, Aperm)
-  return Adest
+  Aperm = permutedims(Esrc, perm)
+  copyto!(expose(parent(Edest)), expose(Aperm))
+  return unexpose(Edest)
 end
9 changes: 3 additions & 6 deletions NDTensors/src/NDTensors.jl
@@ -30,6 +30,8 @@ include("SortedSets/src/SortedSets.jl")
 using .SortedSets
 include("TagSets/src/TagSets.jl")
 using .TagSets
+include("Unwrap/src/Unwrap.jl")
+using .Unwrap
 
 using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo
 
@@ -51,16 +53,13 @@ include("algorithm.jl")
 include("aliasstyle.jl")
 include("abstractarray/set_types.jl")
 include("abstractarray/to_shape.jl")
-include("abstractarray/iswrappedarray.jl")
 include("abstractarray/iscu.jl")
 include("abstractarray/similar.jl")
 include("abstractarray/ndims.jl")
-include("abstractarray/copyto.jl")
-include("abstractarray/mul.jl")
 include("abstractarray/append.jl")
 include("abstractarray/permutedims.jl")
 include("abstractarray/fill.jl")
-include("abstractarray/linearalgebra.jl")
+include("abstractarray/mul.jl")
 include("array/set_types.jl")
 include("array/permutedims.jl")
 include("array/mul.jl")
@@ -75,8 +74,6 @@ include("tensor/tensor.jl")
 include("dims.jl")
 include("tensor/set_types.jl")
 include("tensor/similar.jl")
-include("tensor/permutedims.jl")
-include("tensor/linearalgebra.jl")
 include("adapt.jl")
 include("tensoralgebra/generic_tensor_operations.jl")
 include("tensoralgebra/contraction_logic.jl")
3 changes: 3 additions & 0 deletions NDTensors/src/Unwrap/README.md
@@ -0,0 +1,3 @@
# Unwrap

A module to unwrap complex array types to assist in the generic programming of array-type based functions.
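The "complex array types" the README refers to are the wrapper stacks that generic Julia code builds up; an illustrative example (not part of the commit):

A = rand(4, 4)
W = transpose(reshape(view(A, :, 1:2), 2, 4))
# typeof(W) is Transpose{Float64, Base.ReshapedArray{Float64, 2,
#   SubArray{Float64, 2, Matrix{Float64}, ...}, ...}} — three wrappers deep,
# but the data still lives in a Matrix{Float64}. Unwrap's unwrap_type
# recovers that leaf type so methods can dispatch on the actual storage.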
4 changes: 4 additions & 0 deletions NDTensors/src/Unwrap/TODO.md
@@ -0,0 +1,4 @@
Replace all `leaf_parenttype` calls by wrapping the arrays in this `expose` type

Fix the issue Ryan found in MPS
Make a GPUArrays extension that has generic GPU algorithms
25 changes: 25 additions & 0 deletions NDTensors/src/Unwrap/src/Unwrap.jl
@@ -0,0 +1,25 @@
module Unwrap
using SimpleTraits
using LinearAlgebra
using Base: ReshapedArray
using Strided.StridedViews

include("expose.jl")
include("iswrappedarray.jl")

include("import.jl")
## TODO: Create functions which take the `Exposed` type and launch functions
## using that type
## Exposed-based functions
include("functions/abstractarray.jl")
include("functions/copyto.jl")
include("functions/linearalgebra.jl")
include("functions/mul.jl")
include("functions/permutedims.jl")

export IsWrappedArray,
  is_wrapped_array, parenttype, unwrap_type, expose, Exposed, unexpose, cpu

## TODO: write Exposed-based functions in the NDTensors extensions when necessary

end
7 changes: 7 additions & 0 deletions NDTensors/src/Unwrap/src/expose.jl
@@ -0,0 +1,7 @@
struct Exposed{Unwrapped,Object}
  object::Object
end

expose(object) = Exposed{unwrap_type(object),typeof(object)}(object)

unexpose(E::Exposed) = E.object
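A short usage sketch of the definitions above (illustrative; `unwrap_type` is defined in iswrappedarray.jl):

A = rand(3, 3)
B = A'          # Adjoint{Float64, Matrix{Float64}}
E = expose(B)   # Exposed{Matrix{Float64}, Adjoint{Float64, Matrix{Float64}}}
@assert unexpose(E) === B  # unexpose returns the stored object unchanged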
22 changes: 22 additions & 0 deletions NDTensors/src/Unwrap/src/functions/abstractarray.jl
@@ -0,0 +1,22 @@
parent(E::Exposed) = parent(unexpose(E))

transpose(E::Exposed) = transpose(unexpose(E))

cpu(E::Exposed) = cpu(unexpose(E))

getindex(E::Exposed) = unexpose(E)[]

function setindex!(E::Exposed, x::Number)
  unexpose(E)[] = x
  return unexpose(E)
end

getindex(E::Exposed, I...) = unexpose(E)[I...]

function copy(E::Exposed)
  return copy(unexpose(E))
end

any(f, E::Exposed) = any(f, unexpose(E))

print_array(io::IO, E::Exposed) = print_array(io, unexpose(E))
4 changes: 4 additions & 0 deletions NDTensors/src/Unwrap/src/functions/copyto.jl
@@ -0,0 +1,4 @@
function copyto!(R::Exposed, T::Exposed)
  copyto!(unexpose(R), unexpose(T))
  return unexpose(R)
end
29 changes: 29 additions & 0 deletions NDTensors/src/Unwrap/src/functions/linearalgebra.jl
@@ -0,0 +1,29 @@
function qr(E::Exposed)
  return qr(unexpose(E))
end
## These functions do not exist in `LinearAlgebra` but were defined
## in NDTensors. Because Unwrap is loaded before NDTensors,
## these functions cannot be imported from NDTensors, so instead
## they are defined here and extended in NDTensors.
## The same is done for the function `cpu`.
## Unwrap.qr_positive
function qr_positive(E::Exposed)
  return qr_positive(unexpose(E))
end

## Unwrap.ql
function ql(E::Exposed)
  return ql(unexpose(E))
end
## Unwrap.ql_positive
function ql_positive(E::Exposed)
  return ql_positive(unexpose(E))
end

function eigen(E::Exposed)
  return eigen(unexpose(E))
end

function svd(E::Exposed; kwargs...)
  return svd(unexpose(E); kwargs...)
end
4 changes: 4 additions & 0 deletions NDTensors/src/Unwrap/src/functions/mul.jl
@@ -0,0 +1,4 @@
function mul!(CM::Exposed, AM::Exposed, BM::Exposed, α, β)
  mul!(unexpose(CM), unexpose(AM), unexpose(BM), α, β)
  return unexpose(CM)
end
13 changes: 13 additions & 0 deletions NDTensors/src/Unwrap/src/functions/permutedims.jl
@@ -0,0 +1,13 @@
function permutedims(E::Exposed, perm)
  return permutedims(unexpose(E), perm)
end

function permutedims!(Edest::Exposed, Esrc::Exposed, perm)
  permutedims!(unexpose(Edest), unexpose(Esrc), perm)
  return unexpose(Edest)
end

function permutedims!(Edest::Exposed, Esrc::Exposed, perm, f)
  unexpose(Edest) .= f.(unexpose(Edest), permutedims(Esrc, perm))
  return unexpose(Edest)
end
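The four-argument method combines the permuted source into the destination with an elementwise `f` rather than requiring a separate copy-then-combine pass. A usage sketch with plain arrays (illustrative; because import.jl below pulls `permutedims!` in from `Base`, these methods extend the `Base` function and no qualified call is needed):

src = reshape(collect(1.0:4.0), 2, 2)
dest = ones(2, 2)
# Equivalent to dest .= dest .+ permutedims(src, (2, 1)), in one call:
permutedims!(expose(dest), expose(src), (2, 1), +)
@assert dest == 1 .+ transpose(src)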
13 changes: 13 additions & 0 deletions NDTensors/src/Unwrap/src/import.jl
@@ -0,0 +1,13 @@
import Base:
  permutedims,
  permutedims!,
  copy,
  copyto!,
  parent,
  print_array,
  transpose,
  getindex,
  setindex!,
  any

import LinearAlgebra: mul!, qr, eigen, svd
NDTensors/src/abstractarray/iswrappedarray.jl → NDTensors/src/Unwrap/src/iswrappedarray.jl
@@ -39,17 +39,20 @@ parenttype(::Type{<:StridedView{<:Any,<:Any,P}}) where {P} = P
 # `SimpleTraits.jl` traits dispatch.
 parenttype(array::AbstractArray) = parenttype(typeof(array))
 
-@traitfn function leaf_parenttype(
+## These functions will be used in place of leaf_parenttype but will be
+## called indirectly through the expose function.
+@traitfn function unwrap_type(
   arraytype::Type{ArrayT}
 ) where {ArrayT; IsWrappedArray{ArrayT}}
-  return leaf_parenttype(parenttype(arraytype))
+  return unwrap_type(parenttype(arraytype))
 end
 
-@traitfn function leaf_parenttype(
+@traitfn function unwrap_type(
   arraytype::Type{ArrayT}
 ) where {ArrayT; !IsWrappedArray{ArrayT}}
   return arraytype
 end
 
 # For working with instances.
-leaf_parenttype(array::AbstractArray) = leaf_parenttype(typeof(array))
+unwrap_type(array::AbstractArray) = unwrap_type(typeof(array))
+unwrap_type(E::Exposed) = unwrap_type(unexpose(E))
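Illustrative expectations for the trait recursion above, with hypothetical REPL results in the comments:

unwrap_type(Matrix{Float64})                     # Matrix{Float64} (not a wrapper)
unwrap_type(Transpose{Float64,Matrix{Float64}})  # Matrix{Float64} (one level stripped)
unwrap_type(rand(4, 4)')                         # Matrix{Float64} (instance method)
unwrap_type(expose(rand(4, 4)'))                 # Matrix{Float64} (Exposed method)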
2 changes: 1 addition & 1 deletion NDTensors/src/abstractarray/append.jl
@@ -2,7 +2,7 @@
 # Used to circumvent issues with some GPU backends like Metal
 # not supporting `resize!`.
 function append!!(collection, collections...)
-  return append!!(leaf_parenttype(collection), collection, collections...)
+  return append!!(unwrap_type(collection), collection, collections...)
 end
 
 function append!!(::Type, collection, collections...)
13 changes: 0 additions & 13 deletions NDTensors/src/abstractarray/copyto.jl

This file was deleted.

4 changes: 2 additions & 2 deletions NDTensors/src/abstractarray/fill.jl
@@ -2,15 +2,15 @@ function generic_randn(
   arraytype::Type{<:AbstractArray}, dim::Integer=0; rng=Random.default_rng()
 )
   arraytype_specified = set_unspecified_parameters(
-    leaf_parenttype(arraytype), DefaultParameters()
+    unwrap_type(arraytype), DefaultParameters()
   )
   data = similar(arraytype_specified, dim)
   return randn!(rng, data)
 end
 
 function generic_zeros(arraytype::Type{<:AbstractArray}, dims...)
   arraytype_specified = set_unspecified_parameters(
-    leaf_parenttype(arraytype), DefaultParameters()
+    unwrap_type(arraytype), DefaultParameters()
   )
   ElT = eltype(arraytype_specified)
   return fill!(similar(arraytype_specified, dims...), zero(ElT))
2 changes: 1 addition & 1 deletion NDTensors/src/abstractarray/iscu.jl
@@ -2,5 +2,5 @@
 # For `isgpu`, will require a `NDTensorsGPUArrayCoreExt`.
 iscu(A::AbstractArray) = iscu(typeof(A))
 function iscu(A::Type{<:AbstractArray})
-  return (leaf_parenttype(A) == A ? false : iscu(leaf_parenttype(A)))
+  return (unwrap_type(A) == A ? false : iscu(unwrap_type(A)))
 end
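`iscu` bottoms out the same way as `unwrap_type`: if unwrapping changes nothing, the type is not CUDA-backed; otherwise the check recurses to the leaf type, whose `CuArray` method lives in the CUDA extension. Hypothetical results:

iscu(Matrix{Float64})                     # false: unwrap_type(A) == A
iscu(Adjoint{Float32,CuMatrix{Float32}})  # true: recurses to CuMatrix, whose
                                          # iscu method is defined in NDTensorsCUDAExt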