Skip to content

Commit

Permalink
ProductOf should use insupport
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Jun 13, 2024
1 parent 353cbb1 commit 79fbb78
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/prod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,14 @@ function BayesBase.support(product::ProductOf)
return fuse_supports(support(getleft(product)), support(getright(product)))
end

function BayesBase.insupport(product::ProductOf, x)
return insupport(getleft(product), x) && insupport(getright(product), x)
end

BayesBase.pdf(product::ProductOf, x) = exp(logpdf(product, x))

function BayesBase.logpdf(product::ProductOf, x)
@assert x support(product) "The `$(x)` does not belong to the support of the product `$(product)`"
@assert insupport(product, x) lazy"The `$(x)` does not belong to the support of the product `$(product)`"
return logpdf(getleft(product), x) + logpdf(getright(product), x)
end

Expand Down Expand Up @@ -332,6 +336,7 @@ function Base.push!(product::LinearizedProductOf{F}, item::F) where {F}
end

BayesBase.support(dist::LinearizedProductOf) = support(first(dist.vector))
BayesBase.insupport(dist::LinearizedProductOf, x) = insupport(first(dist.vector), x)

Base.length(product::LinearizedProductOf) = product.length
Base.eltype(product::LinearizedProductOf) = eltype(first(product.vector))
Expand All @@ -352,7 +357,7 @@ function Base.show(io::IO, product::LinearizedProductOf{F}) where {F}
end

function BayesBase.logpdf(product::LinearizedProductOf, x)
@assert x support(product) "The `$(x)` does not belong to the support of the product `$(product)`"
@assert insupport(product, x) "The `$(x)` does not belong to the support of the product `$(product)`"
return mapreduce(
(d) -> logpdf(d, x),
+,
Expand Down
4 changes: 2 additions & 2 deletions test/densities/function_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ end
d5 = ContinuousUnivariateLogPdf(FullSpace(), (x) -> 2.0 * -x^2)
d6 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> 2.0 * -x^2)

@test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), 1.0) # domains are different
@test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), -1.0) # domains are different
@test logpdf(prod(GenericProd(), d5, d6), 1.0) -4.0
@test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), -1.0) # supports are different
end

@testitem "ContinuousUnivariateLogPdf: vectorised-prod" begin
Expand Down
36 changes: 36 additions & 0 deletions test/prod_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,4 +292,40 @@ end
@test_throws ErrorException resolve_prod_strategy(
PreserveTypeRightProd(), PreserveTypeLeftProd()
)
end

@testitem "`ProductOf` should support distributions that do not have explicitly defined `support`" begin

struct SomeComplexDistribution end

BayesBase.support(::SomeComplexDistribution) = error("not defined")
BayesBase.insupport(::SomeComplexDistribution, x) = x > 0
BayesBase.logpdf(::SomeComplexDistribution, x) = x

prod = ProductOf(SomeComplexDistribution(), SomeComplexDistribution())

@test_throws "not defined" support(prod)
@test insupport(prod, 1)
@test !insupport(prod, -1)
@test logpdf(prod, 1) === 2
@test logpdf(prod, 2) === 4

end

@testitem "`LinearizedProductOf` should support distributions that do not have explicitly defined `support`" begin

struct SomeComplexDistribution end

BayesBase.support(::SomeComplexDistribution) = error("not defined")
BayesBase.insupport(::SomeComplexDistribution, x) = x > 0
BayesBase.logpdf(::SomeComplexDistribution, x) = x

prod = LinearizedProductOf([SomeComplexDistribution(), SomeComplexDistribution()], 2)

@test_throws "not defined" support(prod)
@test insupport(prod, 1)
@test !insupport(prod, -1)
@test logpdf(prod, 1) === 2
@test logpdf(prod, 2) === 4

end

0 comments on commit 79fbb78

Please sign in to comment.