Skip to content

Commit

Permalink
Restrict values_as_in_model API (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm authored Jan 16, 2025
1 parent e673b69 commit 938a69d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.33.1"
version = "0.34.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
41 changes: 15 additions & 26 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ wants to extract the realization of a model in a constrained space.
# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
"values that are extracted from the model"
values::T
values::OrderedDict
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
"child context"
Expand Down Expand Up @@ -114,34 +114,32 @@ function dot_tilde_assume(
end

"""
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
Get the values of `varinfo` as they would be seen in the model.
If no `varinfo` is provided, then this is effectively the same as
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
More specifically, this method attempts to extract the realization _as seen in
the model_. For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a
realization that is compatible with `truncated(Normal(); lower=0)` -- i.e. one
where the value of `x[1]` is positive -- regardless of whether `varinfo` is
working in unconstrained space.
More specifically, this method attempts to extract the realization _as seen in the model_.
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
space.
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
of additional model evaluations.
Hence this method is a "safe" way of obtaining realizations in constrained
space at the cost of additional model evaluations.
# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
- `context::AbstractContext`: base context to use for the extraction. Defaults
to `DynamicPPL.DefaultContext()`.
# Examples
## When `VarInfo` fails
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
The following demonstrates a common pitfall when working with [`VarInfo`](@ref)
and constrained variables.
```jldoctest
julia> using Distributions, StableRNGs
Expand Down Expand Up @@ -191,19 +189,10 @@ true
function values_as_in_model(
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
varinfo::AbstractVarInfo,
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(include_colon_eq, context)
evaluate!!(model, varinfo, context)
return context.values
end
function values_as_in_model(
rng::Random.AbstractRNG,
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
end
16 changes: 0 additions & 16 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,22 +429,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end
end

@testset "check that sampling obeys rng if passed" begin
@model function f()
x ~ Normal(0)
return y ~ Normal(x)
end
model = f()
# Call values_as_in_model with the rng
values = values_as_in_model(Random.Xoshiro(43), model, false)
# Check that they match the values that would be used if vi was seeded
# with that seed instead
expected_vi = VarInfo(Random.Xoshiro(43), model)
for vn in keys(values)
@test values[vn] == expected_vi[vn]
end
end
end

@testset "Erroneous model call" begin
Expand Down

2 comments on commit 938a69d

@penelopeysm
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Breaking changes

  • rng argument removed from values_as_in_model, and varinfo made non-optional. This means that the only signatures allowed are

    values_as_in_model(::Model, ::Bool, ::AbstractVarInfo)
    values_as_in_model(::Model, ::Bool, ::AbstractVarInfo, ::AbstractContext)
    

If you aren't using this function (it's probably only used in Turing.jl) then this won't affect you.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/123134

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.34.0 -m "<description of version>" 938a69dfabb7089f52700ef01fa3a9f2d667b7d1
git push origin v0.34.0

Please sign in to comment.