Skip to content

Commit

Permalink
A few improvements of TerminateSteadyState
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Nov 28, 2024
1 parent c08609c commit b77a021
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions src/terminatesteadystate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
# Terminate when all derivatives fall below a threshold or
# when derivatives are smaller than a fraction of state
function allDerivPass(integrator, abstol, reltol, min_t)
# Early exit
if min_t !== nothing && integrator.t < min_t
return false
end

if DiffEqBase.isinplace(integrator.sol.prob)
testval = first(get_tmp_cache(integrator))
DiffEqBase.get_du!(testval, integrator)
Expand All @@ -16,21 +21,25 @@ function allDerivPass(integrator, abstol, reltol, min_t)
end

if integrator.u isa Array
any(abs(d) > abstol && abs(d) > reltol * abs(u)
any(abs(d) > abstol || abs(d) > reltol * abs(u)
for (d, abstol, reltol, u) in zip(testval, Iterators.cycle(abstol),
Iterators.cycle(reltol), integrator.u)) &&
(return false)
else
any((abs.(testval) .> abstol) .& (abs.(testval) .> reltol .* abs.(integrator.u))) &&
any((abs.(testval) .> abstol) .|| (abs.(testval) .> reltol .* abs.(integrator.u))) &&
(return false)
end

if min_t === nothing
return true
else
return integrator.t >= min_t
end
return true
end

struct WrappedTest{T,A,R,M}
test::T
abstol::A
reltol::R
min_t::M
end
(f::WrappedTest)(u, t, integrator) = f.test(integrator, f.abstol, f.reltol, f.min_t)

# Allow user-defined tolerances and test functions but use sensible defaults
# test function must take integrator, time, followed by absolute
Expand Down Expand Up @@ -62,15 +71,14 @@ the [Steady State Solvers](https://docs.sciml.ai/DiffEqDocs/stable/solvers/stead
- `min_t` specifies an optional minimum `t` before the steady state calculations are allowed
to terminate.
"""
function TerminateSteadyState(abstol = 1e-8, reltol = 1e-6, test = allDerivPass;
min_t = nothing, wrap_test::Val{WT} = Val(true)) where {WT}
function TerminateSteadyState(abstol = 1e-8, reltol = 1e-6, test::T = allDerivPass;
min_t = nothing, wrap_test::Val{WT} = Val(true)) where {T,WT}
condition = if WT
(u, t, integrator) -> test(integrator, abstol, reltol, min_t)
WrappedTest(test, abstol, reltol, min_t)
else
test
end
affect! = (integrator) -> terminate!(integrator)
DiscreteCallback(condition, affect!; save_positions = (true, false))
DiscreteCallback(condition, terminate!; save_positions = (true, false))
end

export TerminateSteadyState

0 comments on commit b77a021

Please sign in to comment.