Skip to content

Commit

Permalink
Merge pull request #423 from ReactiveBayes/ux-rule-suggestions
Browse files Browse the repository at this point in the history
Ux rule suggestions
  • Loading branch information
bvdmitri authored Nov 27, 2024
2 parents 05f0394 + aea1dde commit 7f9e27c
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 12 deletions.
67 changes: 55 additions & 12 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,27 @@ function Base.showerror(io::IO, error::RuleMethodError)
if !isnothing(error.addons)
println(io, "\n\nEnabled addons: ", error.addons, "\n")
end

node_rules = filter(m -> ReactiveMP.get_node_from_rule_method(m) == spec_fform, methods(ReactiveMP.rule))
println(io, "Alternatively, consider re-specifying model using an existing rule:\n")

node_message_names = filter(x -> x != ["Nothing"], get_message_names_from_rule_method.(node_rules))
node_message_types = filter(!isempty, get_message_types_from_rule_method.(node_rules))
for (m_name, m_type) in zip(node_message_names, node_message_types)
message_input = [string("m_", n, "::", t) for (n, t) in zip(m_name, m_type)]
println(io, spec_fform, "(", join(message_input, ", "), ")")
end

node_marginal_names = filter(x -> x != ["Nothing"], get_marginal_names_from_rule_method.(node_rules))
node_marginal_types = filter(!isempty, get_marginal_types_from_rule_method.(node_rules))
for (m_name, m_type) in zip(node_marginal_names, node_marginal_types)
marginal_input = [string("q_", n, "::", t) for (n, t) in zip(m_name, m_type)]
println(io, spec_fform, "(", join(marginal_input, ", "), ")")
end
if !isempty(node_marginal_names)
println(io, "\nNote that for marginal rules (i.e., involving q_*), the order of input types matters.")
end

else
println(io, "\n\n[WARN]: Non-standard rule layout found! Possible fix, define rule with the following arguments:\n")
println(io, "rule.fform: ", error.fform)
Expand Down Expand Up @@ -1261,42 +1282,64 @@ function convert_to_markdown(m::Method)
return output
end

# Extracts node from rule Method
# Extracts node from rule method
function get_node_from_rule_method(m::Method)
_, decls, _, _ = Base.arg_decl_parts(m)
return decls[2][2][8:(end - 1)]
end

# Extracts output from rule Method
# Extracts output from rule method
function get_output_from_rule_method(m::Method)
_, decls, _, _ = Base.arg_decl_parts(m)
return replace(decls[3][2], r"Type|Val|{|}|:|\(|\)|\,|Tuple|Int64" => "")
end

# Extracts messages from rule Method
function get_messages_from_rule_method(m::Method)
# Extracts name of message from rule method (e.g, :a, :out)
function get_message_names_from_rule_method(m::Method)
_, decls, _, _ = Base.arg_decl_parts(m)
return split(replace(decls[5][2], r"Type|Val|{|}|:|\(|\)|\," => ""))
end

# Extracts type of message from rule method (e.g., PointMass, NormalMeanVariance)
function get_message_types_from_rule_method(m::Method)
_, decls, _, _ = Base.arg_decl_parts(m)
tmp1 = replace(replace(decls[6][2][7:(end - 1)], r"ReactiveMP.ManyOf{<:Tuple{Vararg{" => ""), r"\, N}}}" => "xyz")
tmp2 = strip.(strip.(split(tmp1, "Message")[2:end]), ',')
tmp3 = map(x -> x == "xyz" ? "{<:ManyOf{<:Tuple{Vararg{Any, N}}}}" : x, tmp2)
tmp4 = map(x -> x == r"xyz*" ? "{<:ManyOf{<:Tuple{Vararg{" * x[4:end] * ", N}}}}" : x, tmp3)
tmp5 = map(x -> occursin("xyz", x) ? x[1:(end - 3)] : x, tmp4)
interfaces = "μ(" .* split(replace(decls[5][2], r"Type|Val|{|}|:|\(|\)|\," => "")) .* ")"
types = map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5))
return interfaces .* " :: " .* types
return map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5))
end

# Extracts marginals from rule method
function get_marginals_from_rule_method(m::Method)
# Extracts messages from rule method (e.g., "μ(a) :: PointMass")
function get_messages_from_rule_method(m::Method)
interfaces = get_message_names_from_rule_method(m)
types = get_message_types_from_rule_method(m)
return "μ(" .* interfaces .* ")" .* " :: " .* types
end

# Extracts name of marginal from rule method (e.g, :a, :out)
function get_marginal_names_from_rule_method(m::Method)
_, decls, _, _ = Base.arg_decl_parts(m)
return split(replace(decls[7][2], r"Type|Val|{|}|:|\(|\)|\," => ""))
end

# Extracts type of marginal from rule method (e.g., PointMass, NormalMeanVariance)
function get_marginal_types_from_rule_method(m::Method)
_, decls, _, _ = Base.arg_decl_parts(m)
tmp1 = replace(replace(decls[8][2][7:(end - 1)], r"ReactiveMP.ManyOf{<:Tuple{Vararg{" => ""), r"\, N}}}" => "xyz")
tmp2 = strip.(strip.(split(tmp1, "Marginal")[2:end]), ',')
tmp3 = map(x -> x == "xyz" ? "{<:ManyOf{<:Tuple{Vararg{Any, N}}}}" : x, tmp2)
tmp4 = map(x -> x == r"xyz*" ? "{<:ManyOf{<:Tuple{Vararg{" * x[4:end] * ", N}}}}" : x, tmp3)
tmp5 = map(x -> occursin("xyz", x) ? x[1:(end - 3)] : x, tmp4)
interfaces = "q(" .* split(replace(decls[7][2], r"Type|Val|{|}|:|\(|\)|\," => "")) .* ")"
types = map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5))
return interfaces .* " :: " .* types
return map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5))
end

# Extracts marginals from rule method
function get_marginals_from_rule_method(m::Method)
interfaces = get_marginal_names_from_rule_method(m)
types = get_marginal_types_from_rule_method(m)
return "q(" .* interfaces .* ")" .* " :: " .* types
end

# Extracts meta from rule method
Expand Down
103 changes: 103 additions & 0 deletions test/rule_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,83 @@
@test occursin("Possible fix, define", output)
@test occursin("(m_out::NormalMeanVariance, m_μ::NormalMeanVariance, q_out_μ::MvNormalMeanPrecision, meta::Float64)", output)
end

let
err = ReactiveMP.RuleMethodError(
Beta,
Val{:a}(),
Marginalisation(),
Val{(:out, :b)}(),
(Message(PointMass, false, false, nothing), Message(PointMass, false, false, nothing)),
Val{(:out_b,)}(),
(Marginal(PointMass, false, false, nothing),),
1.0,
nothing,
nothing
)

io = IOBuffer()
showerror(io, err)
output = String(take!(io))

@test occursin("Alternatively, consider re-specifying model using an existing rule:", output)
@test occursin("m_a::BayesBase.PointMass", output)
@test occursin("m_b::BayesBase.PointMass", output)
@test occursin("q_a::BayesBase.PointMass", output)
@test occursin("q_b::BayesBase.PointMass", output)
end

let
err = ReactiveMP.RuleMethodError(
GammaShapeRate,
Val{:a}(),
Marginalisation(),
Val{(:out, :b)}(),
(Message(PointMass, false, false, nothing), Message(PointMass, false, false, nothing)),
Val{(:out_b,)}(),
(Marginal(PointMass, false, false, nothing),),
1.0,
nothing,
nothing
)

io = IOBuffer()
showerror(io, err)
output = String(take!(io))

@test occursin("Alternatively, consider re-specifying model using an existing rule:", output)
@test occursin("GammaShapeRate", output)
@test occursin("m_α::BayesBase.PointMass", output)
@test occursin("m_β::BayesBase.PointMass", output)
@test occursin("q_out::activeMP", output)
@test occursin("q_α::Any", output)
@test occursin("q_β::Any", output)
@test occursin("q_β::ExponentialFamily.GammaDistributionsFamily", output)
end

let
err = ReactiveMP.RuleMethodError(
Dirichlet,
Val{:a}(),
Marginalisation(),
Val{(:out,)}(),
(Message(PointMass, false, false, nothing),),
Val{(:a,)}(),
(Marginal(PointMass, false, false, nothing),),
1.0,
nothing,
nothing
)

io = IOBuffer()
showerror(io, err)
output = String(take!(io))

@test occursin("Alternatively, consider re-specifying model using an existing rule:", output)
@test occursin("Dirichlet", output)
@test occursin("m_a::BayesBase.PointMass", output)
@test occursin("q_a::BayesBase.PointMass", output)
end
end

@testset "marginalrule_method_error" begin
Expand Down Expand Up @@ -634,6 +711,32 @@
@test occursin("(m_out::NormalMeanVariance, m_μ::NormalMeanVariance, q_out_μ::MvNormalMeanPrecision, meta::Float64)", output)
end
end

@testset "get_from_rule_method" begin
let
rule1 = methods(ReactiveMP.rule)[1]

messages_rule1 = ReactiveMP.get_messages_from_rule_method(rule1)
message_names_rule1 = ReactiveMP.get_message_names_from_rule_method(rule1)
message_types_rule1 = ReactiveMP.get_message_types_from_rule_method(rule1)

marginals_rule1 = ReactiveMP.get_marginals_from_rule_method(rule1)
marginal_names_rule1 = ReactiveMP.get_marginal_names_from_rule_method(rule1)
marginal_types_rule1 = ReactiveMP.get_marginal_types_from_rule_method(rule1)

@test ReactiveMP.get_node_from_rule_method(rule1) == "*"
@test occursin("μ(A) :: BayesBase.PointMass", messages_rule1[1])
@test occursin("μ(out) :: Union{ExponentialFamily.NormalDistributionsFamily{T}", messages_rule1[2])
@test occursin("A", message_names_rule1[1])
@test occursin("out", message_names_rule1[2])
@test occursin("PointMass", message_types_rule1[1])
@test occursin("Union", message_types_rule1[2])

@test isempty(marginals_rule1)
@test occursin("Nothing", marginal_names_rule1[1])
@test isempty(marginal_types_rule1)
end
end
end

@testset "Check that default meta is `nothing`" begin
Expand Down

0 comments on commit 7f9e27c

Please sign in to comment.