Skip to content

Commit

Permalink
fix: CheckInit bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 12, 2024
1 parent 9ab1a71 commit 705bf97
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ Check if the algebraic constraints are satisfied, and error if they aren't. Retu
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
"""
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,
function get_initial_values(
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
Expand All @@ -109,7 +110,7 @@ function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
Expand All @@ -135,7 +136,8 @@ function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit,
function get_initial_values(
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
Expand Down
40 changes: 39 additions & 1 deletion test/initialization.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test
using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test

@testset "CheckInit" begin
@testset "ODEProblem" begin
Expand Down Expand Up @@ -57,6 +57,44 @@ using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
end
end

@testset "SDEProblem" begin
mm_A = [1 0 0; 0 1 0; 0 0 0]
function sdef!(du, u, p, t)
du[1] = u[1]
du[2] = u[2]
du[3] = u[1] + u[2] + u[3] - 1
end
function sdef(u, p, t)
du = similar(u)
sdef!(du, u, p, t)
du
end

function g!(du, u, p, t)
@. du = 0.1
end
function g(u, p, t)
du = similar(u)
g!(du, u, p, t)
du
end
iipfn = SDEFunction{true}(sdef!, g!; mass_matrix = mm_A)
oopfn = SDEFunction{false}(sdef, g; mass_matrix = mm_A)

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = SDEProblem(f, [1.0, 1.0, -1.0], (0.0, 1.0))
integ = init(prob, ImplicitEM())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
end
end
end

@testset "OverrideInit" begin
Expand Down

0 comments on commit 705bf97

Please sign in to comment.