But the real utility of `TransformedDistribution` becomes more apparent when using `transformed(dist, b)` for any bijector `b`. To get the transformed distribution corresponding to the `Beta(2, 2)`, we called `transformed(dist)` before. This is simply an alias for `transformed(dist, bijector(dist))`. Remember that `bijector(dist)` returns the constrained-to-unconstrained bijector for that particular `Distribution`. But we can of course construct a `TransformedDistribution` using different bijectors with the same `dist`. This is particularly useful in something called Automatic Differentiation Variational Inference (ADVI).[2] An important part of ADVI is to approximate a constrained distribution, e.g. `Beta`, as follows:
- Sample `x` from a `Normal` with parameters `μ` and `σ`, i.e. `x ~ Normal(μ, σ)`.
- Transform `x` to `y` such that `y ∈ support(Beta)`, with the transform being a differentiable bijection with a differentiable inverse (a "bijector").
This then defines a probability density with the same support as `Beta`! Of course, it's unlikely to be the same density, but it's an approximation. Creating such a distribution becomes trivial with `Bijector` and `TransformedDistribution`:
```julia
julia> using StableRNGs: StableRNG

julia> rng = StableRNG(42);

julia> dist = Beta(2, 2)
Beta{Float64}(α=2.0, β=2.0)

julia> b = bijector(dist)  # (0, 1) → ℝ
Bijectors.Logit{Float64}(0.0, 1.0)

julia> b⁻¹ = inverse(b)  # ℝ → (0, 1)
Inverse{Bijectors.Logit{Float64}}(Bijectors.Logit{Float64}(0.0, 1.0))

julia> td = transformed(Normal(), b⁻¹)  # x ∼ 𝓝(0, 1), then b⁻¹(x) ∈ (0, 1)
UnivariateTransformed{Normal{Float64}, Inverse{Bijectors.Logit{Float64}}}(
dist: Normal{Float64}(μ=0.0, σ=1.0)
transform: Inverse{Bijectors.Logit{Float64}}(Bijectors.Logit{Float64}(0.0, 1.0))
)

julia> x = rand(rng, td)  # ∈ (0, 1)
0.3384404850130036
```
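Since `td` is itself a `Distribution`, the usual density queries are available as well. A quick sketch (not part of the original example): `logpdf` applies the change-of-variables formula, i.e. `logpdf(Normal(), b(y)) + logabsdetjac(b, y)`.

```julia
# Log-density of the transformed distribution at a point in (0, 1),
# computed via the change-of-variables formula.
logpdf(td, 0.5)

# The Jacobian correction term on its own.
logabsdetjac(b, 0.5)
```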
It's worth noting that `support(Beta)` is the closed interval `[0, 1]`, while the constrained-to-unconstrained bijection, `Logit` in this case, is only well-defined as a map `(0, 1) → ℝ` on the open interval `(0, 1)`. This is of course not an implementation detail. ℝ is itself open, thus no continuous bijection exists from a closed interval to ℝ. But since the boundaries of a closed interval have what's known as measure zero, this doesn't end up affecting the resulting density with support on the entire real line. In practice, this means that
```julia
julia> td = transformed(Beta())
UnivariateTransformed{Beta{Float64}, Bijectors.Logit{Float64}}(
dist: Beta{Float64}(α=1.0, β=1.0)
transform: Bijectors.Logit{Float64}(0.0, 1.0)
)

julia> inverse(td.transform)(rand(rng, td))
0.8130302707446476
```

will never result in `0` or `1`, though any sample arbitrarily close to either `0` or `1` is possible. Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky.
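If you want to convince yourself of this empirically, a small check along these lines (a sketch, not from the original docs) should essentially always pass:

```julia
# Draw many samples on ℝ and map them back to (0, 1); up to floating-point
# rounding, none of them should hit the boundary exactly.
ys = inverse(td.transform).(rand(rng, td, 10_000))
all(y -> 0 < y < 1, ys)  # `true` with overwhelming probability
```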
We can also do multivariate ADVI using the `Stacked` bijector. `Stacked` gives us a way to combine univariate and/or multivariate bijectors into a single multivariate bijector. Say you have a vector `x` of length 2 and you want to transform the first entry using `Exp` and the second entry using `Log`. `Stacked` gives you an easy and efficient way of representing such a bijector.
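For instance, here is a minimal sketch of that two-entry example (constructed for illustration; in recent versions elementwise maps such as `Exp` and `Log` are written `elementwise(exp)` and `elementwise(log)`, which is also what shows up as `Base.Fix1{typeof(broadcast), typeof(exp)}` in the output further below):

```julia
using Bijectors: Stacked, elementwise

# Apply `exp` to the first entry and `log` to the second.
sb_demo = Stacked((elementwise(exp), elementwise(log)), (1:1, 2:2))
sb_demo([0.0, 1.0])  # ≈ [1.0, 0.0]
```

The same construction lets us combine the bijectors of several distributions for multivariate ADVI: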
```julia
julia> using Bijectors: SimplexBijector

# Original distributions
julia> dists = (Beta(), InverseGamma(), Dirichlet(2, 3));

# Construct the corresponding ranges
julia> ranges = [];

julia> idx = 1;

julia> for i in 1:length(dists)
           d = dists[i]
           push!(ranges, idx:(idx + length(d) - 1))

           global idx
           idx += length(d)
       end;

julia> ranges
3-element Vector{Any}:
 1:1
 2:2
 3:4

julia> num_params = ranges[end][end]
4

# Base distribution; mean-field normal
julia> d = MvNormal(zeros(num_params), ones(num_params));

# Construct the transform
julia> bs = bijector.(dists);  # constrained-to-unconstrained bijectors for dists

julia> ibs = inverse.(bs);  # invert, so we get unconstrained-to-constrained

julia> sb = Stacked(ibs, ranges)  # => Stacked <: Bijector
Stacked(Any[Inverse{Bijectors.Logit{Float64}}(Bijectors.Logit{Float64}(0.0, 1.0)), Base.Fix1{typeof(broadcast), typeof(exp)}(broadcast, exp), Inverse{Bijectors.SimplexBijector}(Bijectors.SimplexBijector())], Any[1:1, 2:2, 3:4], Any[1:1, 2:2, 3:5])

# Mean-field normal with unconstrained-to-constrained stacked bijector
julia> td = transformed(d, sb);

julia> y = rand(td)
5-element Vector{Float64}:
 0.5238835820805957
 0.1876901852149366
 0.35543087024186415
 0.270091955711644
 0.37447717404649183

julia> 0.0 ≤ y[1] ≤ 1.0  # Beta
true

julia> 0.0 < y[2]  # InverseGamma
true

julia> sum(y[3:5]) ≈ 1.0  # Dirichlet; the simplex block occupies y[3:5]
true
```
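As a sanity check (a sketch, not part of the original example), these constraints hold for every sample, not just the one above. Note that the inverse `SimplexBijector` maps the two unconstrained entries in range `3:4` to three constrained ones, which is why `y` has length 5 and the simplex block is `y[3:5]`:

```julia
# Draw a batch of samples and verify the simplex block always sums to one.
ys = [rand(td) for _ in 1:1_000]
all(y -> isapprox(sum(y[3:5]), 1.0), ys)  # expected `true`
```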
A very interesting application is that of normalizing flows.[1] Usually this is done by sampling from a multivariate normal distribution and then transforming it to a target distribution using invertible neural networks. Currently there are two such transforms available in Bijectors.jl: `PlanarLayer` and `RadialLayer`. Let's create a flow with a single `PlanarLayer`:
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> b = PlanarLayer(2)
PlanarLayer(w = [-0.8664015375048035, 0.5750920096114442], u = [-0.36607589140268626, -0.9396265935545343], b = [1.0400481176727907])

julia> flow = transformed(d, b)
MultivariateTransformed{DiagNormal, PlanarLayer{Vector{Float64}, Vector{Float64}}}(
dist: DiagNormal(
dim: 2
μ: [0.0, 0.0]
Σ: [1.0 0.0; 0.0 1.0]
)

transform: PlanarLayer(w = [-0.8664015375048035, 0.5750920096114442], u = [-0.36607589140268626, -0.9396265935545343], b = [1.0400481176727907])
)

julia> flow isa MultivariateDistribution
true
```
That's it. Now we can sample from it using `rand` and compute the `logpdf`, like any other `Distribution`.
```julia
julia> y = rand(rng, flow)
2-element Vector{Float64}:
 -0.8750644748706963
 -0.5455058397340147

julia> logpdf(flow, y)  # uses inverse of `b`
-2.124945483960251
```
Similarly to the multivariate ADVI example, we could use `Stacked` to get a bounded flow:
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta())));

julia> sb = Stacked(ibs);  # == Stacked(ibs, [i:i for i = 1:length(ibs)])

julia> b = sb ∘ PlanarLayer(2);

julia> td = transformed(d, b);

julia> y = rand(rng, td)
2-element Vector{Float64}:
 1.2579058965585515
 0.7486778156915774

julia> 0 < y[1]
true

julia> 0 ≤ y[2] ≤ 1
true
```
Want to fit the flow?
```julia
julia> using Zygote

# Construct the flow.
julia> b = PlanarLayer(2)
PlanarLayer(w = [0.7482590259107407, -0.8636690301719852], u = [0.647695263666225, 0.5409668660245471], b = [-0.3641185212846577])

# Convenient for extracting parameters and reconstructing the flow.
julia> using Functors

julia> θs, reconstruct = Functors.functor(b);

# Make the objective a `struct` to avoid capturing global variables.
julia> struct NLLObjective{R,D,T}
           reconstruct::R
           basedist::D
           data::T
       end

julia> function (obj::NLLObjective)(θs...)
           transformed_dist = transformed(obj.basedist, obj.reconstruct(θs))
           return -sum(Base.Fix1(logpdf, transformed_dist), eachcol(obj.data))
       end

# Some random data to estimate the density of.
julia> xs = randn(2, 1000);

# Construct the objective.
julia> f = NLLObjective(reconstruct, MvNormal(zeros(2), ones(2)), xs);

# Initial loss.
julia> @info "Initial loss: $(f(θs...))"
[ Info: Initial loss: 3039.727539234296

# Train using gradient descent.
julia> ε = 1e-3;

julia> for i in 1:100
           ∇s = Zygote.gradient(f, θs...)
           global θs = map(θs, ∇s) do θ, ∇
               θ - ε .* ∇
           end
       end

# Final loss.
julia> @info "Final loss: $(f(θs...))"
[ Info: Final loss: 2803.4819425544724

# Very simple check to see if we learned something useful.
julia> samples = rand(transformed(f.basedist, f.reconstruct(θs)), 1000);

julia> mean(eachcol(samples))  # ≈ [0, 0]
2-element Vector{Float64}:
 -0.022087993766749765
  0.004507573408651547

julia> cov(samples; dims=2)  # ≈ I
2×2 Matrix{Float64}:
 1.01336    0.0180501
 0.0180501  0.961199
```
We can easily create more complex flows by simply doing `PlanarLayer(10) ∘ PlanarLayer(10) ∘ RadialLayer(10)` and so on.
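For example, a deeper flow might look like the following sketch (dimensions chosen arbitrarily here):

```julia
# Base distribution in 10 dimensions.
d = MvNormal(zeros(10), ones(10))

# Compose three invertible layers into a single transform.
b = PlanarLayer(10) ∘ PlanarLayer(10) ∘ RadialLayer(10)

deep_flow = transformed(d, b)
x = rand(deep_flow)  # sampling only needs forward passes through each layer
```

Sampling only requires evaluating each layer forward; computing `logpdf` additionally requires their inverses, as before.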