From 766f945f94e08c8b80d8eb227b7150a0638d8bf1 Mon Sep 17 00:00:00 2001
From: mtfishman
Date: Tue, 31 Oct 2023 17:23:21 -0400
Subject: [PATCH] Refactor SVD code

---
 .../arraystorage/arraystorage/tensor/qr.jl  |   4 +-
 .../arraystorage/arraystorage/tensor/svd.jl | 141 +++++++++++-------
 src/mps/abstractmps.jl                      |   6 +-
 3 files changed, 87 insertions(+), 64 deletions(-)

diff --git a/NDTensors/src/arraystorage/arraystorage/tensor/qr.jl b/NDTensors/src/arraystorage/arraystorage/tensor/qr.jl
index c99d026f0f..7ee216341f 100644
--- a/NDTensors/src/arraystorage/arraystorage/tensor/qr.jl
+++ b/NDTensors/src/arraystorage/arraystorage/tensor/qr.jl
@@ -1,7 +1,5 @@
 function qr(A::ArrayStorageTensor; positive=false)
-  if positive
-    error("Not implemented")
-  end
+  positive && error("Not implemented")
   Q, R = qr(storage(A))
   Q = convert(typeof(R), Q)
   i, j = inds(A)
diff --git a/NDTensors/src/arraystorage/arraystorage/tensor/svd.jl b/NDTensors/src/arraystorage/arraystorage/tensor/svd.jl
index 46e9ea8e42..1d93c2f1d1 100644
--- a/NDTensors/src/arraystorage/arraystorage/tensor/svd.jl
+++ b/NDTensors/src/arraystorage/arraystorage/tensor/svd.jl
@@ -1,14 +1,47 @@
-# TODO: Rewrite this function to be more modern:
-# 1. Output `Spectrum` as a keyword argument that gets overwritten.
-# 2. Dispatch on `alg`.
-# 3. Make this into two layers, one that handles indices and one that works with `Matrix`.
+backup_svd_alg(::Algorithm"divide_and_conquer") = Algorithm"qr_iteration"()
+backup_svd_alg(::Algorithm"qr_iteration") = Algorithm"recursive"()
+
+function svd(alg::Algorithm"divide_and_conquer", a::ArrayStorage)
+  USV = svd_catch_error(a; alg=LinearAlgebra.DivideAndConquer())
+  if isnothing(USV)
+    return svd(backup_svd_alg(alg), a)
+  end
+  return USV
+end
+
+function svd(alg::Algorithm"qr_iteration", a::ArrayStorage)
+  USV = svd_catch_error(a; alg=LinearAlgebra.QRIteration())
+  if isnothing(USV)
+    return svd(backup_svd_alg(alg), a)
+  end
+  return USV
+end
+
+function svd(alg::Algorithm"recursive", a::ArrayStorage)
+  return svd_recursive(a)
+end
+
+function svd(::Algorithm"QRAlgorithm", a::ArrayStorage)
+  return error("Not implemented yet")
+end
+
+function svd(::Algorithm"JacobiAlgorithm", a::ArrayStorage)
+  return error("Not implemented yet")
+end
+
+function svd(alg::Algorithm, a::ArrayStorage)
+  return error(
+    "svd algorithm $alg is not currently supported. Please see the documentation for currently supported algorithms.",
+  )
+end
+
 """
-  svd(T::ArrayStorageTensor{<:Number,2}; kwargs...)
+  tsvd(a::ArrayStorage{<:Number,2}; kwargs...)
 
-svd of an order-2 DenseTensor
+Truncated SVD of an order-2 `ArrayStorage` matrix.
 """
-function svd(
-  T::ArrayStorageTensor;
+function tsvd(
+  a::ArrayStorage;
   mindim=nothing,
   maxdim=nothing,
   cutoff=nothing,
@@ -18,56 +51,22 @@
   # Only used by BlockSparse svd
   min_blockdim=nothing,
 )
-  truncate = !isnothing(maxdim) || !isnothing(cutoff)
-  mindim = replace_nothing(mindim, default_mindim(T))
-  maxdim = replace_nothing(maxdim, default_maxdim(T))
-  cutoff = replace_nothing(cutoff, default_cutoff(T))
-  use_absolute_cutoff = replace_nothing(use_absolute_cutoff, default_use_absolute_cutoff(T))
-  use_relative_cutoff = replace_nothing(use_relative_cutoff, default_use_relative_cutoff(T))
-  alg = replace_nothing(alg, default_svd_alg(T))
-
-  # TODO: Dispatch on `Algorithm(alg)`.
-  if alg == "divide_and_conquer"
-    MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.DivideAndConquer())
-    if isnothing(MUSV)
-      # If "divide_and_conquer" fails, try "qr_iteration"
-      alg = "qr_iteration"
-      MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.QRIteration())
-      if isnothing(MUSV)
-        # If "qr_iteration" fails, try "recursive"
-        alg = "recursive"
-        MUSV = svd_recursive(matrix(T))
-      end
-    end
-  elseif alg == "qr_iteration"
-    MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.QRIteration())
-    if isnothing(MUSV)
-      # If "qr_iteration" fails, try "recursive"
-      alg = "recursive"
-      MUSV = svd_recursive(matrix(T))
-    end
-  elseif alg == "recursive"
-    MUSV = svd_recursive(matrix(T))
-  elseif alg == "QRAlgorithm" || alg == "JacobiAlgorithm"
-    MUSV = svd_catch_error(matrix(T); alg=alg)
-  else
-    error(
-      "svd algorithm $alg is not currently supported. Please see the documentation for currently supported algorithms.",
-    )
-  end
-  if isnothing(MUSV)
-    if any(isnan, T)
+  alg = replace_nothing(alg, default_svd_alg(a))
+  USV = svd(Algorithm(alg), a)
+  if isnothing(USV)
+    if any(isnan, a)
       println("SVD failed, the matrix you were trying to SVD contains NaNs.")
     else
       println(lapack_svd_error_message(alg))
     end
     return nothing
   end
-  MU, MS, MV = MUSV
-  conj!(MV)
-  P = MS .^ 2
-  if truncate
+  U, S, V = USV
+  conj!(V)
+
+  P = S .^ 2
+  if !isnothing(maxdim) || !isnothing(cutoff)
     P, truncerr, _ = truncate!!(
       P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff
     )
@@ -76,25 +75,51 @@
   end
   spec = Spectrum(P, truncerr)
   dS = length(P)
-  if dS < length(MS)
-    MU = MU[:, 1:dS]
+  if dS < length(S)
+    U = U[:, 1:dS]
     # Fails on some GPU backends like Metal.
-    # resize!(MS, dS)
-    MS = MS[1:dS]
-    MV = MV[:, 1:dS]
+    # resize!(S, dS)
+    S = S[1:dS]
+    V = V[:, 1:dS]
   end
+  return U, DiagonalMatrix(S), V, spec
+end
+
+# TODO: Rewrite this function to be more modern:
+# 1. Output `Spectrum` as a keyword argument that gets overwritten.
+"""
+  svd(T::ArrayStorageTensor{<:Number,2}; kwargs...)
+
+SVD of an order-2 `ArrayStorageTensor`.
+"""
+function svd(
+  T::ArrayStorageTensor;
+  mindim=nothing,
+  maxdim=nothing,
+  cutoff=nothing,
+  alg=nothing,
+  use_absolute_cutoff=nothing,
+  use_relative_cutoff=nothing,
+  # Only used by BlockSparse svd
+  min_blockdim=nothing,
+)
+  U, S, V, spec = tsvd(
+    storage(T); mindim, maxdim, cutoff, alg, use_absolute_cutoff, use_relative_cutoff
+  )
   # Make the new indices to go onto U and V
   # TODO: Put in a separate function, such as
   # `rewrap_inds` or something like that.
+  dS = length(S[DiagIndices()])
   indstype = typeof(inds(T))
   u = eltype(indstype)(dS)
   v = eltype(indstype)(dS)
   Uinds = indstype((ind(T, 1), u))
   Sinds = indstype((u, v))
   Vinds = indstype((ind(T, 2), v))
-  U = tensor(MU, Uinds)
-  S = tensor(DiagonalMatrix(MS), Sinds)
-  V = tensor(MV, Vinds)
-  return U, S, V, spec
+  TU = tensor(U, Uinds)
+  TS = tensor(S, Sinds)
+  TV = tensor(V, Vinds)
+  return TU, TS, TV, spec
 end
diff --git a/src/mps/abstractmps.jl b/src/mps/abstractmps.jl
index a6309e7d2f..57677138a0 100644
--- a/src/mps/abstractmps.jl
+++ b/src/mps/abstractmps.jl
@@ -1592,7 +1592,7 @@ bond indices is performed. Afterward, tensors
 Either modify in-place with `orthogonalize!` or out-of-place with `orthogonalize`.
 """
-function orthogonalize!(M::AbstractMPS, j::Int; kwargs...)
+function orthogonalize!(M::AbstractMPS, j::Int)
   @debug_check begin
     if !(1 <= j <= length(M))
       error("Input j=$j to `orthogonalize!` out of range (valid range = 1:$(length(M)))")
     end
   end
@@ -1608,7 +1608,7 @@ function orthogonalize!(M::AbstractMPS, j::Int; kwargs...)
     else
       ltags = TagSet("Link,l=$b")
     end
-    L, R = factorize(M[b], linds; tags=ltags, kwargs...)
+    L, R = factorize(M[b], linds; tags=ltags)
     M[b] = L
     M[b + 1] *= R
     setleftlim!(M, b)
@@ -1629,7 +1629,7 @@ function orthogonalize!(M::AbstractMPS, j::Int; kwargs...)
    else
       ltags = TagSet("Link,l=$b")
     end
-    L, R = factorize(M[b + 1], rinds; tags=ltags, kwargs...)
+    L, R = factorize(M[b + 1], rinds; tags=ltags)
     M[b + 1] = L
     M[b] *= R
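
Reviewer note on the new layering (a sketch, not part of the patch): the refactor splits the
old monolithic `svd` into a backend layer that dispatches on `Algorithm` types with automatic
fallback via `backup_svd_alg`, and a truncation layer, `tsvd`, that fills in defaults,
truncates the spectrum, and returns a `Spectrum`. A rough usage sketch, assuming `a` is an
order-2 `ArrayStorage` as in the patch (the keyword values below are illustrative only):

    # Backend layer: request an algorithm explicitly; if LAPACK fails,
    # dispatch falls back per backup_svd_alg:
    # "divide_and_conquer" -> "qr_iteration" -> "recursive".
    USV = svd(Algorithm"divide_and_conquer"(), a)

    # Truncation layer: tsvd picks the default algorithm for `a`,
    # truncates the squared singular values, and returns
    # U, DiagonalMatrix(S), V plus the truncation Spectrum.
    U, S, V, spec = tsvd(a; maxdim=10, cutoff=1e-12)

One consequence of the split is that new backends only need a `svd(::Algorithm"...", a)`
method; truncation and index handling are inherited from `tsvd` and the `svd` tensor wrapper.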