Skip to content

Commit

Permalink
Tapir -> Mooncake (#338)
Browse files Browse the repository at this point in the history
* Tapir -> Mooncake

* Bump minor version

* Mark Mooncake test as broken

* Remove BijectorsEnzymeExt on 1.11.1+

* Increase tolerance on `ordered` test
  • Loading branch information
penelopeysm authored Oct 29, 2024
1 parent aed472a commit 2ee8a5d
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 47 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ jobs:
AD:
- Enzyme
- ForwardDiff
- Tapir
- Mooncake
- Tracker
- ReverseDiff
- Zygote
exclude:
- version: 1.6
AD: Tapir
AD: Mooncake
# TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see
# discussion in https://github.com/TuringLang/Bijectors.jl/pull.
- version: 1.6
Expand Down
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.13.19"
version = "0.14.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down Expand Up @@ -30,7 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -40,7 +40,7 @@ BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"]
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsTapirExt = "Tapir"
BijectorsMooncakeExt = "Mooncake"
BijectorsTrackerExt = "Tracker"
BijectorsZygoteExt = "Zygote"

Expand All @@ -67,7 +67,7 @@ Requires = "0.5, 1"
ReverseDiff = "1"
Roots = "1.3.4, 2"
Statistics = "1"
Tapir = "0.2.23"
Mooncake = "0.4.19"
Tracker = "0.2"
Zygote = "0.6.63"
julia = "1.6"
Expand All @@ -79,6 +79,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
15 changes: 8 additions & 7 deletions ext/BijectorsTapirExt.jl → ext/BijectorsMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module BijectorsTapirExt
module BijectorsMooncakeExt

if isdefined(Base, :get_extension)
using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule
using Mooncake:
@is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule
using Bijectors: find_alpha, ChainRulesCore
else
using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule
using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule
using ..Bijectors: find_alpha, ChainRulesCore
end

Expand All @@ -19,20 +20,20 @@ end
# unusual Integer type is encountered.
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})

function Tapir.rrule!!(
function Mooncake.rrule!!(
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I}
) where {P<:Base.IEEEFloat,I<:Integer}
# Require that the integer is non-differentiable.
if tangent_type(I) != Tapir.NoTangent
if tangent_type(I) != Mooncake.NoTangent
msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent."
throw(ArgumentError(msg))
end
out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z))
function find_alpha_pb(dout::P)
_, dx, dy, _ = pb(dout)
return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData()
return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData()
end
return Tapir.zero_fcodual(out), find_alpha_pb
return Mooncake.zero_fcodual(out), find_alpha_pb
end

end
14 changes: 7 additions & 7 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,37 @@ end
test_frule(Bijectors.find_alpha, x, y, z)
test_rrule(Bijectors.find_alpha, x, y, z)

if @isdefined Tapir
if @isdefined Mooncake
rng = Xoshiro(123456)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
z;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
3;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
UInt32(3);
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
end

Expand Down
54 changes: 33 additions & 21 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
b in (
:ForwardDiff,
:Zygote,
:Tapir,
:Mooncake,
:ReverseDiff,
:Enzyme,
:EnzymeForward,
Expand Down Expand Up @@ -78,27 +78,39 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
end
end

if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10"
rule = Tapir.build_rrule(f, x; safety_on=false)
if :tapir in broken
@test_broken(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
else
@test(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
)
)
if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10"
try
Mooncake.build_rrule(f, x)
catch exc
# TODO(penelopeysm):
# @test_throws AssertionError (expr...) doesn't work, unclear why
@test exc isa AssertionError
end
# TODO: The above @test_throws happens because of
# https://github.com/compintell/Mooncake.jl/issues/319. If that test
# fails, it probably means that the issue was fixed, in which case
# we can remove that block and uncomment the following instead.

# rule = Mooncake.build_rrule(f, x)
# if :Mooncake in broken
# @test_broken (
# isapprox(
# Mooncake.value_and_gradient!!(rule, f, x)[2][2],
# finitediff;
# rtol=rtol,
# atol=atol,
# )
# )
# else
# @test(
# isapprox(
# Mooncake.value_and_gradient!!(rule, f, x)[2][2],
# finitediff;
# rtol=rtol,
# atol=atol,
# )
# )
# end
end

return nothing
Expand Down
4 changes: 2 additions & 2 deletions test/bijectors/ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ end
end
end
# Check that the quantiles are reasonable, i.e. within
# 5 standard errors of the true quantiles (and that the MCSE is
# 6 standard errors of the true quantiles (and that the MCSE is
# not too large).
for i in 1:k
for j in 1:length(qts)
@test qs_mcse[i, j] < abs(qs_true[i, end] - qs_true[i, 1]) / 2
@test abs(qs[i, j] - qs_true[i, j]) < 5 * qs_mcse[i, j]
@test abs(qs[i, j] - qs_true[i, j]) < 6 * qs_mcse[i, j]
end
end
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ if VERSION < v"1.9"
using Compat: stack
end

# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing
# Sadly, Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're testing
# on at least version 1.10.
if VERSION >= v"1.10"
using Pkg
Pkg.add("Tapir")
using Tapir
Pkg.add("Mooncake")
using Mooncake
end

const GROUP = get(ENV, "GROUP", "All")
Expand Down

0 comments on commit 2ee8a5d

Please sign in to comment.