Skip to content

Commit

Permalink
WIP: use nl_prob
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Dec 18, 2024
1 parent 4b29bab commit fd08798
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 35 deletions.
41 changes: 29 additions & 12 deletions lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ function initialize!(nlsolver::NLSolver{<:NonlinearSolveAlg, false},
integrator.stats.nnonliniter += cache.cache.stats.nsteps
integrator.stats.njacs += cache.cache.stats.njacs
end
if f isa DAEFunction
nlp_params = (tmp, α, tstep, invγdt, p, dt, uprev, f)
new_prob = __has_nlprob_data(f)
update_nlprob!(cache.prob, u0=z, p=(;dt, γ, inner_tmp, outer_tmp, t, p))
else
nlp_params = (tmp, γ, α, tstep, invγdt, method, p, dt, f)
if f isa DAEFunction
nlp_params = (tmp, α, tstep, invγdt, p, dt, uprev, f)
else
nlp_params = (tmp, γ, α, tstep, invγdt, method, p, dt, f)
end
remake(cache.prob, p = nlp_params, u0 = z)
end
new_prob = remake(cache.prob, p = nlp_params, u0 = z)
cache.cache = init(new_prob, alg.alg)
nothing
end
Expand All @@ -63,26 +67,35 @@ function initialize!(nlsolver::NLSolver{<:NonlinearSolveAlg, true},
integrator.stats.nnonliniter += cache.cache.stats.nsteps
integrator.stats.njacs += cache.cache.stats.njacs
end
if f isa DAEFunction
nlp_params = (tmp, ztmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)

new_prob = __has_nlprob_data(f)
update_nlprob!(cache.prob, u0=z, p=(;dt, γ, inner_tmp, outer_tmp, t, p))
else
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f)
if f isa DAEFunction
nlp_params = (tmp, ztmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
else
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f)
end
remake(cache.prob, p = nlp_params, u0 = z)
end
new_prob = remake(cache.prob, p = nlp_params, u0 = z)
cache.cache = init(new_prob, alg.alg)
nothing
end

## compute_step!

@muladd function compute_step!(nlsolver::NLSolver{<:NonlinearSolveAlg, false}, integrator)
@unpack uprev, t, p, dt, opts = integrator
@unpack uprev, t, p, dt, opts, f = integrator
@unpack z, tmp, ztmp, γ, α, cache, method = nlsolver
@unpack tstep, invγdt = cache

nlcache = nlsolver.cache.cache
step!(nlcache)
nlsolver.ztmp = nlcache.u
if __has_nlprob_data(f)
ztmp = nlprobmap(nlcache)
else
ztmp = nlcache.u
end

ustep = compute_ustep(tmp, γ, z, method)
atmp = calculate_residuals(nlcache.fu, uprev, ustep, opts.abstol, opts.reltol,
Expand All @@ -98,13 +111,17 @@ end
end

@muladd function compute_step!(nlsolver::NLSolver{<:NonlinearSolveAlg, true}, integrator)
@unpack uprev, t, p, dt, opts = integrator
@unpack uprev, t, p, dt, opts, f = integrator
@unpack z, tmp, ztmp, γ, α, cache, method = nlsolver
@unpack tstep, invγdt, atmp, ustep = cache

nlcache = nlsolver.cache.cache
step!(nlcache)
@.. broadcast=false ztmp=nlcache.u
if __has_nlprob_data(f)
@.. ztmp = nlprobmap(nlcache)
else
@.. ztmp=nlcache.u
end

ustep = compute_ustep!(ustep, tmp, γ, z, method)
calculate_residuals!(atmp, nlcache.fu, uprev, ustep, opts.abstol, opts.reltol,
Expand Down
54 changes: 31 additions & 23 deletions lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,21 +206,25 @@ function build_nlsolver(
if nlalg isa NonlinearSolveAlg
α = tTypeNoUnits(α)
dt = tTypeNoUnits(dt)
if isdae
nlf = (ztmp, z, p) -> begin
tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p
_compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1]
end
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
prob = if __has_nlprob_data(f)
f.nlprob_data.nlprob
else
nlf = (ztmp, z, p) -> begin
tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p
_compute_rhs!(
tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1]
if isdae
nlf = (ztmp, z, p) -> begin
tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p
_compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1]
end
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
else
nlf = (ztmp, z, p) -> begin
tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p
_compute_rhs!(
tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1]
end
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f)
end
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f)
NonlinearProblem(NonlinearFunction(nlf), ztmp, nlp_params)
end
prob = NonlinearProblem(NonlinearFunction(nlf), ztmp, nlp_params)
cache = init(prob, nlalg.alg)
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache)
else
Expand Down Expand Up @@ -291,20 +295,24 @@ function build_nlsolver(
if nlalg isa NonlinearSolveAlg
α = tTypeNoUnits(α)
dt = tTypeNoUnits(dt)
if isdae
nlf = (z, p) -> begin
tmp, α, tstep, invγdt, _p, dt, uprev, f = p
_compute_rhs(tmp, α, tstep, invγdt, p, dt, uprev, f, z)[1]
end
nlp_params = (tmp, α, tstep, invγdt, _p, dt, uprev, f)
prob = if __has_nlprob_data(f)
f.nlprob_data.nlprob
else
nlf = (z, p) -> begin
tmp, γ, α, tstep, invγdt, method, _p, dt, f = p
_compute_rhs(tmp, γ, α, tstep, invγdt, method, _p, dt, f, z)[1]
if isdae
nlf = (z, p) -> begin
tmp, α, tstep, invγdt, _p, dt, uprev, f = p
_compute_rhs(tmp, α, tstep, invγdt, p, dt, uprev, f, z)[1]
end
nlp_params = (tmp, α, tstep, invγdt, _p, dt, uprev, f)
else
nlf = (z, p) -> begin
tmp, γ, α, tstep, invγdt, method, _p, dt, f = p
_compute_rhs(tmp, γ, α, tstep, invγdt, method, _p, dt, f, z)[1]
end
nlp_params = (tmp, γ, α, tstep, invγdt, DIRK, p, dt, f)
end
nlp_params = (tmp, γ, α, tstep, invγdt, DIRK, p, dt, f)
NonlinearProblem(NonlinearFunction(nlf), copy(ztmp), nlp_params)
end
prob = NonlinearProblem(NonlinearFunction(nlf), copy(ztmp), nlp_params)
cache = init(prob, nlalg.alg)
nlcache = NonlinearSolveCache(
nothing, tstep, nothing, nothing, invγdt, prob, cache)
Expand Down

0 comments on commit fd08798

Please sign in to comment.