diff --git a/ext/PreallocationToolsReverseDiffExt.jl b/ext/PreallocationToolsReverseDiffExt.jl index 3ee68dc..54311d2 100644 --- a/ext/PreallocationToolsReverseDiffExt.jl +++ b/ext/PreallocationToolsReverseDiffExt.jl @@ -4,8 +4,8 @@ using PreallocationTools isdefined(Base, :get_extension) ? (import ReverseDiff) : (import ..ReverseDiff) # PreallocationTools https://github.com/SciML/PreallocationTools.jl/issues/39 -function Base.getindex(b::PreallocationTools.LazyBufferCache, u::ReverseDiff.TrackedArray) - s = b.sizemap(size(u)) # required buffer size +function Base.getindex(b::PreallocationTools.LazyBufferCache, + u::ReverseDiff.TrackedArray, s = b.sizemap(size(u))) T = ReverseDiff.TrackedArray buf = get!(b.bufs, (T, s)) do # declare type since b.bufs dictionary is untyped diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index 1226eaf..80e3c34 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -202,6 +202,9 @@ end A lazily allocated buffer object. Given an array `u`, `b[u]` returns an array of the same type and size `f(size(u))` (defaulting to the same size), which is allocated as needed and then cached within `b` for subsequent usage. + +Optionally, the size can be explicitly given at calltime using `b[u,s]`, which will +return a cache of size `s`. """ struct LazyBufferCache{F <: Function} bufs::Dict{Any, Any} # a dictionary mapping (type, size) pairs to buffers @@ -216,15 +219,18 @@ function similar_type(x::AbstractArray{T}, s::NTuple{N, Integer}) where {T, N} typeof(similar(x, ntuple(Returns(1), N))) end -function get_tmp(b::LazyBufferCache, u::T) where {T <: AbstractArray} - s = b.sizemap(size(u)) # required buffer size +function get_tmp( + b::LazyBufferCache, u::T, s = b.sizemap(size(u))) where {T <: AbstractArray} get!(b.bufs, (T, s)) do similar(u, s) # buffer to allocate if it was not found in b.bufs end::similar_type(u, s) # declare type since b.bufs dictionary is untyped end # override the [] method -Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} = get_tmp(b, u) +function Base.getindex( + b::LazyBufferCache, u::T, s = b.sizemap(size(u))) where {T <: AbstractArray} + get_tmp(b, u, s) +end # GeneralLazyBufferCache diff --git a/test/general_lbc.jl b/test/general_lbc.jl index db606d3..a35e5f3 100644 --- a/test/general_lbc.jl +++ b/test/general_lbc.jl @@ -42,12 +42,20 @@ y = view(x, 1:900) @test 0 == @allocated cache[y] @test cache[y] === get_tmp(cache, y) +@inferred cache[x, 1111] +@test 0 == @allocated cache[x, 1111] +@test size(cache[x, 1111]) == (1111,) + cache_17 = LazyBufferCache(Returns(17)) x = 1:10 @inferred cache_17[x] @test 0 == @allocated cache_17[x] @test size(cache_17[x]) == (17,) +@inferred cache_17[x, 1111] +@test 0 == @allocated cache_17[x, 1111] +@test size(cache_17[x, 1111]) == (1111,) + cache = GeneralLazyBufferCache(T -> Vector{T}(undef, 1000)) # GeneralLazyBufferCache is documented not to infer. # @inferred cache[Float64]