diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index e65ad71cd..caaa83a03 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -154,6 +154,16 @@ function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where { Timeseries() end +function get_interpolated_discretes(sol::AbstractODESolution, t, deriv, continuity) + is_parameter_timeseries(sol) == Timeseries() || return nothing + + discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) + interp_discs = map(discs) do partition + ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + end + return ParameterTimeseriesCollection(interp_discs, parameter_values(discs)) +end + function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} sol(t, deriv, idxs, continuity) @@ -170,14 +180,7 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::Nothing, continuity) where {deriv} - if is_parameter_timeseries(sol) == Timeseries() - discs = RecursiveArrayTools.get_discretes(sol) - interp_discs = ConstantInterpolation(discs.t, discs.u) - discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity) - else - discretes = nothing - end - + discretes = get_interpolated_discretes(sol, t, deriv, continuity) augment(sol.interp(t, idxs, deriv, sol.prob.p, continuity), sol; discretes) end @@ -214,14 +217,19 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") ps = parameter_values(sol) + # NOTE: This is basically SII.parameter_values_at_time but that isn't public API + # and once we move interpolation to SII, there's no reason for it to be if is_parameter_timeseries(sol) == Timeseries() - discs = RecursiveArrayTools.get_discretes(sol) - interp_discs = ConstantInterpolation(discs.t, discs.u) - discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity) - ps = SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes) + discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) + ps = parameter_values(discs) + for ts_idx in eachindex(discs) + partition = discs[ts_idx] + interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val) + end end - interp_sol = augment(sol.interp([t], nothing, deriv, ps, continuity), sol) - return getu(interp_sol, idxs)(interp_sol, 1) + state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t) + return getu(sol, idxs)(state) end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector, @@ -229,27 +237,26 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") ps = parameter_values(sol) + # NOTE: This is basically SII.parameter_values_at_time but that isn't public API + # and once we move interpolation to SII, there's no reason for it to be if is_parameter_timeseries(sol) == Timeseries() - discs = RecursiveArrayTools.get_discretes(sol) - interp_discs = ConstantInterpolation(discs.t, discs.u) - discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity) - ps = SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes) + discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) + ps = parameter_values(discs) + for ts_idx in eachindex(discs) + partition = discs[ts_idx] + interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val) + end end - interp_sol = augment(sol.interp([t], nothing, deriv, ps, continuity), sol) - first(getu(interp_sol, idxs)(interp_sol)) + state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t) + return getu(sol, idxs)(state) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - if is_parameter_timeseries(sol) == Timeseries() - discs = RecursiveArrayTools.get_discretes(sol) - interp_discs = ConstantInterpolation(discs.t, discs.u) - discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity) - else - discretes = nothing - end + discretes = get_interpolated_discretes(sol, t, deriv, continuity) interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes) return DiffEqArray(getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes) end @@ -259,18 +266,57 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - if is_parameter_timeseries(sol) == Timeseries() - discs = RecursiveArrayTools.get_discretes(sol) - interp_discs = ConstantInterpolation(discs.t, discs.u) - discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity) - else - discretes = nothing - end + discretes = get_interpolated_discretes(sol, t, deriv, continuity) interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes) return DiffEqArray( getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes) end +# public API, used by MTK +""" + create_parameter_timeseries_collection(sys, ps) + +Create a `SymbolicIndexingInterface.ParameterTimeseriesCollection` for the given system +`sys` and parameter object `ps`. Return `nothing` if there are no timeseries parameters. +Defaults to `nothing`. +""" +function create_parameter_timeseries_collection(sys, ps, tspan) + return nothing +end + +const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: AbstractRange} + +# public API, used by MTK +""" + get_saveable_values(ps, timeseries_idx) +""" +function get_saveable_values end + +function save_discretes!(integ::DEIntegrator, timeseries_idx) + save_discretes!(integ.sol, current_time(integ), get_saveable_values(parameter_values(integ), timeseries_idx), timeseries_idx) +end + +save_discretes!(args...) = nothing + +# public API, used by MTK +function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx) + RecursiveArrayTools.has_discretes(sol) || return + disc = RecursiveArrayTools.get_discretes(sol) + _save_discretes_internal!(disc[timeseries_idx], t, vals) +end + +function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals) + push!(A.t, t) + push!(A.u, vals) +end + +function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals) + if all(!isapprox(t), A.t) + error("Tried to save periodic discrete value with timeseries $(A.t) at time $t") + end + push!(A.u, vals) +end + function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, alg, t, u; timeseries_errors = length(u) > 2, dense = false, dense_errors = dense,