diff --git a/src/initialization.jl b/src/initialization.jl index aaf36bfcd..fb489e0dd 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -100,7 +100,8 @@ Check if the algebraic constraints are satisfied, and error if they aren't. Retu the `u0` and `p` as-is, and is always successful if it returns. Valid only for `ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument. """ -function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit, +function get_initial_values( + prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit, isinplace::Union{Val{true}, Val{false}}; kwargs...) u0 = state_values(integrator) p = parameter_values(integrator) @@ -109,7 +110,7 @@ function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit, algebraic_vars = [all(iszero, x) for x in eachcol(M)] algebraic_eqs = [all(iszero, x) for x in eachrow(M)] - (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return + (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true update_coefficients!(M, u0, p, t) tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t) tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) @@ -135,7 +136,8 @@ function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...) return f(args...) end -function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit, +function get_initial_values( + prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit, isinplace::Union{Val{true}, Val{false}}; kwargs...) u0 = state_values(integrator) p = parameter_values(integrator) diff --git a/test/initialization.jl b/test/initialization.jl index 3074aa690..c99ce74b6 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test +using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test @testset "CheckInit" begin @testset "ODEProblem" begin @@ -57,6 +57,44 @@ using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) end end + + @testset "SDEProblem" begin + mm_A = [1 0 0; 0 1 0; 0 0 0] + function sdef!(du, u, p, t) + du[1] = u[1] + du[2] = u[2] + du[3] = u[1] + u[2] + u[3] - 1 + end + function sdef(u, p, t) + du = similar(u) + sdef!(du, u, p, t) + du + end + + function g!(du, u, p, t) + @. du = 0.1 + end + function g(u, p, t) + du = similar(u) + g!(du, u, p, t) + du + end + iipfn = SDEFunction{true}(sdef!, g!; mass_matrix = mm_A) + oopfn = SDEFunction{false}(sdef, g; mass_matrix = mm_A) + + @testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] + prob = SDEProblem(f, [1.0, 1.0, -1.0], (0.0, 1.0)) + integ = init(prob, ImplicitEM()) + u0, _, success = SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + @test success + @test u0 == prob.u0 + + integ.u[2] = 2.0 + @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( + prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + end + end end @testset "OverrideInit" begin