From 35247a5d46c0bcccd6d1cfebee838c901bce64c4 Mon Sep 17 00:00:00 2001 From: Antoine Levitt Date: Wed, 5 Jun 2024 08:59:51 +0200 Subject: [PATCH] Tweak bookkeeping in LOBPCG --- src/eigen/lobpcg_hyper_impl.jl | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/eigen/lobpcg_hyper_impl.jl b/src/eigen/lobpcg_hyper_impl.jl index 0728c079f2..a0dc8a4c30 100644 --- a/src/eigen/lobpcg_hyper_impl.jl +++ b/src/eigen/lobpcg_hyper_impl.jl @@ -298,21 +298,17 @@ end end -function final_retval(X, AX, BX, resid_history, niter, n_matvec) - λ = @views [real((X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n])) for n=1:size(X, 2)] - λ_device = oftype(X[:, 1], λ) # Offload to GPU if needed - residuals = AX .- BX .* λ_device' +function final_retval(X, AX, BX, λ, resid_history, niter, n_matvec) if !issorted(λ) p = sortperm(λ) λ = λ[p] - residuals = residuals[:, p] X = X[:, p] AX = AX[:, p] BX = BX[:, p] resid_history = resid_history[p, :] end - (; λ=λ_device, X, AX, BX, - residual_norms=norm.(eachcol(residuals)), + (; λ=λ, X, AX, BX, + residual_norms=resid_history[:, niter+1], residual_history=resid_history[:, 1:niter+1], n_matvec) end @@ -368,7 +364,7 @@ end nlocked = 0 niter = 0 # the first iteration is fake λs = @views [real((X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n])) for n=1:M] - λs = oftype(X[:, 1], λs) # Offload to GPU if needed + λs = oftype(real(X[:, 1]), λs) # Offload to GPU if needed new_X = X new_AX = AX new_BX = BX @@ -376,6 +372,7 @@ end full_X = X full_AX = AX full_BX = BX + full_λs = λs while true if niter > 0 # first iteration is just to compute the residuals (no X update) @@ -393,7 +390,8 @@ end AY = LazyHcat(AX, AR) BY = LazyHcat(BX, BR) # data shared with (X, R) in non-general case end - cX, λs = rayleigh_ritz(Y, AY, M-nlocked) + cX, λs_RR = rayleigh_ritz(Y, AY, M-nlocked) + λs .= λs_RR # Update X. By contrast to some other implementations, we # wait on updating P because we have to know which vectors @@ -446,7 +444,7 @@ end if nlocked >= n_conv_check # Converged! X .= new_X # Update the part of X which is still active AX .= new_AX - return final_retval(full_X, full_AX, full_BX, resid_history, niter, n_matvec) + return final_retval(full_X, full_AX, full_BX, full_λs, resid_history, niter, n_matvec) end newly_locked = nlocked - prev_nlocked active = newly_locked+1:size(X,2) # newly active vectors @@ -531,9 +529,9 @@ end B_ortho!(R, BR) end - niter < maxiter || break + niter >= maxiter && break niter = niter + 1 end - final_retval(full_X, full_AX, full_BX, resid_history, maxiter, n_matvec) + final_retval(full_X, full_AX, full_BX, full_λs, resid_history, maxiter, n_matvec) end