Skip to content

Commit

Permalink
allow for explicit size specification
Browse files Browse the repository at this point in the history
  • Loading branch information
hexaeder committed Apr 17, 2024
1 parent 1f8db64 commit 0b96e24
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
3 changes: 1 addition & 2 deletions ext/PreallocationToolsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ using PreallocationTools
isdefined(Base, :get_extension) ? (import ReverseDiff) : (import ..ReverseDiff)

# PreallocationTools https://github.com/SciML/PreallocationTools.jl/issues/39
function Base.getindex(b::PreallocationTools.LazyBufferCache, u::ReverseDiff.TrackedArray)
s = b.sizemap(size(u)) # required buffer size
function Base.getindex(b::PreallocationTools.LazyBufferCache, u::ReverseDiff.TrackedArray, s = b.sizemap(size(u)))
T = ReverseDiff.TrackedArray
buf = get!(b.bufs, (T, s)) do
# declare type since b.bufs dictionary is untyped
Expand Down
8 changes: 5 additions & 3 deletions src/PreallocationTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ end
A lazily allocated buffer object. Given an array `u`, `b[u]` returns an array of the
same type and size `f(size(u))` (defaulting to the same size), which is allocated as
needed and then cached within `b` for subsequent usage.
Optionally, the size can be explicitly given at calltime using `b[u,s]`, which will
return a cache of size `s`.
"""
struct LazyBufferCache{F <: Function}
bufs::Dict{Any, Any} # a dictionary mapping (type, size) pairs to buffers
Expand All @@ -216,15 +219,14 @@ function similar_type(x::AbstractArray{T}, s::NTuple{N, Integer}) where {T, N}
typeof(similar(x, ntuple(Returns(1), N)))
end

function get_tmp(b::LazyBufferCache, u::T) where {T <: AbstractArray}
s = b.sizemap(size(u)) # required buffer size
function get_tmp(b::LazyBufferCache, u::T, s=b.sizemap(size(u))) where {T <: AbstractArray}
get!(b.bufs, (T, s)) do
similar(u, s) # buffer to allocate if it was not found in b.bufs
end::similar_type(u, s) # declare type since b.bufs dictionary is untyped
end

# override the [] method
Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} = get_tmp(b, u)
Base.getindex(b::LazyBufferCache, u::T, s=b.sizemap(size(u))) where {T <: AbstractArray} = get_tmp(b, u, s)

# GeneralLazyBufferCache

Expand Down
8 changes: 8 additions & 0 deletions test/general_lbc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,20 @@ y = view(x, 1:900)
@test 0 == @allocated cache[y]
@test cache[y] === get_tmp(cache, y)

@inferred cache[x, 1111]
@test 0 == @allocated cache[x, 1111]
@test size(cache[x, 1111]) == (1111,)

cache_17 = LazyBufferCache(Returns(17))
x = 1:10
@inferred cache_17[x]
@test 0 == @allocated cache_17[x]
@test size(cache_17[x]) == (17,)

@inferred cache_17[x, 1111]
@test 0 == @allocated cache_17[x, 1111]
@test size(cache_17[x, 1111]) == (1111,)

cache = GeneralLazyBufferCache(T -> Vector{T}(undef, 1000))
# GeneralLazyBufferCache is documented not to infer.
# @inferred cache[Float64]
Expand Down

0 comments on commit 0b96e24

Please sign in to comment.