Skip to content

Commit

Permalink
Merge pull request #809 from AayushSabharwal/as/symbolic-save-idxs
Browse files Browse the repository at this point in the history
feat: support symbolic indexing of a subset of the system
  • Loading branch information
ChrisRackauckas authored Oct 23, 2024
2 parents 3c1211a + c219343 commit 7cb1379
Show file tree
Hide file tree
Showing 9 changed files with 648 additions and 159 deletions.
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ include("problems/problem_interface.jl")
include("problems/optimization_problems.jl")

include("clock.jl")
include("solutions/save_idxs.jl")
include("solutions/basic_solutions.jl")
include("solutions/nonlinear_solutions.jl")
include("solutions/ode_solutions.jl")
Expand Down
17 changes: 10 additions & 7 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,19 +331,22 @@ Otherwise the integrator is allowed to skip recalculating the interpolation.
# Arguments
- `continuous_modification`: determines whether the modification is due to a continuous change (continuous callback)
or a discrete callback. For a continuous change, this can include a change to time which requires a re-evaluation
of the interpolations.
- `callback_initializealg`: the initialization algorithm provided by the callback. For DAEs, this is the choice for the
initialization that is done post callback. The default value of `nothing` means that the initialization choice
used for the DAE should be performed post-callback.
- `continuous_modification`: determines whether the modification is due to a continuous change (continuous callback)
or a discrete callback. For a continuous change, this can include a change to time which requires a re-evaluation
of the interpolations.
- `callback_initializealg`: the initialization algorithm provided by the callback. For DAEs, this is the choice for the
initialization that is done post callback. The default value of `nothing` means that the initialization choice
used for the DAE should be performed post-callback.
"""
function reeval_internals_due_to_modification!(
integrator::DEIntegrator, continuous_modification;
callback_initializealg = nothing)
reeval_internals_due_to_modification!(integrator::DEIntegrator)
end
reeval_internals_due_to_modification!(integrator::DEIntegrator; callback_initializealg = nothing) = nothing
function reeval_internals_due_to_modification!(
integrator::DEIntegrator; callback_initializealg = nothing)
nothing
end

"""
set_t!(integrator::DEIntegrator, t)
Expand Down
4 changes: 2 additions & 2 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2693,8 +2693,8 @@ function SplitFunction{iip, specialize}(f1, f2;
f1.jac_prototype :
nothing,
W_prototype = __has_W_prototype(f1) ?
f1.W_prototype :
nothing,
f1.W_prototype :
nothing,
sparsity = __has_sparsity(f1) ? f1.sparsity :
jac_prototype,
Wfact = __has_Wfact(f1) ? f1.Wfact : nothing,
Expand Down
115 changes: 46 additions & 69 deletions src/solutions/dae_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
"""
struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateType} <:
struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateType, V} <:
AbstractDAESolution{T, N, uType}
u::uType
du::duType
Expand All @@ -42,6 +42,31 @@ struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateT
tslocation::Int
stats::S
retcode::ReturnCode.T
saved_subsystem::V
end

function DAESolution{T, N}(u, du, u_analytic, errors, t, k, prob, alg, interp, dense,
tslocation, stats, retcode, saved_subsystem) where {T, N}
return DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors),
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
typeof(saved_subsystem)}(
u, du, u_analytic, errors, t, k, prob, alg, interp, dense, tslocation, stats,
retcode, saved_subsystem
)
end

function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: DAESolution{T, N}}
DAESolution{T, N}
end

function ConstructionBase.setproperties(sol::DAESolution, patch::NamedTuple)
u = get(patch, :u, sol.u)
N = u === nothing ? 2 : ndims(eltype(u)) + 1
T = eltype(eltype(u))
patch = merge(getproperties(sol), patch)
return DAESolution{T, N}(patch.u, patch.du, patch.u_analytic, patch.errors, patch.t,
patch.k, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation,
patch.stats, patch.retcode, patch.saved_subsystem)
end

Base.@propagate_inbounds function Base.getproperty(x::AbstractDAESolution, s::Symbol)
Expand All @@ -65,13 +90,14 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
retcode = ReturnCode.Default,
destats = missing,
stats = nothing,
saved_subsystem = nothing,
kwargs...)
T = eltype(eltype(u))

if prob.u0 === nothing
N = 2
else
N = length((size(prob.u0)..., length(u)))
N = ndims(eltype(u)) + 1
end

if !ismissing(destats)
Expand All @@ -88,7 +114,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
errors = Dict{Symbol, real(eltype(prob.u0))}()

sol = DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors),
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k)}(
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
typeof(saved_subsystem)}(
u,
du,
u_analytic,
Expand All @@ -101,7 +128,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
dense,
0,
stats,
retcode)
retcode,
saved_subsystem)

if calculate_error
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
Expand All @@ -110,15 +138,17 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
sol
else
DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t),
typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k)}(
typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
typeof(saved_subsystem)}(
u, du,
nothing,
nothing, t, k,
prob, alg,
interp,
dense, 0,
stats,
retcode)
retcode,
saved_subsystem)
end
end

Expand Down Expand Up @@ -161,76 +191,23 @@ function calculate_solution_errors!(sol::AbstractDAESolution;
end

function build_solution(sol::AbstractDAESolution{T, N}, u_analytic, errors) where {T, N}
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(u_analytic), typeof(errors),
typeof(sol.t), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp),
typeof(sol.stats), typeof(sol.k)}(sol.u,
sol.du,
u_analytic,
errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
sol.tslocation,
sol.stats,
sol.retcode)
@reset sol.u_analytic = u_analytic
return @set sol.errors = errors
end

function solution_new_retcode(sol::AbstractDAESolution{T, N}, retcode) where {T, N}
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
typeof(sol.interp), typeof(sol.stats), typeof(sol.k)}(sol.u,
sol.du,
sol.u_analytic,
sol.errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
sol.tslocation,
sol.stats,
retcode)
return @set sol.retcode = retcode
end

function solution_new_tslocation(sol::AbstractDAESolution{T, N}, tslocation) where {T, N}
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
typeof(sol.interp), typeof(sol.stats), typeof(k)}(sol.u,
sol.du,
sol.u_analytic,
sol.errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
tslocation,
sol.stats,
sol.retcode)
return @set sol.tslocation = tslocation
end

function solution_slice(sol::AbstractDAESolution{T, N}, I) where {T, N}
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
typeof(sol.interp), typeof(sol.stats), typeof(sol.k)}(sol.u[I],
sol.du[I],
sol.u_analytic ===
nothing ?
nothing :
sol.u_analytic[I],
sol.errors,
sol.t[I],
sol.k[I],
sol.prob,
sol.alg,
sol.interp,
false,
sol.tslocation,
sol.stats,
sol.retcode)
@reset sol.u = sol.u[I]
@reset sol.du = sol.du[I]
@reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I]
@reset sol.t = sol.t[I]
@reset sol.k = sol.dense ? sol.k[I] : sol.k
return @set sol.dense = false
end
48 changes: 35 additions & 13 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,12 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
successfully, whether it terminated early due to a user-defined callback, or whether it
exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
- `saved_subsystem`: a [`SavedSubsystem`](@ref) representing the subset of variables saved
in this solution, or `nothing` if all variables are saved. Here "variables" refers to
both continuous-time state variables and timeseries parameters.
"""
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
AC <: Union{Nothing, Vector{Int}}, R, O} <:
AC <: Union{Nothing, Vector{Int}}, R, O, V} <:
AbstractODESolution{T, N, uType}
u::uType
u_analytic::uType2
Expand All @@ -124,6 +127,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A,
retcode::ReturnCode.T
resid::R
original::O
saved_subsystem::V
end

function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}}
Expand All @@ -137,7 +141,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
patch = merge(getproperties(sol), patch)
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
patch.alg_choice, patch.retcode, patch.resid, patch.original)
patch.alg_choice, patch.retcode, patch.resid, patch.original, patch.saved_subsystem)
end

Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol)
Expand All @@ -154,12 +158,12 @@ end
function ODESolution{T, N}(
u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
tslocation, stats, alg_choice, retcode, resid = nothing,
original = nothing) where {T, N}
original = nothing, saved_subsystem = nothing) where {T, N}
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
typeof(stats), typeof(alg_choice), typeof(resid),
typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
dense, tslocation, stats, alg_choice, retcode, resid, original)
typeof(stats), typeof(alg_choice), typeof(resid), typeof(original),
typeof(saved_subsystem)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
dense, tslocation, stats, alg_choice, retcode, resid, original, saved_subsystem)
end

error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
Expand Down Expand Up @@ -409,15 +413,25 @@ const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: Abstrac
# public API, used by MTK
"""
get_saveable_values(sys, ps, timeseries_idx)
Return the values to be saved in parameter object `ps` for timeseries index `timeseries_idx`. Called by
`save_discretes!`. If this returns `nothing`, `save_discretes!` will not save anything.
"""
function get_saveable_values(sys, ps, timeseries_idx)
return get_saveable_values(symbolic_container(sys), ps, timeseries_idx)
end

"""
save_discretes!(integ::DEIntegrator, timeseries_idx)
Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
get the values to save. If it returns `nothing`, then the save does not happen.
"""
function save_discretes!(integ::DEIntegrator, timeseries_idx)
save_discretes!(integ.sol, current_time(integ),
get_saveable_values(integ, parameter_values(integ), timeseries_idx),
timeseries_idx)
inner_sol = get_sol(integ)
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
vals === nothing && return
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
end

save_discretes!(args...) = nothing
Expand Down Expand Up @@ -451,6 +465,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
interp = LinearInterpolation(t, u),
retcode = ReturnCode.Default, destats = missing, stats = nothing,
resid = nothing, original = nothing,
saved_subsystem = nothing,
kwargs...)
T = eltype(eltype(u))

Expand Down Expand Up @@ -482,7 +497,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},

ps = parameter_values(prob)
if has_sys(prob.f)
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
sswf = if saved_subsystem === nothing
prob.f.sys
else
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
end
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
else
discretes = nothing
end
Expand All @@ -503,7 +523,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
alg_choice,
retcode,
resid,
original)
original,
saved_subsystem)
if calculate_error
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
dense_errors = dense_errors)
Expand All @@ -524,7 +545,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
alg_choice,
retcode,
resid,
original)
original,
saved_subsystem)
end
end

Expand Down Expand Up @@ -593,7 +615,7 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
@reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I]
@reset sol.t = sol.t[I]
@reset sol.k = sol.dense ? sol.k[I] : sol.k
return @set sol.alg = false
return @set sol.dense = false
end

mask_discretes(::Nothing, _, _...) = nothing
Expand Down
Loading

0 comments on commit 7cb1379

Please sign in to comment.