From 2bb044ae7dc631879dcd96c7cfb4f6f7878a40aa Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Sat, 3 Feb 2024 18:19:08 -0800 Subject: [PATCH] Improve caching and dispatch of `LinearizingSavingCallback` This adds a new type, `LinearizingSavingCallbackCache` and some sub-types to allow for efficient re-use of memory as the callback executes over the course of a solve, as well as re-use of that memory in future solves when operating on a large ensemble simulation. The top-level `LinearizingSavingCallbackCache` creates thread-safe cache pool objects that are then used to acquire thread-unsafe cache pool objects to be used within a single solve. Those thread-unsafe cache pool objects can then be released and acquired anew by the next solve. The thread-unsafe pool objects allow for acquisition of pieces of memory such as temporary `u` vectors (the recusrive nature of the `LinearizingSavingCallback` means that we must allocate unknown numbers of temporary `u` vectors) and chunks of `u` blocks that are then compacted into a single large matrix in the finalize method of the callback. All these pieces of memory are stored within that set of thread-unsafe caches, and these are released back to the top-level thread-safe cache pool, for the next solve to acquire and make use of those pieces of memory in the cache pool. Using these techniques, the solve time of a large ensemble simulation with low per-simulation computation has reduced dramatically. The simulation solves a butterworth 3rd-order filter circuit over a certain timespan, swept across different simulus frequencies and circuit parameters. The parameter sweep results in a 13500-element ensemble simulation, that when run with 8 threads on a M1 Pro takes: ``` 48.364827 seconds (625.86 M allocations: 19.472 GiB, 41.81% gc time, 0.17% compilation time) ``` Now, after these caching optimizations, we solve the same ensemble in: ``` 13.208123 seconds (166.76 M allocations: 7.621 GiB, 22.21% gc time, 0.61% compilation time) ``` As a side note, the size requirements of the raw linearized solution data itself is `1.04 GB`. In general, we expect to allocate somewhere between 2-3x the final output data to account for temporaries and inefficient sharing, so while there is still some more work to be done, this gets us significantly closer to minimal overhead. This also adds a package extension on `Sundials`, as `IDA` requires that state vectors are `NVector` types, rather than `Vector{S}` types in order to not allocate. --- Project.toml | 3 + src/independentlylinearizedutils.jl | 253 +++++++++++++++++++++++---- src/saving.jl | 238 +++++++++++++++++-------- test/independentlylinearizedtests.jl | 56 +++++- test/saving_tests.jl | 37 +++- 5 files changed, 472 insertions(+), 115 deletions(-) diff --git a/Project.toml b/Project.toml index 05576122..3f6174a9 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" +[extensions] +DiffEqCallbacksSundialsExt = "Sundials" + [compat] Aqua = "0.8" DataInterpolations = "4" diff --git a/src/independentlylinearizedutils.jl b/src/independentlylinearizedutils.jl index 2bd0b7f5..163d9017 100644 --- a/src/independentlylinearizedutils.jl +++ b/src/independentlylinearizedutils.jl @@ -2,6 +2,113 @@ using SciMLBase export IndependentlyLinearizedSolution + +""" + CachePool(T, alloc; thread_safe = true) + +Simple memory-reusing cache that allows us to grow a cache and keep +re-using those pieces of memory (in our case, typically `u` vectors) +until the solve is finished. By default, this datastructure is made +to be thread-safe by locking on every acquire and release, but it +can be made thread-unsafe (and correspondingly faster) by passing +`thread_safe = false` to the constructor. + +While manual usage with `acquire!()` and `release!()` is possible, +most users will want to use `@with_cache`, which provides lexically- +scoped `acquire!()` and `release!()` usage automatically. Example: + +```julia +us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us); thread_safe=false) +@with_cache us u_prev begin + @with_cache us u_next begin + # perform tasks with these two `u` vectors + end +end +``` + +!!! warning "Escaping values" + You must not use an acquired value after you have released it; + the memory may be immediately re-used by some other consumer of + your cache pool. Do not allow the acquired value to escape + outside of the `@with_cache` block, or past a `release!()`. +""" +mutable struct CachePool{T, THREAD_SAFE} + const pool::Vector{T} + const alloc::Function + lock::ReentrantLock + num_alloced::Int + num_acquired::Int + + function CachePool(T, alloc::F; thread_safe::Bool = true) where {F} + return new{T,Val{thread_safe}}(T[], alloc, ReentrantLock(), 0, 0) + end +end +const ThreadSafeCachePool{T} = CachePool{T,Val{true}} +const ThreadUnsafeCachePool{T} = CachePool{T,Val{false}} + +""" + acquire!(cache::CachePool) + +Returns a cached element of the cache pool, calling `cache.alloc()` if none +are available. +""" +Base.@inline function acquire!(cache::CachePool{T}, _dummy = nothing) where {T} + cache.num_acquired += 1 + if isempty(cache.pool) + cache.num_alloced += 1 + return cache.alloc()::T + end + return pop!(cache.pool) +end + +""" + release!(cache::CachePool, val) + +Returns the value `val` to the cache pool. +""" +Base.@inline function release!(cache::CachePool, val, _dummy = nothing) + push!(cache.pool, val) + cache.num_acquired -= 1 +end + +function is_fully_released(cache::CachePool, _dummy = nothing) + return cache.num_acquired == 0 +end + +# Thread-safe versions just sub out to the other methods, using `_dummy` to force correct dispatch +acquire!(cache::ThreadSafeCachePool) = @lock cache.lock acquire!(cache, nothing) +release!(cache::ThreadSafeCachePool, val) = @lock cache.lock release!(cache, val, nothing) +is_fully_released(cache::ThreadSafeCachePool) = @lock cache.lock is_fully_released(cache, nothing) + +macro with_cache(cache, name, body) + return quote + $(esc(name)) = acquire!($(esc(cache))) + try + $(esc(body)) + finally + release!($(esc(cache)), $(esc(name))) + end + end +end + + +struct IndependentlyLinearizedSolutionChunksCache{T,S} + t_chunks::ThreadUnsafeCachePool{Vector{T}} + u_chunks::ThreadUnsafeCachePool{Matrix{S}} + time_masks::ThreadUnsafeCachePool{BitMatrix} + + function IndependentlyLinearizedSolutionChunksCache{T,S}(num_us::Int, num_derivatives::Int, chunk_size::Int) where {T,S} + t_chunks_alloc = () -> Vector{T}(undef, chunk_size) + u_chunks_alloc = () -> Matrix{S}(undef, num_derivatives+1, chunk_size) + time_masks_alloc = () -> BitMatrix(undef, num_us, chunk_size) + return new( + CachePool(Vector{T}, t_chunks_alloc; thread_safe=false), + CachePool(Matrix{S}, u_chunks_alloc; thread_safe=false), + CachePool(BitMatrix, time_masks_alloc; thread_safe=false), + ) + end +end + """ IndependentlyLinearizedSolutionChunks @@ -9,23 +116,30 @@ When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLin we use this indermediate structure to reduce allocations and collect the unknown number of timesteps that the solve will generate. """ -mutable struct IndependentlyLinearizedSolutionChunks{T, S} +mutable struct IndependentlyLinearizedSolutionChunks{T, S, N} t_chunks::Vector{Vector{T}} u_chunks::Vector{Vector{Matrix{S}}} time_masks::Vector{BitMatrix} + # Temporary array that gets used by `get_chunks` + last_chunks::Vector{Matrix{S}} + # Index of next write into the last chunk u_offsets::Vector{Int} t_offset::Int + cache::IndependentlyLinearizedSolutionChunksCache + function IndependentlyLinearizedSolutionChunks{T, S}(num_us::Int, num_derivatives::Int = 0, - chunk_size::Int = 100) where {T, S} - return new([Vector{T}(undef, chunk_size)], - [[Matrix{S}(undef, num_derivatives+1, chunk_size)] for _ in 1:num_us], - [BitMatrix(undef, num_us, chunk_size)], - [1 for _ in 1:num_us], - 1, - ) + chunk_size::Int = 512, + cache::IndependentlyLinearizedSolutionChunksCache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size)) where {T, S} + t_chunks = [acquire!(cache.t_chunks)] + u_chunks = [[acquire!(cache.u_chunks)] for _ in 1:num_us] + time_masks = [acquire!(cache.time_masks)] + last_chunks = [u_chunks[u_idx][1] for u_idx in 1:num_us] + u_offsets = [1 for _ in 1:num_us] + t_offset = 1 + return new{T,S,num_derivatives}(t_chunks, u_chunks, time_masks, last_chunks, u_offsets, t_offset, cache) end end @@ -44,14 +158,8 @@ function num_us(ilsc::IndependentlyLinearizedSolutionChunks) end return length(ilsc.u_chunks) end +num_derivatives(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}) where {T,S,N} = N -function num_derivatives(ilsc::IndependentlyLinearizedSolutionChunks) - # If we've been finalized, just return `0` (which means only the primal) - if isempty(ilsc.t_chunks) - return 0 - end - return size(first(first(ilsc.u_chunks)), 1) - 1 -end function Base.isempty(ilsc::IndependentlyLinearizedSolutionChunks) return length(ilsc.t_chunks) == 1 && ilsc.t_offset == 1 @@ -61,24 +169,25 @@ function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T, # Check if we need to allocate new `t` chunk chunksize = chunk_size(ilsc) if ilsc.t_offset > chunksize - push!(ilsc.t_chunks, Vector{T}(undef, chunksize)) - push!(ilsc.time_masks, BitMatrix(undef, length(ilsc.u_offsets), chunksize)) + push!(ilsc.t_chunks, acquire!(ilsc.cache.t_chunks)) + push!(ilsc.time_masks, acquire!(ilsc.cache.time_masks)) ilsc.t_offset = 1 end # Check if we need to allocate any new `u` chunks (but only for those with `u_mask`) for (u_idx, u_chunks) in enumerate(ilsc.u_chunks) if ilsc.u_offsets[u_idx] > chunksize - push!(u_chunks, Matrix{S}(undef, num_derivatives(ilsc)+1, chunksize)) + push!(u_chunks, acquire!(ilsc.cache.u_chunks)) ilsc.u_offsets[u_idx] = 1 end + ilsc.last_chunks[u_idx] = u_chunks[end] end # return the last chunk for each return ( ilsc.t_chunks[end], ilsc.time_masks[end], - [u_chunks[end] for u_chunks in ilsc.u_chunks], + ilsc.last_chunks, ) end @@ -135,16 +244,18 @@ function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S}, ts, time_mask, us = get_chunks(ilsc) # Store into the chunks, gated by `u_mask` - for u_idx in 1:size(u, 2) + @inbounds for u_idx in 1:size(u, 2) if u_mask[u_idx] for deriv_idx in 1:size(u, 1) us[u_idx][deriv_idx, ilsc.u_offsets[u_idx]] = u[deriv_idx, u_idx] end ilsc.u_offsets[u_idx] += 1 end + + # Update our `time_mask` while we're at it + time_mask[u_idx, ilsc.t_offset] = u_mask[u_idx] end ts[ilsc.t_offset] = t - time_mask[:, ilsc.t_offset] .= u_mask ilsc.t_offset += 1 end @@ -161,7 +272,7 @@ efficient `iterate()` method that can be used to reconstruct coherent views of the state variables at all timepoints, as well as an efficient `sample!()` method that can sample at arbitrary timesteps. """ -mutable struct IndependentlyLinearizedSolution{T, S} +mutable struct IndependentlyLinearizedSolution{T, S, N} # All timepoints, shared by all `us` ts::Vector{T} @@ -173,33 +284,42 @@ mutable struct IndependentlyLinearizedSolution{T, S} time_mask::BitMatrix # Temporary object used during construction, will be set to `nothing` at the end. - ilsc::Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}} + ilsc::Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S,N}} + ilsc_cache_pool::Union{Nothing,ThreadSafeCachePool{IndependentlyLinearizedSolutionChunksCache{T,S}}} end # Helper function to create an ILS wrapped around an in-progress ILSC -function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S}) where {T,S} - ils = IndependentlyLinearizedSolution( +function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S,N}, cache_pool = nothing) where {T,S,N} + return IndependentlyLinearizedSolution{T,S,N}( T[], Matrix{S}[], BitMatrix(undef, 0,0), ilsc, + cache_pool, ) - return ils end # Automatically create an ILS wrapped around an ILSC from a `prob` -function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem, num_derivatives = 0) +function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem, num_derivatives = 0; + cache_pool = nothing, + chunk_size::Int = 512) T = eltype(prob.tspan) + S = eltype(prob.u0) U = isnothing(prob.u0) ? Float64 : eltype(prob.u0) - N = isnothing(prob.u0) ? 0 : length(prob.u0) - chunks = IndependentlyLinearizedSolutionChunks{T,U}(N, num_derivatives) - return IndependentlyLinearizedSolution(chunks) + num_us = isnothing(prob.u0) ? 0 : length(prob.u0) + if cache_pool === nothing + cache = IndependentlyLinearizedSolutionChunksCache{T,S}(num_us, num_derivatives, chunk_size) + else + cache = acquire!(cache_pool) + end + chunks = IndependentlyLinearizedSolutionChunks{T,U}(num_us, num_derivatives, chunk_size, cache) + return IndependentlyLinearizedSolution(chunks, nothing) end -num_derivatives(ils::IndependentlyLinearizedSolution) = !isempty(ils.us) ? size(first(ils.us), 1) : 0 +num_derivatives(::IndependentlyLinearizedSolution{T,S,N}) where {T,S,N} = N num_us(ils::IndependentlyLinearizedSolution) = length(ils.us) Base.size(ils::IndependentlyLinearizedSolution) = size(ils.time_mask) Base.length(ils::IndependentlyLinearizedSolution) = length(ils.ts) -function finish!(ils::IndependentlyLinearizedSolution) +function finish!(ils::IndependentlyLinearizedSolution{T,S}) where {T,S} function trim_chunk(chunks::Vector, offset) chunks = [chunk for chunk in chunks] if eltype(chunks) <: AbstractVector @@ -216,10 +336,52 @@ function finish!(ils::IndependentlyLinearizedSolution) end ilsc = ils.ilsc::IndependentlyLinearizedSolutionChunks - ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...) - time_mask = hcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...) - us = [hcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...) - for u_idx in 1:length(ilsc.u_chunks)] + + chunk_len(chunk) = size(chunk, ndims(chunk)) + function chunks_len(chunks::Vector, offset) + len = 0 + for chunk_idx in 1:length(chunks)-1 + len += chunk_len(chunks[chunk_idx]) + end + return len + offset - 1 + end + + function copy_chunk!(out::Vector, in::Vector, out_offset::Int, len=chunk_len(in)) + for idx in 1:len + out[idx+out_offset] = in[idx] + end + end + function copy_chunk!(out::AbstractMatrix, in::AbstractMatrix, out_offset::Int, len=chunk_len(in)) + for zdx in 1:size(in, 1) + for idx in 1:len + out[zdx, idx+out_offset] = in[zdx, idx] + end + end + end + + function collapse_chunks!(out, chunks, offset::Int) + write_offset = 0 + for chunk_idx in 1:(length(chunks)-1) + chunk = chunks[chunk_idx] + copy_chunk!(out, chunk, write_offset) + write_offset += chunk_len(chunk) + end + copy_chunk!(out, chunks[end], write_offset, offset-1) + end + + # Collapse t_chunks + ts = Vector{T}(undef, chunks_len(ilsc.t_chunks, ilsc.t_offset)) + collapse_chunks!(ts, ilsc.t_chunks, ilsc.t_offset) + + # Collapse u_chunks + us = Vector{Matrix{S}}(undef, length(ilsc.u_chunks)) + for u_idx in 1:length(ilsc.u_chunks) + us[u_idx] = Matrix{S}(undef, size(ilsc.u_chunks[u_idx][1],1), chunks_len(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])) + collapse_chunks!(us[u_idx], ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx]) + end + + time_mask = BitMatrix(undef, size(ilsc.time_masks[1], 1), chunks_len(ilsc.time_masks, ilsc.t_offset)) + collapse_chunks!(time_mask, ilsc.time_masks, ilsc.t_offset) # Sanity-check lengths if length(ts) != size(time_mask, 2) @@ -238,7 +400,24 @@ function finish!(ils::IndependentlyLinearizedSolution) throw(ArgumentError("Time mask must indicate same length as `us` ($(time_mask_lens) != $(us_lens))")) end - # Update our struct, release the `ilsc` + # Update our struct, release the `ilsc` and its caches + for t_chunk in ilsc.t_chunks + release!(ilsc.cache.t_chunks, t_chunk) + end + @assert is_fully_released(ilsc.cache.t_chunks) + for u_idx in 1:length(ilsc.u_chunks) + for u_chunk in ilsc.u_chunks[u_idx] + release!(ilsc.cache.u_chunks, u_chunk) + end + end + @assert is_fully_released(ilsc.cache.u_chunks) + for time_mask in ilsc.time_masks + release!(ilsc.cache.time_masks, time_mask) + end + @assert is_fully_released(ilsc.cache.time_masks) + if ils.ilsc_cache_pool !== nothing + release!(ils.ilsc_cache_pool, ilsc.cache) + end ils.ilsc = nothing ils.ts = ts ils.us = us diff --git a/src/saving.jl b/src/saving.jl index e1058358..b6064226 100644 --- a/src/saving.jl +++ b/src/saving.jl @@ -1,3 +1,5 @@ +import SciMLBase: AbstractSciMLAlgorithm + """ SavedValues{tType<:Real, savevalType} @@ -182,14 +184,19 @@ as_array(t::AbstractArray) = t function is_linear_enough!(caches, is_linear, t₀, t₁, u₀, u₁, integ, abstol, reltol) (; y_linear, y_interp, slopes) = caches tspread = t₁ - t₀ - slopes .= (u₁ .- u₀) ./ tspread + num_us = length(u₀) + @inbounds for u_idx in 1:num_us + slopes[u_idx] = (u₁[u_idx] - u₀[u_idx])/tspread + end t_quartile(t_idx) = t₀ + tspread * t_idx/4.0 # Calculate interpolated and linear samplings in our three quartiles - for t_idx in 1:3 + @inbounds for t_idx in 1:3 t = t_quartile(t_idx) # Linear interpolation - y_linear[:, t_idx] .= u₀ .+ (t - t₀) .* slopes + @inbounds for u_idx in 1:num_us + y_linear[u_idx, t_idx] = u₀[u_idx] .+ (t - t₀) .* slopes[u_idx] + end # Solver interpolation # We would like to use `integ(@view(y_interp[:, t_idx]))` here, @@ -197,14 +204,16 @@ function is_linear_enough!(caches, is_linear, t₀, t₁, u₀, u₁, integ, abs # memory that the view would have given us, so we instead use a # temporary array then copy it into `y_interp`, which loses very # little time and still prevents allocations from `integ(t)`. - with_cache(caches.us) do u_interp - integ(u_interp, t) - y_interp[:, t_idx] .= u_interp + @with_cache caches.us u_interp begin + integ(u_interp, t, Val{0}; idxs=nothing) + for u_idx in 1:num_us + y_interp[u_idx, t_idx] = u_interp[u_idx] + end end end # Return `is_linear` for each state - for u_idx in 1:length(u₀) + @inbounds for u_idx in 1:num_us is_linear[u_idx] = true for t_idx in 1:3 is_linear[u_idx] &= isapprox(y_linear[u_idx, t_idx], @@ -216,8 +225,8 @@ function is_linear_enough!(caches, is_linear, t₀, t₁, u₀, u₁, integ, abs # Find worst time index so that we split our period there t_max_idx = 1 e_max = 0.0 - for t_idx in 1:3 - for u_idx in 1:length(u₀) + @inbounds for t_idx in 1:3 + for u_idx in 1:num_us e = abs(y_linear[u_idx, t_idx] - y_interp[u_idx, t_idx]) if e > e_max t_max_idx = t_idx @@ -235,7 +244,7 @@ function linearize_period(t₀, t₁, u₀, u₁, integ, ilsc, caches, u_mask, d throw(ArgumentError("Linearization failed, fell below linearization subdivision threshold")) end - with_cache(caches.u_masks) do is_linear + @with_cache caches.u_masks is_linear begin tᵦ = is_linear_enough!(caches, is_linear, t₀, t₁, @@ -252,8 +261,8 @@ function linearize_period(t₀, t₁, u₀, u₁, integ, ilsc, caches, u_mask, d if any(is_nonlinear) # If it's not linear, split this period into two and recurse, altering our `u_mask`: - with_cache(caches.us) do uᵦ - integ(uᵦ, tᵦ) + @with_cache caches.us uᵦ begin + integ(uᵦ, tᵦ, Val{0}; idxs=nothing) linearize_period( t₀, tᵦ, u₀, uᵦ, @@ -283,50 +292,144 @@ function linearize_period(t₀, t₁, u₀, u₁, integ, ilsc, caches, u_mask, d end else # If everyone is linear, store this period, according to our `u_mask`! - store_u_block!(ilsc, integ, caches, t₁, u₁, u_mask) + store_u_block!(ilsc, Val(num_derivatives(ilsc)), integ, caches, t₁, u₁, u_mask) end end end -function store_u_block!(ilsc, integ, caches, t₁, u₁, u_mask) - with_cache(caches.us) do u - caches.u_block[1, :] .= u₁ - for deriv_idx in 1:num_derivatives(ilsc) - integ(u, t₁, Val{deriv_idx}) - caches.u_block[deriv_idx+1, :] .= u +function store_u_block!(ilsc, ::Val{num_derivatives}, integ, caches, t₁, u₁, u_mask) where {num_derivatives} + @with_cache caches.us u begin + for u_idx in 1:length(u) + caches.u_block[1, u_idx] = u₁[u_idx] + end + for deriv_idx in 1:num_derivatives + integ(u, t₁, Val{deriv_idx}; idxs=nothing) + for u_idx in 1:length(u) + caches.u_block[deriv_idx+1, u_idx] = u[u_idx] + end end store!(ilsc, t₁, caches.u_block, u_mask) end end -""" - CachePool +struct LinearizingSavingCallbackCacheType{S,U} + y_linear::Matrix{S} + y_interp::Matrix{S} + slopes::Vector{S} + # U is not necessarily a `Vector{S}` because it can be an `NVector` thanks to Sundials. + us::ThreadUnsafeCachePool{U} + u_block::Matrix{S} + u_masks::ThreadUnsafeCachePool{BitVector} + + function LinearizingSavingCallbackCacheType{S,U}(num_us::Int, num_derivatives::Int, U_alloc::Function) where {S,U} + y_linear = Matrix{S}(undef, (num_us, 3)) + y_interp = Matrix{S}(undef, (num_us, 3)) + slopes = Vector{S}(undef, num_us) + u_block = Matrix{S}(undef, (num_derivatives+1, num_us)) + F_umasks = () -> BitVector(undef, num_us) + u_masks = CachePool(BitVector, F_umasks; thread_safe=false) + + # Workaround for Sundials allocations; conversion from `Vector{S}` + # to `NVector()` allocates, so we require the caller to pass in a + # `U` and a `U_alloc` so that the default of `U = Vector{S}` and + # U_alloc = () -> Vector{S}(undef, num_us)` can be overridden. + # This is automatically done by `DiffEqCallbacksSundialsExt`, via + # the `solver_state_type()` and `solver_state_alloc()` hooks below. + us = CachePool(U, U_alloc; thread_safe=false) + return new{S,U}( + y_linear, + y_interp, + slopes, + us, + u_block, + u_masks, + ) + end +end + +# This exists purely so that different solver types can wrap/alter the +# type of the state vectors cached by the `LinearizingSavingCallbackCache`. +# U is typically something like `Vector{Float64}`. +solver_state_type(solver::AbstractSciMLAlgorithm, U::DataType) = U +solver_state_alloc(solver::AbstractSciMLAlgorithm, U::DataType, num_us::Int) = () -> U(undef, num_us) -Simple memory-reusing cache that allows us to grow a cache and keep -re-using those pieces of memory (in our case, typically `u` vectors) -until the solve is finished. Note that this datastructure is _not_ -thread-safe! """ -mutable struct CachePool{T,F} - pool::Vector{T} - alloc::F - write_idx::Int + LinearizingSavingCallbackCache(prob, solver; num_derivatives=0) - function CachePool(T, alloc::F) where {F} - return new{T,F}(T[], alloc, 0) - end +Top-level cache for the `LinearizingSavingCallback`. Typically used +to vastly reduce the number of allocations when performing an ensemble +solve, where allocations from one solution can be used by the next. + +Users must pass in `solver` to allow for solver-specific allocation +strategies. As an example, `IDA` requires allocation of `NVector` +objects rather than `Vector{S}` objects, and to automatically +determine this, the `LinearizingSavingCallbackCache` takes in the +solver as well. See the `DiffEqCallbacksSundialsExt` extension +for the details on how this type adjustment is made. + +This top-level cache creates two thread-safe cache pools that are then +used by each solve to allocate thread-unsafe cache pools. Those per- +solution cache pools are then re-used across solutions as the ensemble +finishes one trajectory and moves to another. + +Example usage: + +```julia +# Linearize the primal, and the first derivative +num_derivatives = 1 + +# Create a cache, to be used across all ensemble simulations +cache = LinearizingSavingCallbackCache(prob, solver; num_derivatives) + +# Store the results in this array of independently linearized solutions +ilss = Vector{IndependentlyLinearizedSolution}(undef, num_trajectories) + +# Create `prob_func` piece ot remake `prob` to have the correct callback, +# hooking up the necessary caching pieces. +function linearizer_adding_remake(prob,i,_) + ilss[i] = IndependentlyLinearizedSolution(prob, num_derivatives; cache_pool=cache.ils_cache) + lsc = LinearizingSavingCallback(ilss[i]; cache_pool=cache.lsc_cache) + return remake(prob; callback=lsc) end -function with_cache(f::Function, cache::CachePool{T}) where {T} - cache.write_idx += 1 - if length(cache.pool) < cache.write_idx - push!(cache.pool, cache.alloc()) - end - try - f(cache.pool[cache.write_idx]) - finally - cache.write_idx -= 1 - end +ensembleprob = EnsembleProblem(prob; prob_func=linearizer_adding_remake) +solve(ensembleprob, solver, EnsembleThreads(); ...) +``` +""" +function LinearizingSavingCallbackCache(prob, solver; num_derivatives=0, chunk_size=512) + T = eltype(prob.tspan) + S = eltype(prob.u0) + U = solver_state_type(solver, typeof(prob.u0)) + num_us = length(prob.u0) + U_alloc = solver_state_alloc(solver, typeof(prob.u0), num_us) + return LinearizingSavingCallbackCache(T, S, U, U_alloc, num_us; num_derivatives, chunk_size) +end + +function LinearizingSavingCallbackCache(T, S, U, U_alloc, num_us; num_derivatives=0, chunk_size=512) + return (; + # This cache is used by the LinearizingSavingCallback, it creates `LinearizingSavingCallbackCacheType` + # objects, which is quite a mouthful, but contains all the temporary values needed for a single + # solve's linearization. Notably, it contains within itself cachepools for `u` vectors and whatnot, + # and it is _not_ thread-safe, because we assume that a single solve is single-threaded, so we use a + # single thread-safe cache pool (the `lsc_cache`) to spawn off a collection of these smaller, thread- + # unsafe (but faster to acquire/release) cache pools. + lsc_cache = CachePool( + LinearizingSavingCallbackCacheType{S,U}, + () -> LinearizingSavingCallbackCacheType{S,U}(num_us, num_derivatives, U_alloc); + thread_safe = true, + ), + # This cache is used by the `IndependentlyLinearizedSolutionChunks` to do things like allocate `u`, + # `t` and `time_mask` chunks. + ils_cache = CachePool( + IndependentlyLinearizedSolutionChunksCache{T,S}, + () -> IndependentlyLinearizedSolutionChunksCache{T,S}( + num_us, + num_derivatives, + chunk_size, + ), + thread_safe = true, + ), + ) end """ @@ -366,11 +469,16 @@ function LinearizingSavingCallback(ils::IndependentlyLinearizedSolution{T,S}; interpolate_mask = BitVector(true for _ in 1:length(ils.ilsc.u_chunks)), abstol::Union{S,Nothing} = nothing, reltol::Union{S,Nothing} = nothing, - ) where {T, S} + cache_pool::CachePool{C} = LinearizingSavingCallbackCache(T, S, Vector{S}, length(ils.ilsc.u_chunks); num_derivatives=num_derivatives(ils.ilsc)).lsc_cache, + ) where {T, S, C} ilsc = ils.ilsc full_mask = BitVector(true for _ in 1:length(ilsc.u_chunks)) - # caches will be allocated in `initialize()` - caches = nothing + num_derivatives_val = Val(num_derivatives(ilsc)) + + # `caches` is initialized in `initialize`, but we need to constrain + # its type here so that the closures in `DiscreteCallback` are stable + local caches::C + #caches = acquire!(cache_pool) return DiscreteCallback( # We will process every timestep (u, t, integ) -> begin @@ -380,17 +488,17 @@ function LinearizingSavingCallback(ils::IndependentlyLinearizedSolution{T,S}; integ -> begin t₀ = integ.tprev t₁ = integ.t - with_cache(caches.us) do u₀ - with_cache(caches.us) do u₁ + @with_cache caches.us u₀ begin + @with_cache caches.us u₁ begin # Get `u₀` and `u₁` from the integrator - integ(u₀, t₀) - integ(u₁, t₁) + integ(u₀, t₀, Val{0}; idxs=nothing) + integ(u₁, t₁, Val{0}; idxs=nothing) # Store first timepoints. Usually we'd do this in `initialize` # but `integ(u, t, deriv)` doesn't work that early, and so we # must wait until we've taken at least a single step. if isempty(ilsc) - store_u_block!(ilsc, integ, caches, t₀, u₀, full_mask) + store_u_block!(ilsc, num_derivatives_val, integ, caches, t₀, u₀, full_mask) end dtmin = eps(t₁ - t₀)*1000.0 @@ -406,38 +514,20 @@ function LinearizingSavingCallback(ils::IndependentlyLinearizedSolution{T,S}; end u_modified!(integ, false) end, - # In our `initialize`, we create some caches so we allocate less + # In our `initialize`, we create some caches so we allocate less. initialize = (c, u, t, integ) -> begin - u = as_array(u) - num_us = length(ilsc.u_chunks) - - # Workaround for Sundials allocations; `NVector()` allocates, - # so we first use `typeof(integ.u_nvec)` to pull out the `NVector` type, - # then teach our `CachePool` to create pre-wrapped `NVector`s rather - # than just `Vector{S}`'s. - if hasfield(typeof(integ), :u_nvec) - NVector = typeof(integ.u_nvec) - us = CachePool(NVector, () -> NVector(Vector{S}(undef, num_us))) - else - us = CachePool(Vector{S}, () -> Vector{S}(undef, num_us)) - end - caches = (; - y_linear = Matrix{S}(undef, (num_us, 3)), - y_interp = Matrix{S}(undef, (num_us, 3)), - slopes = Vector{S}(undef, num_us), - us = us, - u_block = Matrix{S}(undef, (num_derivatives(ilsc)+1, num_us)), - u_masks = CachePool(BitVector, () -> BitVector(undef, num_us)) - ) + caches = acquire!(cache_pool) u_modified!(integ, false) end, # We need to finalize the ils and free our caches finalize = (c, u, t, integ) -> begin finish!(ils) - caches = nothing + if cache_pool !== nothing + release!(cache_pool, caches) + end end, # Don't add tstops to the left and right. save_positions = (false, false)) end -export SavingCallback, SavedValues, LinearizingSavingCallback +export SavingCallback, SavedValues, LinearizingSavingCallback, LinearizingSavingCallbackCache diff --git a/test/independentlylinearizedtests.jl b/test/independentlylinearizedtests.jl index afc3c9c7..38332eb0 100644 --- a/test/independentlylinearizedtests.jl +++ b/test/independentlylinearizedtests.jl @@ -1,8 +1,60 @@ using Test, DiffEqCallbacks using DiffEqCallbacks: sample, store!, IndependentlyLinearizedSolutionChunks, finish! +using DiffEqCallbacks: CachePool, acquire!, release!, @with_cache +@testset "CachePool" begin + num_us = 10 + for thread_safe in (false, true) + pool = CachePool(Vector{Float64}, () -> Vector{Float64}(undef, num_us); thread_safe) + + @with_cache pool u₀ begin + @with_cache pool u₁ begin + @test length(u₀) == num_us + @test length(u₁) == num_us + + # We get two separate arrays + u₀[1] = 0.0 + u₁[1] = 1.0 + @test u₀[1] == 0.0 + @test u₁[1] == 1.0 + end + + # Test that u₁ gets re-used + @with_cache pool u₁ begin + @test length(u₁) == num_us + @test u₁[1] == 1.0 + end + + previously_alloced = pool.num_alloced + # Test that asking for a vector while we have one free does not allocate + function foo(pool) + return @with_cache pool u begin + u[1] + end + end + @test foo(pool) == 1.0 + @test pool.num_alloced == previously_alloced + end + end + + # Test the threadsafe versions actually lock + pool = CachePool(Vector{Float64}, () -> Vector{Float64}(undef, num_us); thread_safe=true) + # elapsed acquire once to warm everything up + @elapsed acquire!(pool) + @test @elapsed(acquire!(pool)) < 0.001 + + t_locker = Threads.@spawn begin + @lock pool.lock begin + sleep(0.01) + end + end + yield() + @test @elapsed(acquire!(pool)) > 0.001 +end + + @testset "IndependentlyLinearizedSolution" begin - ils = IndependentlyLinearizedSolution{Float64,Float64}( + ils = IndependentlyLinearizedSolution{Float64,Float64,0}( # t [0.0, 0.5, 0.75, 1.0], # us (primal only, no derivatives) @@ -19,6 +71,8 @@ using DiffEqCallbacks: sample, store!, IndependentlyLinearizedSolutionChunks, fi ), # ilsc nothing, + # ilsc_cache_pool + nothing, ) # Test `iterate()` diff --git a/test/saving_tests.jl b/test/saving_tests.jl index 6854afa5..9a9ca139 100644 --- a/test/saving_tests.jl +++ b/test/saving_tests.jl @@ -158,7 +158,8 @@ cb = SavingCallback((u, t, integrator) -> integrator.EEst * integrator.dt, saved # Test that our `LinearizingSavingCallback` gives back something that when interpolated, # respects our `abstol`/`reltol` versus the actual solution: using DataInterpolations -import DiffEqCallbacks: as_array, finish! +import DiffEqCallbacks: as_array, finish!, + IndependentlyLinearizedSolution, LinearizingSavingCallbackCache as_array(T::Type{<:AbstractArray}) = T as_array(T::Type{<:Number}) = Vector{T} @@ -217,9 +218,39 @@ if VERSION >= v"1.9" # stack test_linearization(prob_ode_rigidbody, solver(); max_deriv) test_linearization(prob_ode_nonlinchem, solver(); max_deriv) test_linearization(prob_ode_lorenz, solver(); max_deriv) + + # We do not support 2d states yet. + #test_linearization(prob_ode_2Dlinear, solver(); max_deriv) end end - # We do not support 2d states yet. - #test_linearization(prob_ode_2Dlinear, Tsit5()) + + # Benchmark that shows how to use a cache to speed up linearization of large ensemble solutions: + #= + num_trajectories = 10000 + num_derivatives = 1 + cache = LinearizingSavingCallbackCache(prob_ode_lorenz, Tsit5(); num_derivatives) + function p_disturbing_remake(prob, i, repeat; kwargs...) + return remake(prob; + p=prob.p .+ 0.1*randn(size(prob.p)), + kwargs... + ) + end + ilss = Vector{IndependentlyLinearizedSolution}(undef, num_trajectories) + function linearizer_adding_remake(prob,i,repeat) + ilss[i] = IndependentlyLinearizedSolution(prob, num_derivatives; cache_pool=cache.ils_cache) + lsc = LinearizingSavingCallback(ilss[i]; cache_pool=cache.lsc_cache) + return p_disturbing_remake(prob, i, repeat; callback=lsc) + end + ensembleprob = EnsembleProblem( + prob_ode_lorenz; + prob_func=linearizer_adding_remake, + safetycopy=false, + ) + @time sol = solve(ensembleprob, + Tsit5(), + EnsembleThreads(); + trajectories=num_trajectories, + ) + =# end