Commit: Refactor SVD code
mtfishman committed Oct 31, 2023
1 parent 34b229b commit 766f945
Showing 3 changed files with 87 additions and 64 deletions.
4 changes: 1 addition & 3 deletions NDTensors/src/arraystorage/arraystorage/tensor/qr.jl
@@ -1,7 +1,5 @@
 function qr(A::ArrayStorageTensor; positive=false)
-  if positive
-    error("Not implemented")
-  end
+  positive && error("Not implemented")
   Q, R = qr(storage(A))
   Q = convert(typeof(R), Q)
   i, j = inds(A)
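Side note on the qr.jl change above: it is a pure style refactor, replacing the three-line `if` block with Julia's short-circuit `&&` guard, which evaluates its right-hand side only when the left-hand side is `true`. A minimal standalone sketch of the equivalence (the function name `checked_qr` is hypothetical, for illustration only):

    using LinearAlgebra: qr

    function checked_qr(A::AbstractMatrix; positive=false)
      # Behaves the same as:
      #   if positive
      #     error("Not implemented")
      #   end
      positive && error("Not implemented")
      return qr(A)
    end

    checked_qr(rand(3, 3))                   # works
    # checked_qr(rand(3, 3); positive=true)  # ErrorException("Not implemented")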
141 changes: 83 additions & 58 deletions NDTensors/src/arraystorage/arraystorage/tensor/svd.jl
@@ -1,14 +1,47 @@
-# TODO: Rewrite this function to be more modern:
-# 1. Output `Spectrum` as a keyword argument that gets overwritten.
-# 2. Dispatch on `alg`.
-# 3. Make this into two layers, one that handles indices and one that works with `Matrix`.
+backup_svd_alg(::Algorithm"divide_and_conquer") = Algorithm"qr_iteration"()
+backup_svd_alg(::Algorithm"qr_iteration") = Algorithm"recursive"()
+
+function svd(alg::Algorithm"divide_and_conquer", a::ArrayStorage)
+  USV = svd_catch_error(a; alg=LinearAlgebra.DivideAndConquer())
+  if isnothing(USV)
+    return svd(backup_svd_alg(alg), a)
+  end
+  return USV
+end
+
+function svd(alg::Algorithm"qr_iteration", a::ArrayStorage)
+  USV = svd_catch_error(a; alg=LinearAlgebra.QRIteration())
+  if isnothing(USV)
+    return svd(backup_svd_alg(alg), a)
+  end
+  return USV
+end
+
+function svd(alg::Algorithm"recursive", a::ArrayStorage)
+  return svd_recursive(a)
+end
+
+function svd(::Algorithm"QRAlgorithm", a::ArrayStorage)
+  return error("Not implemented yet")
+end
+
+function svd(::Algorithm"JacobiAlgorithm", a::ArrayStorage)
+  return error("Not implemented yet")
+end
+
+function svd(alg::Algorithm, a::ArrayStorage)
+  return error(
+    "svd algorithm $alg is not currently supported. Please see the documentation for currently supported algorithms.",
+  )
+end
+
 """
-    svd(T::ArrayStorageTensor{<:Number,2}; kwargs...)
+    tsvd(a::ArrayStorage{<:Number,2}; kwargs...)

 svd of an order-2 DenseTensor
 """
-function svd(
-  T::ArrayStorageTensor;
+function tsvd(
+  a::ArrayStorage;
   mindim=nothing,
   maxdim=nothing,
   cutoff=nothing,
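The `Algorithm`-dispatched methods added above form a fallback chain: "divide_and_conquer" retries with "qr_iteration" on failure, which in turn retries with "recursive", with the chain encoded by `backup_svd_alg`. A self-contained sketch of the same pattern, using toy names (`Alg`, `backup_alg`, `flaky_solve`, `solve`) rather than the NDTensors API:

    struct Alg{Name} end
    Alg(name::String) = Alg{Symbol(name)}()

    backup_alg(::Alg{:divide_and_conquer}) = Alg{:qr_iteration}()
    backup_alg(::Alg{:qr_iteration}) = Alg{:recursive}()

    # Stand-in for `svd_catch_error`: returns `nothing` on failure.
    flaky_solve(a) = any(isnan, a) ? nothing : sum(a)

    function solve(alg::Alg{:divide_and_conquer}, a)
      result = flaky_solve(a)
      isnothing(result) && return solve(backup_alg(alg), a)
      return result
    end

    function solve(alg::Alg{:qr_iteration}, a)
      result = flaky_solve(a)
      isnothing(result) && return solve(backup_alg(alg), a)
      return result
    end

    # Last resort never fails in this toy.
    solve(::Alg{:recursive}, a) = sum(replace(a, NaN => 0.0))

    solve(Alg("divide_and_conquer"), [1.0, NaN])  # falls through to :recursive, returns 1.0

The upshot of the refactor is that each algorithm lives in its own method, so the fallback logic is written once per algorithm instead of being nested inside a string-comparison `if`/`elseif` ladder.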
@@ -18,56 +51,22 @@ function svd(
   # Only used by BlockSparse svd
   min_blockdim=nothing,
 )
-  truncate = !isnothing(maxdim) || !isnothing(cutoff)
-  mindim = replace_nothing(mindim, default_mindim(T))
-  maxdim = replace_nothing(maxdim, default_maxdim(T))
-  cutoff = replace_nothing(cutoff, default_cutoff(T))
-  use_absolute_cutoff = replace_nothing(use_absolute_cutoff, default_use_absolute_cutoff(T))
-  use_relative_cutoff = replace_nothing(use_relative_cutoff, default_use_relative_cutoff(T))
-  alg = replace_nothing(alg, default_svd_alg(T))
-
-  # TODO: Dispatch on `Algorithm(alg)`.
-  if alg == "divide_and_conquer"
-    MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.DivideAndConquer())
-    if isnothing(MUSV)
-      # If "divide_and_conquer" fails, try "qr_iteration"
-      alg = "qr_iteration"
-      MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.QRIteration())
-      if isnothing(MUSV)
-        # If "qr_iteration" fails, try "recursive"
-        alg = "recursive"
-        MUSV = svd_recursive(matrix(T))
-      end
-    end
-  elseif alg == "qr_iteration"
-    MUSV = svd_catch_error(matrix(T); alg=LinearAlgebra.QRIteration())
-    if isnothing(MUSV)
-      # If "qr_iteration" fails, try "recursive"
-      alg = "recursive"
-      MUSV = svd_recursive(matrix(T))
-    end
-  elseif alg == "recursive"
-    MUSV = svd_recursive(matrix(T))
-  elseif alg == "QRAlgorithm" || alg == "JacobiAlgorithm"
-    MUSV = svd_catch_error(matrix(T); alg=alg)
-  else
-    error(
-      "svd algorithm $alg is not currently supported. Please see the documentation for currently supported algorithms.",
-    )
-  end
-  if isnothing(MUSV)
-    if any(isnan, T)
+  alg = replace_nothing(alg, default_svd_alg(a))
+  USV = svd(Algorithm(alg), a)
+  if isnothing(USV)
+    if any(isnan, a)
       println("SVD failed, the matrix you were trying to SVD contains NaNs.")
     else
       println(lapack_svd_error_message(alg))
     end
     return nothing
   end
-  MU, MS, MV = MUSV
-  conj!(MV)
-
-  P = MS .^ 2
-  if truncate
+  U, S, V = USV
+  conj!(V)
+
+  P = S .^ 2
+  if !isnothing(maxdim) || !isnothing(cutoff)
     P, truncerr, _ = truncate!!(
       P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff
     )
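For context on the truncation step just above: `P = S .^ 2` holds the squared singular values, and `truncate!!` drops the smallest ones subject to the `cutoff`/`maxdim`/`mindim` keywords, returning the truncated spectrum and the discarded weight. A rough standalone sketch of that logic, assuming an absolute cutoff and ignoring `mindim` and the relative-cutoff options (`truncate_squared` is illustrative, not the actual `truncate!!` implementation):

    function truncate_squared(P::Vector{Float64}; maxdim=length(P), cutoff=0.0)
      P = sort(P; rev=true)
      # Weight discarded if only the first `n` values are kept.
      tail(n) = sum(P[(n + 1):end]; init=0.0)
      n = length(P)
      # Drop trailing values while over `maxdim`, or while the
      # discarded weight would still stay within `cutoff`.
      while n > 1 && (n > maxdim || tail(n - 1) <= cutoff)
        n -= 1
      end
      return P[1:n], tail(n)
    end

    P, truncerr = truncate_squared([0.7, 0.2, 0.06, 0.03, 0.01]; cutoff=0.05)
    # keeps [0.7, 0.2, 0.06] with truncerr = 0.04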
@@ -76,25 +75,51 @@ function svd(
   end
   spec = Spectrum(P, truncerr)
   dS = length(P)
-  if dS < length(MS)
-    MU = MU[:, 1:dS]
+  if dS < length(S)
+    U = U[:, 1:dS]
     # Fails on some GPU backends like Metal.
     # resize!(MS, dS)
-    MS = MS[1:dS]
-    MV = MV[:, 1:dS]
+    S = S[1:dS]
+    V = V[:, 1:dS]
   end
+  return U, DiagonalMatrix(S), V, spec
+end
+
+# TODO: Rewrite this function to be more modern:
+# 1. Output `Spectrum` as a keyword argument that gets overwritten.
+# 2. Dispatch on `alg`.
+# 3. Make this into two layers, one that handles indices and one that works with `Matrix`.
+"""
+    svd(T::ArrayStorageTensor{<:Number,2}; kwargs...)
+
+svd of an order-2 DenseTensor
+"""
+function svd(
+  T::ArrayStorageTensor;
+  mindim=nothing,
+  maxdim=nothing,
+  cutoff=nothing,
+  alg=nothing,
+  use_absolute_cutoff=nothing,
+  use_relative_cutoff=nothing,
+  # Only used by BlockSparse svd
+  min_blockdim=nothing,
+)
+  U, S, V, spec = tsvd(
+    storage(T); mindim, maxdim, cutoff, alg, use_absolute_cutoff, use_relative_cutoff
+  )
   # Make the new indices to go onto U and V
   # TODO: Put in a separate function, such as
   # `rewrap_inds` or something like that.
+  dS = length(S[DiagIndices()])
   indstype = typeof(inds(T))
   u = eltype(indstype)(dS)
   v = eltype(indstype)(dS)
   Uinds = indstype((ind(T, 1), u))
   Sinds = indstype((u, v))
   Vinds = indstype((ind(T, 2), v))
-  U = tensor(MU, Uinds)
-  S = tensor(DiagonalMatrix(MS), Sinds)
-  V = tensor(MV, Vinds)
-  return U, S, V, spec
+  TU = tensor(U, Uinds)
+  TS = tensor(S, Sinds)
+  TV = tensor(V, Vinds)
+  return TU, TS, TV, spec
 end
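The wrapper above realizes point 3 of the TODO: `tsvd` does the numerical work on the raw storage, and `svd` only unwraps the storage and rewraps the factors with tensor indices. A toy version of that two-layer split (the `IndexedMatrix` type and all names here are invented for illustration; real NDTensors indices carry much more structure than a bare dimension):

    using LinearAlgebra

    struct IndexedMatrix{I}
      data::Matrix{Float64}
      inds::Tuple{I,I}
    end

    # Layer 1: numerics only, knows nothing about indices.
    matrix_svd(a::Matrix) = svd(a)

    # Layer 2: unwrap storage, call the numeric layer, rewrap with new indices.
    function indexed_svd(T::IndexedMatrix)
      F = matrix_svd(T.data)
      dS = length(F.S)
      i, j = T.inds
      U = IndexedMatrix(Matrix(F.U), (i, dS))
      S = IndexedMatrix(Matrix(Diagonal(F.S)), (dS, dS))
      V = IndexedMatrix(Matrix(F.V), (j, dS))
      return U, S, V
    end

    T = IndexedMatrix(rand(4, 3), (4, 3))
    U, S, V = indexed_svd(T)
    U.data * S.data * V.data' ≈ T.data  # true up to roundoff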
6 changes: 3 additions & 3 deletions src/mps/abstractmps.jl
@@ -1592,7 +1592,7 @@ bond indices is performed. Afterward, tensors
 Either modify in-place with `orthogonalize!` or
 out-of-place with `orthogonalize`.
 """
-function orthogonalize!(M::AbstractMPS, j::Int; kwargs...)
+function orthogonalize!(M::AbstractMPS, j::Int)
   @debug_check begin
     if !(1 <= j <= length(M))
       error("Input j=$j to `orthogonalize!` out of range (valid range = 1:$(length(M)))")
@@ -1608,7 +1608,7 @@ function orthogonalize!(M::AbstractMPS, j::Int; kwargs...)
   else
     ltags = TagSet("Link,l=$b")
   end
-  L, R = factorize(M[b], linds; tags=ltags, kwargs...)
+  L, R = factorize(M[b], linds; tags=ltags)
   M[b] = L
   M[b + 1] *= R
   setleftlim!(M, b)
@@ -1629,7 +1629,7 @@ function orthogonalize!(M::AbstractMPS, j::Int; kwargs...)
   else
     ltags = TagSet("Link,l=$b")
   end
-  L, R = factorize(M[b + 1], rinds; tags=ltags, kwargs...)
+  L, R = factorize(M[b + 1], rinds; tags=ltags)
   M[b + 1] = L
   M[b] *= R

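With the passthrough removed, `orthogonalize!` no longer forwards keyword arguments to `factorize`; callers that relied on tuning the factorization through `orthogonalize!` will need another route. Typical usage is unchanged; a sketch against the public ITensors API (10 sites and `linkdims=4` chosen arbitrarily):

    using ITensors

    s = siteinds("S=1/2", 10)
    psi = randomMPS(s; linkdims=4)
    orthogonalize!(psi, 5)  # move the orthogonality center to site 5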
