Skip to content

Commit

Permalink
Merge pull request #582 from AayushSabharwal/as/getu
Browse files Browse the repository at this point in the history
feat: add support for `getu`/`setu` in SII, add tests
  • Loading branch information
ChrisRackauckas authored Jan 4, 2024
2 parents f6de48a + ed57e16 commit c6d281f
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 27 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.13.18"
RecipesBase = "1.0"
RecursiveArrayTools = "3.0"
RecursiveArrayTools = "3.3.4"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5"
SciMLOperators = "0.3.7"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.9"
SymbolicIndexingInterface = "0.3"
SymbolicIndexingInterface = "0.3.2"
Tables = "1.11"
TruncatedStacktraces = "1.4"
Zygote = "0.6.67"
Expand Down
15 changes: 11 additions & 4 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,14 +426,20 @@ end

# SymbolicIndexingInterface
SymbolicIndexingInterface.symbolic_container(A::DEIntegrator) = A.f
SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p

function SymbolicIndexingInterface.is_observed(A::DEIntegrator, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()
end

function SymbolicIndexingInterface.observed(A::DEIntegrator, sym)
(u, p, t) -> getobserved(A)(sym, u, p, t)
return getobserved(A)(sym)
end
SymbolicIndexingInterface.parameter_values(A::DEIntegrator) = A.p
SymbolicIndexingInterface.state_values(A::DEIntegrator) = A.u
SymbolicIndexingInterface.current_time(A::DEIntegrator) = A.t
function SymbolicIndexingInterface.set_state!(A::DEIntegrator, val, idx)
# So any error checking happens to ensure we actually _can_ set state
set_u!(A, A.u)
A.u[idx] = val
u_modified!(A, true)
end

SymbolicIndexingInterface.is_time_dependent(::DEIntegrator) = true
Expand Down Expand Up @@ -506,6 +512,7 @@ function Base.setindex!(A::DEIntegrator, val, sym)
if symbolic_type(sym) == ScalarSymbolic()
if is_variable(A, sym)
A.u[variable_index(A, sym)] = val
u_modified!(A, true)
elseif is_parameter(A, sym)
error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.")
else
Expand Down
8 changes: 8 additions & 0 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
function SymbolicIndexingInterface.is_observed(A::AbstractSciMLProblem, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()
end
function SymbolicIndexingInterface.observed(A::AbstractSciMLProblem, sym)
return getobserved(A)(sym)
end
SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p
SymbolicIndexingInterface.state_values(prob::AbstractSciMLProblem) = prob.u0
SymbolicIndexingInterface.current_time(prob::AbstractSciMLProblem) = prob.tspan[1]

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(prob, variable_symbols(prob))
Expand Down
13 changes: 1 addition & 12 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,6 @@ SymbolicIndexingInterface.is_independent_variable(::AbstractNoTimeSolution, sym)

SymbolicIndexingInterface.independent_variable_symbols(::AbstractNoTimeSolution) = []

function SymbolicIndexingInterface.is_observed(A::AbstractSolution, sym)
return !is_variable(A, sym) && !is_parameter(A, sym) && !is_independent_variable(A, sym) && symbolic_type(sym) == ScalarSymbolic()
end

function SymbolicIndexingInterface.observed(A::AbstractTimeseriesSolution, sym)
(u, p, t) -> getobserved(A)(sym, u, p, t)
end

function SymbolicIndexingInterface.observed(A::AbstractNoTimeSolution, sym)
(u, p) -> getobserved(A)(sym, u, p)
end

for soltype in [AbstractTimeseriesSolution, AbstractNoTimeSolution]
@eval function SymbolicIndexingInterface.observed(A::$(soltype), sym::Symbol)
has_sys(A.prob.f) || error("Cannot use observed without system")
Expand All @@ -63,6 +51,7 @@ SymbolicIndexingInterface.is_time_dependent(::AbstractNoTimeSolution) = false

# TODO make this nontrivial once dynamic state selection works
SymbolicIndexingInterface.constant_structure(::AbstractSolution) = true
SymbolicIndexingInterface.state_values(A::AbstractNoTimeSolution) = A.u

Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, ::Colon)
return A.u[:]
Expand Down
27 changes: 27 additions & 0 deletions test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,30 @@ prob = ODEProblem(sys, [], (0, 1.0))
integrator = init(prob, Tsit5())
@test integrator[x] isa Vector{Float64}
@test integrator[@nonamespace sys.x] isa Vector{Float64}

getx = getu(integrator, x)
gety = getu(integrator, :y)
get_arr = getu(integrator, [x, y])
get_tuple = getu(integrator, (x, y))
get_obs = getu(integrator, x[1] / p[1])
@test getx(integrator) == [1.0, 2.0, 3.0]
@test gety(integrator) == 1.0
@test get_arr(integrator) == [[1.0, 2.0, 3.0], 1.0]
@test get_tuple(integrator) == ([1.0, 2.0, 3.0], 1.0)
@test get_obs(integrator) == 1.0

setx! = setu(integrator, x)
sety! = setu(integrator, :y)
set_arr! = setu(integrator, [x, y])
set_tuple! = setu(integrator, (x, y))

setx!(integrator, [4.0, 5.0, 6.0])
@test getx(integrator) == [4.0, 5.0, 6.0]
sety!(integrator, 3.0)
@test gety(integrator) == 3.0
set_arr!(integrator, [1.0, 2.0])
@test get_arr(integrator) == [[1.0, 1.0, 1.0], 2.0]
set_arr!(integrator, [[1.0, 2.0, 3.0], 1.0])
@test get_arr(integrator) == [[1.0, 2.0, 3.0], 1.0]
set_tuple!(integrator, ([2.0, 4.0, 6.0], 2.0))
@test get_tuple(integrator) == ([2.0, 4.0, 6.0], 2.0)
58 changes: 58 additions & 0 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,35 @@ oprob[sys.y] = 10.0
oprob[:z] = 1.0
@test oprob[z] == oprob[sys.z] == oprob[:z] == 1.0

getx = getu(oprob, x)
gety = getu(oprob, :y)
get_arr = getu(oprob, [x, y])
get_tuple = getu(oprob, (y, z))
get_obs = getu(oprob, sys.x + sys.z + t + σ)
@test getx(oprob) == 10.0
@test gety(oprob) == 10.0
@test get_arr(oprob) == [10.0, 10.0]
@test get_tuple(oprob) == (10.0, 1.0)
@test get_obs(oprob) == 39.0

setx! = setu(oprob, x)
sety! = setu(oprob, :y)
set_arr! = setu(oprob, [x, y])
set_tuple! = setu(oprob, (y, z))

setx!(oprob, 11.0)
@test getx(oprob) == 11.0
sety!(oprob, 12.0)
@test gety(oprob) == 12.0
set_arr!(oprob, 10.0)
@test get_arr(oprob) == [10.0, 10.0]
set_arr!(oprob, [11.0, 12.0])
@test get_arr(oprob) == [11.0, 12.0]
set_tuple!(oprob, 13.0)
@test get_tuple(oprob) == (13.0, 13.0)
set_tuple!(oprob, [10.0, 10.0])
@test get_tuple(oprob) == (10.0, 10.0)

# SDEProblem.
noiseeqs = [0.1 * x,
0.1 * y,
Expand Down Expand Up @@ -97,3 +126,32 @@ sprob[noise_sys.y] = 10.0
@test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 10.0
sprob[:z] = 1.0
@test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 1.0

getx = getu(sprob, x)
gety = getu(sprob, :y)
get_arr = getu(sprob, [x, y])
get_tuple = getu(sprob, (y, z))
get_obs = getu(sprob, sys.x + sys.z + t + σ)
@test getx(sprob) == 10.0
@test gety(sprob) == 10.0
@test get_arr(sprob) == [10.0, 10.0]
@test get_tuple(sprob) == (10.0, 1.0)
@test get_obs(sprob) == 39.0

setx! = setu(sprob, x)
sety! = setu(sprob, :y)
set_arr! = setu(sprob, [x, y])
set_tuple! = setu(sprob, (y, z))

setx!(sprob, 11.0)
@test getx(sprob) == 11.0
sety!(sprob, 12.0)
@test gety(sprob) == 12.0
set_arr!(sprob, 10.0)
@test get_arr(sprob) == [10.0, 10.0]
set_arr!(sprob, [11.0, 12.0])
@test get_arr(sprob) == [11.0, 12.0]
set_tuple!(sprob, 13.0)
@test get_tuple(sprob) == (13.0, 13.0)
set_tuple!(sprob, [10.0, 10.0])
@test get_tuple(sprob) == (10.0, 10.0)
26 changes: 26 additions & 0 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,23 @@ sol9 = sol(0.0:1.0:10.0, idxs = 2)
sol10 = sol(0.1, idxs = 2)
@test sol10 isa Real

@test is_timeseries(sol) == Timeseries()
getx = getu(sys_simplified, lorenz1.x)
get_arr = getu(sys_simplified, [lorenz1.x, lorenz2.x])
get_tuple = getu(sys_simplified, (lorenz1.x, lorenz2.x))
get_obs = getu(sys_simplified, lorenz1.x + lorenz2.x)
get_obs_arr = getu(sys_simplified, [lorenz1.x + lorenz2.x, lorenz1.y + lorenz2.y])
l1x_idx = variable_index(sol, lorenz1.x)
l2x_idx = variable_index(sol, lorenz2.x)
l1y_idx = variable_index(sol, lorenz1.y)
l2y_idx = variable_index(sol, lorenz2.y)

@test getx(sol) == sol[:, l1x_idx]
@test get_arr(sol) == sol[:, [l1x_idx, l2x_idx]]
@test get_tuple(sol) == tuple.(sol[:, l1x_idx], sol[:, l2x_idx])
@test get_obs(sol) == sol[:, l1x_idx] + sol[:, l2x_idx]
@test get_obs_arr(sol) == vcat.(sol[:, l1x_idx] + sol[:, l2x_idx], sol[:, l1y_idx] + sol[:, l2y_idx])

#=
using Plots
plot(sol,idxs=(lorenz2.x,lorenz2.z))
Expand All @@ -197,6 +214,15 @@ sol = solve(prob, Tsit5())
@test sol[x] isa Vector{<:Vector}
@test sol[@nonamespace sys.x] isa Vector{<:Vector}

getx = getu(sys, x)
get_mix_arr = getu(sys, [x, y])
get_mix_tuple = getu(sys, (x, y))
x_idx = variable_index.((sys,), [x[1], x[2], x[3]])
y_idx = variable_index(sys, y)
@test getx(sol) == sol[:, x_idx]
@test get_mix_arr(sol) == vcat.(sol[:, x_idx], sol[:, y_idx])
@test get_mix_tuple(sol) == tuple.(sol[:, x_idx], sol[:, y_idx])

# accessing parameters
@variables t x(t)
@parameters tau
Expand Down
23 changes: 14 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,7 @@ end
@time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin
include("downstream/ensemble_diffeq.jl")
end
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
include("downstream/symbol_indexing.jl")
end
if VERSION >= v"1.8"
@time @safetestset "Symbol and integer based indexing of integrators" begin
include("downstream/integrator_indexing.jl")
end
@time @safetestset "Problem Indexing" begin
include("downstream/problem_interface.jl")
end
@time @safetestset "Solution Indexing" begin
include("downstream/solution_interface.jl")
end
Expand All @@ -112,6 +103,20 @@ end
end
end

if !is_APPVEYOR && (GROUP == "Downstream" || GROUP == "SymbolicIndexingInterface")
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
include("downstream/symbol_indexing.jl")
end
if VERSION >= v"1.8"
@time @safetestset "Symbol and integer based indexing of integrators" begin
include("downstream/integrator_indexing.jl")
end
@time @safetestset "Problem Indexing" begin
include("downstream/problem_interface.jl")
end
end
end

if !is_APPVEYOR && GROUP == "Python"
activate_python_env()
@time @safetestset "PyCall" begin
Expand Down

0 comments on commit c6d281f

Please sign in to comment.