From 04b79dd46eca8cea2f988348c47bd5e720a2b9a4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Sep 2023 23:35:57 +0100 Subject: [PATCH] Fix for #287 (#288) * added has_constant_bijector and made bijector of product distributions return the identity whenever possible * no need to limit ourselves to identity for constant bijectors * no need to limit ourselves to identity for Product * bump patch version * Update test/bijectors/ordered.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed tests * attempt at fix for ordered MvTDist test * dispatch on GenericMvTDist instead of TDist * Update test/bijectors/ordered.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added some tests for MvTDist * make elementwise acting on identity return identity * fixed bug in error --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 +- src/interface.jl | 3 ++- src/transformed_distribution.jl | 35 ++++++++++++++++++++++++++++++++- test/bijectors/ordered.jl | 33 +++++++++++++++++-------------- test/interface.jl | 2 ++ 5 files changed, 57 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index 38cb6f42..18d60250 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.6" +version = "0.13.7" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/interface.jl b/src/interface.jl index 099df1bb..8d401e5a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -14,6 +14,7 @@ In the case where `f::ComposedFunction`, the result is `Base.Fix1(broadcast, f)`. """ elementwise(f) = Base.Fix1(broadcast, f) +elementwise(f::typeof(identity)) = identity # TODO: This is makes dispatching quite a bit easier, but uncertain if this is really # the way to go. function elementwise(f::ComposedFunction) @@ -91,7 +92,7 @@ function transform(t::Transform, x) res = with_logabsdet_jacobian(t, x) if res isa ChangesOfVariables.NoLogAbsDetJacobian error( - "`transform` not implemented for $(typeof(b)); implement `transform` and/or `with_logabsdet_jacobian`.", + "`transform` not implemented for $(typeof(f)); implement `transform` and/or `with_logabsdet_jacobian`.", ) end diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 9c14fbf5..0126db05 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -47,13 +47,43 @@ function bijector(td::TransformedDistribution) b = bijector(td.dist) return b === identity ? inverse(td.transform) : b ∘ inverse(td.transform) end + +""" + has_constant_bijector(dist_type::Type) + +Returns `true` if the distribution type `dist_type` has a constant bijector, +i.e. the return-value of [`bijector`](@ref) does not depend on runtime information. +""" +has_constant_bijector(d::Type) = false +has_constant_bijector(d::Type{<:Normal}) = true +has_constant_bijector(d::Type{<:Distributions.AbstractMvNormal}) = true +has_constant_bijector(d::Type{<:Distributions.AbstractMvLogNormal}) = true +has_constant_bijector(d::Type{<:TDist}) = true +has_constant_bijector(d::Type{<:Distributions.GenericMvTDist}) = true +has_constant_bijector(d::Type{<:PositiveDistribution}) = true +has_constant_bijector(d::Type{<:SimplexDistribution}) = true +has_constant_bijector(d::Type{<:KSOneSided}) = true +function has_constant_bijector(::Type{<:Product{Continuous,D}}) where {D} + return has_constant_bijector(D) +end + +# Container distributions. bijector(d::DiscreteUnivariateDistribution) = identity bijector(d::DiscreteMultivariateDistribution) = identity bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d)) bijector(d::Product{Discrete}) = identity function bijector(d::Product{Continuous}) - return TruncatedBijector(_minmax(d.v)...) + D = eltype(d.v) + return if has_constant_bijector(D) + elementwise(bijector(d.v[1])) + else + # FIXME: This is not great. Should use something like + # `Stacked(map(bijector, d.v))` instead. + # TODO: Specialize. F.ex. for FillArrays.jl we can do much better. + TruncatedBijector(_minmax(d.v)...) + end end + @generated function _minmax(d::AbstractArray{T}) where {T} try min, max = minimum(T), maximum(T) @@ -63,9 +93,12 @@ end end end +# Specialized implementations. bijector(d::Normal) = identity bijector(d::Distributions.AbstractMvNormal) = identity bijector(d::Distributions.AbstractMvLogNormal) = elementwise(log) +bijector(d::TDist) = identity +bijector(d::Distributions.GenericMvTDist) = identity bijector(d::PositiveDistribution) = elementwise(log) bijector(d::SimplexDistribution) = SimplexBijector() bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d))) diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index 0f9802e7..9b65bb2e 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -17,20 +17,23 @@ using LinearAlgebra end @testset "ordered" begin - d = MvNormal(1:5, Diagonal(6:10)) - d_ordered = ordered(d) - @test d_ordered isa Bijectors.TransformedDistribution - @test d_ordered.dist === d - @test d_ordered.transform isa OrderedBijector - y = randn(5) - x = inverse(bijector(d_ordered))(y) - @test issorted(x) + @testset "$d" for d in [ + MvNormal(1:5, Diagonal(6:10)), + MvTDist(1, collect(1.0:5), Matrix(I(5))), + product_distribution(fill(Normal(), 5)), + product_distribution(fill(TDist(1), 5)), + ] + d_ordered = ordered(d) + @test d_ordered isa Bijectors.TransformedDistribution + @test d_ordered.dist === d + @test d_ordered.transform isa OrderedBijector + y = randn(5) + x = inverse(bijector(d_ordered))(y) + @test issorted(x) + end - d = Product(fill(Normal(), 5)) - # currently errors because `bijector(Product(fill(Normal(), 5)))` is not an `Identity` - @test_broken ordered(d) isa Bijectors.TransformedDistribution - - # non-Identity bijector is not supported - d = Dirichlet(ones(5)) - @test_throws ArgumentError ordered(d) + @testset "non-identity bijector is not supported" begin + d = Dirichlet(ones(5)) + @test_throws ArgumentError ordered(d) + end end diff --git a/test/interface.jl b/test/interface.jl index 207adec1..3272e0f1 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -136,6 +136,7 @@ end MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))), Dirichlet([1000 * one(Float64), eps(Float64)]), Dirichlet([eps(Float64), 1000 * one(Float64)]), + MvTDist(1, randn(10), Matrix(Diagonal(exp.(randn(10))))), transformed(MvNormal(randn(10), Diagonal(exp.(randn(10))))), transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10)))))), transformed(reshape(product_distribution(fill(InverseGamma(2, 3), 6)), 2, 3)), @@ -200,6 +201,7 @@ end TuringInverseWishart(v, S), LKJ(3, 1.0), reshape(MvNormal(zeros(6), I), 2, 3), + product_distribution(fill(InverseGamma(2, 3), 6)), ] for dist in matrix_dists