From 0b96e24ea9ec331666452007d7292ad71640fe7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 17 Apr 2024 16:30:55 +0200 Subject: [PATCH] allow for explicit size specification --- ext/PreallocationToolsReverseDiffExt.jl | 3 +-- src/PreallocationTools.jl | 8 +++++--- test/general_lbc.jl | 8 ++++++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ext/PreallocationToolsReverseDiffExt.jl b/ext/PreallocationToolsReverseDiffExt.jl index 3ee68dc..4c7f323 100644 --- a/ext/PreallocationToolsReverseDiffExt.jl +++ b/ext/PreallocationToolsReverseDiffExt.jl @@ -4,8 +4,7 @@ using PreallocationTools isdefined(Base, :get_extension) ? (import ReverseDiff) : (import ..ReverseDiff) # PreallocationTools https://github.com/SciML/PreallocationTools.jl/issues/39 -function Base.getindex(b::PreallocationTools.LazyBufferCache, u::ReverseDiff.TrackedArray) - s = b.sizemap(size(u)) # required buffer size +function Base.getindex(b::PreallocationTools.LazyBufferCache, u::ReverseDiff.TrackedArray, s = b.sizemap(size(u))) T = ReverseDiff.TrackedArray buf = get!(b.bufs, (T, s)) do # declare type since b.bufs dictionary is untyped diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index 1226eaf..93b5d91 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -202,6 +202,9 @@ end A lazily allocated buffer object. Given an array `u`, `b[u]` returns an array of the same type and size `f(size(u))` (defaulting to the same size), which is allocated as needed and then cached within `b` for subsequent usage. + +Optionally, the size can be explicitly given at calltime using `b[u,s]`, which will +return a cache of size `s`. """ struct LazyBufferCache{F <: Function} bufs::Dict{Any, Any} # a dictionary mapping (type, size) pairs to buffers @@ -216,15 +219,14 @@ function similar_type(x::AbstractArray{T}, s::NTuple{N, Integer}) where {T, N} typeof(similar(x, ntuple(Returns(1), N))) end -function get_tmp(b::LazyBufferCache, u::T) where {T <: AbstractArray} - s = b.sizemap(size(u)) # required buffer size +function get_tmp(b::LazyBufferCache, u::T, s=b.sizemap(size(u))) where {T <: AbstractArray} get!(b.bufs, (T, s)) do similar(u, s) # buffer to allocate if it was not found in b.bufs end::similar_type(u, s) # declare type since b.bufs dictionary is untyped end # override the [] method -Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} = get_tmp(b, u) +Base.getindex(b::LazyBufferCache, u::T, s=b.sizemap(size(u))) where {T <: AbstractArray} = get_tmp(b, u, s) # GeneralLazyBufferCache diff --git a/test/general_lbc.jl b/test/general_lbc.jl index db606d3..a35e5f3 100644 --- a/test/general_lbc.jl +++ b/test/general_lbc.jl @@ -42,12 +42,20 @@ y = view(x, 1:900) @test 0 == @allocated cache[y] @test cache[y] === get_tmp(cache, y) +@inferred cache[x, 1111] +@test 0 == @allocated cache[x, 1111] +@test size(cache[x, 1111]) == (1111,) + cache_17 = LazyBufferCache(Returns(17)) x = 1:10 @inferred cache_17[x] @test 0 == @allocated cache_17[x] @test size(cache_17[x]) == (17,) +@inferred cache_17[x, 1111] +@test 0 == @allocated cache_17[x, 1111] +@test size(cache_17[x, 1111]) == (1111,) + cache = GeneralLazyBufferCache(T -> Vector{T}(undef, 1000)) # GeneralLazyBufferCache is documented not to infer. # @inferred cache[Float64]