From 08392094bb28f19b02f039685956b4d05f8a69a3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 08:49:20 +0100 Subject: [PATCH] Fixed incorrect reference in DistributionsAD extension (#299) * fixed incorrect reference in DistributionsAD extension * bump patch version * added tests for #298 * bump patch version --- Project.toml | 2 +- ext/BijectorsDistributionsADExt.jl | 2 +- test/distributionsad.jl | 7 +++++++ test/runtests.jl | 2 ++ 4 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 test/distributionsad.jl diff --git a/Project.toml b/Project.toml index 2a7ccaf3..5c3db28c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.9" +version = "0.13.10" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/ext/BijectorsDistributionsADExt.jl b/ext/BijectorsDistributionsADExt.jl index 99534b0c..69dce30a 100644 --- a/ext/BijectorsDistributionsADExt.jl +++ b/ext/BijectorsDistributionsADExt.jl @@ -51,7 +51,7 @@ function Bijectors.bijector(d::FillMatrixOfUnivariate{Continuous}) end Bijectors.bijector(d::MatrixOfUnivariate{Discrete}) = identity function Bijectors.bijector(d::MatrixOfUnivariate{Continuous}) - return TruncatedBijectors.Bijector(_minmax(d.dists)...) + return Bijectors.TruncatedBijector(Bijectors._minmax(d.dists)...) end Bijectors.bijector(d::VectorOfMultivariate{Discrete}) = identity for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) diff --git a/test/distributionsad.jl b/test/distributionsad.jl new file mode 100644 index 00000000..d30fb0dc --- /dev/null +++ b/test/distributionsad.jl @@ -0,0 +1,7 @@ +@testset "DistributionsAD" begin + # https://github.com/TuringLang/Bijectors.jl/issues/298 + @testset "#298" begin + dists = arraydist(fill(InverseGamma(), 2, 2)) + @test bijector(dists) isa Bijectors.TruncatedBijector + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2cb3b7df..5c4f2df6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,8 @@ if GROUP == "All" || GROUP == "Interface" include("bijectors/pd.jl") include("bijectors/reshape.jl") include("bijectors/corr.jl") + + include("distributionsad.jl") end if GROUP == "All" || GROUP == "AD"