Skip to content

Commit

Permalink
Reset the indices in the PresetTimeCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Mar 11, 2024
1 parent 2a46de7 commit 7ecf6a4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/preset_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ function PresetTimeCallback(tstops, user_affect!;
# Initialization: first call to `f` should be *before* any time steps have been taken:
initialize_preset = function (c, u, t, integrator)
initialize(c, u, t, integrator)
if tstops isa AbstractVector
search_start, search_end = firstindex(tstops), lastindex(tstops)

Check warning on line 60 in src/preset_time.jl

View check run for this annotation

Codecov / codecov/patch

src/preset_time.jl#L59-L60

Added lines #L59 - L60 were not covered by tests
else
search_start, search_end = 0, 0

Check warning on line 62 in src/preset_time.jl

View check run for this annotation

Codecov / codecov/patch

src/preset_time.jl#L62

Added line #L62 was not covered by tests
end

if filter_tstops
tdir = integrator.tdir
_tstops = tstops[@.((tdir * tstops >
Expand Down
30 changes: 30 additions & 0 deletions test/preset_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,33 @@ cb = PresetTimeCallback([-0.2], integrator -> begin
end, filter_tstops = false)
sol = solve(prob, Tsit5(), callback = cb)
@test !notcalled

# Test indexes reset
# https://github.com/SciML/DifferentialEquations.jl/issues/1022

function mod(du, u, p, t)
du[1] = -p[1]*u[1]
end

p = [1.0]
u0 = [10.0]
tspan = (0.0,72.0)

times1 = 0.0:24.0:tspan[2]
times2 = 24.0:24.0:tspan[2]
affect!(integrator) = integrator.u[1] += 10.0
cb1 = PresetTimeCallback(times1, affect!)
cb2 = PresetTimeCallback(times2, affect!)

prob1 = ODEProblem(mod, u0, tspan, p, callback=cb1)
prob2 = ODEProblem(mod, u0, tspan, p)

sol1 = solve(prob1)
sol2 = solve(prob2, callback=cb1)

@test sol1(0.0) == [10.0]
@test sol1(24.0 + eps(24.0)) [10.0]
@test sol1(48.0 + eps(48.0)) [10.0]
@test sol2(0.0) == [10.0]
@test sol2(24.0 + eps(24.0)) [10.0]
@test sol2(48.0 + eps(48.0)) [10.0]

0 comments on commit 7ecf6a4

Please sign in to comment.