From 651899b37143725a001999365a57196bd2bc0fed Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Nov 2024 16:40:36 +0530 Subject: [PATCH] feat: add lazy initialization to `remake` --- src/initialization.jl | 12 ++++++++++++ src/remake.jl | 17 ++++++++++++++++- test/downstream/modelingtoolkit_remake.jl | 9 +++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/initialization.jl b/src/initialization.jl index 1179e606c..29a95ed7d 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -225,3 +225,15 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, return u0, p, success end + +function is_trivial_initialization(initdata::OverrideInitData) + state_values(initdata.initializeprob) === nothing +end + +function is_trivial_initialization(f::AbstractSciMLFunction) + has_initialization_data(f) && is_trivial_initialization(f.initialization_data) +end + +function is_trivial_initialization(prob::AbstractSciMLProblem) + is_trivial_initialization(prob.f) +end diff --git a/src/remake.jl b/src/remake.jl index 1a3d1c34b..5979cda2e 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -114,6 +114,7 @@ function remake(prob::ODEProblem; f = missing, interpret_symbolicmap = true, build_initializeprob = true, use_defaults = false, + lazy_initialization = nothing, _kwargs...) if tspan === missing tspan = prob.tspan @@ -123,6 +124,8 @@ function remake(prob::ODEProblem; f = missing, iip = isinplace(prob) + initialization_data = prob.f.initialization_data + if f === missing if build_initializeprob initialization_data = remake_initialization_data_compat_wrapper( @@ -170,13 +173,25 @@ function remake(prob::ODEProblem; f = missing, _f = ODEFunction{isinplace(prob), specialization(prob.f)}(f) end - if kwargs === missing + prob = if kwargs === missing ODEProblem{isinplace(prob)}( _f, newu0, tspan, newp, prob.problem_type; prob.kwargs..., _kwargs...) else ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...) end + + if lazy_initialization === nothing + lazy_initialization = !is_trivial_initialization(initialization_data) + end + if !lazy_initialization + u0, p, _ = get_initial_values( + prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + @reset prob.u0 = u0 + @reset prob.p = p + end + + return prob end """ diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 0395bb395..81a7c4868 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -336,3 +336,12 @@ end @test sccprob4.parameter_object !== sccprob4.probs[1].p @test sccprob4.parameter_object !== sccprob4.probs[2].p end + +@testset "Lazy initialization" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p=missing [guess = 1.0] + @mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t) + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) + prob2 = remake(prob; u0 = [x => 2.0]) + @test prob2.ps[p] ≈ 3.0 +end