diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index d1358c10b6..1c59d3259b 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -186,10 +186,13 @@ function split_system(ci::ClockInference{S}) where {S} end tss[id] = ts_i end - tss[continuous_id], tss[1] = tss[1], tss[continuous_id] - inputs[continuous_id], inputs[1] = inputs[1], inputs[continuous_id] - id_to_clock[continuous_id], id_to_clock[1] = id_to_clock[1], id_to_clock[continuous_id] - continuous_id = 1 + if continuous_id !== 0 + tss[continuous_id], tss[end] = tss[end], tss[continuous_id] + inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id] + id_to_clock[continuous_id], id_to_clock[end] = id_to_clock[end], + id_to_clock[continuous_id] + continuous_id = lastindex(tss) + end return tss, inputs, continuous_id, id_to_clock end @@ -267,7 +270,7 @@ function generate_discrete_affect( let_block) |> toexpr cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] - save_expr = :($(SciMLBase.save_discretes!)(integrator, $(i - 1))) + save_expr = :($(SciMLBase.save_discretes!)(integrator, $i)) empty_disc = isempty(disc_range) # @show disc_to_cont_idxs @@ -290,13 +293,13 @@ function generate_discrete_affect( # d2c comes last # @show t # @show "incoming", p - result = c2d_obs(integrator.u, p..., t) + result = c2d_obs(u, p..., t) for (val, i) in zip(result, $cont_to_disc_idxs) $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end $(if !empty_disc quote - disc(disc_unknowns, integrator.u, p..., t) + disc(disc_unknowns, u, p..., t) for (val, i) in zip(disc_unknowns, $disc_range) $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end @@ -316,6 +319,7 @@ function generate_discrete_affect( $(SciMLStructures.Discrete()), p) repack(discretes) end) + push!(affect_funs, affect!) end if eval_expression diff --git a/test/clock.jl b/test/clock.jl index d1e7fdd56d..91f96ce201 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -64,10 +64,11 @@ By inference: ci, varmap = infer_clocks(sys) eqmap = ci.eq_domain -tss, inputs = ModelingToolkit.split_system(deepcopy(ci)) -sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) +tss, inputs, continuous_id = ModelingToolkit.split_system(deepcopy(ci)) +sss, = ModelingToolkit._structural_simplify!( + deepcopy(tss[continuous_id]), (inputs[continuous_id], ())) @test equations(sss) == [D(x) ~ u - x] -sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[2]), (inputs[2], ())) +sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) @test isempty(equations(sss)) d = Clock(t, dt) k = ShiftIndex(d)