Stackoverflow in custom bijector #292

Closed
hanyas opened this issue Oct 21, 2023 · 2 comments

hanyas commented Oct 21, 2023

I am trying to define my own Tanh bijector:

using Random
using Distributions
using LinearAlgebra

import Bijectors


struct Tanh <: Bijectors.Bijector end
(b::Tanh)(x::Real) = tanh(x)
(b::Tanh)(x) = map(b, x)

(ib::Bijectors.Inverse{<: Tanh})(y::Real) = atanh(y)
(ib::Bijectors.Inverse{<: Tanh})(y) = map(ib, y)

Bijectors.logabsdetjac(b::Tanh, x::Real) = 2.0 * (log(2.0) - x - softplus(-2.0 * x))
Bijectors.logabsdetjac(b::Tanh, x) = map(Bijectors.logabsdetjac, x)

dist = Distributions.MvNormal(zeros(1), I)
td = Bijectors.transformed(dist, Tanh())

y = rand(td)
Distributions.logpdf(td, y)

The forward transformation appears to work, but I am struggling to understand the error I get when computing the logpdf:

ERROR: StackOverflowError:
Stacktrace:
     [1] with_logabsdet_jacobian(ib::Bijectors.Inverse{Tanh}, y::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:213
     [2] transform(t::Bijectors.Inverse{Tanh}, x::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:92
--- the last 2 lines are repeated 39990 more times ---
 [79983] with_logabsdet_jacobian(ib::Bijectors.Inverse{Tanh}, y::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:213

hanyas commented Oct 22, 2023

I've also tried defining the bijector by following a recipe similar to that of the Scale bijector, but without success:

struct Tanh <: Bijector end

with_logabsdet_jacobian(b::Tanh, x) = transform(b, x), logabsdetjac(b, x)

transform(b::Tanh, x) = tanh(x)
transform(b::Tanh, x::AbstractVecOrMat) = tanh.(x)
transform(ib::Inverse{<:Tanh}, y) = transform(atanh, y)
transform(ib::Inverse{<:Tanh}, y::AbstractVecOrMat) = transform(@. atanh, y)

logabsdetjac(b::Tanh, x::Real) = _logabsdetjac_tanh(b, x, Val(0))
function logabsdetjac(b::Tanh, x::AbstractArray{<:Real,N}) where {N}
    return _logabsdetjac_tanh(b, x, Val(N))
end

_logabsdetjac_tanh(b::Tanh, x::Real, ::Val{0}) = 2.0 * (log(2.0) - x - softplus(-2.0 * x))
_logabsdetjac_tanh(b::Tanh, x::AbstractVector, ::Val{1}) = 2.0 * (log(2.0) - x - softplus(-2.0 * x)) * length(x)

torfjelde (member) commented Oct 23, 2023

The default implementation of with_logabsdet_jacobian for a Bijector is (transform(b, x), logabsdetjac(b, x)), but since you haven't defined logabsdetjac(::Inverse{Tanh}, y), you also hit the default implementation of that method, which is -logabsdetjac(inverse(b), inverse(b)(y)).

You then get a stack overflow error because transform(::Inverse{Tanh}, y) is also not defined. (Scale does not need an Inverse{<:Scale} implementation, because inverting it just means inverting the scale factor and returning a new Scale.)
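To make the loop concrete, here is a minimal plain-Julia sketch of the fallback pattern described above. All names in it (AbstractMyTransform, MyTanh, MyInverse, my_transform, my_with_logabsdet_jacobian, my_logabsdetjac) are hypothetical stand-ins; this is an assumption about the shape of the defaults, not Bijectors' actual source. Each fallback calls the other, so a call only terminates once a concrete method for the specific transform breaks the cycle:

```julia
# Hypothetical stand-ins mimicking the fallback pattern described above;
# this is NOT the Bijectors source, only an illustration of the cycle.
abstract type AbstractMyTransform end

struct MyTanh <: AbstractMyTransform end

struct MyInverse{T<:AbstractMyTransform} <: AbstractMyTransform
    orig::T
end

# Fallback 1: transform is defined via with_logabsdet_jacobian.
my_transform(t::AbstractMyTransform, x) = first(my_with_logabsdet_jacobian(t, x))

# Fallback 2: with_logabsdet_jacobian is defined via transform.
my_with_logabsdet_jacobian(t::AbstractMyTransform, x) =
    (my_transform(t, x), my_logabsdetjac(t, x))

# Fallback 3: the inverse's log-abs-det Jacobian is minus the forward one,
# evaluated at the preimage, which again needs my_transform on the inverse.
my_logabsdetjac(it::MyInverse, y) = -my_logabsdetjac(it.orig, my_transform(it, y))

# Forward direction: concrete methods exist, so no recursion happens.
my_transform(::MyTanh, x::Real) = tanh(x)
my_logabsdetjac(::MyTanh, x::Real) = log1p(-tanh(x)^2)  # fine for moderate x

# Without the next line, my_transform(MyInverse(MyTanh()), 0.5) bounces
# between fallbacks 1 and 2 until a StackOverflowError is thrown.
my_transform(::MyInverse{MyTanh}, y::Real) = atanh(y)

x, ladj = my_with_logabsdet_jacobian(MyInverse(MyTanh()), 0.5)
# x ≈ atanh(0.5) ≈ 0.5493 and ladj ≈ -log(0.75) ≈ 0.2877
```

In this sketch the concrete my_transform method for MyInverse{MyTanh} plays the role that a transform(::Inverse{<:Tanh}, y) method would presumably play in the Bijectors version.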

In fact, you don't need to deal with the Bijector machinery here at all: tanh is already a function, so it doesn't need a "new" representation, and its inverse, atanh, is likewise already defined.

I'd implement the above as:

using ChangesOfVariables, InverseFunctions, StatsFuns

InverseFunctions.inverse(::typeof(tanh)) = atanh
InverseFunctions.inverse(::typeof(atanh)) = tanh

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(tanh), x::Real)
    y = tanh(x)
    return y, _logabsdetjac_tanh(x)
end
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(atanh), y::Real)
    x = atanh(y)
    return x, -_logabsdetjac_tanh(x)
end

# Use the irrational representation `StatsFuns.logtwo` to defer type-promotion.
# Similarly, I've removed all explicit usages of `Float64`, e.g. converted `2.0` to `2`
# to allow type-promotion to do its thing rather than forcing usage of `Float64`.
_logabsdetjac_tanh(x::Real) = 2 * (StatsFuns.logtwo - x - softplus(-2 * x))
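For reference, the expression in _logabsdetjac_tanh follows from the derivative of tanh, with softplus(z) = log(1 + e^z). The same identity explains the minus sign in the atanh method, since the derivative of an inverse is the reciprocal of the forward derivative evaluated at the preimage:

```latex
\begin{aligned}
\frac{d}{dx}\tanh x &= 1 - \tanh^2 x = \operatorname{sech}^2 x = \frac{4}{(e^{x} + e^{-x})^2} \\
\log\left|\frac{d}{dx}\tanh x\right| &= 2\bigl(\log 2 - \log(e^{x} + e^{-x})\bigr) \\
&= 2\bigl(\log 2 - x - \log(1 + e^{-2x})\bigr) = 2\bigl(\log 2 - x - \operatorname{softplus}(-2x)\bigr) \\
\log\left|\frac{d}{dy}\operatorname{atanh} y\right| &= -\log\left|\frac{d}{dx}\tanh x\right|_{x = \operatorname{atanh} y}
\end{aligned}
```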

If you want a version that acts elementwise, you can use Bijectors.elementwise(f):

julia> using Bijectors

julia> elementwise(tanh)(rand(10))
10-element Vector{Float64}:
 0.22076308094447367
 0.06828859488600718
 0.3496810171644955
 0.02413051400382789
 0.6228303792319176
 0.5772825278828461
 0.7370222452215927
 0.45865543543291265
 0.6128386429868988
 0.7094298145373448

julia> with_logabsdet_jacobian(elementwise(tanh), rand(10))
([0.5475308984676883, 0.7498770212815672, 0.11406375475912378, 0.04598020777639154, 0.41278517115619784, 0.3067650082385844, 0.6441810700388316, 0.7430095366528289, 0.7023124306195118, 0.2806093226497268], -3.5844094465162772)
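The elementwise case can be sanity-checked in plain Julia: an elementwise map has a diagonal Jacobian, so its log-abs-det Jacobian is the sum of the per-element terms. This is only a sketch, not how Bijectors.elementwise is implemented; softplus below is a local stand-in for StatsFuns.softplus, and elementwise_tanh_with_ladj and naive_logabsdetjac_tanh are hypothetical helpers. It compares the summed terms against log|det| of a finite-difference Jacobian and shows why the softplus form is preferable to the naive log(1 - tanh(x)^2) for large |x|:

```julia
using LinearAlgebra  # standard library: Diagonal, det

# Local stand-in for StatsFuns.softplus, i.e. log(1 + exp(z)), written so
# that exp never overflows.
softplus(z::Real) = z > 0 ? z + log1p(exp(-z)) : log1p(exp(z))

# Same expression as _logabsdetjac_tanh above, with log(2) in place of
# StatsFuns.logtwo so the sketch needs no packages.
logabsdetjac_tanh(x::Real) = 2 * (log(2) - x - softplus(-2 * x))

# Hypothetical helper (not part of Bijectors): apply tanh elementwise and sum
# the per-element terms, which is valid because the Jacobian is diagonal.
function elementwise_tanh_with_ladj(x::AbstractVector{<:Real})
    return tanh.(x), sum(logabsdetjac_tanh, x)
end

x = [-1.5, -0.2, 0.0, 0.7, 2.0]
y, ladj = elementwise_tanh_with_ladj(x)

# Central finite differences approximate the diagonal of the Jacobian.
h = 1e-5
fd = (tanh.(x .+ h) .- tanh.(x .- h)) ./ (2h)
isapprox(ladj, log(abs(det(Diagonal(fd)))); rtol = 1e-6)  # true

# tanh(30.0) rounds to exactly 1.0 in Float64, so the naive form hits log(0).
naive_logabsdetjac_tanh(x::Real) = log(1 - tanh(x)^2)
naive_logabsdetjac_tanh(30.0)  # -Inf
logabsdetjac_tanh(30.0)        # finite, approximately 2 * (log(2) - 30)
```

Summing per-element terms (rather than, say, scaling a single term by length(x)) is what the diagonal structure of the Jacobian requires.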

hanyas closed this as completed Oct 23, 2023