Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for new symbol indexing methods in SII #571

Merged
merged 6 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, sym_to_index, remake,
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SymbolicIndexingInterface: symbolic_type, NotSymbolic
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index
using RecursiveArrayTools

# This method resolves the ambiguity with the pullback defined in
Expand All @@ -34,7 +34,7 @@

@adjoint function getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym

Check warning on line 37 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L37

Added line #L37 was not covered by tests
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
Expand Down Expand Up @@ -96,7 +96,7 @@

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? sym_to_index(sym, VA) : sym
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym

Check warning on line 99 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L99

Added line #L99 was not covered by tests
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Expand Down
2 changes: 1 addition & 1 deletion src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)
Base.@propagate_inbounds function RecursiveArrayTools._getindex(x::AbstractEnsembleSolution, ::Union{ScalarSymbolic,ArraySymbolic}, s, ::Colon)

Check warning on line 212 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L212

Added line #L212 was not covered by tests
return [xi[s] for xi in x.u]
end

Expand Down
26 changes: 16 additions & 10 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,18 +450,17 @@
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::NotSymbolic, I::Union{Int, AbstractArray{Int},

Check warning on line 453 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L453

Added line #L453 was not covered by tests
CartesianIndex, Colon, BitArray,
AbstractArray{Bool}}...)
A.u[I...]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym)

Check warning on line 459 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L459

Added line #L459 was not covered by tests
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing")

Check warning on line 463 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L463

Added line #L463 was not covered by tests
elseif is_independent_variable(A, sym)
return A.t
elseif is_observed(A, sym)
Expand All @@ -471,11 +470,11 @@
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ArraySymbolic, sym)
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym)

Check warning on line 473 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L473

Added line #L473 was not covered by tests
return A[collect(sym)]
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})

Check warning on line 477 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L477

Added line #L477 was not covered by tests
return getindex.((A,), sym)
end

Expand All @@ -484,12 +483,20 @@
elsymtype = symbolic_type(eltype(sym))

if symtype != NotSymbolic()
return getindex(A, symtype, sym)
return _getindex(A, symtype, sym)

Check warning on line 486 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L486

Added line #L486 was not covered by tests
else
return getindex(A, elsymtype, sym)
return _getindex(A, elsymtype, sym)

Check warning on line 488 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L488

Added line #L488 was not covered by tests
end
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(A, variable_symbols(A))

Check warning on line 493 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L492-L493

Added lines #L492 - L493 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, ::SymbolicIndexingInterface.AllVariables)
return getindex(A, all_variable_symbols(A))

Check warning on line 497 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L496-L497

Added lines #L496 - L497 were not covered by tests
end

function observed(A::DEIntegrator, sym)
getobserved(A)(sym, A.u, A.p, A.t)
end
Expand All @@ -500,8 +507,7 @@
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
elseif is_parameter(A, sym)
Base.depwarn("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.", :parameter_setindex)
setp(A, sym)(A, val)
error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.")

Check warning on line 510 in src/integrator_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/integrator_interface.jl#L510

Added line #L510 was not covered by tests
else
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
14 changes: 10 additions & 4 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(prob, variable_symbols(prob))

Check warning on line 5 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L4-L5

Added lines #L4 - L5 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.AllVariables)
return getindex(prob, all_variable_symbols(prob))

Check warning on line 9 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(prob.f, sym)
return prob.u0[variable_index(prob.f, sym)]
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.", :parameter_getindex)
return getp(prob, sym)(prob)
error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.")

Check warning on line 17 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L17

Added line #L17 was not covered by tests
elseif is_independent_variable(prob.f, sym)
return getindepsym(prob)
elseif is_observed(prob.f, sym)
Expand Down Expand Up @@ -37,8 +44,7 @@
if is_variable(prob.f, sym)
prob.u0[variable_index(prob.f, sym)] = val
elseif is_parameter(prob.f, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.", :parameter_setindex)
setp(prob, sym)(prob, val)
error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.")

Check warning on line 47 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L47

Added line #L47 was not covered by tests
else
error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.")
end
Expand Down
18 changes: 13 additions & 5 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,22 +179,30 @@
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)

Check warning on line 183 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L182-L183

Added lines #L182 - L183 were not covered by tests
else
return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]

Check warning on line 185 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L185

Added line #L185 was not covered by tests
end
end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
[first(interp_sol[idx]) for idx in idxs]
[is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs]

Check warning on line 193 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L193

Added line #L193 was not covered by tests
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(interp_sol[idxs], t, p, sol)
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)

Check warning on line 200 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L199-L200

Added lines #L199 - L200 were not covered by tests
else
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(interp_sol[idxs], t, p, sol)

Check warning on line 204 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L202-L204

Added lines #L202 - L204 were not covered by tests
end
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
Expand Down
11 changes: 9 additions & 2 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@
if is_variable(A, sym)
return A[variable_index(A, sym)]
elseif is_parameter(A, sym)
Base.depwarn("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.", :parameter_getindex)
return getp(A, sym)(A)
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.")

Check warning on line 76 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L76

Added line #L76 was not covered by tests
elseif is_observed(A, sym)
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.prob.p)
else
Expand All @@ -88,6 +87,14 @@
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(A, variable_symbols(A))

Check warning on line 91 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L90-L91

Added lines #L90 - L91 were not covered by tests
end

Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, ::SymbolicIndexingInterface.AllVariables)
return getindex(A, all_variable_symbols(A))

Check warning on line 95 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L94-L95

Added lines #L94 - L95 were not covered by tests
end

function observed(A::AbstractTimeseriesSolution, sym, i::Int)
getobserved(A)(sym, A[i], A.prob.p, A.t[i])
end
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/ensemble_nondes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistr
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5])
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))
ensembleprob = Optimization.EnsembleProblem(prob, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
Expand All @@ -35,4 +35,4 @@ ensembleprob = EnsembleProblem(prob, [u0, u0 .+ rand(2), u0 .+ rand(2), u0 .+ ra

sol = solve(ensembleprob, EnsembleThreads(), trajectories = 4, maxiters = 100)

sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)
sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)
15 changes: 8 additions & 7 deletions test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@ tspan = (0.0, 1000000.0)
oprob = ODEProblem(population_model, u0, tspan, p)
integrator = init(oprob, Rodas4())

@test_deprecated integrator[a]
@test_deprecated integrator[population_model.a]
@test_deprecated integrator[:a]
@test_throws Exception integrator[a]
@test_throws Exception integrator[population_model.a]
@test_throws Exception integrator[:a]
@test getp(oprob, a)(integrator) == getp(oprob, population_model.a)(integrator) == getp(oprob, :a)(integrator) == 2.0
@test getp(oprob, b)(integrator) == getp(oprob, population_model.b)(integrator) == getp(oprob, :b)(integrator) == 1.0
@test getp(oprob, c)(integrator) == getp(oprob, population_model.c)(integrator) == getp(oprob, :c)(integrator) == 1.0
@test getp(oprob, d)(integrator) == getp(oprob, population_model.d)(integrator) == getp(oprob, :d)(integrator) == 1.0

@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0
@test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0

@test integrator[solvedvariables] == integrator.u
@test integrator[allvariables] == integrator.u
step!(integrator, 100.0, true)

@test getp(population_model, a)(integrator) == getp(population_model, population_model.a)(integrator) == getp(population_model, :a)(integrator) == 2.0
Expand Down Expand Up @@ -299,6 +300,6 @@ eqs = [collect(D.(x) .~ x)
D(y) ~ norm(x) * y - x[1]]
@named sys = ODESystem(eqs, t, [sts...;], [ps...;])
prob = ODEProblem(sys, [], (0, 1.0))
@test_broken local integrator = init(prob, Tsit5())
@test_broken integrator[x] isa Vector{<:Vector}
@test_broken integrator[@nonamespace sys.x] isa Vector{<:Vector}
integrator = init(prob, Tsit5())
@test integrator[x] isa Vector{Float64}
@test integrator[@nonamespace sys.x] isa Vector{Float64}
58 changes: 39 additions & 19 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, OrdinaryDiffEq, Test
using SymbolicIndexingInterface

@parameters σ ρ β
@variables t x(t) y(t) z(t)
Expand Down Expand Up @@ -26,20 +27,37 @@ tspan = (0.0, 100.0)
# ODEProblem.
oprob = ODEProblem(sys, u0, tspan, p, jac = true)

@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 28.0
@test oprob[ρ] == oprob[sys.ρ] == oprob[:ρ] == 10.0
@test oprob[β] == oprob[sys.β] == oprob[:β] == 8 / 3
@test_throws Exception oprob[σ]
@test_throws Exception oprob[sys.σ]
@test_throws Exception oprob[:σ]
getσ1 = getp(sys, σ)
getσ2 = getp(sys, sys.σ)
getσ3 = getp(sys, :σ)
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 28.0
getρ1 = getp(sys, ρ)
getρ2 = getp(sys, sys.ρ)
getρ3 = getp(sys, :ρ)
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 10.0
getβ1 = getp(sys, β)
getβ2 = getp(sys, sys.β)
getβ3 = getp(sys, :β)
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 8 / 3

@test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0
@test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0
@test oprob[z] == oprob[sys.z] == oprob[:z] == 0.0

oprob[σ] = 10.0
@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 10.0
oprob[sys.ρ] = 20.0
@test oprob[ρ] == oprob[sys.ρ] == oprob[:ρ] == 20.0
oprob[σ] = 30.0
@test oprob[σ] == oprob[sys.σ] == oprob[:σ] == 30.0
@test oprob[solvedvariables] == oprob[variable_symbols(sys)]
@test oprob[allvariables] == oprob[all_variable_symbols(sys)]

setσ = setp(sys, σ)
setσ(oprob, 10.0)
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 10.0
setρ = setp(sys, sys.ρ)
setρ(oprob, 20.0)
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 20.0
setβ = setp(sys, :β)
setβ(oprob, 30.0)
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 30.0

oprob[x] = 10.0
@test oprob[x] == oprob[sys.x] == oprob[:x] == 10.0
Expand All @@ -56,20 +74,22 @@ noiseeqs = [0.1 * x,
sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p)
u0

@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 28.0
@test sprob[ρ] == sprob[noise_sys.ρ] == sprob[:ρ] == 10.0
@test sprob[β] == sprob[noise_sys.β] == sprob[:β] == 8 / 3
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 28.0
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 10.0
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 8 / 3

@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 1.0
@test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 0.0
@test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 0.0

sprob[σ] = 10.0
@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 10.0
sprob[noise_sys.ρ] = 20.0
@test sprob[ρ] == sprob[noise_sys.ρ] == sprob[:ρ] == 20.0
sprob[σ] = 30.0
@test sprob[σ] == sprob[noise_sys.σ] == sprob[:σ] == 30.0
setσ(sprob, 10.0)
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 10.0
setρ(sprob, 20.0)
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 20.0
setp(noise_sys, noise_sys.ρ)(sprob, 25.0)
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 25.0
setβ(sprob, 30.0)
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 30.0

sprob[x] = 10.0
@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 10.0
Expand Down
12 changes: 6 additions & 6 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ sol = solve(oprob, Rodas4())
@test sol[s1] == sol[population_model.s1] == sol[:s1]
@test sol[s2] == sol[population_model.s2] == sol[:s2]
@test sol[s1][end] ≈ 1.0
@test_deprecated sol[a]
@test_deprecated sol[population_model.a]
@test_deprecated sol[:a]
@test_throws Exception sol[a]
@test_throws Exception sol[population_model.a]
@test_throws Exception sol[:a]

# Tests on SDEProblem
noiseeqs = [0.1 * s1,
Expand All @@ -34,9 +34,9 @@ sol = solve(sprob, ImplicitEM())

@test sol[s1] == sol[noisy_population_model.s1] == sol[:s1]
@test sol[s2] == sol[noisy_population_model.s2] == sol[:s2]
@test_deprecated sol[a]
@test_deprecated sol[noisy_population_model.a]
@test_deprecated sol[:a]
@test_throws Exception sol[a]
@test_throws Exception sol[noisy_population_model.a]
@test_throws Exception sol[:a]
### Tests on layered model (some things should not work). ###

@parameters t σ ρ β
Expand Down
Loading
Loading