Skip to content

Commit

Permalink
Merge branch 'master' into torfjelde/product-bijector
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde authored May 7, 2024
2 parents e8fc8cf + d364639 commit 9887270
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.13.9"
version = "0.13.11"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
2 changes: 1 addition & 1 deletion ext/BijectorsDistributionsADExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function Bijectors.bijector(d::FillMatrixOfUnivariate{Continuous})
end
Bijectors.bijector(d::MatrixOfUnivariate{Discrete}) = identity
function Bijectors.bijector(d::MatrixOfUnivariate{Continuous})
return TruncatedBijectors.Bijector(_minmax(d.dists)...)
return Bijectors.TruncatedBijector(Bijectors._minmax(d.dists)...)
end
Bijectors.bijector(d::VectorOfMultivariate{Discrete}) = identity
for T in (:VectorOfMultivariate, :FillVectorOfMultivariate)
Expand Down
6 changes: 3 additions & 3 deletions src/bijectors/logit.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
######################
# Logit and Logistic #
######################
struct Logit{T} <: Bijector
a::T
b::T
struct Logit{T1,T2} <: Bijector
a::T1
b::T2
end

Functors.@functor Logit
Expand Down
7 changes: 7 additions & 0 deletions test/distributionsad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
@testset "DistributionsAD" begin
# https://github.com/TuringLang/Bijectors.jl/issues/298
@testset "#298" begin
dists = arraydist(fill(InverseGamma(), 2, 2))
@test bijector(dists) isa Bijectors.TruncatedBijector
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ if GROUP == "All" || GROUP == "Interface"
include("bijectors/pd.jl")
include("bijectors/reshape.jl")
include("bijectors/corr.jl")

include("distributionsad.jl")
end

if GROUP == "All" || GROUP == "AD"
Expand Down

0 comments on commit 9887270

Please sign in to comment.