From c8b17eac17a3b4c71cdfd14f88d6d47aaa5cb36e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 6 Dec 2024 16:38:47 +0000 Subject: [PATCH 01/20] Remove selector stuff from ESS --- src/mcmc/Inference.jl | 2 +- src/mcmc/ess.jl | 42 ++++++++++++++---------------------------- test/mcmc/ess.jl | 13 ++++--------- 3 files changed, 19 insertions(+), 38 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 5905b1686..37f03cb02 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -592,7 +592,7 @@ include("emcee.jl") # Typing tools # ################ -for alg in (:SMC, :PG, :MH, :IS, :ESS, :Emcee) +for alg in (:SMC, :PG, :MH, :IS, :Emcee) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index aa1a9fe38..1b7128605 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -20,12 +20,11 @@ Mean │ 1 │ m │ 0.824853 │ ``` """ -struct ESS{space} <: InferenceAlgorithm end +struct ESS <: InferenceAlgorithm end -ESS() = ESS{()}() -ESS(space::Symbol) = ESS{(space,)}() - -drop_space(alg::ESS) = ESS() +# TODO(mhauru) Remove the below once DPPL no longer relies on it. +drop_space(alg::ESS) = alg +DynamicPPL.getspace(::Sampler{<:ESS}) = () # always accept in the first step function DynamicPPL.initialstep( @@ -35,7 +34,7 @@ function DynamicPPL.initialstep( vns = _getvns(vi, spl) length(vns) == 1 || error("[ESS] does only support one variable ($(length(vns)) variables specified)") - for vn in vns[1] + for vn in only(vns) dist = getdist(vi, vn) EllipticalSliceSampling.isgaussian(typeof(dist)) || error("[ESS] only supports Gaussian prior distributions") @@ -48,7 +47,7 @@ function AbstractMCMC.step( rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... ) # obtain previous sample - f = vi[spl] + f = vi[:] # define previous sampler state # (do not use cache to avoid in-place sampling from prior) @@ -129,13 +128,11 @@ function (ℓ::ESSLogLikelihood)(f::AbstractVector) end function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn, vi + rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi ) - return if inspace(vn, sampler) - DynamicPPL.tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi) - else - DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi) - end + return DynamicPPL.tilde_assume( + rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi + ) end function DynamicPPL.tilde_observe( @@ -145,22 +142,11 @@ function DynamicPPL.tilde_observe( end function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, - ctx::DefaultContext, - sampler::Sampler{<:ESS}, - right, - left, - vns, - vi, + rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, left, vns, vi ) - # TODO: Or should we do `all(Base.Fix2(inspace, sampler), vns)`? 
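As a point of reference, here is a minimal usage sketch (not part of the patch; the model is invented for illustration). With the `space` argument gone, `ESS()` takes no variable names, and restricting it to particular variables goes through `Gibbs`, as the updated tests in this patch do:

```julia
using Turing

# Illustrative model only: ESS requires a Gaussian prior, which `m` has here.
@model function gauss_demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

# ESS samples every variable it is assigned; per-variable use goes via Gibbs.
chain = sample(gauss_demo(1.5), Gibbs(; m=ESS(), s=MH()), 1000)
```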
- return if inspace(first(vns), sampler) - DynamicPPL.dot_tilde_assume( - rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, vi - ) - else - DynamicPPL.dot_tilde_assume(rng, ctx, SampleFromPrior(), right, left, vns, vi) - end + return DynamicPPL.dot_tilde_assume( + rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, vi + ) end function DynamicPPL.dot_tilde_observe( diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 6db469b76..5eaa0a592 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -30,18 +30,13 @@ using Turing N = 10 s1 = ESS() - s2 = ESS(:m) - for s in (s1, s2) - @test DynamicPPL.alg_str(Sampler(s, demo_default)) == "ESS" - end + @test DynamicPPL.alg_str(Sampler(s1, demo_default)) == "ESS" c1 = sample(demo_default, s1, N) - c2 = sample(demo_default, s2, N) - c3 = sample(demodot_default, s1, N) - c4 = sample(demodot_default, s2, N) + c2 = sample(demodot_default, s1, N) - s3 = Gibbs(; m=ESS(), s=MH()) - c5 = sample(gdemo_default, s3, N) + s2 = Gibbs(; m=ESS(), s=MH()) + c3 = sample(gdemo_default, s2, N) end @testset "ESS inference" begin From c62d35d38083d4d33fa52cf56200c9d252c93e96 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 6 Dec 2024 16:54:40 +0000 Subject: [PATCH 02/20] Remove selector stuff from MH --- src/mcmc/Inference.jl | 2 +- src/mcmc/mh.jl | 77 ++++++++++++++++++++----------------------- test/mcmc/mh.jl | 18 +++++----- 3 files changed, 47 insertions(+), 50 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 37f03cb02..f5e01bf0c 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -592,7 +592,7 @@ include("emcee.jl") # Typing tools # ################ -for alg in (:SMC, :PG, :MH, :IS, :Emcee) +for alg in (:SMC, :PG, :IS, :Emcee) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index edd46a457..65ef24184 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -2,10 +2,6 @@ ### Sampler states ### -struct MH{space,P} <: InferenceAlgorithm - proposals::P -end - proposal(p::AdvancedMH.Proposal) = p proposal(f::Function) = AdvancedMH.StaticProposal(f) proposal(d::Distribution) = AdvancedMH.StaticProposal(d) @@ -13,15 +9,15 @@ proposal(cov::AbstractMatrix) = AdvancedMH.RandomWalkProposal(MvNormal(cov)) proposal(x) = error("proposals of type ", typeof(x), " are not supported") """ - MH(space...) + MH(proposals...) Construct a Metropolis-Hastings algorithm. -The arguments `space` can be +The arguments `proposals` can be - Blank (i.e. `MH()`), in which case `MH` defaults to using the prior for each parameter as the proposal distribution. - An iterable of pairs or tuples mapping a `Symbol` to a `AdvancedMH.Proposal`, `Distribution`, or `Function` - that generates returns a conditional proposal distribution. + that returns a conditional proposal distribution. - A covariance matrix to use as for mean-zero multivariate normal proposals. # Examples @@ -108,44 +104,43 @@ mean(chain) ``` """ -function MH(space...) - syms = Symbol[] - - prop_syms = Symbol[] - props = AMH.Proposal[] - - for s in space - if s isa Symbol - # If it's just a symbol, proceed as normal. - push!(syms, s) - elseif s isa Pair || s isa Tuple - # Check to see whether it's a pair that specifies a kernel - # or a specific proposal distribution. - push!(prop_syms, s[1]) - push!(props, proposal(s[2])) - elseif length(space) == 1 - # If we hit this block, check to see if it's - # a run-of-the-mill proposal or covariance - # matrix. 
- prop = proposal(s) - - # Return early, we got a covariance matrix. - return MH{(),typeof(prop)}(prop) - else - # Try to convert it to a proposal anyways, - # throw an error if not acceptable. - prop = proposal(s) - push!(props, prop) +struct MH{P} <: InferenceAlgorithm + proposals::P + + function MH(proposals...) + prop_syms = Symbol[] + props = AMH.Proposal[] + + for s in proposals + if s isa Pair || s isa Tuple + # Check to see whether it's a pair that specifies a kernel + # or a specific proposal distribution. + push!(prop_syms, s[1]) + push!(props, proposal(s[2])) + elseif length(proposals) == 1 + # If we hit this block, check to see if it's + # a run-of-the-mill proposal or covariance + # matrix. + prop = proposal(s) + + # Return early, we got a covariance matrix. + return new{typeof(prop)}(prop) + else + # Try to convert it to a proposal anyways, + # throw an error if not acceptable. + prop = proposal(s) + push!(props, prop) + end end - end - proposals = NamedTuple{tuple(prop_syms...)}(tuple(props...)) - syms = vcat(syms, prop_syms) + proposals = NamedTuple{tuple(prop_syms...)}(tuple(props...)) - return MH{tuple(syms...),typeof(proposals)}(proposals) + return new{typeof(proposals)}(proposals) + end end -drop_space(alg::MH{space,P}) where {space,P} = MH{(),P}(alg.proposals) +drop_space(alg::MH) = alg +DynamicPPL.getspace(::MH) = () # Some of the proposals require working in unconstrained space. transform_maybe(proposal::AMH.Proposal) = proposal @@ -351,7 +346,7 @@ function propose!!( ) # If this is the case, we can just draw directly from the proposal # matrix. - vals = vi[spl] + vals = vi[:] # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index 3823c2986..5cbb1d309 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -24,26 +24,28 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @testset "mh constructor" begin N = 10 s1 = MH((:s, InverseGamma(2, 3)), (:m, GKernel(3.0))) - s2 = MH(:s, :m) + s2 = MH(:s => InverseGamma(2, 3), :m => GKernel(3.0)) s3 = MH() - for s in (s1, s2, s3) + s4 = MH([1.0 0.1; 0.1 1.0]) + for s in (s1, s2, s3, s4) @test DynamicPPL.alg_str(Sampler(s, gdemo_default)) == "MH" end c1 = sample(gdemo_default, s1, N) c2 = sample(gdemo_default, s2, N) c3 = sample(gdemo_default, s3, N) - - s4 = Gibbs(; m=MH(), s=MH()) c4 = sample(gdemo_default, s4, N) - # s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) - # c5 = sample(gdemo_default, s5, N) + s5 = Gibbs(; m=MH(), s=MH()) + c5 = sample(gdemo_default, s5, N) + + # s6 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) + # c6 = sample(gdemo_default, s6, N) # NOTE: Broken because MH doesn't really follow the `logdensity` interface, but calls # it with `NamedTuple` instead of `AbstractVector`. 
- # s6 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.StaticProposal)) - # c6 = sample(gdemo_default, s6, N) + # s7 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.StaticProposal)) + # c7 = sample(gdemo_default, s7, N) end @testset "mh inference" begin From 2ac1d9f497bafc8970b90e489dd510368f3b60f7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 15:56:10 +0000 Subject: [PATCH 03/20] Remove selector stuff from HMC --- src/mcmc/Inference.jl | 3 --- src/mcmc/hmc.jl | 54 +++++++++++++------------------------ src/mcmc/mh.jl | 1 + test/dynamicppl/compiler.jl | 4 +-- test/mcmc/Inference.jl | 2 +- test/mcmc/gibbs.jl | 8 +----- test/mcmc/hmc.jl | 8 ------ test/mcmc/sghmc.jl | 14 ++-------- 8 files changed, 25 insertions(+), 69 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index f5e01bf0c..615330fcf 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -595,9 +595,6 @@ include("emcee.jl") for alg in (:SMC, :PG, :IS, :Emcee) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end -for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) - @eval DynamicPPL.getspace(::$alg{<:Any,space}) where {space} = space -end function DynamicPPL.get_matching_type( spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV} diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 80de196c6..b27999344 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -53,7 +53,7 @@ sample(gdemo([1.5, 2]), HMC(0.1, 10), 1000) sample(gdemo([1.5, 2]), HMC(0.01, 10), 1000) ``` """ -struct HMC{AD,space,metricT<:AHMC.AbstractMetric} <: StaticHamiltonian +struct HMC{AD,metricT<:AHMC.AbstractMetric} <: StaticHamiltonian ϵ::Float64 # leapfrog step size n_leapfrog::Int # leapfrog step number adtype::AD @@ -62,25 +62,23 @@ end function HMC( ϵ::Float64, n_leapfrog::Int, - ::Type{metricT}, - space::Tuple; + ::Type{metricT}; adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) where {metricT<:AHMC.AbstractMetric} - return HMC{typeof(adtype),space,metricT}(ϵ, n_leapfrog, adtype) + return HMC{typeof(adtype),metricT}(ϵ, n_leapfrog, adtype) end function HMC( ϵ::Float64, - n_leapfrog::Int, - space::Symbol...; + n_leapfrog::Int; metricT=AHMC.UnitEuclideanMetric, adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) - return HMC(ϵ, n_leapfrog, metricT, space; adtype=adtype) + return HMC(ϵ, n_leapfrog, metricT; adtype=adtype) end -function drop_space(alg::HMC{AD,space,metricT}) where {AD,space,metricT} - return HMC{AD,(),metricT}(alg.ϵ, alg.n_leapfrog, alg.adtype) -end +# TODO(mhauru) Remove the below once DPPL no longer relies on it. +drop_space(alg::Hamiltonian) = alg +DynamicPPL.getspace(::Hamiltonian) = () DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform() @@ -336,7 +334,7 @@ Hoffman, Matthew D., and Andrew Gelman. "The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning Research 15, no. 1 (2014): 1593-1623. 
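For orientation, a sketch (not part of the diff) of the constructor calls once the trailing variable-name symbols are dropped; the numeric settings below are arbitrary:

```julia
using Turing

hmc   = HMC(0.1, 10)           # step size, number of leapfrog steps
hmcda = HMCDA(200, 0.65, 0.3)  # adaptation steps, target accept rate, target leapfrog length
nuts  = NUTS(1000, 0.65)       # adaptation steps, target accept rate

# Restricting a Hamiltonian sampler to particular variables, formerly HMC(0.1, 10, :m),
# now goes through Gibbs, as the updated tests elsewhere in this series do.
spl = Gibbs(; m=HMC(0.1, 10), s=MH())
```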
""" -struct HMCDA{AD,space,metricT<:AHMC.AbstractMetric} <: AdaptiveHamiltonian +struct HMCDA{AD,metricT<:AHMC.AbstractMetric} <: AdaptiveHamiltonian n_adapts::Int # number of samples with adaption for ϵ δ::Float64 # target accept rate λ::Float64 # target leapfrog length @@ -349,11 +347,10 @@ function HMCDA( δ::Float64, λ::Float64, ϵ::Float64, - ::Type{metricT}, - space::Tuple; + ::Type{metricT}; adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) where {metricT<:AHMC.AbstractMetric} - return HMCDA{typeof(adtype),space,metricT}(n_adapts, δ, λ, ϵ, adtype) + return HMCDA{typeof(adtype),metricT}(n_adapts, δ, λ, ϵ, adtype) end function HMCDA( @@ -373,17 +370,12 @@ end function HMCDA( n_adapts::Int, δ::Float64, - λ::Float64, - space::Symbol...; + λ::Float64; init_ϵ::Float64=0.0, metricT=AHMC.UnitEuclideanMetric, adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) - return HMCDA(n_adapts, δ, λ, init_ϵ, metricT, space; adtype=adtype) -end - -function drop_space(alg::HMCDA{AD,space,metricT}) where {AD,space,metricT} - return HMCDA{AD,(),metricT}(alg.n_adapts, alg.δ, alg.λ, alg.ϵ, alg.adtype) + return HMCDA(n_adapts, δ, λ, init_ϵ, metricT; adtype=adtype) end """ @@ -409,7 +401,7 @@ Arguments: If not specified, `ForwardDiff` is used, with its `chunksize` automatically determined. """ -struct NUTS{AD,space,metricT<:AHMC.AbstractMetric} <: AdaptiveHamiltonian +struct NUTS{AD,metricT<:AHMC.AbstractMetric} <: AdaptiveHamiltonian n_adapts::Int # number of samples with adaption for ϵ δ::Float64 # target accept rate max_depth::Int # maximum tree depth @@ -424,11 +416,10 @@ function NUTS( max_depth::Int, Δ_max::Float64, ϵ::Float64, - ::Type{metricT}, - space::Tuple; + ::Type{metricT}; adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) where {metricT} - return NUTS{typeof(adtype),space,metricT}(n_adapts, δ, max_depth, Δ_max, ϵ, adtype) + return NUTS{typeof(adtype),metricT}(n_adapts, δ, max_depth, Δ_max, ϵ, adtype) end function NUTS(n_adapts::Int, δ::Float64, ::Tuple{}; kwargs...) @@ -437,15 +428,14 @@ end function NUTS( n_adapts::Int, - δ::Float64, - space::Symbol...; + δ::Float64; max_depth::Int=10, Δ_max::Float64=1000.0, init_ϵ::Float64=0.0, metricT=AHMC.DiagEuclideanMetric, adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) - return NUTS(n_adapts, δ, max_depth, Δ_max, init_ϵ, metricT, space; adtype=adtype) + return NUTS(n_adapts, δ, max_depth, Δ_max, init_ϵ, metricT; adtype=adtype) end function NUTS( @@ -463,14 +453,8 @@ function NUTS(; kwargs...) return NUTS(-1, 0.65; kwargs...) end -function drop_space(alg::NUTS{AD,space,metricT}) where {AD,space,metricT} - return NUTS{AD,(),metricT}( - alg.n_adapts, alg.δ, alg.max_depth, alg.Δ_max, alg.ϵ, alg.adtype - ) -end - for alg in (:HMC, :HMCDA, :NUTS) - @eval getmetricT(::$alg{<:Any,<:Any,metricT}) where {metricT} = metricT + @eval getmetricT(::$alg{<:Any,metricT}) where {metricT} = metricT end ##### diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 65ef24184..35f1711eb 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -139,6 +139,7 @@ struct MH{P} <: InferenceAlgorithm end end +# TODO(mhauru) Remove the below once DPPL no longer relies on it. 
drop_space(alg::MH) = alg DynamicPPL.getspace(::MH) = () diff --git a/test/dynamicppl/compiler.jl b/test/dynamicppl/compiler.jl index 7939c7beb..e3647a947 100644 --- a/test/dynamicppl/compiler.jl +++ b/test/dynamicppl/compiler.jl @@ -133,9 +133,7 @@ const gdemo_default = gdemo_d() end chain = sample( - newinterface(obs), - HMC(0.75, 3, :p, :x; adtype=AutoForwardDiff(; chunksize=2)), - 100, + newinterface(obs), HMC(0.75, 3; adtype=AutoForwardDiff(; chunksize=2)), 100 ) end @testset "no return" begin diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 9356fbcc1..4b9ccb298 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -368,7 +368,7 @@ using Turing sample( StableRNG(seed), newinterface(obs), - HMC(0.75, 3, :p, :x; adtype=Turing.AutoForwardDiff(; chunksize=2)), + HMC(0.75, 3; adtype=Turing.AutoForwardDiff(; chunksize=2)), 100, ) end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 43bdcdbb8..41b069708 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -269,25 +269,19 @@ end @info "Starting Gibbs tests with $adbackend" @testset "Deprecated Gibbs constructors" begin N = 10 - @test_deprecated s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend)) @test_deprecated s2 = Gibbs(PG(10, :s, :m)) @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - @test_deprecated s6 = Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)) - @test_deprecated s7 = Gibbs((HMC(0.1, 5, :s; adtype=adbackend), 2), (ESS(:m), 3)) - for s in (s1, s2, s3, s4, s5, s6, s7) + for s in (s2, s3, s4, s5) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end # Check that the samplers work despite using the deprecated constructor. 
- sample(gdemo_default, s1, N) sample(gdemo_default, s2, N) sample(gdemo_default, s3, N) sample(gdemo_default, s4, N) sample(gdemo_default, s5, N) - sample(gdemo_default, s6, N) - sample(gdemo_default, s7, N) g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 47ff73b1c..7b0fa90de 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -160,10 +160,6 @@ using Turing sampler = Sampler(alg, gdemo_default) @test DynamicPPL.alg_str(sampler) == "HMCDA" - alg = HMCDA(200, 0.8, 0.75, :s; adtype=adbackend) - sampler = Sampler(alg, gdemo_default) - @test DynamicPPL.alg_str(sampler) == "HMCDA" - @test isa(alg, HMCDA) @test isa(sampler, Sampler{<:Turing.Hamiltonian}) end @@ -182,10 +178,6 @@ using Turing alg = NUTS(0.65; adtype=adbackend) sampler = Sampler(alg, gdemo_default) @test DynamicPPL.alg_str(sampler) == "NUTS" - - alg = NUTS(200, 0.65, :m; adtype=adbackend) - sampler = Sampler(alg, gdemo_default) - @test DynamicPPL.alg_str(sampler) == "NUTS" end @testset "check discard" begin diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index c1d07d2ce..c878f755d 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -19,12 +19,7 @@ using Turing sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGHMC} - alg = SGHMC(:m; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) - @test alg isa SGHMC - sampler = Turing.Sampler(alg) - @test sampler isa Turing.Sampler{<:SGHMC} - - alg = SGHMC(:s; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) + alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) @test alg isa SGHMC sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGHMC} @@ -45,12 +40,7 @@ end sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGLD} - alg = SGLD(:m; stepsize=PolynomialStepsize(0.25), adtype=adbackend) - @test alg isa SGLD - sampler = Turing.Sampler(alg) - @test sampler isa Turing.Sampler{<:SGLD} - - alg = SGLD(:s; stepsize=PolynomialStepsize(0.25), adtype=adbackend) + alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) @test alg isa SGLD sampler = Turing.Sampler(alg) @test sampler isa Turing.Sampler{<:SGLD} From e6e4e56842ffdb2d544f029e923e2881dbea6010 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 16:02:21 +0000 Subject: [PATCH 04/20] Remove selector stuff from Emcee --- src/mcmc/Inference.jl | 2 +- src/mcmc/emcee.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 615330fcf..64190660f 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -592,7 +592,7 @@ include("emcee.jl") # Typing tools # ################ -for alg in (:SMC, :PG, :IS, :Emcee) +for alg in (:SMC, :PG, :IS) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 816d90578..45e9d2778 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -13,7 +13,7 @@ Foreman-Mackey, D., Hogg, D. W., Lang, D., & Goodman, J. (2013). emcee: The MCMC Hammer. Publications of the Astronomical Society of the Pacific, 125 (925), 306. 
https://doi.org/10.1086/670067 """ -struct Emcee{space,E<:AMH.Ensemble} <: InferenceAlgorithm +struct Emcee{E<:AMH.Ensemble} <: InferenceAlgorithm ensemble::E end @@ -26,7 +26,7 @@ function Emcee(n_walkers::Int, stretch_length=2.0) return Emcee{(),typeof(ensemble)}(ensemble) end -drop_space(alg::Emcee{space,E}) where {space,E} = Emcee{(),E}(alg.ensemble) +drop_space(alg::Emcee) = alg struct EmceeState{V<:AbstractVarInfo,S} vi::V From d1840aaa08b32762d152f2b1d344c1ed8c39b69c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 16:03:42 +0000 Subject: [PATCH 05/20] Remove selector stuff from IS --- src/mcmc/Inference.jl | 2 +- src/mcmc/is.jl | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 64190660f..d7da38c99 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -592,7 +592,7 @@ include("emcee.jl") # Typing tools # ################ -for alg in (:SMC, :PG, :IS) +for alg in (:SMC, :PG) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 083bc7bc3..af2be00bb 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -24,11 +24,9 @@ end sample(gdemo([1.5, 2]), IS(), 1000) ``` """ -struct IS{space} <: InferenceAlgorithm end +struct IS <: InferenceAlgorithm end -IS() = IS{()}() - -drop_space(alg::IS) = IS() +drop_space(alg::IS) = alg DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler From 0bebd3e9af11c5febb2a2c12f2744b73c223f464 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 16:08:47 +0000 Subject: [PATCH 06/20] Add missing getspace methods --- src/mcmc/emcee.jl | 1 + src/mcmc/is.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 45e9d2778..2f8c49429 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -27,6 +27,7 @@ function Emcee(n_walkers::Int, stretch_length=2.0) end drop_space(alg::Emcee) = alg +DynamicPPL.getspace(::Emcee) = () struct EmceeState{V<:AbstractVarInfo,S} vi::V diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index af2be00bb..7b8b74f32 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -27,6 +27,7 @@ sample(gdemo([1.5, 2]), IS(), 1000) struct IS <: InferenceAlgorithm end drop_space(alg::IS) = alg +DynamicPPL.getspace(::IS) = () DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler From eb90d4cae7b7f684a45b9c7bc02fc1a85308b066 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 16:16:19 +0000 Subject: [PATCH 07/20] Remove selector stuff for particle methods --- src/mcmc/Inference.jl | 4 -- src/mcmc/particle_mcmc.jl | 99 +++++++++++++++----------------------- test/mcmc/gibbs.jl | 19 -------- test/mcmc/particle_mcmc.jl | 63 ------------------------ 4 files changed, 40 insertions(+), 145 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index d7da38c99..2fde2930b 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -592,10 +592,6 @@ include("emcee.jl") # Typing tools # ################ -for alg in (:SMC, :PG) - @eval DynamicPPL.getspace(::$alg{space}) where {space} = space -end - function DynamicPPL.get_matching_type( spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV} ) where {T,N,TV<:Array{T,N}} diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index c5abb56f1..4642985f3 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -15,37 +15,34 @@ Sequential Monte Carlo sampler. 
$(TYPEDFIELDS) """ -struct SMC{space,R} <: ParticleInference +struct SMC{R} <: ParticleInference resampler::R end """ - SMC(space...) - SMC([resampler = AdvancedPS.ResampleWithESSThreshold(), space = ()]) - SMC([resampler = AdvancedPS.resample_systematic, ]threshold[, space = ()]) + SMC() + SMC([resampler = AdvancedPS.ResampleWithESSThreshold()]) + SMC([resampler = AdvancedPS.resample_systematic, ]threshold) -Create a sequential Monte Carlo sampler of type [`SMC`](@ref) for the variables in `space`. +Create a sequential Monte Carlo sampler of type [`SMC`](@ref). If the algorithm for the resampling step is not specified explicitly, systematic resampling is performed if the estimated effective sample size per particle drops below 0.5. """ -function SMC(resampler=AdvancedPS.ResampleWithESSThreshold(), space::Tuple=()) - return SMC{space,typeof(resampler)}(resampler) +function SMC(resampler=AdvancedPS.ResampleWithESSThreshold()) + return SMC{typeof(resampler)}(resampler) end # Convenient constructors with ESS threshold -function SMC(resampler, threshold::Real, space::Tuple=()) - return SMC(AdvancedPS.ResampleWithESSThreshold(resampler, threshold), space) +function SMC(resampler, threshold::Real) + return SMC(AdvancedPS.ResampleWithESSThreshold(resampler, threshold)) end -function SMC(threshold::Real, space::Tuple=()) - return SMC(AdvancedPS.resample_systematic, threshold, space) +function SMC(threshold::Real) + return SMC(AdvancedPS.resample_systematic, threshold) end -# If only the space is defined -SMC(space::Symbol...) = SMC(space) -SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space) - -drop_space(alg::SMC{space,R}) where {space,R} = SMC{(),R}(alg.resampler) +drop_space(alg::SMC) = alg +DynamicPPL.getspace(::SMC) = () struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." @@ -184,7 +181,7 @@ Particle Gibbs sampler. $(TYPEDFIELDS) """ -struct PG{space,R} <: ParticleInference +struct PG{R} <: ParticleInference """Number of particles.""" nparticles::Int """Resampling algorithm.""" @@ -192,37 +189,31 @@ struct PG{space,R} <: ParticleInference end """ - PG(n, space...) - PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold(), space = ()]) - PG(n, [resampler = AdvancedPS.resample_systematic, ]threshold[, space = ()]) + PG(n) + PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold()]) + PG(n, [resampler = AdvancedPS.resample_systematic, ]threshold) -Create a Particle Gibbs sampler of type [`PG`](@ref) with `n` particles for the variables -in `space`. +Create a Particle Gibbs sampler of type [`PG`](@ref) with `n` particles. If the algorithm for the resampling step is not specified explicitly, systematic resampling is performed if the estimated effective sample size per particle drops below 0.5. 
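As an illustrative aside (not part of the patch), the constructor forms after this change, mirroring the updated tests; `resample_multinomial` is AdvancedPS's multinomial resampler:

```julia
using Turing
using AdvancedPS: resample_multinomial

smc        = SMC()                              # resample when ESS per particle drops below 0.5
smc_custom = SMC(resample_multinomial, 0.6)     # explicit resampler and threshold
pg         = PG(20)                             # 20 particles, default resampler
pg_custom  = PG(20, resample_multinomial, 0.6)  # 20 particles, explicit resampler and threshold
```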
""" function PG( - nparticles::Int, resampler=AdvancedPS.ResampleWithESSThreshold(), space::Tuple=() + nparticles::Int, resampler=AdvancedPS.ResampleWithESSThreshold() ) - return PG{space,typeof(resampler)}(nparticles, resampler) + return PG{typeof(resampler)}(nparticles, resampler) end # Convenient constructors with ESS threshold -function PG(nparticles::Int, resampler, threshold::Real, space::Tuple=()) - return PG(nparticles, AdvancedPS.ResampleWithESSThreshold(resampler, threshold), space) -end -function PG(nparticles::Int, threshold::Real, space::Tuple=()) - return PG(nparticles, AdvancedPS.resample_systematic, threshold, space) +function PG(nparticles::Int, resampler, threshold::Real) + return PG(nparticles, AdvancedPS.ResampleWithESSThreshold(resampler, threshold)) end - -# If only the number of particles and the space is defined -PG(nparticles::Int, space::Symbol...) = PG(nparticles, space) -function PG(nparticles::Int, space::Tuple) - return PG(nparticles, AdvancedPS.ResampleWithESSThreshold(), space) +function PG(nparticles::Int, threshold::Real) + return PG(nparticles, AdvancedPS.resample_systematic, threshold) end -drop_space(alg::PG{space,R}) where {space,R} = PG{(),R}(alg.nparticles, alg.resampler) +drop_space(alg::PG) = alg +DynamicPPL.getspace(::PG) = () """ CSMC(...) @@ -384,31 +375,21 @@ function DynamicPPL.assume( vi = trace_local_varinfo_maybe(_vi) trng = trace_local_rng_maybe(rng) - if inspace(vn, spl) - if ~haskey(vi, vn) - r = rand(trng, dist) - push!!(vi, vn, r, dist, spl) - elseif is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") # Reference particle parent - r = rand(trng, dist) - vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setgid!(vi, spl.selector, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - DynamicPPL.updategid!(vi, vn, spl) # Pick data from reference particle - r = vi[vn] - end - # TODO: Should we make this `zero(promote_type(eltype(dist), eltype(r)))` or something? - lp = 0 - else # vn belongs to other sampler <=> conditioning on vn - if haskey(vi, vn) - r = vi[vn] - else - r = rand(rng, dist) - push!!(vi, vn, r, dist, DynamicPPL.Selector(:invalid)) - end - lp = logpdf_with_trans(dist, r, istrans(vi, vn)) + if ~haskey(vi, vn) + r = rand(trng, dist) + push!!(vi, vn, r, dist, spl) + elseif is_flagged(vi, vn, "del") + unset_flag!(vi, vn, "del") # Reference particle parent + r = rand(trng, dist) + vi[vn] = DynamicPPL.tovec(r) + DynamicPPL.setgid!(vi, spl.selector, vn) + setorder!(vi, vn, get_num_produce(vi)) + else + DynamicPPL.updategid!(vi, vn, spl) # Pick data from reference particle + r = vi[vn] end + # TODO: Should we make this `zero(promote_type(eltype(dist), eltype(r)))` or something? + lp = 0 return r, lp, vi end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 41b069708..b5758b22d 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -267,25 +267,6 @@ end @testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends @info "Starting Gibbs tests with $adbackend" - @testset "Deprecated Gibbs constructors" begin - N = 10 - @test_deprecated s2 = Gibbs(PG(10, :s, :m)) - @test_deprecated s3 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - @test_deprecated s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - @test_deprecated s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m; adtype=adbackend)) - for s in (s2, s3, s4, s5) - @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" - end - - # Check that the samplers work despite using the deprecated constructor. 
- sample(gdemo_default, s2, N) - sample(gdemo_default, s3, N) - sample(gdemo_default, s4, N) - sample(gdemo_default, s5, N) - - g = Turing.Sampler(s3, gdemo_default) - @test sample(gdemo_default, g, N) isa MCMCChains.Chains - end @testset "Gibbs constructors" begin # Create Gibbs samplers with various configurations and ways of passing the diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 3378fea32..221244b9d 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -15,45 +15,17 @@ using Turing @test s.resampler == ResampleWithESSThreshold() @test getspace(s) === () - s = SMC(:x) - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x,) - - s = SMC((:x,)) - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x,) - - s = SMC(:x, :y) - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x, :y) - - s = SMC((:x, :y)) - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x, :y) - s = SMC(0.6) @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) @test getspace(s) === () - s = SMC(0.6, (:x,)) - @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - @test getspace(s) === (:x,) - s = SMC(resample_multinomial, 0.6) @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) @test getspace(s) === () - s = SMC(resample_multinomial, 0.6, (:x,)) - @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - @test getspace(s) === (:x,) - s = SMC(resample_systematic) @test s.resampler === resample_systematic @test getspace(s) === () - - s = SMC(resample_systematic, (:x,)) - @test s.resampler === resample_systematic - @test getspace(s) === (:x,) end @testset "models" begin @@ -108,55 +80,20 @@ end @test s.resampler == ResampleWithESSThreshold() @test getspace(s) === () - s = PG(20, :x) - @test s.nparticles == 20 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x,) - - s = PG(30, (:x,)) - @test s.nparticles == 30 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x,) - - s = PG(40, :x, :y) - @test s.nparticles == 40 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x, :y) - - s = PG(50, (:x, :y)) - @test s.nparticles == 50 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x, :y) - s = PG(60, 0.6) @test s.nparticles == 60 @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) @test getspace(s) === () - s = PG(70, 0.6, (:x,)) - @test s.nparticles == 70 - @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - @test getspace(s) === (:x,) - s = PG(80, resample_multinomial, 0.6) @test s.nparticles == 80 @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) @test getspace(s) === () - s = PG(90, resample_multinomial, 0.6, (:x,)) - @test s.nparticles == 90 - @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - @test getspace(s) === (:x,) - s = PG(100, resample_systematic) @test s.nparticles == 100 @test s.resampler === resample_systematic @test getspace(s) === () - - s = PG(110, resample_systematic, (:x,)) - @test s.nparticles == 110 - @test s.resampler === resample_systematic - @test getspace(s) === (:x,) end @testset "logevidence" begin From 54df711e66e0dd5cbdf202e65f51488b72f71f16 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 16:21:21 +0000 Subject: [PATCH 08/20] Fix an HMC selector bug --- src/mcmc/hmc.jl | 4 ++-- 
1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index b27999344..d5d21b275 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -360,7 +360,7 @@ function HMCDA( metricT=AHMC.UnitEuclideanMetric, adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) - return HMCDA(-1, δ, λ, init_ϵ, metricT, (); adtype=adtype) + return HMCDA(-1, δ, λ, init_ϵ, metricT; adtype=adtype) end function HMCDA(n_adapts::Int, δ::Float64, λ::Float64, ::Tuple{}; kwargs...) @@ -446,7 +446,7 @@ function NUTS( metricT=AHMC.DiagEuclideanMetric, adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) - return NUTS(-1, δ, max_depth, Δ_max, init_ϵ, metricT, (); adtype=adtype) + return NUTS(-1, δ, max_depth, Δ_max, init_ϵ, metricT; adtype=adtype) end function NUTS(; kwargs...) From 8af52ea76e5fd3c8f0db298562705816f4dd345d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 16:25:35 +0000 Subject: [PATCH 09/20] Code style --- src/mcmc/particle_mcmc.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 4642985f3..5b11ab132 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -198,9 +198,7 @@ Create a Particle Gibbs sampler of type [`PG`](@ref) with `n` particles. If the algorithm for the resampling step is not specified explicitly, systematic resampling is performed if the estimated effective sample size per particle drops below 0.5. """ -function PG( - nparticles::Int, resampler=AdvancedPS.ResampleWithESSThreshold() -) +function PG(nparticles::Int, resampler=AdvancedPS.ResampleWithESSThreshold()) return PG{typeof(resampler)}(nparticles, resampler) end From 890bc29cc237355de12e4a4cb87cdc93c19d02e7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 17:01:10 +0000 Subject: [PATCH 10/20] Fix Emcee selector bug --- src/mcmc/emcee.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 2f8c49429..751143c2e 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -23,7 +23,7 @@ function Emcee(n_walkers::Int, stretch_length=2.0) # ensemble sampling. 
prop = AMH.StretchProposal(nothing, stretch_length) ensemble = AMH.Ensemble(n_walkers, prop) - return Emcee{(),typeof(ensemble)}(ensemble) + return Emcee{typeof(ensemble)}(ensemble) end drop_space(alg::Emcee) = alg From 8ec521e8fb966ec0a787fb1716301fbe360290a1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 17:18:14 +0000 Subject: [PATCH 11/20] Fix typo in ESS tests --- test/mcmc/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 9047e0ee5..4675f61e2 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -36,7 +36,7 @@ using Turing c2 = sample(demodot_default, s1, N) s2 = Gibbs(:m => ESS(), :s => MH()) - c3 = sample(gdemo_default, s3, N) + c3 = sample(gdemo_default, s2, N) end @testset "ESS inference" begin From ed3081c2281a3ce8661116b88d6b6a3849b25003 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 17:50:21 +0000 Subject: [PATCH 12/20] Fix some constructor overwrites --- src/mcmc/particle_mcmc.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 5b11ab132..380986be3 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -17,6 +17,10 @@ $(TYPEDFIELDS) """ struct SMC{R} <: ParticleInference resampler::R + + # Specifying this constructor stops the creation of default constructors. We rather + # define our own outer constructors below. + SMC{R}(resampler::R) where {R} = new{R}(resampler) end """ @@ -186,6 +190,10 @@ struct PG{R} <: ParticleInference nparticles::Int """Resampling algorithm.""" resampler::R + + # Specifying this constructor stops the creation of default constructors. We rather + # define our own outer constructors below. + PG{R}(nparticles::Int, resampler::R) where {R} = new{R}(nparticles, resampler) end """ From 934ebefff75acf2bb35afe2795a7f390f80d4795 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 17:52:24 +0000 Subject: [PATCH 13/20] Remove unnecessary tests --- test/mcmc/particle_mcmc.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 221244b9d..699ee6854 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -4,7 +4,6 @@ using ..Models: gdemo_default #using ..Models: MoGtest, MoGtest_default using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial using Distributions: Bernoulli, Beta, Gamma, Normal, sample -using DynamicPPL: getspace using Random: Random using Test: @test, @test_throws, @testset using Turing @@ -13,19 +12,15 @@ using Turing @testset "constructor" begin s = SMC() @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === () s = SMC(0.6) @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - @test getspace(s) === () s = SMC(resample_multinomial, 0.6) @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - @test getspace(s) === () s = SMC(resample_systematic) @test s.resampler === resample_systematic - @test getspace(s) === () end @testset "models" begin @@ -78,22 +73,18 @@ end s = PG(10) @test s.nparticles == 10 @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === () s = PG(60, 0.6) @test s.nparticles == 60 @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - @test getspace(s) === () s = PG(80, resample_multinomial, 0.6) @test s.nparticles == 80 @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - @test getspace(s) === 
() s = PG(100, resample_systematic) @test s.nparticles == 100 @test s.resampler === resample_systematic - @test getspace(s) === () end @testset "logevidence" begin From 02f251da71799cab8437ee64f29d24d0356b8f35 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 17:57:39 +0000 Subject: [PATCH 14/20] Remove selector stuff from SGHMC --- src/mcmc/sghmc.jl | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index c79337c50..367d7dd27 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -1,7 +1,7 @@ """ - SGHMC{AD,space} + SGHMC{AD} -Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) sampler.e +Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) sampler. # Fields $(TYPEDFIELDS) @@ -12,15 +12,14 @@ Tianqi Chen, Emily Fox, & Carlos Guestrin (2014). Stochastic Gradient Hamiltonia Carlo. In: Proceedings of the 31st International Conference on Machine Learning (pp. 1683–1691). """ -struct SGHMC{AD,space,T<:Real} <: StaticHamiltonian +struct SGHMC{AD,T<:Real} <: StaticHamiltonian learning_rate::T momentum_decay::T adtype::AD end """ - SGHMC( - space::Symbol...; + SGHMC(; learning_rate::Real, momentum_decay::Real, adtype::ADTypes.AbstractADType = AutoForwardDiff(), @@ -37,21 +36,18 @@ Tianqi Chen, Emily Fox, & Carlos Guestrin (2014). Stochastic Gradient Hamiltonia Carlo. In: Proceedings of the 31st International Conference on Machine Learning (pp. 1683–1691). """ -function SGHMC( - space::Symbol...; +function SGHMC(; learning_rate::Real, momentum_decay::Real, adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) _learning_rate, _momentum_decay = promote(learning_rate, momentum_decay) - return SGHMC{typeof(adtype),space,typeof(_learning_rate)}( + return SGHMC{typeof(adtype),typeof(_learning_rate)}( _learning_rate, _momentum_decay, adtype ) end -function drop_space(alg::SGHMC{AD,space,T}) where {AD,space,T} - return SGHMC{AD,(),T}(alg.learning_rate, alg.momentum_decay, alg.adtype) -end +drop_space(alg::SGHMC) = alg struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}} logdensity::L @@ -128,15 +124,13 @@ Max Welling & Yee Whye Teh (2011). Bayesian Learning via Stochastic Gradient Lan Dynamics. In: Proceedings of the 28th International Conference on Machine Learning (pp. 681–688). """ -struct SGLD{AD,space,S} <: StaticHamiltonian +struct SGLD{AD,S} <: StaticHamiltonian "Step size function." stepsize::S adtype::AD end -function drop_space(alg::SGLD{AD,space,S}) where {AD,space,S} - return SGLD{AD,(),S}(alg.stepsize, alg.adtype) -end +drop_space(alg::SGLD) = alg struct PolynomialStepsize{T<:Real} "Constant scale factor of the step size." @@ -172,8 +166,7 @@ end (f::PolynomialStepsize)(t::Int) = f.a / (t + f.b)^f.γ """ - SGLD( - space::Symbol...; + SGLD(; stepsize = PolynomialStepsize(0.01), adtype::ADTypes.AbstractADType = AutoForwardDiff(), ) @@ -193,12 +186,10 @@ Dynamics. 
In: Proceedings of the 28th International Conference on Machine Learni See also: [`PolynomialStepsize`](@ref) """ -function SGLD( - space::Symbol...; - stepsize=PolynomialStepsize(0.01), - adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, +function SGLD(; + stepsize=PolynomialStepsize(0.01), adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE ) - return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype) + return SGLD{typeof(adtype),typeof(stepsize)}(stepsize, adtype) end struct SGLDTransition{T,F<:Real} <: AbstractTransition From 6fde35095b339c47ced2aac6186484321ebbe9ec Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 9 Jan 2025 15:04:05 +0000 Subject: [PATCH 15/20] Remove drop_space and other non-longer-necessary deprecation measures --- src/mcmc/Inference.jl | 21 ++---------------- src/mcmc/emcee.jl | 3 --- src/mcmc/ess.jl | 4 ---- src/mcmc/gibbs.jl | 45 +------------------------------------- src/mcmc/hmc.jl | 4 ---- src/mcmc/is.jl | 3 --- src/mcmc/mh.jl | 4 ---- src/mcmc/particle_mcmc.jl | 6 ----- src/mcmc/repeat_sampler.jl | 3 +-- src/mcmc/sghmc.jl | 4 ---- test/mcmc/gibbs.jl | 3 +-- 11 files changed, 5 insertions(+), 95 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 2fde2930b..a60afa530 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -91,18 +91,6 @@ abstract type Hamiltonian <: InferenceAlgorithm end abstract type StaticHamiltonian <: Hamiltonian end abstract type AdaptiveHamiltonian <: Hamiltonian end -# TODO(mhauru) Remove the below function once all the space/Selector stuff has been removed. -""" - drop_space(alg::InferenceAlgorithm) - -Return an `InferenceAlgorithm` like `alg`, but with all space information removed. -""" -function drop_space end - -function drop_space(sampler::Sampler) - return Sampler(drop_space(sampler.alg), sampler.selector) -end - include("repeat_sampler.jl") """ @@ -146,11 +134,6 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain end end -# External samplers don't have notion of space to begin with. -drop_space(x::ExternalSampler) = x - -DynamicPPL.getspace(::ExternalSampler) = () - """ requires_unconstrained_space(sampler::ExternalSampler) @@ -217,8 +200,6 @@ Algorithm for sampling from the prior. """ struct Prior <: InferenceAlgorithm end -drop_space(x::Prior) = x - function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -602,6 +583,8 @@ end # Utilities # ############## +# TODO(mhauru) Remove this once DynamicPPL has removed all its Selector stuff. +DynamicPPL.getspace(::InferenceAlgorithm) = () DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg) DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg)) diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 751143c2e..9d272f5a0 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -26,9 +26,6 @@ function Emcee(n_walkers::Int, stretch_length=2.0) return Emcee{typeof(ensemble)}(ensemble) end -drop_space(alg::Emcee) = alg -DynamicPPL.getspace(::Emcee) = () - struct EmceeState{V<:AbstractVarInfo,S} vi::V states::S diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 1b7128605..98d0adabc 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -22,10 +22,6 @@ Mean """ struct ESS <: InferenceAlgorithm end -# TODO(mhauru) Remove the below once DPPL no longer relies on it. 
-drop_space(alg::ESS) = alg -DynamicPPL.getspace(::Sampler{<:ESS}) = () - # always accept in the first step function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index ada5f611b..04ebe3063 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -345,7 +345,7 @@ struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <: # Ensure that samplers have the same selector, and that varnames are lists of # VarNames. - samplers = tuple(map(set_selector ∘ drop_space, samplers)...) + samplers = tuple(map(set_selector, samplers)...) varnames = tuple(map(to_varname_list, varnames)...) return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers) end @@ -355,49 +355,6 @@ function Gibbs(algs::Pair...) return Gibbs(map(first, algs), map(last, algs)) end -# The below two constructors only provide backwards compatibility with the constructor of -# the old Gibbs sampler. They are deprecated and will be removed in the future. -function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...) - algs = [alg1, other_algs...] - varnames = map(algs) do alg - space = getspace(alg) - if (space isa VarName) - space - elseif (space isa Symbol) - VarName{space}() - else - tuple((s isa Symbol ? VarName{s}() : s for s in space)...) - end - end - msg = ( - "Specifying which sampler to use with which variable using syntax like " * - "`Gibbs(NUTS(:x), MH(:y))` is deprecated and will be removed in the future. " * - "Please use `Gibbs(; x=NUTS(), y=MH())` instead. If you want different iteration " * - "counts for different subsamplers, use e.g. " * - "`Gibbs(@varname(x) => RepeatSampler(NUTS(), 2), @varname(y) => MH())`" - ) - Base.depwarn(msg, :Gibbs) - return Gibbs(varnames, map(set_selector ∘ drop_space, algs)) -end - -function Gibbs( - alg_with_iters1::Tuple{<:InferenceAlgorithm,Int}, - other_algs_with_iters::Tuple{<:InferenceAlgorithm,Int}..., -) - algs_with_iters = [alg_with_iters1, other_algs_with_iters...] - algs = Iterators.map(first, algs_with_iters) - iters = Iterators.map(last, algs_with_iters) - algs_duplicated = Iterators.flatten(( - Iterators.repeated(alg, iter) for (alg, iter) in zip(algs, iters) - )) - # This calls the other deprecated constructor from above, hence no need for a depwarn - # here. - return Gibbs(algs_duplicated...) -end - -# TODO: Remove when no longer needed. -DynamicPPL.getspace(::Gibbs) = () - struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} vi::V states::S diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index d5d21b275..98a42482e 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -76,10 +76,6 @@ function HMC( return HMC(ϵ, n_leapfrog, metricT; adtype=adtype) end -# TODO(mhauru) Remove the below once DPPL no longer relies on it. 
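For readers migrating code, a sketch (not part of the diff) of the replacements for the Gibbs constructors removed earlier in this patch, following the hints in the deleted deprecation message:

```julia
using Turing

# Instead of the removed Gibbs(NUTS(:x), MH(:y)):
spl = Gibbs(; x=NUTS(), y=MH())  # equivalently Gibbs(:x => NUTS(), :y => MH())

# Instead of the removed per-sampler iteration counts, e.g. Gibbs((NUTS(:x), 2), (MH(:y), 3)):
spl_repeat = Gibbs(
    @varname(x) => RepeatSampler(NUTS(), 2),
    @varname(y) => RepeatSampler(MH(), 3),
)
```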
-drop_space(alg::Hamiltonian) = alg -DynamicPPL.getspace(::Hamiltonian) = () - DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform() # Handle setting `nadapts` and `discard_initial` diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 7b8b74f32..23a7bbef1 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -26,9 +26,6 @@ sample(gdemo([1.5, 2]), IS(), 1000) """ struct IS <: InferenceAlgorithm end -drop_space(alg::IS) = alg -DynamicPPL.getspace(::IS) = () - DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler function DynamicPPL.initialstep( diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 35f1711eb..8dfad8c7b 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -139,10 +139,6 @@ struct MH{P} <: InferenceAlgorithm end end -# TODO(mhauru) Remove the below once DPPL no longer relies on it. -drop_space(alg::MH) = alg -DynamicPPL.getspace(::MH) = () - # Some of the proposals require working in unconstrained space. transform_maybe(proposal::AMH.Proposal) = proposal function transform_maybe(proposal::AMH.RandomWalkProposal) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 380986be3..58637897a 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -45,9 +45,6 @@ function SMC(threshold::Real) return SMC(AdvancedPS.resample_systematic, threshold) end -drop_space(alg::SMC) = alg -DynamicPPL.getspace(::SMC) = () - struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T @@ -218,9 +215,6 @@ function PG(nparticles::Int, threshold::Real) return PG(nparticles, AdvancedPS.resample_systematic, threshold) end -drop_space(alg::PG) = alg -DynamicPPL.getspace(::PG) = () - """ CSMC(...) diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl index a3e38f46a..2f8ab8674 100644 --- a/src/mcmc/repeat_sampler.jl +++ b/src/mcmc/repeat_sampler.jl @@ -28,10 +28,9 @@ function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int) return RepeatSampler(Sampler(alg), num_repeat) end -drop_space(rs::RepeatSampler) = RepeatSampler(drop_space(rs.sampler), rs.num_repeat) getADType(spl::RepeatSampler) = getADType(spl.sampler) DynamicPPL.default_chain_type(sampler::RepeatSampler) = default_chain_type(sampler.sampler) -DynamicPPL.getspace(spl::RepeatSampler) = getspace(spl.sampler) +# TODO(mhauru) Remove the below once DynamicPPL has removed all its Selector stuff. DynamicPPL.inspace(vn::VarName, spl::RepeatSampler) = inspace(vn, spl.sampler) function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params) diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 367d7dd27..20d7bc148 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -47,8 +47,6 @@ function SGHMC(; ) end -drop_space(alg::SGHMC) = alg - struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}} logdensity::L vi::V @@ -130,8 +128,6 @@ struct SGLD{AD,S} <: StaticHamiltonian adtype::AD end -drop_space(alg::SGLD) = alg - struct PolynomialStepsize{T<:Real} "Constant scale factor of the step size." 
a::T diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 174cc24b0..9b5ba4c2a 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -17,7 +17,7 @@ using Random: Random using ReverseDiff: ReverseDiff import Mooncake using StableRNGs: StableRNG -using Test: @inferred, @test, @test_broken, @test_deprecated, @test_throws, @testset +using Test: @inferred, @test, @test_broken, @test_throws, @testset using Turing using Turing: Inference using Turing.Inference: AdvancedHMC, AdvancedMH @@ -150,7 +150,6 @@ end # Methods we need to define to be able to use AlgWrapper instead of an actual algorithm. # They all just propagate the call to the inner algorithm. Inference.isgibbscomponent(wrap::AlgWrapper) = Inference.isgibbscomponent(wrap.inner) - Inference.drop_space(wrap::AlgWrapper) = AlgWrapper(Inference.drop_space(wrap.inner)) function Inference.setparams_varinfo!!( model::DynamicPPL.Model, sampler::DynamicPPL.Sampler{<:AlgWrapper}, From fb868fe2a045185e3af9d6b53fc5754864537510 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 9 Jan 2025 15:05:20 +0000 Subject: [PATCH 16/20] Bump minor version 0.37. Add a HISTORY.md entry --- HISTORY.md | 6 ++++++ Project.toml | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 64be69106..82b2fd081 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,9 @@ +# Release 0.37.0 + +## Breaking changes + +0.37 removes the old Gibbs constructors deprecated in 0.36. + # Release 0.36.0 ## Breaking changes diff --git a/Project.toml b/Project.toml index eea0b0b7e..b66edc35b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.36.0" +version = "0.37.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 091d114dd2721d4ba04f616d347458f84b243e47 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 10 Jan 2025 16:54:01 +0000 Subject: [PATCH 17/20] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/mcmc/particle_mcmc.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 58637897a..04389a771 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -24,7 +24,6 @@ struct SMC{R} <: ParticleInference end """ - SMC() SMC([resampler = AdvancedPS.ResampleWithESSThreshold()]) SMC([resampler = AdvancedPS.resample_systematic, ]threshold) @@ -33,9 +32,7 @@ Create a sequential Monte Carlo sampler of type [`SMC`](@ref). If the algorithm for the resampling step is not specified explicitly, systematic resampling is performed if the estimated effective sample size per particle drops below 0.5. 
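As a final usage illustration (not taken from the patches; the coin-flip model is invented), the simplified particle-sampler constructors exercised end to end:

```julia
using Turing

# Toy model for illustration only.
@model function coin(y)
    p ~ Beta(1, 1)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end

data = [1, 0, 1, 1, 0]

sample(coin(data), SMC(), 200)                # sequential Monte Carlo, no variable names needed
sample(coin(data), PG(20), 200)               # particle Gibbs with 20 particles
sample(coin(data), Gibbs(:p => PG(10)), 200)  # per-variable assignment via pairs
```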
""" -function SMC(resampler=AdvancedPS.ResampleWithESSThreshold()) - return SMC{typeof(resampler)}(resampler) -end +SMC() = SMC(AdvancedPS.ResampleWithESSThreshold()) # Convenient constructors with ESS threshold function SMC(resampler, threshold::Real) @@ -194,7 +191,6 @@ struct PG{R} <: ParticleInference end """ - PG(n) PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold()]) PG(n, [resampler = AdvancedPS.resample_systematic, ]threshold) From 9a0dee6f0d70e3f6a30b9316e81839ad1024583a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 10 Jan 2025 17:12:41 +0000 Subject: [PATCH 18/20] Remove unnecessary type parameters Co-authored-by: Penelope Yong --- src/mcmc/sghmc.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 20d7bc148..7cf5cd6e4 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -42,9 +42,7 @@ function SGHMC(; adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) _learning_rate, _momentum_decay = promote(learning_rate, momentum_decay) - return SGHMC{typeof(adtype),typeof(_learning_rate)}( - _learning_rate, _momentum_decay, adtype - ) + return SGHMC(_learning_rate, _momentum_decay, adtype) end struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}} @@ -185,7 +183,7 @@ See also: [`PolynomialStepsize`](@ref) function SGLD(; stepsize=PolynomialStepsize(0.01), adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE ) - return SGLD{typeof(adtype),typeof(stepsize)}(stepsize, adtype) + return SGLD(stepsize, adtype) end struct SGLDTransition{T,F<:Real} <: AbstractTransition From bd78c89c40ea4d6f41bd439eb45c302c83fc4a6e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 10 Jan 2025 17:06:33 +0000 Subject: [PATCH 19/20] Simplify constructors in particle_mcmc.jl --- src/mcmc/particle_mcmc.jl | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 04389a771..3e977f216 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -17,10 +17,6 @@ $(TYPEDFIELDS) """ struct SMC{R} <: ParticleInference resampler::R - - # Specifying this constructor stops the creation of default constructors. We rather - # define our own outer constructors below. - SMC{R}(resampler::R) where {R} = new{R}(resampler) end """ @@ -184,10 +180,6 @@ struct PG{R} <: ParticleInference nparticles::Int """Resampling algorithm.""" resampler::R - - # Specifying this constructor stops the creation of default constructors. We rather - # define our own outer constructors below. - PG{R}(nparticles::Int, resampler::R) where {R} = new{R}(nparticles, resampler) end """ @@ -199,8 +191,8 @@ Create a Particle Gibbs sampler of type [`PG`](@ref) with `n` particles. If the algorithm for the resampling step is not specified explicitly, systematic resampling is performed if the estimated effective sample size per particle drops below 0.5. 
""" -function PG(nparticles::Int, resampler=AdvancedPS.ResampleWithESSThreshold()) - return PG{typeof(resampler)}(nparticles, resampler) +function PG(nparticles::Int) + return PG(nparticles, AdvancedPS.ResampleWithESSThreshold()) end # Convenient constructors with ESS threshold From d6ba522866314d14dd0c69d930e3af21fa0a5d93 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 10 Jan 2025 17:08:56 +0000 Subject: [PATCH 20/20] Remove calls to setgid and updategid --- src/mcmc/hmc.jl | 3 --- src/mcmc/mh.jl | 8 -------- src/mcmc/particle_mcmc.jl | 2 -- 3 files changed, 13 deletions(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 98a42482e..524d02c6a 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -495,7 +495,6 @@ end function DynamicPPL.assume( rng, spl::Sampler{<:Hamiltonian}, dist::Distribution, vn::VarName, vi ) - DynamicPPL.updategid!(vi, vn, spl) return DynamicPPL.assume(dist, vn, vi) end @@ -507,7 +506,6 @@ function DynamicPPL.dot_assume( var::AbstractMatrix, vi, ) - DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl)) return DynamicPPL.dot_assume(dist, var, vns, vi) end function DynamicPPL.dot_assume( @@ -518,7 +516,6 @@ function DynamicPPL.dot_assume( var::AbstractArray, vi, ) - DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl)) return DynamicPPL.dot_assume(dists, var, vns, vi) end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 8dfad8c7b..33587f1e4 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -398,9 +398,6 @@ function DynamicPPL.assume( ) # Just defer to `SampleFromPrior`. retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) - # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call. - DynamicPPL.updategid!(vi, vn, spl) - # Return. return retval end @@ -414,9 +411,6 @@ function DynamicPPL.dot_assume( ) # Just defer to `SampleFromPrior`. retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dist, vns[1], var, vi) - # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call. - DynamicPPL.updategid!.((vi,), vns, (spl,)) - # Return. return retval end function DynamicPPL.dot_assume( @@ -429,8 +423,6 @@ function DynamicPPL.dot_assume( ) # Just defer to `SampleFromPrior`. retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dists, vns, var, vi) - # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call. - DynamicPPL.updategid!.((vi,), vns, (spl,)) return retval end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 3e977f216..733d572c7 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -370,10 +370,8 @@ function DynamicPPL.assume( unset_flag!(vi, vn, "del") # Reference particle parent r = rand(trng, dist) vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setgid!(vi, spl.selector, vn) setorder!(vi, vn, get_num_produce(vi)) else - DynamicPPL.updategid!(vi, vn, spl) # Pick data from reference particle r = vi[vn] end # TODO: Should we make this `zero(promote_type(eltype(dist), eltype(r)))` or something?