From e114a31af8ce52e21891ae57347336e49cc0a6ec Mon Sep 17 00:00:00 2001 From: wmkouw Date: Tue, 15 Oct 2024 14:27:15 +0200 Subject: [PATCH 1/8] Added rule table lookup and filtered print of existing rules. --- src/rule.jl | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/rule.jl b/src/rule.jl index 9f0a51509..98981f6b1 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -1151,13 +1151,35 @@ function Base.showerror(io::IO, error::RuleMethodError) arguments_spec = join(spec, ", ") meta_spec = rule_method_error_extract_meta(error.meta) + println(io, "\n\nExisting rule(s) for node:\n") + + # Retrieve all rules for this node + all_rules = methods(ReactiveMP.rule) + this_node_fform = error.fform + this_node_rules = all_rules[get_node_from_rule_method.(all_rules) .== "$this_node_fform"] + + for node_rule in this_node_rules + + node_name = get_node_from_rule_method(node_rule) + node_inputs = get_messages_from_rule_method(node_rule) + + if typeof(node_inputs) !== Vector{Any} + + node_rule_string = """ + $node_name($(join(node_inputs, ", "))) + """ + println(io, node_rule_string) + + end + end + possible_fix_definition = """ @rule $(spec_fform)($spec_on, $spec_vconstraint) ($arguments_spec, $meta_spec) = begin return ... end """ - println(io, "\n\nPossible fix, define:\n") + println(io, "\nPossible fix, define:\n") println(io, possible_fix_definition) if !isnothing(error.addons) println(io, "\n\nEnabled addons: ", error.addons, "\n") From 0488672fef310d5ad7de2194aa9131e37c6eb984 Mon Sep 17 00:00:00 2001 From: wmkouw Date: Tue, 15 Oct 2024 15:16:21 +0200 Subject: [PATCH 2/8] Compressed rule lookup and print code. --- src/rule.jl | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/rule.jl b/src/rule.jl index 98981f6b1..f94a52a59 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -1151,25 +1151,15 @@ function Base.showerror(io::IO, error::RuleMethodError) arguments_spec = join(spec, ", ") meta_spec = rule_method_error_extract_meta(error.meta) + # Print list of existing rules for this node println(io, "\n\nExisting rule(s) for node:\n") - - # Retrieve all rules for this node all_rules = methods(ReactiveMP.rule) - this_node_fform = error.fform - this_node_rules = all_rules[get_node_from_rule_method.(all_rules) .== "$this_node_fform"] - + this_node_rules = all_rules[get_node_from_rule_method.(all_rules) .== "$(error.fform)"] for node_rule in this_node_rules - node_name = get_node_from_rule_method(node_rule) node_inputs = get_messages_from_rule_method(node_rule) - if typeof(node_inputs) !== Vector{Any} - - node_rule_string = """ - $node_name($(join(node_inputs, ", "))) - """ - println(io, node_rule_string) - + println(io, "$node_name($(join(node_inputs, ", ")))") end end From 8d3b66ee36448061df3cb5ae92a9ecd56bf54202 Mon Sep 17 00:00:00 2001 From: wmkouw Date: Tue, 15 Oct 2024 15:30:54 +0200 Subject: [PATCH 3/8] Added test for rule look up under rule_method_error testset --- test/rule_tests.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/rule_tests.jl b/test/rule_tests.jl index b2f6d83b0..35d2b3591 100644 --- a/test/rule_tests.jl +++ b/test/rule_tests.jl @@ -530,6 +530,29 @@ @test occursin("Possible fix, define", output) @test occursin("(m_out::NormalMeanVariance, m_μ::NormalMeanVariance, q_out_μ::MvNormalMeanPrecision, meta::Float64)", output) end + + let + err = ReactiveMP.RuleMethodError( + Beta, + Val{:a}(), + Marginalisation(), + Val{(:out, :b)}(), + (Message(PointMass, false, false, nothing), Message(PointMass, false, false, nothing)), + Val{(:out_b,)}(), + (Marginal(PointMass, false, false, nothing),), + 1.0, + nothing, + nothing + ) + + io = IOBuffer() + showerror(io, err) + output = String(take!(io)) + + @test occursin("Existing rule(s) for node:", output) + @test occursin("Distributions.Beta", output) + @test occursin("μ(a) :: BayesBase.PointMass, μ(b) :: BayesBase.PointMass", output) + end end @testset "marginalrule_method_error" begin From 8756c0f48c6bee90f9ff528eec0a8ef41db4ff9e Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 4 Nov 2024 12:50:45 +0100 Subject: [PATCH 4/8] modify tests a bit --- src/rule.jl | 2 +- test/rule_tests.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rule.jl b/src/rule.jl index f94a52a59..e4847ceed 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -1159,7 +1159,7 @@ function Base.showerror(io::IO, error::RuleMethodError) node_name = get_node_from_rule_method(node_rule) node_inputs = get_messages_from_rule_method(node_rule) if typeof(node_inputs) !== Vector{Any} - println(io, "$node_name($(join(node_inputs, ", ")))") + println(io, node_name, "(", join(node_inputs, ", "), ")") end end diff --git a/test/rule_tests.jl b/test/rule_tests.jl index 35d2b3591..1d33174b8 100644 --- a/test/rule_tests.jl +++ b/test/rule_tests.jl @@ -551,7 +551,8 @@ @test occursin("Existing rule(s) for node:", output) @test occursin("Distributions.Beta", output) - @test occursin("μ(a) :: BayesBase.PointMass, μ(b) :: BayesBase.PointMass", output) + @test occursin("μ(a) :: BayesBase.PointMass", output) + @test occursin("μ(b) :: BayesBase.PointMass", output) end end From 1ebaae1cb2557316a6b907e1b6993142fe86592c Mon Sep 17 00:00:00 2001 From: wmkouw Date: Tue, 26 Nov 2024 12:50:07 +0100 Subject: [PATCH 5/8] Reworked rule suggestions. Split Bart's get_message/marginal_from_rule_method functions. Formatted to m/q_ instead of m/q(). --- src/rule.jl | 82 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/src/rule.jl b/src/rule.jl index e4847ceed..a5c63e169 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -1151,29 +1151,39 @@ function Base.showerror(io::IO, error::RuleMethodError) arguments_spec = join(spec, ", ") meta_spec = rule_method_error_extract_meta(error.meta) - # Print list of existing rules for this node - println(io, "\n\nExisting rule(s) for node:\n") - all_rules = methods(ReactiveMP.rule) - this_node_rules = all_rules[get_node_from_rule_method.(all_rules) .== "$(error.fform)"] - for node_rule in this_node_rules - node_name = get_node_from_rule_method(node_rule) - node_inputs = get_messages_from_rule_method(node_rule) - if typeof(node_inputs) !== Vector{Any} - println(io, node_name, "(", join(node_inputs, ", "), ")") - end - end - possible_fix_definition = """ @rule $(spec_fform)($spec_on, $spec_vconstraint) ($arguments_spec, $meta_spec) = begin return ... end """ - println(io, "\nPossible fix, define:\n") + println(io, "\n\nPossible fix, define:\n") println(io, possible_fix_definition) if !isnothing(error.addons) println(io, "\n\nEnabled addons: ", error.addons, "\n") end + + all_rules = methods(ReactiveMP.rule) + node_rules = all_rules[get_node_from_rule_method.(all_rules) .== spec_fform] + println(io, "Alternatively, consider re-specifying model using an existing rule:\n") + + node_message_names = filter(x -> x != ["Nothing"], get_message_names_from_rule_method.(node_rules)) + node_message_types = filter(!isempty, get_message_types_from_rule_method.(node_rules)) + for (m_name, m_type) in zip(node_message_names, node_message_types) + message_input = [string("m_",n," :: ",t) for (n,t) in zip(m_name, m_type)] + println(io, spec_fform, "(", join(message_input, ", "), ")") + end + + node_marginal_names = filter(x -> x != ["Nothing"], get_marginal_names_from_rule_method.(node_rules)) + node_marginal_types = filter(!isempty, get_marginal_types_from_rule_method.(node_rules)) + for (m_name, m_type) in zip(node_marginal_names, node_marginal_types) + marginal_input = [string("q_",n," :: ",t) for (n,t) in zip(m_name, m_type)] + println(io, spec_fform, "(", join(marginal_input, ", "), ")") + end + if !isempty(node_marginal_names) + println("\nNote that for marginal rules (i.e., involving q_*), the order of input types matters.") + end + else println(io, "\n\n[WARN]: Non-standard rule layout found! Possible fix, define rule with the following arguments:\n") println(io, "rule.fform: ", error.fform) @@ -1273,42 +1283,64 @@ function convert_to_markdown(m::Method) return output end -# Extracts node from rule Method +# Extracts node from rule method function get_node_from_rule_method(m::Method) _, decls, _, _ = Base.arg_decl_parts(m) return decls[2][2][8:(end - 1)] end -# Extracts output from rule Method +# Extracts output from rule method function get_output_from_rule_method(m::Method) _, decls, _, _ = Base.arg_decl_parts(m) return replace(decls[3][2], r"Type|Val|{|}|:|\(|\)|\,|Tuple|Int64" => "") end -# Extracts messages from rule Method -function get_messages_from_rule_method(m::Method) +# Extracts name of message from rule method (e.g, :a, :out) +function get_message_names_from_rule_method(m::Method) + _, decls, _, _ = Base.arg_decl_parts(m) + return split(replace(decls[5][2], r"Type|Val|{|}|:|\(|\)|\," => "")) +end + +# Extracts type of message from rule method (e.g., PointMass, NormalMeanVariance) +function get_message_types_from_rule_method(m::Method) _, decls, _, _ = Base.arg_decl_parts(m) tmp1 = replace(replace(decls[6][2][7:(end - 1)], r"ReactiveMP.ManyOf{<:Tuple{Vararg{" => ""), r"\, N}}}" => "xyz") tmp2 = strip.(strip.(split(tmp1, "Message")[2:end]), ',') tmp3 = map(x -> x == "xyz" ? "{<:ManyOf{<:Tuple{Vararg{Any, N}}}}" : x, tmp2) tmp4 = map(x -> x == r"xyz*" ? "{<:ManyOf{<:Tuple{Vararg{" * x[4:end] * ", N}}}}" : x, tmp3) tmp5 = map(x -> occursin("xyz", x) ? x[1:(end - 3)] : x, tmp4) - interfaces = "μ(" .* split(replace(decls[5][2], r"Type|Val|{|}|:|\(|\)|\," => "")) .* ")" - types = map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5)) - return interfaces .* " :: " .* types + return map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5)) end -# Extracts marginals from rule method -function get_marginals_from_rule_method(m::Method) +# Extracts messages from rule method (e.g., "μ(a) :: PointMass") +function get_messages_from_rule_method(m::Method) + interfaces = get_message_names_from_rule_method(m) + types = get_message_types_from_rule_method(m) + return "μ(" .* interfaces .* ")" .* " :: " .* types +end + +# Extracts name of marginal from rule method (e.g, :a, :out) +function get_marginal_names_from_rule_method(m::Method) + _, decls, _, _ = Base.arg_decl_parts(m) + return split(replace(decls[7][2], r"Type|Val|{|}|:|\(|\)|\," => "")) +end + +# Extracts type of marginal from rule method (e.g., PointMass, NormalMeanVariance) +function get_marginal_types_from_rule_method(m::Method) _, decls, _, _ = Base.arg_decl_parts(m) tmp1 = replace(replace(decls[8][2][7:(end - 1)], r"ReactiveMP.ManyOf{<:Tuple{Vararg{" => ""), r"\, N}}}" => "xyz") tmp2 = strip.(strip.(split(tmp1, "Marginal")[2:end]), ',') tmp3 = map(x -> x == "xyz" ? "{<:ManyOf{<:Tuple{Vararg{Any, N}}}}" : x, tmp2) tmp4 = map(x -> x == r"xyz*" ? "{<:ManyOf{<:Tuple{Vararg{" * x[4:end] * ", N}}}}" : x, tmp3) tmp5 = map(x -> occursin("xyz", x) ? x[1:(end - 3)] : x, tmp4) - interfaces = "q(" .* split(replace(decls[7][2], r"Type|Val|{|}|:|\(|\)|\," => "")) .* ")" - types = map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5)) - return interfaces .* " :: " .* types + return map(x -> isempty(x) ? "Any" : x, map(x -> x[4:(end - 1)], tmp5)) +end + +# Extracts marginals from rule method +function get_marginals_from_rule_method(m::Method) + interfaces = get_marginal_names_from_rule_method(m) + types = get_marginal_types_from_rule_method(m) + return "q(" .* interfaces .* ")" .* " :: " .* types end # Extracts meta from rule method From 86ddd5af8ad2af02fd1827ed19c246afe1bcbfda Mon Sep 17 00:00:00 2001 From: wmkouw Date: Tue, 26 Nov 2024 15:09:24 +0100 Subject: [PATCH 6/8] Replaced node rule index selection with filter on methods iterable. --- src/rule.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/rule.jl b/src/rule.jl index a5c63e169..b77f723cd 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -1163,25 +1163,24 @@ function Base.showerror(io::IO, error::RuleMethodError) println(io, "\n\nEnabled addons: ", error.addons, "\n") end - all_rules = methods(ReactiveMP.rule) - node_rules = all_rules[get_node_from_rule_method.(all_rules) .== spec_fform] + node_rules = filter(m -> ReactiveMP.get_node_from_rule_method(m) == spec_fform, methods(ReactiveMP.rule)) println(io, "Alternatively, consider re-specifying model using an existing rule:\n") node_message_names = filter(x -> x != ["Nothing"], get_message_names_from_rule_method.(node_rules)) node_message_types = filter(!isempty, get_message_types_from_rule_method.(node_rules)) for (m_name, m_type) in zip(node_message_names, node_message_types) - message_input = [string("m_",n," :: ",t) for (n,t) in zip(m_name, m_type)] + message_input = [string("m_", n, "::", t) for (n, t) in zip(m_name, m_type)] println(io, spec_fform, "(", join(message_input, ", "), ")") end node_marginal_names = filter(x -> x != ["Nothing"], get_marginal_names_from_rule_method.(node_rules)) node_marginal_types = filter(!isempty, get_marginal_types_from_rule_method.(node_rules)) for (m_name, m_type) in zip(node_marginal_names, node_marginal_types) - marginal_input = [string("q_",n," :: ",t) for (n,t) in zip(m_name, m_type)] + marginal_input = [string("q_", n, "::", t) for (n, t) in zip(m_name, m_type)] println(io, spec_fform, "(", join(marginal_input, ", "), ")") end if !isempty(node_marginal_names) - println("\nNote that for marginal rules (i.e., involving q_*), the order of input types matters.") + println(io, "\nNote that for marginal rules (i.e., involving q_*), the order of input types matters.") end else From 56d425e84f14ad0296054abee63a26499e165937 Mon Sep 17 00:00:00 2001 From: wmkouw Date: Tue, 26 Nov 2024 15:10:23 +0100 Subject: [PATCH 7/8] Corrected existing tests. Added tests with Gamma and Dirichlet models. Added testset on get_*_from_rule_method functions. --- test/rule_tests.jl | 88 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 4 deletions(-) diff --git a/test/rule_tests.jl b/test/rule_tests.jl index 1d33174b8..d7af639f2 100644 --- a/test/rule_tests.jl +++ b/test/rule_tests.jl @@ -549,10 +549,63 @@ showerror(io, err) output = String(take!(io)) - @test occursin("Existing rule(s) for node:", output) - @test occursin("Distributions.Beta", output) - @test occursin("μ(a) :: BayesBase.PointMass", output) - @test occursin("μ(b) :: BayesBase.PointMass", output) + @test occursin("Alternatively, consider re-specifying model using an existing rule:", output) + @test occursin("m_a::BayesBase.PointMass", output) + @test occursin("m_b::BayesBase.PointMass", output) + @test occursin("q_a::BayesBase.PointMass", output) + @test occursin("q_b::BayesBase.PointMass", output) + end + + let + err = ReactiveMP.RuleMethodError( + GammaShapeRate, + Val{:a}(), + Marginalisation(), + Val{(:out, :b)}(), + (Message(PointMass, false, false, nothing), Message(PointMass, false, false, nothing)), + Val{(:out_b,)}(), + (Marginal(PointMass, false, false, nothing),), + 1.0, + nothing, + nothing + ) + + io = IOBuffer() + showerror(io, err) + output = String(take!(io)) + + @test occursin("Alternatively, consider re-specifying model using an existing rule:", output) + @test occursin("GammaShapeRate", output) + @test occursin("m_α::BayesBase.PointMass", output) + @test occursin("m_β::BayesBase.PointMass", output) + @test occursin("q_out::activeMP", output) + @test occursin("q_α::Any", output) + @test occursin("q_β::Any", output) + @test occursin("q_β::ExponentialFamily.GammaDistributionsFamily", output) + end + + let + err = ReactiveMP.RuleMethodError( + Dirichlet, + Val{:a}(), + Marginalisation(), + Val{(:out,)}(), + (Message(PointMass, false, false, nothing),), + Val{(:a,)}(), + (Marginal(PointMass, false, false, nothing),), + 1.0, + nothing, + nothing + ) + + io = IOBuffer() + showerror(io, err) + output = String(take!(io)) + + @test occursin("Alternatively, consider re-specifying model using an existing rule:", output) + @test occursin("Dirichlet", output) + @test occursin("m_a::BayesBase.PointMass", output) + @test occursin("q_a::BayesBase.PointMass", output) end end @@ -658,6 +711,33 @@ @test occursin("(m_out::NormalMeanVariance, m_μ::NormalMeanVariance, q_out_μ::MvNormalMeanPrecision, meta::Float64)", output) end end + + @testset "get_from_rule_method" begin + + let + rule1 = methods(ReactiveMP.rule)[1] + + messages_rule1 = ReactiveMP.get_messages_from_rule_method(rule1) + message_names_rule1 = ReactiveMP.get_message_names_from_rule_method(rule1) + message_types_rule1 = ReactiveMP.get_message_types_from_rule_method(rule1) + + marginals_rule1 = ReactiveMP.get_marginals_from_rule_method(rule1) + marginal_names_rule1 = ReactiveMP.get_marginal_names_from_rule_method(rule1) + marginal_types_rule1 = ReactiveMP.get_marginal_types_from_rule_method(rule1) + + @test ReactiveMP.get_node_from_rule_method(rule1) == "*" + @test occursin("μ(A) :: BayesBase.PointMass", messages_rule1[1]) + @test occursin("μ(out) :: Union{ExponentialFamily.NormalDistributionsFamily{T}", messages_rule1[2]) + @test occursin("A", message_names_rule1[1]) + @test occursin("out", message_names_rule1[2]) + @test occursin("PointMass", message_types_rule1[1]) + @test occursin("Union", message_types_rule1[2]) + + @test isempty(marginals_rule1) + @test occursin("Nothing", marginal_names_rule1[1]) + @test isempty(marginal_types_rule1) + end + end end @testset "Check that default meta is `nothing`" begin From aea1dde7a7fa86bd4813866ff3f1f6193a092b3a Mon Sep 17 00:00:00 2001 From: wmkouw Date: Tue, 26 Nov 2024 15:26:44 +0100 Subject: [PATCH 8/8] Ran formatter --- test/rule_tests.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/rule_tests.jl b/test/rule_tests.jl index d7af639f2..86f15089d 100644 --- a/test/rule_tests.jl +++ b/test/rule_tests.jl @@ -713,18 +713,17 @@ end @testset "get_from_rule_method" begin - - let + let rule1 = methods(ReactiveMP.rule)[1] - messages_rule1 = ReactiveMP.get_messages_from_rule_method(rule1) - message_names_rule1 = ReactiveMP.get_message_names_from_rule_method(rule1) - message_types_rule1 = ReactiveMP.get_message_types_from_rule_method(rule1) + messages_rule1 = ReactiveMP.get_messages_from_rule_method(rule1) + message_names_rule1 = ReactiveMP.get_message_names_from_rule_method(rule1) + message_types_rule1 = ReactiveMP.get_message_types_from_rule_method(rule1) marginals_rule1 = ReactiveMP.get_marginals_from_rule_method(rule1) marginal_names_rule1 = ReactiveMP.get_marginal_names_from_rule_method(rule1) marginal_types_rule1 = ReactiveMP.get_marginal_types_from_rule_method(rule1) - + @test ReactiveMP.get_node_from_rule_method(rule1) == "*" @test occursin("μ(A) :: BayesBase.PointMass", messages_rule1[1]) @test occursin("μ(out) :: Union{ExponentialFamily.NormalDistributionsFamily{T}", messages_rule1[2])