Skip to content

Commit

Permalink
Rework LinearizingSavingCallback
Browse files Browse the repository at this point in the history
The original implementation of `LinearizingSavingCallback` had two
modes; one with a single `SavedValues` that held all `u` values, and one
with multiple `SavedValues` structs, one for each `u` index.
Unfortunately, this second version (being the one that most users would
want) was inefficient, wasteful in memory, and a pain to deal with after
the fact due to each `u` index having a different time vector.

This PR deletes the first mode and improves the second mode to now have
a much more efficient and useful datastructure backing the independently
linearized states.  In particular, it now preallocates memory in chunks
while solving, stores the time vector only once (with a memory-efficient
bitmap storing time vector occupancy for each `u` variable) and provides
easy-to-use methods to consume the output both through iteration and
explicit sampling at timepoints.  In my experiments with a moderate-size
ODE with 22 states that takes ~2000 timesteps, adding this callback
increases solve time by ~10%, and allocations by ~100%.
  • Loading branch information
staticfloat committed Nov 15, 2023
1 parent 69c684c commit 351b2d7
Show file tree
Hide file tree
Showing 6 changed files with 540 additions and 180 deletions.
1 change: 1 addition & 0 deletions src/DiffEqCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include("manifold.jl")
include("domain.jl")
include("stepsizelimiters.jl")
include("function_caller.jl")
include("independentlylinearizedutils.jl")
include("saving.jl")
include("integrating.jl")
include("integrating_sum.jl")
Expand Down
286 changes: 286 additions & 0 deletions src/independentlylinearizedutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
using SciMLBase

export IndependentlyLinearizedSolution

"""
IndependentlyLinearizedSolutionChunks
When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
we use this indermediate structure to reduce allocations and collect the unknown number of timesteps
that the solve will generate.
"""
mutable struct IndependentlyLinearizedSolutionChunks{T, S}
t_chunks::Vector{Vector{T}}
u_chunks::Vector{Vector{Vector{S}}}
time_masks::Vector{BitMatrix}

chunk_size::Int

# Index of next write into the last chunk
u_offsets::Vector{Int}
t_offset::Int

function IndependentlyLinearizedSolutionChunks{T, S}(num_us::Int,

Check warning on line 23 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L23

Added line #L23 was not covered by tests
chunk_size::Int = 100) where {T, S}
return new([Vector{T}(undef, chunk_size)],

Check warning on line 25 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L25

Added line #L25 was not covered by tests
[[Vector{S}(undef, chunk_size)] for _ in 1:num_us],
[BitMatrix(undef, chunk_size, num_us)],
chunk_size,
[1 for _ in 1:num_us],
1,
)
end
end

function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T, S}

Check warning on line 35 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L35

Added line #L35 was not covered by tests
# Check if we need to allocate new `t` chunk
if ilsc.t_offset > ilsc.chunk_size
push!(ilsc.t_chunks, Vector{T}(undef, ilsc.chunk_size))
push!(ilsc.time_masks, BitMatrix(undef, ilsc.chunk_size, length(ilsc.u_offsets)))
ilsc.t_offset = 1

Check warning on line 40 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L37-L40

Added lines #L37 - L40 were not covered by tests
end

# Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
for (u_idx, u_chunks) in enumerate(ilsc.u_chunks)
if ilsc.u_offsets[u_idx] > ilsc.chunk_size
push!(u_chunks, Vector{S}(undef, ilsc.chunk_size))
ilsc.u_offsets[u_idx] = 1

Check warning on line 47 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L44-L47

Added lines #L44 - L47 were not covered by tests
end
end

Check warning on line 49 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L49

Added line #L49 was not covered by tests

# return the last chunk for each
return ilsc.t_chunks[end],

Check warning on line 52 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L52

Added line #L52 was not covered by tests
ilsc.time_masks[end],
[u_chunks[end] for u_chunks in ilsc.u_chunks]
end

"""
store!(ilsc::IndependentlyLinearizedSolutionChunks, t, u, u_mask)
Store a new `u` vector into our `ilsc`, but only the values identified by the
given `u_mask`.
"""
function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},

Check warning on line 63 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L63

Added line #L63 was not covered by tests
t::T,
u::Vector{S},
u_mask::BitVector) where {T, S}
ts, time_mask, us = get_chunks(ilsc)

Check warning on line 67 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L67

Added line #L67 was not covered by tests

# Store into the chunks, gated by `u_mask`
for u_idx in 1:length(us)
if u_mask[u_idx]
us[u_idx][ilsc.u_offsets[u_idx]] = u[u_idx]
ilsc.u_offsets[u_idx] += 1

Check warning on line 73 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L70-L73

Added lines #L70 - L73 were not covered by tests
end
end
ts[ilsc.t_offset] = t
time_mask[ilsc.t_offset, :] .= u_mask
ilsc.t_offset += 1

Check warning on line 78 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L75-L78

Added lines #L75 - L78 were not covered by tests
end



"""
IndependentlyLinearizedSolution
Efficient datastructure that holds a set of independently linearized solutions
(obtained via the `LinearizingSavingCallback`) with related, but slightly
different time vectors. Stores a single time vector with a packed `BitMatrix`
denoting which `u` vectors are sampled at which timepoints. Provides an
efficient `iterate()` method that can be used to reconstruct coherent views
of the state variables at all timepoints, as well as an efficient `sample!()`
method that can sample at arbitrary timesteps.
"""
mutable struct IndependentlyLinearizedSolution{T, S}
# All timepoints, shared by all `us`
ts::Vector{T}

# Ragged matrix of `us`
us::Vector{Vector{S}}

# Bitmatrix denoting which time indices are used for which us.
time_mask::BitMatrix

# Temporary object used during construction, will be set to `nothing` at the end.
ilsc::Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}}
end
# Helper function to create an ILS wrapped around an in-progress ILSC
function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S}) where {T,S}
ils = IndependentlyLinearizedSolution(

Check warning on line 109 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L108-L109

Added lines #L108 - L109 were not covered by tests
T[],
Vector{S}[],
BitMatrix(undef, 0,0),
ilsc,
)
return ils

Check warning on line 115 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L115

Added line #L115 was not covered by tests
end
# Automatically create an ILS wrapped around an ILSC from a `prob`
function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem)
return IndependentlyLinearizedSolution(

Check warning on line 119 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
IndependentlyLinearizedSolutionChunks{eltype(prob.tspan),eltype(prob.u0)}(length(prob.u0))
)
end

Base.size(ils::IndependentlyLinearizedSolution) = size(ils.time_mask)
Base.length(ils::IndependentlyLinearizedSolution) = length(ils.ts)

Check warning on line 125 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L124-L125

Added lines #L124 - L125 were not covered by tests

function finish!(ils::IndependentlyLinearizedSolution)
function trim_chunk(chunks::Vector, offset)
chunks = [chunk for chunk in chunks]
if eltype(chunks) <: Vector
chunks[end] = chunks[end][1:(offset - 1)]
elseif eltype(chunks) <: BitMatrix
chunks[end] = chunks[end][1:(offset - 1), :]

Check warning on line 133 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L127-L133

Added lines #L127 - L133 were not covered by tests
else
error(eltype(chunks))

Check warning on line 135 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L135

Added line #L135 was not covered by tests
end
if isempty(chunks[end])
pop!(chunks)

Check warning on line 138 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L137-L138

Added lines #L137 - L138 were not covered by tests
end
return chunks

Check warning on line 140 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L140

Added line #L140 was not covered by tests
end

ilsc = ils.ilsc::IndependentlyLinearizedSolutionChunks
ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
time_mask = vcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
us = [vcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)

Check warning on line 146 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L143-L146

Added lines #L143 - L146 were not covered by tests
for u_idx in 1:length(ilsc.u_chunks)]

# Sanity-check lengths
if length(ts) != size(time_mask, 1)
throw(ArgumentError("`length(ts)` must equal `size(time_mask, 1)`!"))

Check warning on line 151 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L150-L151

Added lines #L150 - L151 were not covered by tests
end

# All time masks must start and end with `1`:
if !all(@view time_mask[1, :]) || !all(@view time_mask[end, :])
throw(ArgumentError("Time mask must start and end with 1s!"))

Check warning on line 156 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L155-L156

Added lines #L155 - L156 were not covered by tests
end

# Length of time mask enable vectors must equal the lengths of our `us`:
time_mask_lens = vec(sum(time_mask; dims = 1))
if !all(time_mask_lens .== length.(us))
throw(ArgumentError("Time mask must indicate same length as `us` ($(time_mask_lens) != $(length.(us)))"))

Check warning on line 162 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L160-L162

Added lines #L160 - L162 were not covered by tests
end

# Update our struct, release the `ilsc`
ils.ilsc = nothing
ils.ts = ts
ils.us = us
ils.time_mask = time_mask
return ils

Check warning on line 170 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L166-L170

Added lines #L166 - L170 were not covered by tests
end

struct ILSStateCursor
# The index into `us` that identifies this state
u_idx::Int

idx_u₀::Int
# idx_u₁ is by definition idx_u₀ + 1, so we don't store it

# Time index of u₀
idx_t₀::Int
# Time index of u₁
idx_t₁::Int
end
# Helper to construct a state cursor off of an ILS, at a particular time index
function ILSStateCursor(ils::IndependentlyLinearizedSolution, u_idx::Int, t_idx::Int = 1)
cursor = ILSStateCursor(u_idx,

Check warning on line 187 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L186-L187

Added lines #L186 - L187 were not covered by tests
1,
1,
findfirst(@view ils.time_mask[2:end, u_idx]) + 1)
return seek_forward(ils, cursor, ils.ts[t_idx])

Check warning on line 191 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L190-L191

Added lines #L190 - L191 were not covered by tests
end
function interpolate(ils::IndependentlyLinearizedSolution{T},

Check warning on line 193 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L193

Added line #L193 was not covered by tests
cursor::ILSStateCursor,
t::T) where {T}
u₀ = ils.us[cursor.u_idx][cursor.idx_u₀]
u₁ = ils.us[cursor.u_idx][cursor.idx_u₀ + 1]
t₀ = ils.ts[cursor.idx_t₀]
t₁ = ils.ts[cursor.idx_t₁]
return (u₁ - u₀) / (t₁ - t₀) * (t - t₀) + u₀

Check warning on line 200 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L196-L200

Added lines #L196 - L200 were not covered by tests
end

"""
seek_forward(ils::IndependentlyLinearizedSolution, cursor::ILSStateCursor, t_target)
Seek the given `cursor` forward until it contains `t_target`. Does not seek backward, use `seek()`
for the more general formulation, this form is optimized for the inner loop of `iterate()`.
"""
function seek_forward(ils::IndependentlyLinearizedSolution{T},

Check warning on line 209 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L209

Added line #L209 was not covered by tests
cursor::ILSStateCursor,
t_target::T) where {T}
# We do not test `t_start` because we don't support seeking backward here
while ils.ts[cursor.idx_t₁] < t_target
next_t = findfirst(@view ils.time_mask[(cursor.idx_t₁ + 1):end, cursor.u_idx])
cursor = ILSStateCursor(cursor.u_idx,

Check warning on line 215 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L213-L215

Added lines #L213 - L215 were not covered by tests
cursor.idx_u₀ + 1,
cursor.idx_t₁,
next_t + cursor.idx_t₁)
end
return cursor

Check warning on line 220 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L219-L220

Added lines #L219 - L220 were not covered by tests
end

function Base.seek(ils::IndependentlyLinearizedSolution{T},

Check warning on line 223 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L223

Added line #L223 was not covered by tests
cursor::ILSStateCursor,
t_target::T) where {T}
# If we need to rewind, just start from the beginning
if t_target < ils.ts[cursor.idx_t₀]
cursor = ILSStateCursor(ils, cursor.u_idx)

Check warning on line 228 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L227-L228

Added lines #L227 - L228 were not covered by tests
end
return seek_forward(ils, cursor, t_target)

Check warning on line 230 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L230

Added line #L230 was not covered by tests
end

function Base.seek(ils::IndependentlyLinearizedSolution, t_idx = 1)
return [ILSStateCursor(ils, u_idx, t_idx) for u_idx in 1:length(ils.us)]

Check warning on line 234 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L233-L234

Added lines #L233 - L234 were not covered by tests
end
function default_state(ils::IndependentlyLinearizedSolution{T, S}) where {T, S}
t_idx = 1

Check warning on line 237 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L236-L237

Added lines #L236 - L237 were not covered by tests
# Nice little hack so we don't have to allocate `u` over and over again
u = S[S(0) for _ in ils.us]
cursors = seek(ils, t_idx)
return (t_idx, u, cursors)

Check warning on line 241 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L239-L241

Added lines #L239 - L241 were not covered by tests
end
function Base.iterate(ils::IndependentlyLinearizedSolution{T, S},

Check warning on line 243 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L243

Added line #L243 was not covered by tests
(t_idx, u, cursors) = default_state(ils)) where {T, S}
if t_idx > length(ils.ts)
return nothing

Check warning on line 246 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L245-L246

Added lines #L245 - L246 were not covered by tests
end

# We iteratively inch `offsets` forward, efficiently reconstructing a full set of `u`'s
t = ils.ts[t_idx]
for u_idx in 1:length(u)
cursors[u_idx] = seek_forward(ils, cursors[u_idx], t)
u[u_idx] = interpolate(ils, cursors[u_idx], t)
end

Check warning on line 254 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L250-L254

Added lines #L250 - L254 were not covered by tests

return (t, u), (t_idx + 1, u, cursors)

Check warning on line 256 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L256

Added line #L256 was not covered by tests
end

"""
sample!(out::Matrix{S}, ils::IndependentlyLinearizedSolution, ts::Vector{T})
Batch-sample `ils` at the given timepoints, storing into `out`.
"""
function sample!(out::Matrix{S},

Check warning on line 264 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L264

Added line #L264 was not covered by tests
ils::IndependentlyLinearizedSolution{T, S},
ts::AbstractVector{T}) where {T, S}
sampled_size = (length(ts), length(ils.us))
if size(out) != sampled_size
throw(ArgumentError("Output size ($(size(out))) != sampled size ($(sampled_size))"))

Check warning on line 269 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L267-L269

Added lines #L267 - L269 were not covered by tests
end

# We don't make use of `iterate` here because we're sampling at arbitrary timepoints
cursors = seek(ils)
for (t_idx, t) in enumerate(ts)
for u_idx in 1:length(ils.us)
cursors[u_idx] = seek_forward(ils, cursors[u_idx], t)
out[t_idx, u_idx] = interpolate(ils, cursors[u_idx], t)
end
end
return out

Check warning on line 280 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L273-L280

Added lines #L273 - L280 were not covered by tests
end
function sample(ils::IndependentlyLinearizedSolution{T, S},

Check warning on line 282 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L282

Added line #L282 was not covered by tests
ts::AbstractVector{T}) where {T, S}
out = Matrix{S}(undef, length(ts), length(ils.us))
return sample!(out, ils, ts)

Check warning on line 285 in src/independentlylinearizedutils.jl

View check run for this annotation

Codecov / codecov/patch

src/independentlylinearizedutils.jl#L284-L285

Added lines #L284 - L285 were not covered by tests
end
Loading

0 comments on commit 351b2d7

Please sign in to comment.