diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl index 994ecea696..5fde821eb6 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl @@ -38,12 +38,16 @@ function initialize!(nlsolver::NLSolver{<:NonlinearSolveAlg, false}, integrator.stats.nnonliniter += cache.cache.stats.nsteps integrator.stats.njacs += cache.cache.stats.njacs end - if f isa DAEFunction - nlp_params = (tmp, α, tstep, invγdt, p, dt, uprev, f) + new_prob = __has_nlprob_data(f) + update_nlprob!(cache.prob, u0=z, p=(;dt, γ, inner_tmp, outer_tmp, t, p)) else - nlp_params = (tmp, γ, α, tstep, invγdt, method, p, dt, f) + if f isa DAEFunction + nlp_params = (tmp, α, tstep, invγdt, p, dt, uprev, f) + else + nlp_params = (tmp, γ, α, tstep, invγdt, method, p, dt, f) + end + remake(cache.prob, p = nlp_params, u0 = z) end - new_prob = remake(cache.prob, p = nlp_params, u0 = z) cache.cache = init(new_prob, alg.alg) nothing end @@ -63,12 +67,17 @@ function initialize!(nlsolver::NLSolver{<:NonlinearSolveAlg, true}, integrator.stats.nnonliniter += cache.cache.stats.nsteps integrator.stats.njacs += cache.cache.stats.njacs end - if f isa DAEFunction - nlp_params = (tmp, ztmp, ustep, γ, α, tstep, k, invγdt, p, dt, f) + + new_prob = __has_nlprob_data(f) + update_nlprob!(cache.prob, u0=z, p=(;dt, γ, inner_tmp, outer_tmp, t, p)) else - nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f) + if f isa DAEFunction + nlp_params = (tmp, ztmp, ustep, γ, α, tstep, k, invγdt, p, dt, f) + else + nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f) + end + remake(cache.prob, p = nlp_params, u0 = z) end - new_prob = remake(cache.prob, p = nlp_params, u0 = z) cache.cache = init(new_prob, alg.alg) nothing end @@ -76,13 +85,17 @@ end ## compute_step! @muladd function compute_step!(nlsolver::NLSolver{<:NonlinearSolveAlg, false}, integrator) - @unpack uprev, t, p, dt, opts = integrator + @unpack uprev, t, p, dt, opts, f = integrator @unpack z, tmp, ztmp, γ, α, cache, method = nlsolver @unpack tstep, invγdt = cache nlcache = nlsolver.cache.cache step!(nlcache) - nlsolver.ztmp = nlcache.u + if __has_nlprob_data(f) + ztmp = nlprobmap(nlcache) + else + ztmp = nlcache.u + end ustep = compute_ustep(tmp, γ, z, method) atmp = calculate_residuals(nlcache.fu, uprev, ustep, opts.abstol, opts.reltol, @@ -98,13 +111,17 @@ end end @muladd function compute_step!(nlsolver::NLSolver{<:NonlinearSolveAlg, true}, integrator) - @unpack uprev, t, p, dt, opts = integrator + @unpack uprev, t, p, dt, opts, f = integrator @unpack z, tmp, ztmp, γ, α, cache, method = nlsolver @unpack tstep, invγdt, atmp, ustep = cache nlcache = nlsolver.cache.cache step!(nlcache) - @.. broadcast=false ztmp=nlcache.u + if __has_nlprob_data(f) + @.. ztmp = nlprobmap(nlcache) + else + @.. ztmp=nlcache.u + end ustep = compute_ustep!(ustep, tmp, γ, z, method) calculate_residuals!(atmp, nlcache.fu, uprev, ustep, opts.abstol, opts.reltol, diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl index 3bed6a453a..958bc62d68 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl @@ -206,21 +206,25 @@ function build_nlsolver( if nlalg isa NonlinearSolveAlg α = tTypeNoUnits(α) dt = tTypeNoUnits(dt) - if isdae - nlf = (ztmp, z, p) -> begin - tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p - _compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1] - end - nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f) + prob = if __has_nlprob_data(f) + f.nlprob_data.nlprob else - nlf = (ztmp, z, p) -> begin - tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p - _compute_rhs!( - tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1] + if isdae + nlf = (ztmp, z, p) -> begin + tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p + _compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1] + end + nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f) + else + nlf = (ztmp, z, p) -> begin + tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p + _compute_rhs!( + tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1] + end + nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f) end - nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f) + NonlinearProblem(NonlinearFunction(nlf), ztmp, nlp_params) end - prob = NonlinearProblem(NonlinearFunction(nlf), ztmp, nlp_params) cache = init(prob, nlalg.alg) nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache) else @@ -291,20 +295,24 @@ function build_nlsolver( if nlalg isa NonlinearSolveAlg α = tTypeNoUnits(α) dt = tTypeNoUnits(dt) - if isdae - nlf = (z, p) -> begin - tmp, α, tstep, invγdt, _p, dt, uprev, f = p - _compute_rhs(tmp, α, tstep, invγdt, p, dt, uprev, f, z)[1] - end - nlp_params = (tmp, α, tstep, invγdt, _p, dt, uprev, f) + prob = if __has_nlprob_data(f) + f.nlprob_data.nlprob else - nlf = (z, p) -> begin - tmp, γ, α, tstep, invγdt, method, _p, dt, f = p - _compute_rhs(tmp, γ, α, tstep, invγdt, method, _p, dt, f, z)[1] + if isdae + nlf = (z, p) -> begin + tmp, α, tstep, invγdt, _p, dt, uprev, f = p + _compute_rhs(tmp, α, tstep, invγdt, p, dt, uprev, f, z)[1] + end + nlp_params = (tmp, α, tstep, invγdt, _p, dt, uprev, f) + else + nlf = (z, p) -> begin + tmp, γ, α, tstep, invγdt, method, _p, dt, f = p + _compute_rhs(tmp, γ, α, tstep, invγdt, method, _p, dt, f, z)[1] + end + nlp_params = (tmp, γ, α, tstep, invγdt, DIRK, p, dt, f) end - nlp_params = (tmp, γ, α, tstep, invγdt, DIRK, p, dt, f) + NonlinearProblem(NonlinearFunction(nlf), copy(ztmp), nlp_params) end - prob = NonlinearProblem(NonlinearFunction(nlf), copy(ztmp), nlp_params) cache = init(prob, nlalg.alg) nlcache = NonlinearSolveCache( nothing, tstep, nothing, nothing, invγdt, prob, cache)