From 79fbb7809f8a68e0f698e47906040ea2b34b5213 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Thu, 13 Jun 2024 10:02:52 +0200 Subject: [PATCH] ProductOf should use `insupport` --- src/prod.jl | 9 ++++++-- test/densities/function_tests.jl | 4 ++-- test/prod_tests.jl | 36 ++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/prod.jl b/src/prod.jl index 169fd63..b13e10b 100644 --- a/src/prod.jl +++ b/src/prod.jl @@ -182,10 +182,14 @@ function BayesBase.support(product::ProductOf) return fuse_supports(support(getleft(product)), support(getright(product))) end +function BayesBase.insupport(product::ProductOf, x) + return insupport(getleft(product), x) && insupport(getright(product), x) +end + BayesBase.pdf(product::ProductOf, x) = exp(logpdf(product, x)) function BayesBase.logpdf(product::ProductOf, x) - @assert x ∈ support(product) "The `$(x)` does not belong to the support of the product `$(product)`" + @assert insupport(product, x) lazy"The `$(x)` does not belong to the support of the product `$(product)`" return logpdf(getleft(product), x) + logpdf(getright(product), x) end @@ -332,6 +336,7 @@ function Base.push!(product::LinearizedProductOf{F}, item::F) where {F} end BayesBase.support(dist::LinearizedProductOf) = support(first(dist.vector)) +BayesBase.insupport(dist::LinearizedProductOf, x) = insupport(first(dist.vector), x) Base.length(product::LinearizedProductOf) = product.length Base.eltype(product::LinearizedProductOf) = eltype(first(product.vector)) @@ -352,7 +357,7 @@ function Base.show(io::IO, product::LinearizedProductOf{F}) where {F} end function BayesBase.logpdf(product::LinearizedProductOf, x) - @assert x ∈ support(product) "The `$(x)` does not belong to the support of the product `$(product)`" + @assert insupport(product, x) "The `$(x)` does not belong to the support of the product `$(product)`" return mapreduce( (d) -> logpdf(d, x), +, diff --git a/test/densities/function_tests.jl b/test/densities/function_tests.jl index a780232..f711379 100644 --- a/test/densities/function_tests.jl +++ b/test/densities/function_tests.jl @@ -191,8 +191,8 @@ end d5 = ContinuousUnivariateLogPdf(FullSpace(), (x) -> 2.0 * -x^2) d6 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> 2.0 * -x^2) - @test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), 1.0) # domains are different - @test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), -1.0) # domains are different + @test logpdf(prod(GenericProd(), d5, d6), 1.0) ≈ -4.0 + @test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), -1.0) # supports are different end @testitem "ContinuousUnivariateLogPdf: vectorised-prod" begin diff --git a/test/prod_tests.jl b/test/prod_tests.jl index 90f72eb..db372b8 100644 --- a/test/prod_tests.jl +++ b/test/prod_tests.jl @@ -292,4 +292,40 @@ end @test_throws ErrorException resolve_prod_strategy( PreserveTypeRightProd(), PreserveTypeLeftProd() ) +end + +@testitem "`ProductOf` should support distributions that do not have explicitly defined `support`" begin + + struct SomeComplexDistribution end + + BayesBase.support(::SomeComplexDistribution) = error("not defined") + BayesBase.insupport(::SomeComplexDistribution, x) = x > 0 + BayesBase.logpdf(::SomeComplexDistribution, x) = x + + prod = ProductOf(SomeComplexDistribution(), SomeComplexDistribution()) + + @test_throws "not defined" support(prod) + @test insupport(prod, 1) + @test !insupport(prod, -1) + @test logpdf(prod, 1) === 2 + @test logpdf(prod, 2) === 4 + +end + +@testitem "`LinearizedProductOf` should support distributions that do not have explicitly defined `support`" begin + + struct SomeComplexDistribution end + + BayesBase.support(::SomeComplexDistribution) = error("not defined") + BayesBase.insupport(::SomeComplexDistribution, x) = x > 0 + BayesBase.logpdf(::SomeComplexDistribution, x) = x + + prod = LinearizedProductOf([SomeComplexDistribution(), SomeComplexDistribution()], 2) + + @test_throws "not defined" support(prod) + @test insupport(prod, 1) + @test !insupport(prod, -1) + @test logpdf(prod, 1) === 2 + @test logpdf(prod, 2) === 4 + end \ No newline at end of file