Skip to content

Commit

Permalink
[NDTensorsCUDAExt] Fix QR-based SVD for some rectangular matrices (#1229
Browse files Browse the repository at this point in the history
)
  • Loading branch information
kmp5VT authored Nov 2, 2023
1 parent 1a31294 commit 5a22695
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 8 deletions.
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ else
end

include("imports.jl")
include("default_kwargs.jl")
include("set_types.jl")
include("iscu.jl")
include("adapt.jl")
Expand Down
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/default_kwargs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
NDTensors.default_svd_alg(::Type{<:CuArray}, a) = "qr_algorithm"
38 changes: 34 additions & 4 deletions NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,43 @@
function NDTensors.svd_catch_error(A::CuMatrix; alg="JacobiAlgorithm")
if alg == "JacobiAlgorithm"
function NDTensors.svd_catch_error(A::CuMatrix; alg::String="jacobi_algorithm")
if alg == "jacobi_algorithm"
alg = CUDA.CUSOLVER.JacobiAlgorithm()
else
elseif alg == "qr_algorithm"
alg = CUDA.CUSOLVER.QRAlgorithm()
else
error(
"svd algorithm $alg is not currently supported. Please see the documentation for currently supported algorithms.",
)
end
return NDTensors.svd_catch_error(A, alg)
end

function NDTensors.svd_catch_error(A::CuMatrix, ::CUDA.CUSOLVER.JacobiAlgorithm)
USV = try
svd(expose(A); alg=alg)
svd(A; alg=CUDA.CUSOLVER.JacobiAlgorithm())
catch
return nothing
end
return USV
end

function NDTensors.svd_catch_error(A::CuMatrix, ::CUDA.CUSOLVER.QRAlgorithm)
s = size(A)
if s[1] < s[2]
At = copy(Adjoint(A))

USV = try
svd(At; alg=CUDA.CUSOLVER.QRAlgorithm())
catch
return nothing
end
MV, MS, MU = USV
USV = (MU, MS, MV)
else
USV = try
svd(A; alg=CUDA.CUSOLVER.QRAlgorithm())
catch
return nothing
end
end
return USV
end
3 changes: 2 additions & 1 deletion NDTensors/src/default_kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ replace_nothing(value, replacement) = value
default_maxdim(a) = minimum(size(a))
default_mindim(a) = true
default_cutoff(a) = zero(eltype(a))
default_svd_alg(a) = "divide_and_conquer"
default_svd_alg(a) = default_svd_alg(unwrap_type(a), a)
default_svd_alg(::Type{<:AbstractArray}, a) = "divide_and_conquer"
default_use_absolute_cutoff(a) = false
default_use_relative_cutoff(a) = true
6 changes: 3 additions & 3 deletions NDTensors/src/linearalgebra/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ function lapack_svd_error_message(alg)
" ``A^{\\dagger} A`` is used to compute `U` and then a `qr` of\n" *
" ``A^{\\dagger} U`` is used to compute `V`. This is performed\n" *
" recursively to compute small singular values.\n" *
" - `\"QRAlgorithm\"` is a CUDA.jl implemented SVD algorithm using QR.\n" *
" - `\"JacobiAlgorithm\"` is a CUDA.jl implemented SVD algorithm.\n\n" *
" - `\"qr_algorithm\"` is a CUDA.jl implemented SVD algorithm using QR.\n" *
" - `\"jacobi_algorithm\"` is a CUDA.jl implemented SVD algorithm.\n\n" *
"Returning `nothing`. For an output `F = svd(A, ...)` you can check if\n" *
"`isnothing(F)` in your code and try a different algorithm.\n\n" *
"To suppress this message in the future, you can wrap the `svd` call in the\n" *
Expand Down Expand Up @@ -127,7 +127,7 @@ function svd(
end
elseif alg == "recursive"
MUSV = svd_recursive(matrix(T))
elseif alg == "QRAlgorithm" || alg == "JacobiAlgorithm"
elseif alg == "qr_algorithm" || alg == "jacobi_algorithm"
MUSV = svd_catch_error(matrix(T); alg)
else
error(
Expand Down

0 comments on commit 5a22695

Please sign in to comment.