Skip to content

Commit

Permalink
fixup! refactor: store parameters from different clock partitions sep…
Browse files Browse the repository at this point in the history
…arately
  • Loading branch information
AayushSabharwal committed May 27, 2024
1 parent 9a263de commit c79851f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
34 changes: 17 additions & 17 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,32 +91,32 @@ function IndexCache(sys::AbstractSystem)

for (i, (inps, disc_sys)) in enumerate(zip(inputs, syss))
i == continuous_id && continue
disc_buffers[i - 1] = Dict{Any, Set{BasicSymbolic}}()
disc_buffers[i] = Dict{Any, Set{BasicSymbolic}}()

for inp in inps
inp = unwrap(inp)
is_parameter(sys, inp) ||
error("Discrete subsystem $i input $inp is not a parameter")
disc_clocks[inp] = i - 1
disc_clocks[default_toterm(inp)] = i - 1
disc_clocks[inp] = i
disc_clocks[default_toterm(inp)] = i
if hasname(inp) && (!istree(inp) || operation(inp) !== getindex)
disc_clocks[getname(inp)] = i - 1
disc_clocks[default_toterm(inp)] = i - 1
disc_clocks[getname(inp)] = i
disc_clocks[default_toterm(inp)] = i
end
insert_by_type!(disc_buffers[i - 1], inp)
insert_by_type!(disc_buffers[i], inp)
end

for sym in unknowns(disc_sys)
sym = unwrap(sym)
is_parameter(sys, sym) ||
error("Discrete subsystem $i unknown $sym is not a parameter")
disc_clocks[sym] = i - 1
disc_clocks[default_toterm(sym)] = i - 1
disc_clocks[sym] = i
disc_clocks[default_toterm(sym)] = i
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
disc_clocks[getname(sym)] = i - 1
disc_clocks[getname(default_toterm(sym))] = i - 1
disc_clocks[getname(sym)] = i
disc_clocks[getname(default_toterm(sym))] = i
end
insert_by_type!(disc_buffers[i - 1], sym)
insert_by_type!(disc_buffers[i], sym)
end
t = get_iv(sys)
for eq in observed(disc_sys)
Expand All @@ -127,12 +127,12 @@ function IndexCache(sys::AbstractSystem)
if istree(sym) && operation(sym) == Shift(t, 1)
sym = only(arguments(sym))
end
disc_clocks[sym] = i - 1
disc_clocks[sym] = i - 1
disc_clocks[default_toterm(sym)] = i - 1
disc_clocks[sym] = i
disc_clocks[sym] = i
disc_clocks[default_toterm(sym)] = i
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
disc_clocks[getname(sym)] = i - 1
disc_clocks[getname(default_toterm(sym))] = i - 1
disc_clocks[getname(sym)] = i
disc_clocks[getname(default_toterm(sym))] = i
end
end
end
Expand All @@ -154,7 +154,7 @@ function IndexCache(sys::AbstractSystem)
end

affs = vcat(affects(continuous_events(sys)), affects(discrete_events(sys)))
user_affect_clock = maximum(values(disc_clocks); init = 1)
user_affect_clock = maximum(values(disc_clocks); init = 0) + 1
for affect in affs
if affect isa Equation
is_parameter(sys, affect.lhs) || continue
Expand Down
4 changes: 3 additions & 1 deletion src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,13 +449,15 @@ function SciMLBase.create_parameter_timeseries_collection(
buffers = []

for (i, partition) in enumerate(ps.discrete)
clock = id_to_clock[i + 1]
clock = id_to_clock[i]
if clock isa Clock
ts = tspan[1]:(clock.dt):tspan[2]
push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], ts, (1, 1)))
elseif clock isa SolverStepClock
push!(buffers,
DiffEqArray(NestedGetIndex{typeof(partition)}[], eltype(tspan)[], (1, 1)))
elseif clock isa Continuous
continue
else
error("Unhandled clock $clock")
end
Expand Down

0 comments on commit c79851f

Please sign in to comment.