diff --git a/Project.toml b/Project.toml index 5ccbbf752..2ef9edc8b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.29.2" +version = "0.29.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 1b09668ec..c4fa1eae1 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -98,19 +98,15 @@ Wrap a sampler so it can be used as an inference algorithm. """ externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler) -""" - ESLogDensityFunction - -A log density function for the External sampler. - -""" -const ESLogDensityFunction{M<:Model,S<:Sampler{<:ExternalSampler},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.DefaultContext} -function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple) +function LogDensityProblems.logdensity( + f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext}, + x::NamedTuple +) return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x)) end # TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL. -function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) +function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) set_namedtuple!(deepcopy(vi), θ) return vi end diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index f24d5c02f..00be16e3b 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -7,7 +7,7 @@ state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f) function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition) # TODO: We should probably rename this `getparams` since it returns something # very different from `Turing.Inference.getparams`. - θ = getparams(transition) + θ = getparams(f.model, transition) varinfo = DynamicPPL.unflatten(f.varinfo, θ) return Transition(f.model, varinfo, transition) end @@ -15,10 +15,10 @@ end # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: # https://github.com/TuringLang/AbstractMCMC.jl/pull/86 -getparams(transition::AdvancedHMC.Transition) = transition.z.θ +getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ getstats(transition::AdvancedHMC.Transition) = transition.stat -getparams(transition::AdvancedMH.Transition) = transition.params +getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))