From 2ee8a5d4637fc1feff9ca2fc45cd402d052d872d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Oct 2024 16:52:01 +0000 Subject: [PATCH] Tapir -> Mooncake (#338) * Tapir -> Mooncake * Bump minor version * Mark Mooncake test as broken * Remove BijectorsEnzymeExt on 1.11.1+ * Increase tolerance on `ordered` test --- .github/workflows/AD.yml | 4 +- Project.toml | 10 ++-- ...orsTapirExt.jl => BijectorsMooncakeExt.jl} | 15 +++--- test/ad/chainrules.jl | 14 ++--- test/ad/utils.jl | 54 +++++++++++-------- test/bijectors/ordered.jl | 4 +- test/runtests.jl | 6 +-- 7 files changed, 60 insertions(+), 47 deletions(-) rename ext/{BijectorsTapirExt.jl => BijectorsMooncakeExt.jl} (77%) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 3777b346..7d0aa4ae 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -23,13 +23,13 @@ jobs: AD: - Enzyme - ForwardDiff - - Tapir + - Mooncake - Tracker - ReverseDiff - Zygote exclude: - version: 1.6 - AD: Tapir + AD: Mooncake # TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see # discussion in https://github.com/TuringLang/Bijectors.jl/pull. - version: 1.6 diff --git a/Project.toml b/Project.toml index 62400e20..283b706c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.19" +version = "0.14.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -30,7 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -40,7 +40,7 @@ BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"] BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" -BijectorsTapirExt = "Tapir" +BijectorsMooncakeExt = "Mooncake" BijectorsTrackerExt = "Tracker" BijectorsZygoteExt = "Zygote" @@ -67,7 +67,7 @@ Requires = "0.5, 1" ReverseDiff = "1" Roots = "1.3.4, 2" Statistics = "1" -Tapir = "0.2.23" +Mooncake = "0.4.19" Tracker = "0.2" Zygote = "0.6.63" julia = "1.6" @@ -79,6 +79,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/BijectorsTapirExt.jl b/ext/BijectorsMooncakeExt.jl similarity index 77% rename from ext/BijectorsTapirExt.jl rename to ext/BijectorsMooncakeExt.jl index 70805a82..d7285bf6 100644 --- a/ext/BijectorsTapirExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -1,10 +1,11 @@ -module BijectorsTapirExt +module BijectorsMooncakeExt if isdefined(Base, :get_extension) - using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule + using Mooncake: + @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule using Bijectors: find_alpha, ChainRulesCore else - using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule + using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule using ..Bijectors: find_alpha, ChainRulesCore end @@ -19,20 +20,20 @@ end # unusual Integer type is encountered. @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) -function Tapir.rrule!!( +function Mooncake.rrule!!( ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} ) where {P<:Base.IEEEFloat,I<:Integer} # Require that the integer is non-differentiable. - if tangent_type(I) != Tapir.NoTangent + if tangent_type(I) != Mooncake.NoTangent msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." throw(ArgumentError(msg)) end out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z)) function find_alpha_pb(dout::P) _, dx, dy, _ = pb(dout) - return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData() + return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData() end - return Tapir.zero_fcodual(out), find_alpha_pb + return Mooncake.zero_fcodual(out), find_alpha_pb end end diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index bcdb9523..a2c13df1 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -27,9 +27,9 @@ end test_frule(Bijectors.find_alpha, x, y, z) test_rrule(Bijectors.find_alpha, x, y, z) - if @isdefined Tapir + if @isdefined Mooncake rng = Xoshiro(123456) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -37,9 +37,9 @@ end z; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -47,9 +47,9 @@ end 3; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -57,7 +57,7 @@ end UInt32(3); is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 3e21e693..2e709491 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) b in ( :ForwardDiff, :Zygote, - :Tapir, + :Mooncake, :ReverseDiff, :Enzyme, :EnzymeForward, @@ -78,27 +78,39 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10" - rule = Tapir.build_rrule(f, x; safety_on=false) - if :tapir in broken - @test_broken( - isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], - finitediff; - rtol=rtol, - atol=atol, - ) - ) - else - @test( - isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], - finitediff; - rtol=rtol, - atol=atol, - ) - ) + if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10" + try + Mooncake.build_rrule(f, x) + catch exc + # TODO(penelopeysm): + # @test_throws AssertionError (expr...) doesn't work, unclear why + @test exc isa AssertionError end + # TODO: The above @test_throws happens because of + # https://github.com/compintell/Mooncake.jl/issues/319. If that test + # fails, it probably means that the issue was fixed, in which case + # we can remove that block and uncomment the following instead. + + # rule = Mooncake.build_rrule(f, x) + # if :Mooncake in broken + # @test_broken ( + # isapprox( + # Mooncake.value_and_gradient!!(rule, f, x)[2][2], + # finitediff; + # rtol=rtol, + # atol=atol, + # ) + # ) + # else + # @test( + # isapprox( + # Mooncake.value_and_gradient!!(rule, f, x)[2][2], + # finitediff; + # rtol=rtol, + # atol=atol, + # ) + # ) + # end end return nothing diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index b2115fe2..60354005 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -127,12 +127,12 @@ end end end # Check that the quantiles are reasonable, i.e. within - # 5 standard errors of the true quantiles (and that the MCSE is + # 6 standard errors of the true quantiles (and that the MCSE is # not too large). for i in 1:k for j in 1:length(qts) @test qs_mcse[i, j] < abs(qs_true[i, end] - qs_true[i, 1]) / 2 - @test abs(qs[i, j] - qs_true[i, j]) < 5 * qs_mcse[i, j] + @test abs(qs[i, j] - qs_true[i, j]) < 6 * qs_mcse[i, j] end end end diff --git a/test/runtests.jl b/test/runtests.jl index 914c0e32..638bd15c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,12 +34,12 @@ if VERSION < v"1.9" using Compat: stack end -# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing +# Sadly, Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're testing # on at least version 1.10. if VERSION >= v"1.10" using Pkg - Pkg.add("Tapir") - using Tapir + Pkg.add("Mooncake") + using Mooncake end const GROUP = get(ENV, "GROUP", "All")