-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors]
Unwrap
module for dispatching on unwrapped types (#1220)
- Loading branch information
Showing
59 changed files
with
339 additions
and
276 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,24 @@ | ||
function Base.getindex(::Type{<:CuArray}, T::DenseTensor{<:Number}) | ||
return CUDA.@allowscalar data(T)[] | ||
function Base.getindex(E::Exposed{<:CuArray}) | ||
return CUDA.@allowscalar unexpose(E)[] | ||
end | ||
|
||
function Base.setindex!(::Type{<:CuArray}, T::DenseTensor{<:Number}, x::Number) | ||
CUDA.@allowscalar data(T)[] = x | ||
return T | ||
function setindex!(E::Exposed{<:CuArray}, x::Number) | ||
CUDA.@allowscalar unexpose(E)[] = x | ||
return unexpose(E) | ||
end | ||
|
||
function Base.getindex(E::Exposed{<:CuArray,<:Adjoint}, I...) | ||
Ap = parent(E) | ||
return expose(Ap)[I...] | ||
end | ||
|
||
function Base.copy(E::Exposed{<:CuArray,<:Base.ReshapedArray}) | ||
Ap = parent(E) | ||
return copy(expose(Ap)) | ||
end | ||
|
||
Base.any(f, E::Exposed{<:CuArray,<:NDTensors.Tensor}) = any(f, data(unexpose(E))) | ||
|
||
function Base.print_array(io::IO, E::Exposed{<:CuArray}) | ||
return Base.print_array(io, expose(NDTensors.cpu(E))) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import NDTensors: mtl, set_ndims, set_eltype, set_eltype_if_unspecified | ||
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter | ||
|
||
using NDTensors.Unwrap: Exposed, unwrap_type, unexpose, expose | ||
using Metal: DefaultStorageMode | ||
using NDTensors: adapt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
function Base.getindex(::Type{<:MtlArray}, T::DenseTensor{<:Number}) | ||
return Metal.@allowscalar data(T)[] | ||
function Base.getindex(E::Exposed{<:MtlArray}) | ||
return Metal.@allowscalar unexpose(E)[] | ||
end | ||
|
||
function Base.setindex!(::Type{<:MtlArray}, T::DenseTensor{<:Number}, x::Number) | ||
Metal.@allowscalar data(T)[] = x | ||
return T | ||
function Base.setindex!(E::Exposed{<:MtlArray}, x::Number) | ||
Metal.@allowscalar unexpose(E)[] = x | ||
return unexpose(E) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,30 @@ | ||
function NDTensors.qr(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix) | ||
Q, R = NDTensors.qr(NDTensors.cpu(A)) | ||
return adapt(leaf_parenttype, Matrix(Q)), adapt(leaf_parenttype, R) | ||
function LinearAlgebra.qr(A::Exposed{<:MtlMatrix}) | ||
Q, R = qr(expose(NDTensors.cpu(A))) | ||
return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), R) | ||
end | ||
|
||
function NDTensors.eigen(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix) | ||
D, U = NDTensors.eigen(NDTensors.cpu(A)) | ||
return adapt(set_ndims(leaf_parenttype, ndims(D)), D), adapt(leaf_parenttype, U) | ||
function NDTensors.Unwrap.qr_positive(A::Exposed{<:MtlMatrix}) | ||
Q, R = qr_positive(expose(NDTensors.cpu(A))) | ||
return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), R) | ||
end | ||
|
||
function NDTensors.svd(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix) | ||
U, S, V = NDTensors.svd(NDTensors.cpu(A)) | ||
return adapt(leaf_parenttype, U), | ||
adapt(set_ndims(leaf_parenttype, ndims(S)), S), | ||
adapt(leaf_parenttype, V) | ||
function NDTensors.Unwrap.ql(A::Exposed{<:MtlMatrix}) | ||
Q, L = ql(expose(NDTensors.cpu(A))) | ||
return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), L) | ||
end | ||
function NDTensors.Unwrap.ql_positive(A::Exposed{<:MtlMatrix}) | ||
Q, L = ql_positive(expose(NDTensors.cpu(A))) | ||
return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), L) | ||
end | ||
|
||
function LinearAlgebra.eigen(A::Exposed{<:MtlMatrix}) | ||
D, U = eigen(expose(NDTensors.cpu(A))) | ||
return adapt(set_ndims(unwrap_type(A), ndims(D)), D), adapt(unwrap_type(A), U) | ||
end | ||
|
||
function LinearAlgebra.svd(A::Exposed{<:MtlMatrix}; kwargs...) | ||
U, S, V = svd(expose(NDTensors.cpu(A)); kwargs...) | ||
return adapt(unwrap_type(A), U), | ||
adapt(set_ndims(unwrap_type(A), ndims(S)), S), | ||
adapt(unwrap_type(A), V) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,12 @@ | ||
# This was calling generic matrix multiplication. | ||
# TODO: Raise an issue with `Metal.jl`. | ||
function NDTensors.mul!!( | ||
::Type{<:MtlArray}, | ||
CM::Transpose, | ||
::Type{<:MtlArray}, | ||
AM::AbstractMatrix, | ||
::Type{<:MtlArray}, | ||
BM::AbstractMatrix, | ||
function LinearAlgebra.mul!( | ||
CM::Exposed{<:MtlArray,<:Transpose}, | ||
AM::Exposed{<:MtlArray}, | ||
BM::Exposed{<:MtlArray}, | ||
α, | ||
β, | ||
) | ||
mul!(transpose(CM), transpose(BM), transpose(AM), α, β) | ||
return CM | ||
return unexpose(CM) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,7 @@ | ||
function NDTensors.permutedims!( | ||
::Type{<:MtlArray}, | ||
Adest::Base.ReshapedArray{<:Any,<:Any,<:SubArray}, | ||
::Type{<:MtlArray}, | ||
A, | ||
perm, | ||
function permutedims!( | ||
Edest::Exposed{<:MtlArray,<:Base.ReshapedArray}, Esrc::Exposed{<:MtlArray}, perm | ||
) | ||
Aperm = permutedims(A, perm) | ||
Adest_parent = parent(Adest) | ||
copyto!(Adest_parent, Aperm) | ||
return Adest | ||
Aperm = permutedims(Esrc, perm) | ||
copyto!(expose(parent(Edest)), expose(Aperm)) | ||
return unexpose(Edest) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Unwrap | ||
|
||
A module to unwrap complex array types to assist in the generic programming of array-type based functions. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Replace all `leaf_parenttype` calls by wrapping the arrays in this `expose` type | ||
|
||
Fix the issue Ryan found in MPS | ||
Make a GPUArrays extension that has generic GPU algorithms |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
module Unwrap | ||
using SimpleTraits | ||
using LinearAlgebra | ||
using Base: ReshapedArray | ||
using Strided.StridedViews | ||
|
||
include("expose.jl") | ||
include("iswrappedarray.jl") | ||
|
||
include("import.jl") | ||
## TODO Create functions which take the `Expose` type and launch functions | ||
## using that type | ||
## Exposed based functions | ||
include("functions/abstractarray.jl") | ||
include("functions/copyto.jl") | ||
include("functions/linearalgebra.jl") | ||
include("functions/mul.jl") | ||
include("functions/permutedims.jl") | ||
|
||
export IsWrappedArray, | ||
is_wrapped_array, parenttype, unwrap_type, expose, Exposed, unexpose, cpu | ||
|
||
## TODO write exposed based functions in the NDTensors Extensions when necessary | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
struct Exposed{Unwrapped,Object} | ||
object::Object | ||
end | ||
|
||
expose(object) = Exposed{unwrap_type(object),typeof(object)}(object) | ||
|
||
unexpose(E::Exposed) = E.object |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
parent(E::Exposed) = parent(unexpose(E)) | ||
|
||
transpose(E::Exposed) = transpose(unexpose(E)) | ||
|
||
cpu(E::Exposed) = cpu(unexpose(E)) | ||
|
||
getindex(E::Exposed) = unexpose(E)[] | ||
|
||
function setindex!(E::Exposed, x::Number) | ||
unexpose(E)[] = x | ||
return unexpose(E) | ||
end | ||
|
||
getindex(E::Exposed, I...) = unexpose(E)[I...] | ||
|
||
function copy(E::Exposed) | ||
return copy(unexpose(E)) | ||
end | ||
|
||
any(f, E::Exposed) = any(f, unexpose(E)) | ||
|
||
print_array(io::IO, E::Exposed) = print_array(io, unexpose(E)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
function copyto!(R::Exposed, T::Exposed) | ||
copyto!(unexpose(R), unexpose(T)) | ||
return unexpose(R) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
function qr(E::Exposed) | ||
return qr(unexpose(E)) | ||
end | ||
## These functions do not exist in `LinearAlgebra` but were defined | ||
## in NDTensors. Because Unwrap is imported before NDTensors, | ||
## one cannot import a these functions from NDTensors so instead | ||
## I define them here and extend them in NDTensors | ||
## I have done the same thing for the function cpu | ||
## Unwrap.qr_positive | ||
function qr_positive(E::Exposed) | ||
return qr_positive(unexpose(E)) | ||
end | ||
|
||
## Unwrap.ql | ||
function ql(E::Exposed) | ||
return ql(unexpose(E)) | ||
end | ||
## Unwrap.ql_positive | ||
function ql_positive(E::Exposed) | ||
return ql_positive(unexpose(E)) | ||
end | ||
|
||
function eigen(E::Exposed) | ||
return eigen(unexpose(E)) | ||
end | ||
|
||
function svd(E::Exposed; kwargs...) | ||
return svd(unexpose(E); kwargs...) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
function mul!(CM::Exposed, AM::Exposed, BM::Exposed, α, β) | ||
mul!(unexpose(CM), unexpose(AM), unexpose(BM), α, β) | ||
return unexpose(CM) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
function permutedims(E::Exposed, perm) | ||
return permutedims(unexpose(E), perm) | ||
end | ||
|
||
function permutedims!(Edest::Exposed, Esrc::Exposed, perm) | ||
permutedims!(unexpose(Edest), unexpose(Esrc), perm) | ||
return unexpose(Edest) | ||
end | ||
|
||
function permutedims!(Edest::Exposed, Esrc::Exposed, perm, f) | ||
unexpose(Edest) .= f.(unexpose(Edest), permutedims(Esrc, perm)) | ||
return unexpose(Edest) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import Base: | ||
permutedims, | ||
permutedims!, | ||
copy, | ||
copyto!, | ||
parent, | ||
print_array, | ||
transpose, | ||
getindex, | ||
setindex!, | ||
any | ||
|
||
import LinearAlgebra: mul!, qr, eigen, svd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.