Skip to content

Commit

Permalink
Merge pull request #512 from pepijndevos/pv/enstats
Browse files Browse the repository at this point in the history
add stats to ensemble solution
  • Loading branch information
ChrisRackauckas authored Oct 28, 2023
2 parents a29ecf1 + 35f5f06 commit ff42d8b
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 13 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
24 changes: 23 additions & 1 deletion ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module SciMLBaseZygoteExt

using Zygote: pullback
using ZygoteRules: @adjoint
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
import ZygoteRules
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -55,4 +56,25 @@ end
VA[sym, j], ODESolution_getindex_pullback
end

ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats)
out = EnsembleSolution(sim, time, converged)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
(EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(p̄, nothing, nothing, nothing)
end
out, EnsembleSolution_adjoint
end

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
::Val{:u})
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)
end

end
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using LinearAlgebra
using Statistics
using Distributed
using Markdown
using Printf
import Preferences

import Logging, ArrayInterface
Expand Down
13 changes: 10 additions & 3 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ $(TYPEDEF)
"""
struct EnsembleSerial <: BasicEnsembleAlgorithm end

function merge_stats(us)
st = Iterators.filter(!isnothing, (hasproperty(x, :stats) ? x.stats : nothing for x in us))
isempty(st) && return nothing
reduce(merge, st)
end

function __solve(prob::AbstractEnsembleProblem,
alg::Union{AbstractDEAlgorithm, Nothing};
kwargs...)
Expand Down Expand Up @@ -64,7 +70,8 @@ function __solve(prob::AbstractEnsembleProblem,
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
return EnsembleSolution(_u, elapsed_time, true)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)
end

converged::Bool = false
Expand All @@ -88,8 +95,8 @@ function __solve(prob::AbstractEnsembleProblem,
end
end
_u = tighten_container_eltype(u)

return EnsembleSolution(_u, elapsed_time, converged)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)
end

function batch_func(i, prob, alg; kwargs...)
Expand Down
20 changes: 11 additions & 9 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,23 @@ struct EnsembleSolution{T, N, S} <: AbstractEnsembleSolution{T, N, S}
u::S
elapsedTime::Float64
converged::Bool
stats
end
function EnsembleSolution(sim, dims::NTuple{N}, elapsedTime, converged) where {N}
EnsembleSolution{eltype(eltype(sim)), N, typeof(sim)}(sim, elapsedTime, converged)
function EnsembleSolution(sim, dims::NTuple{N}, elapsedTime, converged, stats) where {N}
EnsembleSolution{eltype(eltype(sim)), N, typeof(sim)}(sim, elapsedTime, converged, stats)
end
function EnsembleSolution(sim, elapsedTime, converged)
EnsembleSolution(sim, (length(sim),), elapsedTime, converged)
function EnsembleSolution(sim, elapsedTime, converged, stats=nothing)
EnsembleSolution(sim, (length(sim),), elapsedTime, converged, stats)
end # Vector of some type which is not an array
function EnsembleSolution(sim::T, elapsedTime,
converged) where {T <: AbstractVector{T2}
converged, stats=nothing) where {T <: AbstractVector{T2}
} where {T2 <:
AbstractArray}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1,
typeof(sim)}(sim,
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1, typeof(sim)}(
sim,
elapsedTime,
converged)
converged,
stats)
end

struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
Expand All @@ -56,7 +58,7 @@ struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
end

function Base.reverse(sim::EnsembleSolution)
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged)
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged, sim.stats)
end

"""
Expand Down
14 changes: 14 additions & 0 deletions src/solutions/nonlinear_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@ mutable struct NLStats
nsteps::Int
end

function Base.show(io::IO, ::MIME"text/plain", s::NLStats)
println(io, summary(s))
@printf io "%-50s %-d\n" "Number of function evaluations:" s.nf
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
@printf io "%-50s %-d\n" "Number of factorizations:" s.nfactors
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
@printf io "%-50s %-d" "Number of nonlinear solver iterations:" s.nsteps
end

function Base.merge(s1::NLStats, s2::NLStats)
NLStats(s1.nf + s2.nf, s1.njacs + s2.njacs, s1.nfactors + s2.nfactors,
s1.nsolve + s2.nsolve, s1.nsteps + s2.nsteps)
end

"""
$(TYPEDEF)
Expand Down
73 changes: 73 additions & 0 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,79 @@
"""
$(TYPEDEF)
Statistics from the differential equation solver about the solution process.
## Fields
- nf: Number of function evaluations. If the differential equation is a split function,
such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the
number of function evaluations for the first function (the implicit function)
- nf2: If the differential equation is a split function, such as a `SplitFunction`
for implicit-explicit (IMEX) integration, then `nf2` is the number of function
evaluations for the second function, i.e. the function treated explicitly. Otherwise
it is zero.
- nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving
process.
- nsolve: The number of linear solves `W\b` required for the integration.
- njacs: Number of Jacobians calculated during the integration.
- nnonliniter: Total number of iterations for the nonlinear solvers.
- nnonlinconvfail: Number of nonlinear solver convergence failures.
- ncondition: Number of calls to the condition function for callbacks.
- naccept: Number of accepted steps.
- nreject: Number of rejected steps.
- maxeig: Maximum eigenvalue over the solution. This is only computed if the
method is an auto-switching algorithm.
"""
mutable struct DEStats
nf::Int
nf2::Int
nw::Int
nsolve::Int
njacs::Int
nnonliniter::Int
nnonlinconvfail::Int
ncondition::Int
naccept::Int
nreject::Int
maxeig::Float64
end

DEStats(x::Int = -1) = DEStats(x, x, x, x, x, x, x, x, x, x, 0.0)

function Base.show(io::IO, ::MIME"text/plain", s::DEStats)
println(io, summary(s))
@printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf
@printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2
@printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
@printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter
@printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail
@printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition
@printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept
@printf io "%-50s %-d" "Number of rejected steps:" s.nreject
iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig
end

function Base.merge(a::DEStats, b::DEStats)
DEStats(
a.nf + b.nf,
a.nf2 + b.nf2,
a.nw + b.nw,
a.nsolve + b.nsolve,
a.njacs + b.njacs,
a.nnonliniter + b.nnonliniter,
a.nnonlinconvfail + b.nnonlinconvfail,
a.ncondition + b.ncondition,
a.naccept + b.naccept,
a.nreject + b.nreject,
max(a.maxeig, b.maxeig),
)
end

"""
$(TYPEDEF)
Representation of the solution to an ordinary differential equation defined by an ODEProblem.
## DESolution Interface
Expand Down
13 changes: 13 additions & 0 deletions test/downstream/ensemble_stats.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using OrdinaryDiffEq
using Test

f(u,p,t) = 1.01*u
u0=1/2
tspan = (0.0,1.0)
prob = ODEProblem(f,u0,tspan)
function prob_func(prob, i, repeat)
remake(prob, u0 = rand() * prob.u0)
end
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
@test sim.stats.nf == mapreduce(x -> x.stats.nf, +, sim.u)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ end
@time @safetestset "solving Ensembles with multiple problems" begin
include("downstream/ensemble_multi_prob.jl")
end
@time @safetestset "Ensemble solution statistics" begin
include("downstream/ensemble_stats.jl")
end
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
include("downstream/symbol_indexing.jl")
end
Expand Down

0 comments on commit ff42d8b

Please sign in to comment.