Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: only solve parameter initialization for NonlinearSystem #3345

Merged
merged 2 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.71"
SciMLBase = "2.71.1"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down
2 changes: 1 addition & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,

# TODO: throw on uninitialized arrays
filter!(x -> !(x isa Symbolics.Arr), uninit)
if !isempty(uninit)
if is_time_dependent(sys) && !isempty(uninit)
allow_incomplete || throw(IncompleteInitializationError(uninit))
# for incomplete initialization, we will add the missing variables as parameters.
# they will be updated by `update_initializeprob!` and `initializeprobmap` will
Expand Down
106 changes: 58 additions & 48 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,51 +42,53 @@ function generate_initializesystem(sys::AbstractSystem;
diffmap = Dict()
end

if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
# 2) process dummy derivatives and u0map into initialization system
# prepare map for dummy derivative substitution
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
end
function process_u0map_with_dummysubs(y, x)
y = get(schedule.dummy_sub, y, y)
y = fixpoint_sub(y, diffmap)
if y ∈ vars_set
# variables specified in u0 overrides defaults
push!(defs, y => x)
elseif y isa Symbolics.Arr
# TODO: don't scalarize arrays
merge!(defs, Dict(scalarize(y .=> x)))
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded; add it to the initialization equations
push!(eqs_ics, y ~ x)
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
if is_time_dependent(sys)
if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
# 2) process dummy derivatives and u0map into initialization system
# prepare map for dummy derivative substitution
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
end
end
for (y, x) in u0map
if Symbolics.isarraysymbolic(y)
process_u0map_with_dummysubs.(collect(y), collect(x))
else
process_u0map_with_dummysubs(y, x)
function process_u0map_with_dummysubs(y, x)
y = get(schedule.dummy_sub, y, y)
y = fixpoint_sub(y, diffmap)
if y ∈ vars_set
# variables specified in u0 overrides defaults
push!(defs, y => x)
elseif y isa Symbolics.Arr
# TODO: don't scalarize arrays
merge!(defs, Dict(scalarize(y .=> x)))
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded; add it to the initialization equations
push!(eqs_ics, y ~ x)
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
end
for (y, x) in u0map
if Symbolics.isarraysymbolic(y)
process_u0map_with_dummysubs.(collect(y), collect(x))
else
process_u0map_with_dummysubs(y, x)
end
end
else
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
for (k, v) in u0map
defs[k] = v
end
end
else
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
for (k, v) in u0map
defs[k] = v
end
end

# 3) process other variables
for var in vars
if var ∈ keys(defs)
push!(eqs_ics, var ~ defs[var])
elseif var ∈ keys(guesses)
push!(defs, var => guesses[var])
elseif check_defguess
error("Invalid setup: variable $(var) has no default value or initial guess")
# 3) process other variables
for var in vars
if var ∈ keys(defs)
push!(eqs_ics, var ~ defs[var])
elseif var ∈ keys(guesses)
push!(defs, var => guesses[var])
elseif check_defguess
error("Invalid setup: variable $(var) has no default value or initial guess")
end
end
end

Expand Down Expand Up @@ -180,16 +182,24 @@ function generate_initializesystem(sys::AbstractSystem;
pars = Vector{SymbolicParam}(filter(p -> !haskey(paramsubs, p), parameters(sys)))
is_time_dependent(sys) && push!(pars, get_iv(sys))

# 8) use observed equations for guesses of observed variables if not provided
for eq in trueobs
haskey(defs, eq.lhs) && continue
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue
if is_time_dependent(sys)
# 8) use observed equations for guesses of observed variables if not provided
for eq in trueobs
haskey(defs, eq.lhs) && continue
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue

defs[eq.lhs] = eq.rhs
defs[eq.lhs] = eq.rhs
end
append!(eqs_ics, trueobs)
end

eqs_ics = Symbolics.substitute.(eqs_ics, (paramsubs,))
if is_time_dependent(sys)
vars = [vars; collect(values(paramsubs))]
else
vars = collect(values(paramsubs))
end

eqs_ics = Symbolics.substitute.([eqs_ics; trueobs], (paramsubs,))
vars = [vars; collect(values(paramsubs))]
for k in keys(defs)
defs[k] = substitute(defs[k], paramsubs)
end
Expand Down
38 changes: 27 additions & 11 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,30 +546,46 @@ function maybe_build_initialization_problem(
initializeprob = ModelingToolkit.InitializationProblem(
sys, t, u0map, pmap; guesses, kwargs...)

all_init_syms = Set(all_symbols(initializeprob))
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
initializeprobmap = getu(initializeprob, solved_unknowns)
if is_time_dependent(sys)
all_init_syms = Set(all_symbols(initializeprob))
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
initializeprobmap = getu(initializeprob, solved_unknowns)
else
initializeprobmap = nothing
end

punknowns = [p
for p in all_variable_symbols(initializeprob)
if is_parameter(sys, p)]
getpunknowns = getu(initializeprob, punknowns)
setpunknowns = setp(sys, punknowns)
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
if isempty(punknowns)
initializeprobpmap = nothing
else
getpunknowns = getu(initializeprob, punknowns)
setpunknowns = setp(sys, punknowns)
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
end

reqd_syms = parameter_symbols(initializeprob)
update_initializeprob! = UpdateInitializeprob(
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
# we still want the `initialization_data` because it helps with `remake`
if initializeprobmap === nothing && initializeprobpmap === nothing
update_initializeprob! = nothing
else
update_initializeprob! = UpdateInitializeprob(
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
end

for p in punknowns
p = unwrap(p)
stype = symtype(p)
op[p] = get_temporary_value(p)
end

for v in missing_unknowns
op[v] = zero_var(v)
if is_time_dependent(sys)
for v in missing_unknowns
op[v] = zero_var(v)
end
empty!(missing_unknowns)
end
empty!(missing_unknowns)
return (;
initialization_data = SciMLBase.OverrideInitData(
initializeprob, update_initializeprob!, initializeprobmap,
Expand Down
Loading
Loading