From 4e2886284939f2bc48047f1bf311c91558343f44 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 10 Jun 2024 11:00:50 +0200 Subject: [PATCH 1/5] Add InplaceLogpdf wrapper --- docs/src/index.md | 1 + src/statsfuns.jl | 69 +++++++++++++++++++++++++++++++++++++++++ test/statsfuns_tests.jl | 61 ++++++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index 7552b0c..a7c658b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -113,4 +113,5 @@ BayesBase.distribution_typewrapper BayesBase.CountingReal BayesBase.Infinity BayesBase.MinusInfinity +BayesBase.InplaceLogpdf ``` diff --git a/src/statsfuns.jl b/src/statsfuns.jl index 5ecaea0..f5e00b4 100644 --- a/src/statsfuns.jl +++ b/src/statsfuns.jl @@ -326,3 +326,72 @@ Base.promote_rule(::Type{CountingReal}, ::Type{T}) where {T<:Real} = CountingRea function Base.:(==)(left::CountingReal{T}, right::CountingReal{T}) where {T} return (value(left) == value(right)) && (infinities(left) == infinities(right)) end + +""" + InplaceLogpdf(logpdf!) + +Wraps a `logpdf!` function in a type that can later on be used for dispatch. +The sole purpose of this wrapper type is to allow for in-place logpdf operation on a batch of samples. +Accepts a function `logpdf!` that takes two arguments: `out` and `sample` and writes the logpdf of the sample to the `out` array. +A regular `logpdf` function can be converted to `logpdf!` by using `convert(InplaceLogpdf, logpdf)`. + +```jldoctest +julia> using Distributions, BayesBase + +julia> d = Beta(2, 3); + +julia> inplace = convert(BayesBase.InplaceLogpdf, (sample) -> logpdf(d, sample)); + +julia> out = zeros(9); + +julia> inplace(out, 0.1:0.1:0.9) +9-element Vector{Float64}: + -0.028399474521697776 + 0.42918163472548043 + 0.5675839575845996 + 0.5469646703818638 + 0.4054651081081646 + 0.14149956227369964 + -0.2797139028026039 + -0.9571127263944104 + -2.2256240518579173 +``` + +```jldoctest +julia> using Distributions, BayesBase + +julia> d = Beta(2, 3); + +julia> inplace = BayesBase.InplaceLogpdf((out, sample) -> logpdf!(out, d, sample)); + +julia> out = zeros(9); + +julia> inplace(out, 0.1:0.1:0.9) +9-element Vector{Float64}: + -0.028399474521697776 + 0.42918163472548043 + 0.5675839575845996 + 0.5469646703818638 + 0.4054651081081646 + 0.14149956227369964 + -0.2797139028026039 + -0.9571127263944104 + -2.2256240518579173 +``` +""" +struct InplaceLogpdf{F} + logpdf!::F +end + +function (inplace::InplaceLogpdf)(out, x) + inplace.logpdf!(out, x) + return out +end + +function Base.convert(::Type{InplaceLogpdf}, something) + return InplaceLogpdf((out, x) -> map!(something, out, x)) +end + +function Base.convert(::Type{InplaceLogpdf}, inplace::InplaceLogpdf) + return inplace +end \ No newline at end of file diff --git a/test/statsfuns_tests.jl b/test/statsfuns_tests.jl index 4211b96..94f2fc0 100644 --- a/test/statsfuns_tests.jl +++ b/test/statsfuns_tests.jl @@ -89,4 +89,65 @@ end @test float(convert(CountingReal{Float64}, r)) ≈ zero(Float64) end +end + +@testitem "InplaceLogpdf" begin + import BayesBase: InplaceLogpdf + using Distributions, LinearAlgebra, StableRNGs + + @testset "Vector based samples" begin + distribution = Beta(10, 10) + fn = (x) -> logpdf(distribution, x) + inplacefn = convert(InplaceLogpdf, fn) + + @test fn !== inplacefn + + rng = StableRNG(42) + samples = rand(rng, distribution, 100) + evaluated = map(fn, samples) + + + container = similar(evaluated) + inplacefn(container, samples) + + @test evaluated == container + end + + @testset "Matrix based samples" begin + distribution = MvNormal(ones(2), ones(2)) + fn = (x) -> logpdf(distribution, x) + inplacefn = convert(InplaceLogpdf, fn) + + @test inplacefn !== fn + + rng = StableRNG(42) + samples = rand(rng, distribution, 100) + evaluated = map(fn, eachcol(samples)) + + + container = similar(evaluated) + inplacefn(container, eachcol(samples)) + + @test evaluated == container + end + + @testset "Do not convert already inplace version" begin + distribution = MvNormal(ones(2), ones(2)) + fn = InplaceLogpdf((out, x) -> logpdf!(out, distribution, x)) + inplacefn = convert(InplaceLogpdf, fn) + + @test inplacefn === fn + + rng = StableRNG(42) + samples = rand(rng, distribution, 100) + evaluated = zeros(100) + fn(evaluated, eachcol(samples)) + + container = similar(evaluated) + inplacefn(container, eachcol(samples)) + + @test evaluated == container + end + + end \ No newline at end of file From b7fab60db316039078a42fa63f8a783f695c1cc5 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 10 Jun 2024 11:01:21 +0200 Subject: [PATCH 2/5] formatting --- test/statsfuns_tests.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/test/statsfuns_tests.jl b/test/statsfuns_tests.jl index 94f2fc0..1042aa5 100644 --- a/test/statsfuns_tests.jl +++ b/test/statsfuns_tests.jl @@ -26,10 +26,10 @@ end end end -@testitem "dtanh" begin +@testitem "dtanh" begin for T in (Float32, Float64, BigFloat) foreach(rand(T, 10)) do number - @test dtanh(number) ≈ 1 - tanh(number) ^ 2 + @test dtanh(number) ≈ 1 - tanh(number)^2 end end end @@ -87,7 +87,6 @@ end @test float(convert(CountingReal, r)) ≈ zero(T) @test float(convert(CountingReal{Float64}, r)) ≈ zero(Float64) - end end @@ -106,7 +105,6 @@ end samples = rand(rng, distribution, 100) evaluated = map(fn, samples) - container = similar(evaluated) inplacefn(container, samples) @@ -124,7 +122,6 @@ end samples = rand(rng, distribution, 100) evaluated = map(fn, eachcol(samples)) - container = similar(evaluated) inplacefn(container, eachcol(samples)) @@ -148,6 +145,4 @@ end @test evaluated == container end - - end \ No newline at end of file From 5b07e7d57a8a16ee820f71ed4c1633fd9a5bc565 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 10 Jun 2024 11:30:48 +0200 Subject: [PATCH 3/5] re-export logpdf! --- src/BayesBase.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/BayesBase.jl b/src/BayesBase.jl index 6df188d..19ff7ea 100644 --- a/src/BayesBase.jl +++ b/src/BayesBase.jl @@ -30,6 +30,7 @@ import Distributions: pdf!, cdf, logpdf, + logpdf!, logdetcov, VariateForm, ValueSupport, @@ -58,6 +59,7 @@ export failprob, pdf!, cdf, logpdf, + logpdf!, logdetcov, VariateForm, ValueSupport, From d01000d1f692ae3aed0068a302c71fde82bd27d5 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 10 Jun 2024 11:38:48 +0200 Subject: [PATCH 4/5] add non-allocating check for simple logpdf! --- test/statsfuns_tests.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/statsfuns_tests.jl b/test/statsfuns_tests.jl index 1042aa5..85316ee 100644 --- a/test/statsfuns_tests.jl +++ b/test/statsfuns_tests.jl @@ -145,4 +145,13 @@ end @test evaluated == container end + + @testset "Shouldn't allocate anything for simple `logpdf!`" begin + fn = InplaceLogpdf((out, x) -> out .= log.(x)) + samples = 1:10 + out = zeros(10) + fn(out, samples) + @test out == log.(samples) + @test @allocated(fn(out, samples)) === 0 + end end \ No newline at end of file From ffc40849e1df83dec19d5f1091e96a1b380ffc33 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 10 Jun 2024 11:40:30 +0200 Subject: [PATCH 5/5] add Distributions to the docs project --- docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Project.toml b/docs/Project.toml index 2c96d67..7300733 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,4 @@ [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"