From 9a90f51d4a5e19a63e5bf4151eb8f22c311811af Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Thu, 31 Oct 2024 14:11:25 -0700 Subject: [PATCH] fixed bug in is_constant. Was recursing through the graph for each call. Made code asymptotically slower than it should have been. Now much faster for large graphs. changed simplification rules for unary -. Previously the conditionals in the if statements were creating node, which caused allocations even if the condition proved false. Changed to removed all Node create in conditions. over 2x speed improvement. --- src/ExpressionGraph.jl | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index c7b217fd..b15986bd 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -125,7 +125,9 @@ is_tree(a::Node) = arity(a) >= 1 function is_variable(a::Node) tmp = (value(a) isa Symbol) - tmp2 = (arity(a) == 0) || all(is_variable.(children(a))) #this allows q(t) types. Even allows q(q(t)) types. Now have to figure out what the derivative of this is. + tmp2 = (arity(a) == 0) + #commented out this line because it is causing enormous slowdown. Recurses through the graph for every arity test. For large graphs, ≈ 1e4 nodes, it is x slower than not doing it. For larger graphs it takes so long most people wouldn't be willing to wait. + #|| all(is_variable.(children(a))) #this allows q(t) types. Even allows q(q(t)) types. Now have to figure out what the derivative of this is. return tmp && tmp2 #previously had this all in a single line statement but the compiler generated weird code end @@ -247,9 +249,9 @@ function constant_and_term(a::Node) if constant_product(a) return children(a)[1], children(a)[2] elseif is_negate(a) - return -1, children(a)[1] + return Node(-1), children(a)[1] else - return 1, a + return one(Node), a end end @@ -257,7 +259,9 @@ end function constant_sum_simplification(lchild::Node, rchild::Node) lconstant, lterm = constant_and_term(lchild) rconstant, rterm = constant_and_term(rchild) - if lterm === rterm + if !is_constant(lconstant) || !is_constant(rconstant) + return nothing + elseif lterm === rterm return (lconstant, rconstant, lterm) else return nothing @@ -274,7 +278,13 @@ function simplify_check_cache(::typeof(+), na, nb)::Node return b elseif is_identically_zero(b) return a - elseif a === -b || -a === b + + #the next two cases test for a + -a => 0 or -a + a => 0 form. Previously used the single test, commented out below, but this doubled total evaluation time because of creation of nodes in the test. + # elseif a === -b || -a === b + # return zero(Node) + elseif is_constant(a) && is_constant(b) && value(a) == -value(b) + return zero(Node) + elseif is_negate(a) && !is_negate(b) && children(a)[1] === b || is_negate(b) && !is_negate(a) && children(b)[1] === a return zero(Node) elseif is_constant(a) && is_constant(b) return Node(value(a) + value(b)) @@ -331,7 +341,13 @@ end Special case only for unary -. No simplifications are currently applied to any other unary functions""" function simplify_check_cache(::typeof(-), a)::Node - na = Node(a) #this is safe because Node constructor is idempotent + #this used to be na = Node(a) but this resulted in runtime dispatch and considerable time wasted. + if !isa(a, Node) + na = Node(a) #this is safe because Node constructor is idempotent + else + na = a + end + if arity(na) == 1 && typeof(value(na)) == typeof(-) return children(na)[1] elseif constant_value(na) !== nothing