From 61332585a17c87669d6557f55d0d8b6e90f82b01 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Jun 2024 14:40:14 +0530 Subject: [PATCH] refactor: update implementation of discrete save interface --- src/systems/abstractsystem.jl | 60 +++++++++++++++++++---------- src/systems/index_cache.jl | 18 ++++----- src/systems/parameter_buffer.jl | 5 ++- test/mtkparameters.jl | 52 ++++--------------------- test/parameter_dependencies.jl | 9 +++-- test/symbolic_indexing_interface.jl | 29 +++++++------- 6 files changed, 79 insertions(+), 94 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 3d1b6d46fe..c3066255ef 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -429,7 +429,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing return sym isa ParameterIndex || is_parameter(ic, sym) || - iscall(sym) && operation(sym) === getindex && + iscall(sym) && + operation(sym) === getindex && is_parameter(ic, first(arguments(sym))) end if unwrap(sym) isa Int @@ -508,34 +509,19 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym end function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym) + is_time_dependent(sys) || return false has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false is_timeseries_parameter(ic, sym) end function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym) + is_time_dependent(sys) || return nothing has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing timeseries_parameter_index(ic, sym) end function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - allvars = vars(sym; op = Symbolics.Operator) - ts_idxs = Set{Int}() - for var in allvars - var = unwrap(var) - # FIXME: Shouldn't have to shift systems - if istree(var) && (op = operation(var)) isa Shift && op.steps == 1 - var = only(arguments(var)) - end - ts_idx = check_index_map(ic.discrete_idx, unwrap(var)) - ts_idx === nothing && continue - push!(ts_idxs, ts_idx[1]) - end - if length(ts_idxs) == 1 - ts_idx = only(ts_idxs) - else - ts_idx = nothing - end rawobs = build_explicit_observed_function( sys, sym; param_only = true, return_inplace = true) if rawobs isa Tuple @@ -562,10 +548,44 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) end end else - ts_idx = nothing obsfn = build_explicit_observed_function(sys, sym; param_only = true) end - return ParameterObservedFunction(ts_idx, obsfn) + return obsfn +end + +function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym) + if is_variable(sys, sym) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, sym) + push!(ts_idxs, timeseries_parameter_index(sys, sym).timeseries_idx) + end +end +# Need this to avoid ambiguity with the array case +for traitT in [ + ScalarSymbolic, + ArraySymbolic +] + @eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym) + allsyms = vars(sym; op = Symbolics.Operator) + foreach(allsyms) do s + _all_ts_idxs!(ts_idxs, sys, s) + end + end +end +function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray) + foreach(sym) do s + _all_ts_idxs!(ts_idxs, sys, s) + end +end +_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, NotSymbolic(), sys, sym) + +function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSystem, sym) + if !is_time_dependent(sys) + return Set() + end + ts_idxs = Set() + _all_ts_idxs!(ts_idxs, sys, sym) + return ts_idxs end function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index b1063f214e..13fb7adef2 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -113,7 +113,7 @@ function IndexCache(sys::AbstractSystem) error("Discrete subsystem $i input $inp is not a parameter") disc_clocks[inp] = i disc_clocks[default_toterm(inp)] = i - if hasname(inp) && (!istree(inp) || operation(inp) !== getindex) + if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex) disc_clocks[getname(inp)] = i disc_clocks[default_toterm(inp)] = i end @@ -126,7 +126,7 @@ function IndexCache(sys::AbstractSystem) error("Discrete subsystem $i unknown $sym is not a parameter") disc_clocks[sym] = i disc_clocks[default_toterm(sym)] = i - if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_clocks[getname(sym)] = i disc_clocks[getname(default_toterm(sym))] = i end @@ -138,13 +138,13 @@ function IndexCache(sys::AbstractSystem) # FIXME: This shouldn't be necessary eq.rhs === -0.0 && continue sym = eq.lhs - if istree(sym) && operation(sym) == Shift(t, 1) + if iscall(sym) && operation(sym) == Shift(t, 1) sym = only(arguments(sym)) end disc_clocks[sym] = i disc_clocks[sym] = i disc_clocks[default_toterm(sym)] = i - if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_clocks[getname(sym)] = i disc_clocks[getname(default_toterm(sym))] = i end @@ -153,7 +153,7 @@ function IndexCache(sys::AbstractSystem) for par in inputs[continuous_id] is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") - istree(par) && operation(par) isa Hold || + iscall(par) && operation(par) isa Hold || error("Continuous subsystem input is not a Hold") if haskey(disc_clocks, par) sym = par @@ -176,7 +176,7 @@ function IndexCache(sys::AbstractSystem) disc_clocks[affect.lhs] = user_affect_clock disc_clocks[default_toterm(affect.lhs)] = user_affect_clock if hasname(affect.lhs) && - (!istree(affect.lhs) || operation(affect.lhs) !== getindex) + (!iscall(affect.lhs) || operation(affect.lhs) !== getindex) disc_clocks[getname(affect.lhs)] = user_affect_clock disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock end @@ -190,7 +190,7 @@ function IndexCache(sys::AbstractSystem) disc = unwrap(disc) disc_clocks[disc] = user_affect_clock disc_clocks[default_toterm(disc)] = user_affect_clock - if hasname(disc) && (!istree(disc) || operation(disc) !== getindex) + if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex) disc_clocks[getname(disc)] = user_affect_clock disc_clocks[getname(default_toterm(disc))] = user_affect_clock end @@ -245,7 +245,7 @@ function IndexCache(sys::AbstractSystem) for (j, sym) in enumerate(buffer[btype]) disc_idxs[sym] = (clockidx, i, j) disc_idxs[default_toterm(sym)] = (clockidx, i, j) - if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_idxs[getname(sym)] = (clockidx, i, j) disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j) end @@ -256,7 +256,7 @@ function IndexCache(sys::AbstractSystem) haskey(disc_idxs, sym) && continue disc_idxs[sym] = (clockid, 0, 0) disc_idxs[default_toterm(sym)] = (clockid, 0, 0) - if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_idxs[getname(sym)] = (clockid, 0, 0) disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0) end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index aec2287e62..ad6e5de368 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -362,7 +362,7 @@ function SymbolicIndexingInterface.set_parameter!( if validate_size && size(val) !== size(p.discrete[i][j][k]) throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val))) end - p.discrete[i][j][k][l...] = val + p.discrete[i][j][k] = val else p.discrete[i][j][k][l...] = val end @@ -562,7 +562,8 @@ end Base.size(::NestedGetIndex) = () function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( - ps::MTKParameters, args::Pair{A, B}...) where {A, B <: NestedGetIndex} + ::AbstractSystem, ps::MTKParameters, args::Pair{A, B}...) where { + A, B <: NestedGetIndex} for (i, val) in args ps.discrete[i] = val.x end diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 095e6172d8..d2c6ee48b1 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -2,6 +2,7 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters using SymbolicIndexingInterface using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants +using StaticArrays: SizedVector using OrdinaryDiffEq using ForwardDiff using JET @@ -292,29 +293,10 @@ end end # Parameter timeseries -# dt = 0.1 -# dt2 = 0.2 -# @variables x(t)=0 y(t)=0 u(t)=0 yd1(t)=0 ud1(t)=0 yd2(t)=0 ud2(t)=0 -# @parameters kp=1 r=1 - -# eqs = [ -# # controller (time discrete part `dt=0.1`) -# yd1 ~ Sample(t, dt)(y) -# ud1 ~ kp * (r - yd1) -# # controller (time discrete part `dt=0.2`) -# yd2 ~ Sample(t, dt2)(y) -# ud2 ~ kp * (r - yd2) - -# # plant (time continuous part) -# u ~ Hold(ud1) + Hold(ud2) -# D(x) ~ -x + u -# y ~ x] - -# @mtkbuild cl = ODESystem(eqs, t) -ps = MTKParameters(([1.0, 1.0],), SizedArray{2}([([0.0, 0.0],), ([0.0, 0.0],)]), (), (), (), nothing, nothing) -# ps = MTKParameters(cl, [kp => 1.0]) +ps = MTKParameters(([1.0, 1.0],), SizedVector{2}([([0.0, 0.0],), ([0.0, 0.0],)]), + (), (), (), nothing, nothing) with_updated_parameter_timeseries_values( - ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],))) + sys, ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],))) @test ps.discrete[1][1] == [5.0, 10.0] with_updated_parameter_timeseries_values( ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)), @@ -324,27 +306,9 @@ with_updated_parameter_timeseries_values( @test SciMLBase.get_saveable_values(ps, 1).x == ps.discrete[1] # With multiple types and clocks -# @variables x(t) xd1(t) xd2(t) flag(t)::Bool yd1(t) yd2(t) yc1(t) yc2(t) -# dt = 0.1 -# k1 = ShiftIndex(t, dt) -# ssc = ModelingToolkit.SolverStepClock(t) -# k2 = ShiftIndex(ssc) - -# eqs = [ -# flag ~ ~flag(k1 - 1), -# xd1 ~ Sample(t, dt)(x), -# yd1 ~ ifelse(flag, xd1, yd1(k1 - 1)), xd2 ~ Sample(ssc)(x), -# yd2 ~ yd2(k2 - 1) + xd2, yc1 ~ Hold(yd1), -# yc2 ~ Hold(yd2), -# D(x) ~ yc1 + yc2 -# ] -# @mtkbuild sys = ODESystem(eqs, t) -# ps = MTKParameters(sys, -# [flag => true, yd1 => ifelse(flag, Sample(t, dt)(x), 1.0), -# yd2 => 2.0 + Sample(ssc)(x), Sample(t, dt)(x) => x, -# Sample(ssc)(x) => x, Hold(yd1) => yd1, Hold(yd2) => yd2], -# [x => 3.0]) -ps = MTKParameters((), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), (), (), (), nothing, nothing) +ps = MTKParameters( + (), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), + (), (), (), nothing, nothing) @test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, BitVector} # tsidx1 = timeseries_parameter_index(sys, flag).timeseries_idx # tsidx2 = 3 - tsidx1 @@ -355,6 +319,6 @@ tsidx2 = 2 @test length(ps.discrete[tsidx2][1]) == 3 @test length(ps.discrete[tsidx2][2]) == 0 with_updated_parameter_timeseries_values( - ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) + sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) @test ps.discrete[tsidx1][1] == [10.0, 11.0, 12.0] @test ps.discrete[tsidx1][2][] == false diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index 815c63cb59..242be8f1d7 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -173,18 +173,21 @@ end @test_skip begin Tf = 1.0 prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; + yd(k - 2) => 2.0]) @test_nowarn solve(prob, Tsit5()) @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp], discrete_events = [[0.5] => [kp ~ 2.0]]) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; + yd(k - 2) => 2.0]) @test prob.ps[kp] == 1.0 @test prob.ps[kq] == 2.0 @test_nowarn solve(prob, Tsit5()) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; + yd(k - 2) => 2.0]) integ = init(prob, Tsit5()) @test integ.ps[kp] == 1.0 @test integ.ps[kq] == 2.0 diff --git a/test/symbolic_indexing_interface.jl b/test/symbolic_indexing_interface.jl index 70963e3371..10d24fd6f2 100644 --- a/test/symbolic_indexing_interface.jl +++ b/test/symbolic_indexing_interface.jl @@ -38,12 +38,12 @@ using SciMLStructures: Tunable odesys = complete(odesys) @test default_values(odesys)[xy] == 3.0 pobs = parameter_observed(odesys, a + b) - @test pobs.timeseries_idx === nothing - @test pobs.observed_fn( + @test isempty(get_all_timeseries_indexes(odesys, a + b)) + @test pobs( ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) ≈ 3.0 pobs = parameter_observed(odesys, [a + b, a - b]) - @test pobs.timeseries_idx === nothing - @test pobs.observed_fn( + @test isempty(get_all_timeseries_indexes(odesys, [a + b, a - b])) + @test pobs( ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) ≈ [3.0, -1.0] end @@ -102,11 +102,11 @@ end @test !is_time_dependent(ns) ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0]) pobs = parameter_observed(ns, σ + ρ) - @test pobs.timeseries_idx === nothing - @test pobs.observed_fn(ps) == 3.0 + @test isempty(get_all_timeseries_indexes(ns, σ + ρ)) + @test pobs(ps) == 3.0 pobs = parameter_observed(ns, [σ + ρ, ρ + β]) - @test pobs.timeseries_idx === nothing - @test pobs.observed_fn(ps) == [3.0, 5.0] + @test isempty(get_all_timeseries_indexes(ns, [σ + ρ, ρ + β])) + @test pobs(ps) == [3.0, 5.0] end @testset "PDESystem" begin @@ -147,6 +147,11 @@ end domains = [t ∈ (0.0, 1.0), x ∈ (0.0, 1.0)] + analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)] + analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t) + + @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic) + @test isequal(pdesys.ps, [h]) @test isequal(parameter_symbols(pdesys), [h]) @test isequal(parameters(pdesys), [h]) @@ -179,12 +184,4 @@ get_dep = @test_nowarn getu(prob, 2p1) @test getu(prob, z)(prob) == getu(prob, :z)(prob) @test getu(prob, p1)(prob) == getu(prob, :p1)(prob) @test getu(prob, p2)(prob) == getu(prob, :p2)(prob) - analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)] - analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t) - - @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic) - - @test isequal(pdesys.ps, [h]) - @test isequal(parameter_symbols(pdesys), [h]) - @test isequal(parameters(pdesys), [h]) end