Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into torfjelde/ordered-f…
Browse files Browse the repository at this point in the history
…or-monotonic
  • Loading branch information
torfjelde committed Dec 10, 2023
2 parents 3d404f1 + 2402be2 commit 505b92e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
9 changes: 5 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extras]
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsForwardDiffExt = "ForwardDiff"
Expand All @@ -45,7 +42,6 @@ BijectorsZygoteExt = "Zygote"
ArgCheck = "1, 2"
ChainRules = "1"
ChainRulesCore = "0.10.11, 1"
Zygote = "0.6.63"
ChangesOfVariables = "0.1"
Compat = "3, 4"
Distributions = "0.25.33"
Expand All @@ -57,4 +53,9 @@ MappedArrays = "0.2.2, 0.3, 0.4"
Reexport = "0.2, 1"
Requires = "0.5, 1"
Roots = "1.3.4, 2"
Statistics = "1"
Zygote = "0.6.63"
julia = "1.6"

[extras]
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(
Returns the output size of `f` given the input size `sz`.
"""
output_size(f, sz) = sz
output_size(f::ComposedFunction, sz) = output_size(f.outer, output_size(f.inner, sz))

"""
output_length(f, len::Int)
Expand Down
40 changes: 40 additions & 0 deletions test/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,44 @@ end
@test y == [exp(1.0), 2.0]
@test binv(y) == [1.0, 2.0, 0.0]
end

@testset "composition" begin
# Composition with one dimension reduction.
b = Stacked((elementwise(exp), ProjectionBijector() identity), [1:1, 2:3])
binv = inverse(b)
x = [1.0, 2.0, 3.0]
y = b(x)
x_ = binv(y)

# Are the values of correct size?
@test size(y) == (2,)
@test size(x_) == (3,)
# Can we determine the sizes correctly?
@test Bijectors.output_size(b, size(x)) == (2,)
@test Bijectors.output_size(binv, size(y)) == (3,)

# Are values correct?
@test y == [exp(1.0), 2.0]
@test binv(y) == [1.0, 2.0, 0.0]

# Composition with two dimension reductions.
b = Stacked(
(elementwise(exp), ProjectionBijector() ProjectionBijector()), [1:1, 2:4]
)
binv = inverse(b)
x = [1.0, 2.0, 3.0, 4.0]
y = b(x)
x_ = binv(y)

# Are the values of correct size?
@test size(y) == (2,)
@test size(x_) == (4,)
# Can we determine the sizes correctly?
@test Bijectors.output_size(b, size(x)) == (2,)
@test Bijectors.output_size(binv, size(y)) == (4,)

# Are values correct?
@test y == [exp(1.0), 2.0]
@test binv(y) == [1.0, 2.0, 0.0, 0.0]
end
end

0 comments on commit 505b92e

Please sign in to comment.