From c9e3b4b4b40e09da4d23a9c29029f2d380560b29 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Mon, 5 Aug 2024 10:51:18 -0700 Subject: [PATCH 01/29] Add conditionals Fixes #89 added function bool_methods to create overloads for conditionals and ifelse. All tests pass. Now need to change derivative calculation and code generation. --- src/ExpressionGraph.jl | 4 ++-- src/Methods.jl | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 4d21c42c..986ebdd4 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -37,8 +37,8 @@ struct Node <: Real Node(a::T) where {T<:Real} = new(a, nothing) #convert numbers to Node Node(a::T) where {T<:Node} = a #if a is already a special node leave it alone - function Node(operation, args::MVector{N,T}) where {T<:Node,N} #use MVector rather than Vector. 40x faster. - return new(operation, args) + function Node(operation, args::MVector) #use MVector rather than Vector. 40x faster. + return new(operation, Node.(args)) end Node(a::S) where {S<:Symbol} = new(a, nothing) diff --git a/src/Methods.jl b/src/Methods.jl index f9c35945..8622744a 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -54,8 +54,6 @@ function number_methods(T, rhs1, rhs2, options=nothing) (f::$(typeof(f)))(a::$T, b::$T) = $rhs2 (f::$(typeof(f)))(a::$T, b::Real) = $rhs2 (f::$(typeof(f)))(a::Real, b::$T) = $rhs2 - # (f::$(typeof(f)))(a::$T, b::Number) = $rhs2 - # (f::$(typeof(f)))(a::Number, b::$T) = $rhs2 end push!(exprs, expr) @@ -73,6 +71,21 @@ function number_methods(T, rhs1, rhs2, options=nothing) Expr(:block, exprs...) 
end +function bool_methods() + for func in (<, >, ≤, ≥, ≠, ==) + eval(:(Base.$(Symbol(func))(a::Node, b::Node) = Node($func, a, b); + Base.$(Symbol(func))(a::Node, b::Real) = Node(func, a, Node(b)); + Base.$(Symbol(func))(a::Real, b::Node) = Node(func, Node(a), b) + )) + end + + eval(:(Base.ifelse(a::Node, b::Node, c::Real) = Node(ifelse, MVector(a, b, Node(c))); + Base.ifelse(a::Node, b::Real, c::Node) = Node(ifelse, MVector(a, Node(b), c)); + Base.ifelse(a::Node, b::Node, c::Node) = Node(ifelse, MVector(a, b, c)) + )) +end +export bool_methods + macro number_methods(T, rhs1, rhs2, options=nothing) number_methods(T, rhs1, rhs2, options) |> esc end From f3e2f0912ce346c0f257d67e58cceb2ac89a3b56 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 6 Aug 2024 10:10:02 -0700 Subject: [PATCH 02/29] Add conditionals Fixes #89 turned boolean_methods into a macro and called it on package load added tests for booleans and ifelse --- src/ExpressionGraph.jl | 2 ++ src/Methods.jl | 17 +++++++++-------- test/FDTests.jl | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 986ebdd4..02ab83bc 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -657,3 +657,5 @@ export make_variables #create methods that accept Node arguments for all mathematical functions. @number_methods(Node, simplify_check_cache(f, a, EXPRESSION_CACHE), simplify_check_cache(f, a, b, EXPRESSION_CACHE)) #create methods for standard functions that take Node instead of Number arguments. Check cache to see if these arguments have been seen before. + +@boolean_methods(Node) diff --git a/src/Methods.jl b/src/Methods.jl index 8622744a..6ea868ca 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -71,20 +71,21 @@ function number_methods(T, rhs1, rhs2, options=nothing) Expr(:block, exprs...) 
end -function bool_methods() +#Define boolean methods +"""T is the type you want to define the boolean methods for. In this case Node""" +macro boolean_methods(T) for func in (<, >, ≤, ≥, ≠, ==) - eval(:(Base.$(Symbol(func))(a::Node, b::Node) = Node($func, a, b); - Base.$(Symbol(func))(a::Node, b::Real) = Node(func, a, Node(b)); - Base.$(Symbol(func))(a::Real, b::Node) = Node(func, Node(a), b) + eval(:(Base.$(Symbol(func))(a::$T, b::$T) = $T($func, a, b); + Base.$(Symbol(func))(a::$T, b::Real) = $T(func, a, $T(b)); + Base.$(Symbol(func))(a::Real, b::$T) = Node(func, $T(a), b) )) end - eval(:(Base.ifelse(a::Node, b::Node, c::Real) = Node(ifelse, MVector(a, b, Node(c))); - Base.ifelse(a::Node, b::Real, c::Node) = Node(ifelse, MVector(a, Node(b), c)); - Base.ifelse(a::Node, b::Node, c::Node) = Node(ifelse, MVector(a, b, c)) + eval(:(Base.ifelse(a::$T, b, c) = $T(ifelse, MVector(a, b, c)) )) end -export bool_methods + + macro number_methods(T, rhs1, rhs2, options=nothing) number_methods(T, rhs1, rhs2, options) |> esc diff --git a/test/FDTests.jl b/test/FDTests.jl index 7dbbaeb6..cdbafaa8 100644 --- a/test/FDTests.jl +++ b/test/FDTests.jl @@ -2062,3 +2062,40 @@ end end end +@testitem "conditional tests" begin + @variables x y + + #boolean operators + f = x < y + @test ==(FastDifferentiation.value(f), <) + @test FastDifferentiation.children(f)[1] === x + @test FastDifferentiation.children(f)[2] === y + f = x > y + @test ==(FastDifferentiation.value(f), >) + @test FastDifferentiation.children(f)[1] === x + @test FastDifferentiation.children(f)[2] === y + f = x == y + @test ==(FastDifferentiation.value(f), ==) + @test FastDifferentiation.children(f)[1] === x + @test FastDifferentiation.children(f)[2] === y + f = x ≠ y + @test ==(FastDifferentiation.value(f), ≠) + @test FastDifferentiation.children(f)[1] === x + @test FastDifferentiation.children(f)[2] === y + f = x ≤ y + @test ==(FastDifferentiation.value(f), ≤) + @test FastDifferentiation.children(f)[1] === x + @test 
FastDifferentiation.children(f)[2] === y + f = x ≥ y + @test ==(FastDifferentiation.value(f), ≥) + @test FastDifferentiation.children(f)[1] === x + @test FastDifferentiation.children(f)[2] === y + + #conditional + expr = x < y + f = ifelse(expr, x, y) + @test ==(FastDifferentiation.value(f), ifelse) + @test ===(FastDifferentiation.children(f)[1], expr) + @test ===(FastDifferentiation.children(f)[2], x) + @test ===(FastDifferentiation.children(f)[3], y) +end From 7bc7fd385b0107e53660a7f14bb5edaea5a72d82 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 6 Aug 2024 11:03:56 -0700 Subject: [PATCH 03/29] adding boolean operators and builtin functions like isinf, sign, etc., that return boolean values. --- src/Methods.jl | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/Methods.jl b/src/Methods.jl index 6ea868ca..b01912ac 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -71,21 +71,35 @@ function number_methods(T, rhs1, rhs2, options=nothing) Expr(:block, exprs...) end +const comparison_operators = (<, >, ≤, ≥, ≠, ==) +const boolean_operators = (&, !, |, ⊻) +const boolean_like_operators = (Base.sign, Base.signbit, Base.isreal, Base.isfinite, Base.iszero, Base.isnan, Base.isinf, Base.isinteger) #Define boolean methods """T is the type you want to define the boolean methods for. 
In this case Node""" macro boolean_methods(T) - for func in (<, >, ≤, ≥, ≠, ==) + for func in comparison_operators eval(:(Base.$(Symbol(func))(a::$T, b::$T) = $T($func, a, b); Base.$(Symbol(func))(a::$T, b::Real) = $T(func, a, $T(b)); Base.$(Symbol(func))(a::Real, b::$T) = Node(func, $T(a), b) )) end + # for boolean_op in boolean_operators + #want to have tests in this method to ensure that the Node values are boolean in nature, i.e., one of + eval(:(Base.ifelse(a::$T, b, c) = $T(ifelse, MVector(a, b, c)) )) end - +#methods may need to add to get good compatibility with the rest of Julia +# Base.sign +# Base.signbit +# Base.isreal +# Base.isfinite +# Base.iszero +# Base.isnan +# Base.isinf +# Base.isinteger macro number_methods(T, rhs1, rhs2, options=nothing) number_methods(T, rhs1, rhs2, options) |> esc From bf466bf4954d6d7d3f918555e58c1d9a1e8e4c09 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Thu, 8 Aug 2024 11:07:10 -0700 Subject: [PATCH 04/29] Add conditionals Fixes #89 added tests for comparison operators and ifelse fixed bug in boolean_methods, func variable wasn't interpolated into Expr changed make_function to use === to test for variable inclusion --- src/CodeGeneration.jl | 15 ++++++++++++++- src/Methods.jl | 4 ++-- test/FDTests.jl | 39 +++++++++++++++++++++++++-------------- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/src/CodeGeneration.jl b/src/CodeGeneration.jl index 6cd50d7d..da97802b 100644 --- a/src/CodeGeneration.jl +++ b/src/CodeGeneration.jl @@ -316,7 +316,20 @@ function make_function(func_array::AbstractArray{T}, input_variables::AbstractVe vars = variables(func_array) #all unique variables in func_array all_input_vars = vcat(input_variables...) - @assert vars ⊆ all_input_vars "Some of the variables in your function (the func_array argument) were not in the input_variables argument. Every variable that is used in your function must have a corresponding entry in the input_variables argument." 
+ #Because FD defines == operator for Node, which does not return a boolean, many builtin Julia functions will not work as expected. For example: + # vars ⊆ all_input_vars errors because internally issubset tests for equality between the node values using ==, not ===. == returns a Node value but the issubset function expects a Bool. + + temp = Vector{eltype(vars)}(undef, 0) + + input_dict = IdDict(zip(all_input_vars, all_input_vars)) + for one_var in vars + value = get(input_dict, one_var, nothing) + if value === nothing + push!(temp, one_var) + end + end + + @assert length(temp) == 0 "The variables $temp were not in the input_variables argument to make_function. Every variable that is used in your function must have a corresponding entry in the input_variables argument." @RuntimeGeneratedFunction(make_Expr(func_array, all_input_vars, in_place, init_with_zeros)) end diff --git a/src/Methods.jl b/src/Methods.jl index b01912ac..babadf54 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -79,8 +79,8 @@ const boolean_like_operators = (Base.sign, Base.signbit, Base.isreal, Base.isfin macro boolean_methods(T) for func in comparison_operators eval(:(Base.$(Symbol(func))(a::$T, b::$T) = $T($func, a, b); - Base.$(Symbol(func))(a::$T, b::Real) = $T(func, a, $T(b)); - Base.$(Symbol(func))(a::Real, b::$T) = Node(func, $T(a), b) + Base.$(Symbol(func))(a::$T, b::Real) = $T($func, a, $T(b)); + Base.$(Symbol(func))(a::Real, b::$T) = Node($func, $T(a), b) )) end diff --git a/test/FDTests.jl b/test/FDTests.jl index cdbafaa8..82e73e1b 100644 --- a/test/FDTests.jl +++ b/test/FDTests.jl @@ -77,7 +77,7 @@ end function edge_fields_equal(edge1, edge2) return edge1.top_vertex == edge2.top_vertex && edge1.bott_vertex == edge2.bott_vertex && - edge1.edge_value == edge2.edge_value && + edge1.edge_value === edge2.edge_value && #must compare Node types with === because have now defined conditionals for Node edge1.reachable_variables == edge2.reachable_variables && edge1.reachable_roots == 
edge2.reachable_roots end @@ -207,8 +207,8 @@ end a = x * y - @test derivative(a, Val(1)) == y - @test derivative(a, Val(2)) == x + @test derivative(a, Val(1)) === y + @test derivative(a, Val(2)) === x end @testitem "FD.compute_factorable_subgraphs test order" begin @@ -979,7 +979,7 @@ end #first verify all nodes have the postorder numbers we expect for (i, nd) in pairs(gnodes) - @test FD.node(graph, i) == nd + @test FD.node(graph, i) === nd end sub_heap = FD.compute_factorable_subgraphs(graph) @@ -1151,10 +1151,10 @@ end result = FD._symbolic_jacobian!(graph, [nx1, ny2]) #symbolic equality will work here because of common subexpression caching. - @test result[1, 1] == cos(nx1 * ny2) * ny2 - @test result[1, 2] == cos(nx1 * ny2) * nx1 - @test result[2, 1] == -sin(nx1 * ny2) * ny2 - @test result[2, 2] == (-sin(nx1 * ny2)) * nx1 + @test result[1, 1] === cos(nx1 * ny2) * ny2 + @test result[1, 2] === cos(nx1 * ny2) * nx1 + @test result[2, 1] === -sin(nx1 * ny2) * ny2 + @test result[2, 2] === (-sin(nx1 * ny2)) * nx1 end @@ -1290,7 +1290,7 @@ end copy_jac = FD._symbolic_jacobian(graph, [x, y]) jac = FD._symbolic_jacobian!(graph, [x, y]) - @test all(copy_jac .== jac) #make sure the jacobian computed by copying the graph has the same FD.variables as the one computed by destructively modifying the graph + @test all(copy_jac .=== jac) #make sure the jacobian computed by copying the graph has the same FD.variables as the one computed by destructively modifying the graph computed_jacobian = FD.make_function(jac, [x, y]) @@ -1434,7 +1434,7 @@ end ] @test isapprox(zeros(2, 2), FD.value.(derivative(A, nq2))) #taking derivative wrt variable not present in the graph returns all zero matrix - @test DA == derivative(A, nq1) + @test all(DA .=== derivative(A, nq1)) end @testitem "jacobian_times_v" begin @@ -1607,7 +1607,7 @@ end #test to make sure sparse_hessian/evaluate_path bug is not reintroduced (ref commit 4b4aeeb1990a15443ca87c15638dcaf7bd9d34d1) a = hessian(x * y, [x, y]) b = 
sparse_hessian(x * y, [x, y]) - @test all(a .== b) + @test all(a .=== b) end @testitem "hessian_times_v" begin @@ -1783,9 +1783,9 @@ end f = x + x @test f === 2 * x f = (-1 * x) + x - @test f == 0 + @test FastDifferentiation.value(f) == 0 f = x + (-1 * x) - @test f == 0 + @test FastDifferentiation.value(f) == 0 f = 2x + 3x @test f === 5 * x f2 = -x + x @@ -2062,7 +2062,7 @@ end end end -@testitem "conditional tests" begin +@testitem "conditionals" begin @variables x y #boolean operators @@ -2099,3 +2099,14 @@ end @test ===(FastDifferentiation.children(f)[2], x) @test ===(FastDifferentiation.children(f)[3], y) end + +@testitem "conditional code generation" begin + @variables x y + f = ifelse(x < y, cos(x), sin(x)) + + input = [π / 2.0, 20.0] + exe = make_function([f], [x, y]) + @test isapprox(cos(input[1]), exe(input)[1]) + input = [π / 2.0, 1.0] + @test isapprox(sin(input[1]), exe(input)[1]) +end From d215350bc2c41c4926286891cf83f2c8c320706b Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Thu, 8 Aug 2024 11:12:17 -0700 Subject: [PATCH 05/29] renamed boolean_methods to comparison_methods --- src/ExpressionGraph.jl | 2 +- src/Methods.jl | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 02ab83bc..0529f4e3 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -658,4 +658,4 @@ export make_variables #create methods that accept Node arguments for all mathematical functions. @number_methods(Node, simplify_check_cache(f, a, EXPRESSION_CACHE), simplify_check_cache(f, a, b, EXPRESSION_CACHE)) #create methods for standard functions that take Node instead of Number arguments. Check cache to see if these arguments have been seen before. 
-@boolean_methods(Node) +@comparison_methods(Node) diff --git a/src/Methods.jl b/src/Methods.jl index babadf54..bc09f064 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -76,7 +76,7 @@ const boolean_operators = (&, !, |, ⊻) const boolean_like_operators = (Base.sign, Base.signbit, Base.isreal, Base.isfinite, Base.iszero, Base.isnan, Base.isinf, Base.isinteger) #Define boolean methods """T is the type you want to define the boolean methods for. In this case Node""" -macro boolean_methods(T) +macro comparison_methods(T) for func in comparison_operators eval(:(Base.$(Symbol(func))(a::$T, b::$T) = $T($func, a, b); Base.$(Symbol(func))(a::$T, b::Real) = $T($func, a, $T(b)); @@ -91,6 +91,10 @@ macro boolean_methods(T) )) end +macro boolean_methods(T) + # +end + #methods may need to add to get good compatibility with the rest of Julia # Base.sign # Base.signbit From 1ba92c5e76bee7a980f7a6659b6978877b545831 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Thu, 8 Aug 2024 16:30:47 -0700 Subject: [PATCH 06/29] Add conditionals Fixes #89 added ternary Node constructor to handle ifelse replace iszero and isinf methods with ones that return Node objects instead of using node_value. removed special case isfinite from number_methods. Doesn't cause tests to fail but might cause problems. Didn't document why it was there in the first place. added special eval case for ifelse to number_methods macro two sparse tests fail for obscure reasons. --- src/ExpressionGraph.jl | 10 ++++---- src/Methods.jl | 53 +++++++++++------------------------------- 2 files changed, 17 insertions(+), 46 deletions(-) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 0529f4e3..6e6a7464 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -33,6 +33,7 @@ struct Node <: Real Node(f::S, a) where {S} = new(f, MVector{1,Node}(Node(a))) Node(f::S, a, b) where {S} = new(f, MVector{2,Node}(Node(a), Node(b))) #if a,b not a Node convert them. 
+ Node(f::S, a, b, c) where {S} = new(f, MVector{3,Node}(Node(a), Node(b), Node(c))) #if a,b not a Node convert them. Node(a::T) where {T<:Real} = new(a, nothing) #convert numbers to Node Node(a::T) where {T<:Node} = a #if a is already a special node leave it alone @@ -102,9 +103,7 @@ Base.typemin(::Type{Node}) = Node(-Inf) Base.typemax(::Type{Node}) = Node(Inf) Base.float(x::Node) = x -# This one is needed because julia/base/float.jl only defines `isinf` for `Real`, but `Node -# <: Number`. (See https://github.com/brianguenter/FastDifferentiation.jl/issues/73) -Base.isinf(x::Node) = !isnan(x) & !isfinite(x) + Broadcast.broadcastable(a::Node) = (a,) @@ -144,7 +143,7 @@ Base.isless(::Node, ::Number) = error_message() Base.isless(::Number, ::Node) = error_message() Base.isless(::Node, ::Node) = error_message() -Base.iszero(a::Node) = value(a) == 0 #need this because sparse matrix and other code in linear algebra may call it. If it is not defined get a type promotion error. + function is_zero(a::Node) #this: value(a) == 0 would work but when add conditionals to the language if a is not a constant this will generate an expression graph instead of returning a bool value. @@ -173,6 +172,7 @@ end #Simple algebraic simplification rules for *,+,-,/. These are mostly safe, i.e., they will return exactly the same results as IEEE arithmetic. However multiplication by 0 always simplifies to 0, which is not true for IEEE arithmetic: 0*NaN=NaN, 0*Inf = NaN, for example. This should be a good tradeoff, since zeros are common in derivative expressions and can result in considerable expression simplification. Maybe later make this opt-out. 
simplify_check_cache(a, b, c, cache) = check_cache((a, b, c), cache) +simplify_check_cache(a, b, c, d, cache) = check_cache((a, b, c, d), cache) #this version handles ifelse is_nary(a::Node) = arity(a) > 2 is_times(a::Node) = value(a) == * @@ -657,5 +657,3 @@ export make_variables #create methods that accept Node arguments for all mathematical functions. @number_methods(Node, simplify_check_cache(f, a, EXPRESSION_CACHE), simplify_check_cache(f, a, b, EXPRESSION_CACHE)) #create methods for standard functions that take Node instead of Number arguments. Check cache to see if these arguments have been seen before. - -@comparison_methods(Node) diff --git a/src/Methods.jl b/src/Methods.jl index bc09f064..081da1c7 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -22,7 +22,7 @@ const monadic = [deg2rad, rad2deg, asind, log1p, acsch, atand, sec, acscd, cot, exp2, expm1, atanh, gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx, dawson, digamma, trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi, - airybiprime, besselj0, besselj1, bessely0, bessely1, isfinite] + airybiprime, besselj0, besselj1, bessely0, bessely1, signbit, isreal, isfinite, iszero, isnan, isinf, isinteger, !] const diadic = [max, min, hypot, atan, mod, rem, copysign, besselj, bessely, besseli, besselk, hankelh1, hankelh2, @@ -30,7 +30,8 @@ const diadic = [max, min, hypot, atan, mod, rem, copysign, const previously_declared_for = Set([]) const basic_monadic = [-, +] -const basic_diadic = [+, -, *, /, //, \, ^] +const basic_diadic = [+, -, *, /, //, \, ^, &, |, ⊻, <, >, ≤, ≥, ≠, ==] + # TODO: keep domains tighter than this function number_methods(T, rhs1, rhs2, options=nothing) @@ -61,51 +62,23 @@ function number_methods(T, rhs1, rhs2, options=nothing) for f in (skip_basics ? monadic : only_basics ? 
basic_monadic : vcat(basic_monadic, monadic)) nameof(f) in skips && continue - if f === isfinite - push!(exprs, :((f::$(typeof(f)))(a::$T) = true)) - else - push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1)) - end + push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1)) end + println(exprs) push!(exprs, :(push!($previously_declared_for, $T))) Expr(:block, exprs...) end -const comparison_operators = (<, >, ≤, ≥, ≠, ==) -const boolean_operators = (&, !, |, ⊻) -const boolean_like_operators = (Base.sign, Base.signbit, Base.isreal, Base.isfinite, Base.iszero, Base.isnan, Base.isinf, Base.isinteger) -#Define boolean methods -"""T is the type you want to define the boolean methods for. In this case Node""" -macro comparison_methods(T) - for func in comparison_operators - eval(:(Base.$(Symbol(func))(a::$T, b::$T) = $T($func, a, b); - Base.$(Symbol(func))(a::$T, b::Real) = $T($func, a, $T(b)); - Base.$(Symbol(func))(a::Real, b::$T) = Node($func, $T(a), b) - )) - end - - # for boolean_op in boolean_operators - #want to have tests in this method to ensure that the Node values are boolean in nature, i.e., one of - - eval(:(Base.ifelse(a::$T, b, c) = $T(ifelse, MVector(a, b, c)) - )) -end - -macro boolean_methods(T) - # -end - -#methods may need to add to get good compatibility with the rest of Julia -# Base.sign -# Base.signbit -# Base.isreal -# Base.isfinite -# Base.iszero -# Base.isnan -# Base.isinf -# Base.isinteger +# """if the node value is 0 then can evaluate this at compile time. Otherwise have to return an expression which will be evaluated when executing function created by make_function""" +# Base.iszero(a::Node) = node_value(a) == 0 ? true : simplify_check_cache() +# const special_cases = (signbit, isreal, isfinite, iszero, isnan, isinf, isinteger) #iszero must be defined or linear algebra routines, for example in sparse matrix will give type promotion error +# This one is needed because julia/base/float.jl only defines `isinf` for `Real`, but `Node +# <: Number`. 
(See https://github.com/brianguenter/FastDifferentiation.jl/issues/73) +# Base.isinf(x::Node) = !isnan(x) & !isfinite(x) macro number_methods(T, rhs1, rhs2, options=nothing) + eval(:(Base.ifelse(a::$T, b, c) = simplify_check_cache(Base.ifelse, a, b, c, EXPRESSION_CACHE))) + number_methods(T, rhs1, rhs2, options) |> esc end From 41f40e1eacfdac7ea37548a67c9aaf1792ae739f Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Fri, 9 Aug 2024 09:56:41 -0700 Subject: [PATCH 07/29] Add conditionals Fixes #89 SparseArrays.jl needs iszero(a::Node) to return a boolean value or its matrix constructors don't work. Can't have an iszero that returns an expression so iszero is special cased to return a boolean. --- src/ExpressionGraph.jl | 4 ++++ src/Methods.jl | 8 +++----- test/FDTests.jl | 4 +++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 6e6a7464..be7bcafe 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -96,6 +96,10 @@ Base.zero(::Node) = Node(0) Base.one(::Type{Node}) = Node(1) Base.one(::Node) = Node(1) +#special case for iszero. Some other libraries need iszero(a::Node) to return a boolean value. Other code may be okay with iszero return the expression Node(iszero,a), which would then be evaluated at run time in generated function. Unfortunately, SparseArrays requires the former so the sparse Jacobian/Hessian code won't work unless iszero returns Bool. Maybe there is a way to fix this but I suspect not easily. + +Base.iszero(a::Node) = is_constant(a) && iszero(value(a)) ? 
true : false + # These are essentially copied from Symbolics.jl: # https://github.com/JuliaSymbolics/Symbolics.jl/blob/e4c328103ece494eaaab2a265524a64bfbe43dbd/src/num.jl#L31-L34 Base.eps(::Type{Node}) = Node(0) diff --git a/src/Methods.jl b/src/Methods.jl index 081da1c7..f0d76e5a 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -22,7 +22,8 @@ const monadic = [deg2rad, rad2deg, asind, log1p, acsch, atand, sec, acscd, cot, exp2, expm1, atanh, gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx, dawson, digamma, trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi, - airybiprime, besselj0, besselj1, bessely0, bessely1, signbit, isreal, isfinite, iszero, isnan, isinf, isinteger, !] + airybiprime, besselj0, besselj1, bessely0, bessely1, signbit, isreal, isfinite, isnan, isinf, isinteger, !] +#ideally would have iszero in this list but this interferes with SparseArrays, which calls iszero to allocate space. Some other functions may use iszero in a way more compatible with symbolic numbers, meaning they will not crash when iszero returns an expression rather than a boolean. No easy way around this. const diadic = [max, min, hypot, atan, mod, rem, copysign, besselj, bessely, besseli, besselk, hankelh1, hankelh2, @@ -64,7 +65,7 @@ function number_methods(T, rhs1, rhs2, options=nothing) nameof(f) in skips && continue push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1)) end - println(exprs) + push!(exprs, :(push!($previously_declared_for, $T))) Expr(:block, exprs...) 
end @@ -82,6 +83,3 @@ macro number_methods(T, rhs1, rhs2, options=nothing) number_methods(T, rhs1, rhs2, options) |> esc end - - - diff --git a/test/FDTests.jl b/test/FDTests.jl index 82e73e1b..17954c0d 100644 --- a/test/FDTests.jl +++ b/test/FDTests.jl @@ -1607,7 +1607,9 @@ end #test to make sure sparse_hessian/evaluate_path bug is not reintroduced (ref commit 4b4aeeb1990a15443ca87c15638dcaf7bd9d34d1) a = hessian(x * y, [x, y]) b = sparse_hessian(x * y, [x, y]) - @test all(a .=== b) + for index in eachindex(a) + @test FD.value(a[index]) == FD.value(b[index]) + end end @testitem "hessian_times_v" begin From 31fb3ce76dad014db4a204b1ffceec6d4fe58e44 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Fri, 9 Aug 2024 10:27:45 -0700 Subject: [PATCH 08/29] split out non-differentiable new functions (signbit, isnan, >, etc.) into separate constants. This will make it easier to filter these nodes out of derivative calculations. --- src/Methods.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/Methods.jl b/src/Methods.jl index f0d76e5a..26742e43 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -22,17 +22,18 @@ const monadic = [deg2rad, rad2deg, asind, log1p, acsch, atand, sec, acscd, cot, exp2, expm1, atanh, gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx, dawson, digamma, trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi, - airybiprime, besselj0, besselj1, bessely0, bessely1, signbit, isreal, isfinite, isnan, isinf, isinteger, !] + airybiprime, besselj0, besselj1, bessely0, bessely1] #ideally would have iszero in this list but this interferes with SparseArrays, which calls iszero to allocate space. Some other functions may use iszero in a way more compatible with symbolic numbers, meaning they will not crash when iszero returns an expression rather than a boolean. No easy way around this. 
-const diadic = [max, min, hypot, atan, mod, rem, copysign, +const diadic = [hypot, atan, mod, rem, besselj, bessely, besseli, besselk, hankelh1, hankelh2, polygamma, beta, logbeta] const previously_declared_for = Set([]) const basic_monadic = [-, +] -const basic_diadic = [+, -, *, /, //, \, ^, &, |, ⊻, <, >, ≤, ≥, ≠, ==] - +const basic_diadic = [+, -, *, /, //, \, ^] +const diadic_non_differentiable = [max, min, copysign, &, |, ⊻, <, >, ≤, ≥, ≠, ==] +const monadic_non_differentiable = [signbit, isreal, isfinite, isnan, isinf, isinteger, !] # TODO: keep domains tighter than this function number_methods(T, rhs1, rhs2, options=nothing) @@ -42,7 +43,7 @@ function number_methods(T, rhs1, rhs2, options=nothing) only_basics = options !== nothing ? options == :onlybasics : false skips = Meta.isexpr(options, [:vcat, :hcat, :vect]) ? Set(options.args) : [] - for f in (skip_basics ? diadic : only_basics ? basic_diadic : vcat(basic_diadic, diadic)) + for f in (skip_basics ? diadic : only_basics ? basic_diadic : vcat(basic_diadic, diadic, diadic_non_differentiable)) nameof(f) in skips && continue for S in previously_declared_for push!(exprs, quote @@ -61,7 +62,7 @@ function number_methods(T, rhs1, rhs2, options=nothing) push!(exprs, expr) end - for f in (skip_basics ? monadic : only_basics ? basic_monadic : vcat(basic_monadic, monadic)) + for f in (skip_basics ? monadic : only_basics ? 
basic_monadic : vcat(basic_monadic, monadic, monadic_non_differentiable)) nameof(f) in skips && continue push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1)) end From e6f6c07b1173e62dafbf9a6eb712dac1abf089fb Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Sun, 11 Aug 2024 11:25:08 -0700 Subject: [PATCH 09/29] added comments to describe what the different src files contain --- src/FastDifferentiation.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/FastDifferentiation.jl b/src/FastDifferentiation.jl index 1d23b2fe..52c4dffd 100644 --- a/src/FastDifferentiation.jl +++ b/src/FastDifferentiation.jl @@ -36,18 +36,18 @@ RuntimeGeneratedFunctions.init(@__MODULE__) const DefaultNodeIndexType = Int64 -include("Methods.jl") +include("Methods.jl") #functions and macros to generate Node specialized methods for all the common arithmetic, trigonometric, etc., operations. include("Utilities.jl") include("BitVectorFunctions.jl") -include("ExpressionGraph.jl") -include("PathEdge.jl") -include("DerivativeGraph.jl") -include("Reverse.jl") +include("ExpressionGraph.jl") #definition of Node type from which FD expression graphs are created +include("PathEdge.jl") #functions to create and manipulate edges in derivative graphs +include("DerivativeGraph.jl") #functions to compute derivative graph from an expression graph of Node +include("Reverse.jl") #symbolic implementation of conventional reverse automatic differentiation include("GraphProcessing.jl") include("FactorableSubgraph.jl") include("Factoring.jl") -include("Jacobian.jl") -include("CodeGeneration.jl") +include("Jacobian.jl") #functions to compute jacobians, gradients, hessians, etc. 
+include("CodeGeneration.jl") #functions to convert expression graphs of Node to executable functions # FastDifferentiationVisualizationExt overloads them function make_dot_file end From 1299b47993c4a9f35e70d50ffca0676e645c1ca5 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Fri, 16 Aug 2024 16:43:21 -0700 Subject: [PATCH 10/29] Add conditionals Fixes #89 added a new function, is_identically_zero, that replaces old iszero. iszero now returns a node expression and is_identically_zero returns a boolean. --- src/CodeGeneration.jl | 16 ++++++++-------- src/ExpressionGraph.jl | 21 +++++++++++---------- src/Jacobian.jl | 3 +-- src/Methods.jl | 7 ++++--- test/FDTests.jl | 26 +++++++++++++++++++++++++- 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/src/CodeGeneration.jl b/src/CodeGeneration.jl index da97802b..ce8d14a4 100644 --- a/src/CodeGeneration.jl +++ b/src/CodeGeneration.jl @@ -8,7 +8,7 @@ then `sparsity = (nelts-nzeros)/nelts`. Frequently used in combination with a call to `make_function` to determine whether to set keyword argument `init_with_zeros` to false.""" function sparsity(sym_func::AbstractArray{<:Node}) - zeros = mapreduce(x -> is_zero(x) ? 1 : 0, +, sym_func) + zeros = mapreduce(x -> is_identically_zero(x) ? 1 : 0, +, sym_func) tot = prod(size(sym_func)) return zeros == 0 ? 
1.0 : (tot - zeros) / tot end @@ -124,8 +124,8 @@ function make_Expr(func_array::AbstractArray{T}, input_variables::AbstractVector node_to_var = IdDict{Node,Union{Symbol,Real,Expr}}() body = Expr(:block) - num_zeros = count(is_zero, (func_array)) - num_const = count((x) -> is_constant(x) && !is_zero(x), func_array) + num_zeros = count(is_identically_zero, (func_array)) + num_const = count((x) -> is_constant(x) && !is_identically_zero(x), func_array) zero_threshold = 0.5 @@ -166,7 +166,7 @@ function make_Expr(func_array::AbstractArray{T}, input_variables::AbstractVector for (i, node) in pairs(func_array) # skip all terms that we have computed above during construction if is_constant(node) && initialization_strategy === :const || # already initialized as constant above - is_zero(node) && (initialization_strategy === :zero || !init_with_zeros) # was already initialized as zero above or we don't want to initialize with zeros + is_identically_zero(node) && (initialization_strategy === :zero || !init_with_zeros) # was already initialized as zero above or we don't want to initialize with zeros continue end node_body, variable = function_body!(node, node_to_index, node_to_var) @@ -185,11 +185,11 @@ function make_Expr(func_array::AbstractArray{T}, input_variables::AbstractVector # wrap in function body if in_place - return :((result, input_variables) -> @inbounds begin + return :((result, input_variables::AbstractArray) -> @inbounds begin $body end) else - return :((input_variables) -> @inbounds begin + return :((input_variables::AbstractArray) -> @inbounds begin $body end) end @@ -248,9 +248,9 @@ function make_Expr(A::SparseMatrixCSC{T,Ti}, input_variables::AbstractVector{S}, push!(body.args, :(return result)) if in_place - return :((result, input_variables) -> $body) + return :((result, input_variables::AbstractArray) -> $body) else - return :((input_variables) -> $body) + return :((input_variables::AbstractArray) -> $body) end end end diff --git a/src/ExpressionGraph.jl 
b/src/ExpressionGraph.jl index be7bcafe..a8fc0a6f 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -98,7 +98,8 @@ Base.one(::Node) = Node(1) #special case for iszero. Some other libraries need iszero(a::Node) to return a boolean value. Other code may be okay with iszero return the expression Node(iszero,a), which would then be evaluated at run time in generated function. Unfortunately, SparseArrays requires the former so the sparse Jacobian/Hessian code won't work unless iszero returns Bool. Maybe there is a way to fix this but I suspect not easily. -Base.iszero(a::Node) = is_constant(a) && iszero(value(a)) ? true : false + + # These are essentially copied from Symbolics.jl: # https://github.com/JuliaSymbolics/Symbolics.jl/blob/e4c328103ece494eaaab2a265524a64bfbe43dbd/src/num.jl#L31-L34 @@ -148,8 +149,8 @@ Base.isless(::Number, ::Node) = error_message() Base.isless(::Node, ::Node) = error_message() - -function is_zero(a::Node) +"""True if the value of a variable is identically zero and false otherwise.""" +function is_identically_zero(a::Node) #this: value(a) == 0 would work but when add conditionals to the language if a is not a constant this will generate an expression graph instead of returning a bool value. if is_tree(a) || is_variable(a) return false @@ -205,11 +206,11 @@ function simplify_check_cache(::typeof(*), na, nb, cache)::Node #TODO sort variables so if y < x then x*y => y*x. The will automatically get commutativity. #c1*c2 = c3, (c1*x)*(c2*x) = c3*x - if is_zero(a) && is_zero(b) + if is_identically_zero(a) && is_identically_zero(b) return Node(value(a) + value(b)) #user may have mixed types for numbers so use automatic promotion to widen the type. 
- elseif is_zero(a) #b is not zero + elseif is_identically_zero(a) #b is not zero return a #use this node rather than creating a zero since a has the type encoded in it - elseif is_zero(b) #a is not zero + elseif is_identically_zero(b) #a is not zero return b #use this node rather than creating a zero since b has the type encoded in it elseif is_one(a) return b #At this point in processing the type of b may be impossible to determine, for example if b = sin(x) and the value of x won't be known till the expression is evaluated. No easy way to promote the type of b here if a has a wider type than b will eventually be determined to have. Example: a = BigFloat(1.0), b = sin(x). If the value of x is Float32 when the function is evaluated then would expect the type of the result to be BigFloat. But it will be Float32. Need to figure out a type of Node that will eventually generate code something like this: b = promote_type(a,b)(b) where the types of a,b will be known because this will be called in the generated Julia function for the derivative. @@ -261,9 +262,9 @@ function simplify_check_cache(::typeof(+), na, nb, cache)::Node #TODO sort variables so if y < x then x*y => y*x. The will automatically get commutativity. 
- if is_zero(a) + if is_identically_zero(a) return b - elseif is_zero(b) + elseif is_identically_zero(b) return a elseif a === -b || -a === b return zero(Node) @@ -285,9 +286,9 @@ function simplify_check_cache(::typeof(-), na, nb, cache)::Node b = Node(nb) if a === b return zero(Node) - elseif is_zero(b) + elseif is_identically_zero(b) return a - elseif is_zero(a) + elseif is_identically_zero(a) return -b elseif is_negate(b) return a + children(b)[1] diff --git a/src/Jacobian.jl b/src/Jacobian.jl index 03351049..e1580ca8 100644 --- a/src/Jacobian.jl +++ b/src/Jacobian.jl @@ -90,7 +90,6 @@ julia> jacobian([x*y,y*x],[x]) jacobian(terms::AbstractVector{T}, partial_variables::AbstractVector{S}) where {T<:Node,S<:Node} = _symbolic_jacobian(DerivativeGraph(terms), partial_variables) export jacobian - """ _sparse_symbolic_jacobian!( graph::DerivativeGraph, @@ -114,7 +113,7 @@ function _sparse_symbolic_jacobian!(graph::DerivativeGraph, partial_variables::A partial_index = variable_node_to_index(graph, partial_var) #make sure variable is in the domain of the graph if partial_index !== nothing tmp = evaluate_path(graph, root, partial_index) - if !is_zero(tmp) + if !is_identically_zero(tmp) push!(row_indices, root) push!(col_indices, i) push!(values, tmp) diff --git a/src/Methods.jl b/src/Methods.jl index 26742e43..6d33ed79 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -33,7 +33,7 @@ const previously_declared_for = Set([]) const basic_monadic = [-, +] const basic_diadic = [+, -, *, /, //, \, ^] const diadic_non_differentiable = [max, min, copysign, &, |, ⊻, <, >, ≤, ≥, ≠, ==] -const monadic_non_differentiable = [signbit, isreal, isfinite, isnan, isinf, isinteger, !] +const monadic_non_differentiable = [signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !] 
# TODO: keep domains tighter than this function number_methods(T, rhs1, rhs2, options=nothing) @@ -68,9 +68,10 @@ function number_methods(T, rhs1, rhs2, options=nothing) end push!(exprs, :(push!($previously_declared_for, $T))) - Expr(:block, exprs...) + return Expr(:block, exprs...) end + # """if the node value is 0 then can evaluate this at compile time. Otherwise have to return an expression which will be evaluated when executing function created by make_function""" # Base.iszero(a::Node) = node_value(a) == 0 ? true : simplify_check_cache() # const special_cases = (signbit, isreal, isfinite, iszero, isnan, isinf, isinteger) #iszero must be defined or linear algebra routines, for example in sparse matrix will give type promotion error @@ -79,7 +80,7 @@ end # Base.isinf(x::Node) = !isnan(x) & !isfinite(x) macro number_methods(T, rhs1, rhs2, options=nothing) - eval(:(Base.ifelse(a::$T, b, c) = simplify_check_cache(Base.ifelse, a, b, c, EXPRESSION_CACHE))) + eval(:(Base.ifelse(a::$T, b::$T, c::$T) = simplify_check_cache(Base.ifelse, a, b, c, EXPRESSION_CACHE))) number_methods(T, rhs1, rhs2, options) |> esc end diff --git a/test/FDTests.jl b/test/FDTests.jl index 17954c0d..9f969096 100644 --- a/test/FDTests.jl +++ b/test/FDTests.jl @@ -1643,10 +1643,34 @@ end using SparseArrays import FastDifferentiation as FD + + function FD_sparse(a::AbstractArray{T}) where {T<:FD.Node} + + values = FD.Node[] + + inds = findall(!FD.is_identically_zero, a) + row_indices = getindex.(inds, 1) + col_indices = getindex.(inds, 2) + values = getindex.(Ref(a), inds) + + return sparse(row_indices, col_indices, values, size(a)[1], size(a)[2]) + end + + m1 = FastDifferentiation.Node.([1 0 0; 0 2 0; 3 0 0]) + m1_non_zeros = ([1, 2, 3], [1, 2, 1], FastDifferentiation.Node.([1, 2, 3])) + + values = m1_non_zeros[3] + indices = collect(zip(m1_non_zeros[1], m1_non_zeros[2])) + sp = FD_sparse(m1) + for (i, index) in pairs(indices) + @test values[i] === getindex(m1, index...) 
+ end + + FD.@variables a11 a12 a13 a21 a22 a23 a31 a32 a33 vars = vec([a11 a12 a13 a21 a22 a23 a31 a32 a33]) - spmat = sparse([a11 a12 a13; a21 a22 a23; a31 a32 a33]) + spmat = FD_sparse([a11 a12 a13; a21 a22 a23; a31 a32 a33]) f1 = FD.make_function(spmat, vars) inputs = [1 2 3 4 5 6 7 8 9] correct = [1 2 3; 4 5 6; 7 8 9] From dfaa1021f41359230ba8a0a765db829e5b2717de Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Wed, 21 Aug 2024 17:03:59 -0700 Subject: [PATCH 11/29] Add conditionals Fixes #89 added Contitionals.jl and iterator for conditional bit values. --- src/Conditionals.jl | 30 ++++++++++++++++++++++++++++++ src/FastDifferentiation.jl | 1 + 2 files changed, 31 insertions(+) create mode 100644 src/Conditionals.jl diff --git a/src/Conditionals.jl b/src/Conditionals.jl new file mode 100644 index 00000000..9dd1f669 --- /dev/null +++ b/src/Conditionals.jl @@ -0,0 +1,30 @@ +function all_combinations(n::Integer) + num = 2^n + + result = Vector{Vector{Bool}}(undef, num) + + for i in 0:num-1 + result[i+1] = Bool.(digits(i, base=2, pad=n)) + end + return result +end +export all_combinations + +struct Combinations{T<:Integer} + n::T +end + +iterate(a::Combinations) = (a, 0) + +function iterate(a::Combinations{T}, state::T) where {T<:Integer} + if state == 2^a.n + return nothing + else + return (BitVector(Bool.(digits(state, base=2, pad=a.n))), state + 1) + end +end + +# Base.getindex(a::Combinations{T}, b::T) where {T<:Integer} = b +# length(a::Combinations) = 2^a.n +# eltype(a::Combinations) = Combinations +# size(a::Combinations, dims...) 
= length(a) \ No newline at end of file diff --git a/src/FastDifferentiation.jl b/src/FastDifferentiation.jl index 52c4dffd..c9399b5d 100644 --- a/src/FastDifferentiation.jl +++ b/src/FastDifferentiation.jl @@ -41,6 +41,7 @@ include("Utilities.jl") include("BitVectorFunctions.jl") include("ExpressionGraph.jl") #definition of Node type from which FD expression graphs are created include("PathEdge.jl") #functions to create and manipulate edges in derivative graphs +include("Conditionals.jl") include("DerivativeGraph.jl") #functions to compute derivative graph from an expression graph of Node include("Reverse.jl") #symbolic implementation of conventional reverse automatic differentiation include("GraphProcessing.jl") From ea573afb18476650b3233c9093f3cf3ceae77e60 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Mon, 26 Aug 2024 14:04:11 -0700 Subject: [PATCH 12/29] Add conditionals Fixes #89 added error message so that users attempting to compute derivatives through conditionals will know this isn't yet supported. --- src/ExpressionGraph.jl | 18 +++++++++++++++--- src/Methods.jl | 2 ++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index a8fc0a6f..a0131f88 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -414,10 +414,18 @@ end derivative(::NoOp, arg::Tuple{T}, ::Val{1}) where {T} = 1.0 +is_ifelse(a::Node) = value(a) == ifelse + +conditional_error(a::Node) = ErrorException("Your expression contained $(value(a)). 
FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals $(Tuple(not_currently_differentiable))") + +is_conditional(a::Node) = is_ifelse(a) || value(a) in not_currently_differentiable + function derivative(a::Node, index::Val{1}) # if is_variable(a) # if arity(a) == 0 - if is_variable_function(a) + if is_conditional(a) + throw(conditional_error(a)) + elseif is_variable_function(a) return function_variable_derivative(a, index) elseif arity(a) == 1 return derivative(value(a), (children(a)[1],), index) @@ -429,7 +437,9 @@ function derivative(a::Node, index::Val{1}) end function derivative(a::Node, index::Val{2}) - if is_variable_function(a) + if is_conditional(a) + throw(conditional_error(a)) + elseif is_variable_function(a) return function_variable_derivative(a, index) elseif arity(a) == 2 return derivative(value(a), (children(a)[1], children(a)[2]), index) @@ -439,7 +449,9 @@ function derivative(a::Node, index::Val{2}) end function derivative(a::Node, index::Val{i}) where {i} - if is_variable_function(a) + if is_conditional(a) + throw(conditional_error(a)) + elseif is_variable_function(a) return function_variable_derivative(a, index) else return derivative(value(a), (children(a)...,), index) diff --git a/src/Methods.jl b/src/Methods.jl index 6d33ed79..e95d3bb1 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -35,6 +35,8 @@ const basic_diadic = [+, -, *, /, //, \, ^] const diadic_non_differentiable = [max, min, copysign, &, |, ⊻, <, >, ≤, ≥, ≠, ==] const monadic_non_differentiable = [signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !] 
+const not_currently_differentiable = vcat(diadic_non_differentiable, monadic_non_differentiable) + # TODO: keep domains tighter than this function number_methods(T, rhs1, rhs2, options=nothing) exprs = [] From 738e9bea09a6d290b1fe3f965941422617fc6bb1 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 27 Aug 2024 11:09:34 -0700 Subject: [PATCH 13/29] new feature release conditionals --- Project.toml | 2 +- docs/src/index.md | 33 +++++++++++++++++--- docs/src/limitations.md | 69 +---------------------------------------- 3 files changed, 31 insertions(+), 73 deletions(-) diff --git a/Project.toml b/Project.toml index a6dbe310..f81c0dbd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FastDifferentiation" uuid = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" authors = ["BrianGuenter"] -version = "0.3.17" +version = "0.4.0" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" diff --git a/docs/src/index.md b/docs/src/index.md index 01214e75..29d6ed70 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -7,12 +7,38 @@ CurrentModule = FastDifferentiation [![Build Status](https://github.com/brianguenter/FastDifferentiation.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/brianguenter/FastDifferentiation.jl/actions/workflows/CI.yml?query=branch%3Amain) -FastDifferentiation (**FD**) is a package for generating efficient executables to evaluate derivatives of Julia functions. It can also generate efficient true symbolic derivatives for symbolic analysis. Unlike forward and reverse mode automatic differentiation **FD** automatically generates efficient derivatives for arbitrary function types: ℝ¹->ℝ¹, ℝ¹->ℝᵐ, ℝⁿ->ℝ¹, and ℝⁿ->ℝᵐ, m≠1,n≠1. +FastDifferentiation (**FD**) is a package for generating efficient executables to evaluate derivatives of Julia functions. It can also generate efficient true symbolic derivatives for symbolic analysis. 
Unlike forward and reverse mode automatic differentiation **FD** automatically generates efficient derivatives for arbitrary function types: ℝ¹->ℝ¹, ℝ¹->ℝᵐ, ℝⁿ->ℝ¹, and ℝⁿ->ℝᵐ. -For f:ℝⁿ->ℝᵐ with n,m large FD may have better performance than conventional AD algorithms because the **FD** algorithm finds expressions shared between partials and computes them only once. In some cases **FD** derivatives can be as efficient as manually coded derivatives (see the Lagrangian dynamics example in the [D*](https://www.microsoft.com/en-us/research/publication/the-d-symbolic-differentiation-algorithm/) paper or the [Benchmarks](@ref) section of the documentation for another example). +For f:ℝⁿ->ℝᵐ with n,m large **FD** may have better performance than conventional AD algorithms because the **FD** algorithm finds expressions shared between partials and computes them only once. In some cases **FD** derivatives can be as efficient as manually coded derivatives (see the Lagrangian dynamics example in the [D*](https://www.microsoft.com/en-us/research/publication/the-d-symbolic-differentiation-algorithm/) paper or the [Benchmarks](@ref) section of the documentation for another example). - **FD** may take much less time to compute symbolic derivatives than Symbolics.jl even in the ℝ¹->ℝ¹ case. The executables generated by **FD** may also be much faster (see [Symbolic Processing](@ref)[^1]. + **FD** may take much less time to compute symbolic derivatives than Symbolics.jl even in the ℝ¹->ℝ¹ case. The executables generated by **FD** may also be much faster (see [Symbolic Processing](@ref)). 
+ +As of version 0.4.0 **FD** allows you to create expressions with conditionals: +```julia + +julia> @variables x y +y + +julia> f = ifelse(x<y,x,y) +(ifelse  (x < y) x y) + +julia> a = make_function([f],[x,y]) + +julia> a(1.0,2.0) +1-element Vector{Float64}: + 1.0 + +julia> a(2.0,1.0) +1-element Vector{Float64}: + 2.0 +``` +However, you cannot yet compute derivatives of expressions that contain conditionals: +```julia +julia> jacobian([f],[x,y]) +ERROR: Your expression contained ifelse. FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals (max, min, copysign, &, |, xor, <, >, <=, >=, !=, ==, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !) +``` +A future PR will add support for differentiating through conditionals. You should consider using FastDifferentiation when you need: * a fast executable for evaluating the derivative of a function and the overhead of the preprocessing/compilation time is swamped by evaluation time. @@ -37,7 +63,6 @@ This is **beta** software being modified on a daily basis. Expect bugs and frequ The derivative of `|u|` is `u/|u|` which is NaN when `u==0`. This is not a bug. The derivative of the absolute value function is undefined at 0 and the way **FD** signals this is by returning NaN. -[^1]: I am working with the SciML team to see if it is possible to integrate **FD** differentiation directly into Symbolics.jl. diff --git a/docs/src/limitations.md b/docs/src/limitations.md index 1239b34f..9cc490f0 100644 --- a/docs/src/limitations.md +++ b/docs/src/limitations.md @@ -1,72 +1,5 @@ # Limitations -**FD** does not support expressions with conditionals on **FD** variables. For example, you can do this: -```julia -julia> f(a,b,c) = a< 1.0 ? cos(b) : sin(c) -f (generic function with 2 methods) - -julia> f(0.0,x,y) -cos(x) - -julia> f(1.0,x,y) -sin(y) -``` -but you can't do this: -```julia -julia> f(a,b) = a < b ? 
cos(a) : sin(b) -f (generic function with 2 methods) - -julia> f(x,y) -ERROR: MethodError: no method matching isless(::FastDifferentiation.Node{Symbol, 0}, ::FastDifferentiation.Node{Symbol, 0}) - -Closest candidates are: - isless(::Any, ::DataValues.DataValue{Union{}}) - @ DataValues ~/.julia/packages/DataValues/N7oeL/src/scalar/core.jl:291 - isless(::S, ::DataValues.DataValue{T}) where {S, T} - @ DataValues ~/.julia/packages/DataValues/N7oeL/src/scalar/core.jl:285 - isless(::DataValues.DataValue{Union{}}, ::Any) - @ DataValues ~/.julia/packages/DataValues/N7oeL/src/scalar/core.jl:293 - ... -``` -This is because the call `f(x,y)` creates an expression graph. At graph creation time the **FD** variables `x,y` are unevaluated variables with no specific value so they cannot be compared with any other value. - -The algorithm can be extended to work with conditionals applied to **FD** variables but the processing time and graph size may grow exponentially with conditional nesting depth. A future version may allow for limited conditional nesting. See [Future Work](@ref) for a potential long term solution to this problem. **FD** does not support looping internally. All operations with loops, such as matrix vector multiplication, are unrolled into scalar operations. The corresponding executable functions generated by `make_function` have size proportional to the number of operations. -Expressions with ≈10⁵ scalar operations have reasonable symbolic preprocessing and compilation times. Beyond this size LLVM compilation time can become extremely long and eventually the executables become so large that their caching behavior is not good and performance declines. - -A possible solution to this problem is to do what is called rerolling: detecting repreating indexing patterns in the **FD** expressions and automatically generating loops to replace inlined code. This rerolling step would be performed on the **FD** expressions graphs before function compliation. 
- -It is not necessary to completely undo the unrolling back to the original expresssion, just to reduce code size enough to get reasonable compilation times and better caching behavior. - -For example, in this matrix vector multiplication -```julia - -julia> a = make_variables(:a,3,3) -3×3 Matrix{FastDifferentiation.Node}: - a1_1 a1_2 a1_3 - a2_1 a2_2 a2_3 - a3_1 a3_2 a3_3 - -julia> b = make_variables(:b,3) -3-element Vector{FastDifferentiation.Node}: - b1 - b2 - b3 - -julia> a*b -3-element Vector{Any}: - (((a1_1 * b1) + (a1_2 * b2)) + (a1_3 * b3)) - (((a2_1 * b1) + (a2_2 * b2)) + (a2_3 * b3)) - (((a3_1 * b1) + (a3_2 * b2)) + (a3_3 * b3)) -``` -the goal is to replace concrete index numbers with symbolic index variables that represent offsets rather than absolute indices -```julia - -[ - a[i,j]*b[j] + a[i, j+1]*b[j+1] + a[i,j+2]*b[j+2] - a[i+1,j]*b[j] + a[i+1, j+1]*b[j+1] + a[i+1,j+2]*b[j+2] - a[i+2,j]*b[j] + a[i+2, j+1]*b[j+1] + a[i+2,j+2]*b[j+2] -] -``` -and then to extract looping structure from these indices. +Expressions with ≈10⁵ scalar operations have reasonable symbolic preprocessing and compilation times. Beyond this size LLVM compilation time can become extremely long and eventually the executables become so large that their caching behavior is not good and performance declines. \ No newline at end of file From 7c005d045d40d96acdabfd82415e7170245e3817 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 27 Aug 2024 11:27:21 -0700 Subject: [PATCH 14/29] Add conditionals Fixes #89 commment change explaining consequences of adding conditionals --- docs/src/index.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index 29d6ed70..e1441abf 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -38,8 +38,12 @@ Howver, you cannot yet compute derivatives of expressions that contain condition julia> jacobian([f],[x,y]) ERROR: Your expression contained ifelse. 
FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals (max, min, copysign, &, |, xor, <, >, <=, >=, !=, ==, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !) ``` +This may be a breaking change for some users. In previous versions this threw an the expression `x==y` returned a `Bool`. Some data structures, such as `Dict` use `==` by default to determine if two entries are the same. This will no longer work since `x==y` will now return an expression graph. Use an `IDict` instead since this uses `===`. + A future PR will add support for differentiating through conditionals. + + You should consider using FastDifferentiation when you need: * a fast executable for evaluating the derivative of a function and the overhead of the preprocessing/compilation time is swamped by evaluation time. * to do additional symbolic processing on your derivative. **FD** can generate a true symbolic derivative to be processed further in Symbolics.jl or another computer algebra system. From 6a51005dbc7cd439402f5f59f7cc4592606f03c0 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 27 Aug 2024 11:29:39 -0700 Subject: [PATCH 15/29] fix to introd docs --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index e1441abf..7020dc0f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -38,7 +38,7 @@ Howver, you cannot yet compute derivatives of expressions that contain condition julia> jacobian([f],[x,y]) ERROR: Your expression contained ifelse. FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals (max, min, copysign, &, |, xor, <, >, <=, >=, !=, ==, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !) ``` -This may be a breaking change for some users. In previous versions this threw an the expression `x==y` returned a `Bool`. 
Some data structures, such as `Dict` use `==` by default to determine if two entries are the same. This will no longer work since `x==y` will now return an expression graph. Use an `IDict` instead since this uses `===`. +This may be a breaking change for some users. In previous versions `x==y` returned a `Bool` whereas in 0.4.0 and up it returns an **FD** expression. Some data structures, such as `Dict` use `==` by default to determine if two entries are the same. This will no longer work since `x==y` will not return a `Bool`. Use an `IDict` instead since this uses `===`. A future PR will add support for differentiating through conditionals. From aa1a8c074b2fd4f26f28bd78d82ee4ecc0c507e2 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 27 Aug 2024 11:33:08 -0700 Subject: [PATCH 16/29] fix to intro doc --- docs/src/index.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 7020dc0f..1b42c0b7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -61,8 +61,6 @@ Hv computes the Hessian times a vector without explicitly forming the Hessian ma If you use FD in your work please share the functions you differentiate with me. I'll add them to the benchmarks. The more functions available to test the easier it is for others to determine if FD will help with their problem. -This is **beta** software being modified on a daily basis. Expect bugs and frequent, possibly breaking changes, over the next month or so. Documentation is frequently updated so check the latest docs before filing an issue. Your problem may have been fixed and documented. - ## Notes about special derivatives The derivative of `|u|` is `u/|u|` which is NaN when `u==0`. This is not a bug. The derivative of the absolute value function is undefined at 0 and the way **FD** signals this is by returning NaN. 
From cde852edb3e5d0057878262a2187b872248b34af Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 27 Aug 2024 15:17:36 -0700 Subject: [PATCH 17/29] doc readme change --- README.md | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/README.md b/README.md index 079467cc..a91d2e67 100644 --- a/README.md +++ b/README.md @@ -71,27 +71,7 @@ If you use FD in your work please share the functions you differentiate with me. **A**: If you multiply a matrix of **FD** variables times a vector of **FD** variables the matrix vector multiplication loop is effectively unrolled into scalar expressions. Matrix operations on large matrices will generate large executables and long preprocessing time. **FD** functions with up 10⁵ operations should still have reasonable preprocessing/compilation times (approximately 1 minute on a modern laptop) and good run time performance. **Q**: Does **FD** support conditionals? -**A**: **FD** does not yet support conditionals that involve the variables you are differentiating with respect to. You can do this: -```julia -@variables x y #create FD variables - -julia> f(a,b,c) = a< 1.0 ? cos(b) : sin(c) -f (generic function with 2 methods) - -julia> f(0.0,x,y) -cos(x) - -julia> f(1.0,x,y) -sin(y) -``` -but you can't do this: -```julia -julia> f(a,b) = a < b ? cos(a) : sin(b) -f (generic function with 2 methods) - -julia> f(x,y) -ERROR: MethodError: no method matching isless(::FastDifferentiation.Node{Symbol, 0}, ::FastDifferentiation.Node{Symbol, 0}) -``` +**A**: Yes, but see the documentation for limitations. Full functionality will come in a future release. # Release Notes
From bb40f585803431a550a31b4b6a1833544557a539 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Mon, 9 Sep 2024 12:18:06 -0700 Subject: [PATCH 18/29] pre-merge Add Conditionals pr #90 --- README.md | 22 +++++++++++++++++++++- docs/src/index.md | 4 +++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a91d2e67..079467cc 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,27 @@ If you use FD in your work please share the functions you differentiate with me. **A**: If you multiply a matrix of **FD** variables times a vector of **FD** variables the matrix vector multiplication loop is effectively unrolled into scalar expressions. Matrix operations on large matrices will generate large executables and long preprocessing time. **FD** functions with up 10⁵ operations should still have reasonable preprocessing/compilation times (approximately 1 minute on a modern laptop) and good run time performance. **Q**: Does **FD** support conditionals? -**A**: Yes, but see the documentation for limitations. Full functionality will come in a future release. +**A**: **FD** does not yet support conditionals that involve the variables you are differentiating with respect to. You can do this: +```julia +@variables x y #create FD variables + +julia> f(a,b,c) = a< 1.0 ? cos(b) : sin(c) +f (generic function with 2 methods) + +julia> f(0.0,x,y) +cos(x) + +julia> f(1.0,x,y) +sin(y) +``` +but you can't do this: +```julia +julia> f(a,b) = a < b ? cos(a) : sin(b) +f (generic function with 2 methods) + +julia> f(x,y) +ERROR: MethodError: no method matching isless(::FastDifferentiation.Node{Symbol, 0}, ::FastDifferentiation.Node{Symbol, 0}) +``` # Release Notes
diff --git a/docs/src/index.md b/docs/src/index.md index 1b42c0b7..e1441abf 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -38,7 +38,7 @@ Howver, you cannot yet compute derivatives of expressions that contain condition julia> jacobian([f],[x,y]) ERROR: Your expression contained ifelse. FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals (max, min, copysign, &, |, xor, <, >, <=, >=, !=, ==, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !) ``` -This may be a breaking change for some users. In previous versions `x==y` returned a `Bool` whereas in 0.4.0 and up it returns an **FD** expression. Some data structures, such as `Dict` use `==` by default to determine if two entries are the same. This will no longer work since `x==y` will not return a `Bool`. Use an `IDict` instead since this uses `===`. +This may be a breaking change for some users. In previous versions this threw an the expression `x==y` returned a `Bool`. Some data structures, such as `Dict` use `==` by default to determine if two entries are the same. This will no longer work since `x==y` will now return an expression graph. Use an `IDict` instead since this uses `===`. A future PR will add support for differentiating through conditionals. @@ -61,6 +61,8 @@ Hv computes the Hessian times a vector without explicitly forming the Hessian ma If you use FD in your work please share the functions you differentiate with me. I'll add them to the benchmarks. The more functions available to test the easier it is for others to determine if FD will help with their problem. +This is **beta** software being modified on a daily basis. Expect bugs and frequent, possibly breaking changes, over the next month or so. Documentation is frequently updated so check the latest docs before filing an issue. Your problem may have been fixed and documented. + ## Notes about special derivatives The derivative of `|u|` is `u/|u|` which is NaN when `u==0`. 
This is not a bug. The derivative of the absolute value function is undefined at 0 and the way **FD** signals this is by returning NaN. From e794945c93c82c19e10351bcdcffe2e9bc965345 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Thu, 12 Sep 2024 10:29:55 -0700 Subject: [PATCH 19/29] moved differentiation rules into a separate file. All tests pass except. Need to replace ifelse with my own ifelse that generates correct if...else code. Too many functions execute illegal stuff in one of the ifelse branches which causes an exception. --- Project.toml | 1 + src/DifferentiationRules.jl | 76 +++++++++++++++++++++++++++++++++++++ src/ExpressionGraph.jl | 72 +---------------------------------- src/FastDifferentiation.jl | 2 +- 4 files changed, 80 insertions(+), 71 deletions(-) create mode 100644 src/DifferentiationRules.jl diff --git a/Project.toml b/Project.toml index f81c0dbd..da5d9d3b 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.4.0" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/DifferentiationRules.jl b/src/DifferentiationRules.jl new file mode 100644 index 00000000..f27ea130 --- /dev/null +++ b/src/DifferentiationRules.jl @@ -0,0 +1,76 @@ +# Pre-defined derivatives +import DiffRules +#Special case rules for rules in DiffRules that use ?: or if...else, neither of which will work when add conditionals + +#airybix and airyprimex don't work in FastDifferentiation. airybix(x) where x is a Node causes a stack overflow. So no diffrule defined, although airybix uses if...else +#LogExpFunctions.xlogy uses ?: so doesn't work with FastDifferentiation. 
Most of the functions in this package don't work with FastDifferentiation. + +DiffRules.@define_diffrule Base.:^(x, y) = :($y * ($x^($y - 1))), :(ifelse($x isa Real && $x <= 0, Base.oftype(float($x), NaN), ($x^$y) * log($x))) + + +for (modu, fun, arity) ∈ DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath)) + fun in [:*, :+, :abs, :mod, :rem, :max, :min] && continue # special + for i ∈ 1:arity + + expr = if arity == 1 + DiffRules.diffrule(modu, fun, :(args[1])) + else + DiffRules.diffrule(modu, fun, ntuple(k -> :(args[$k]), arity)...)[i] + end + + @eval derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i}) = $expr + end +end + +derivative(::typeof(abs), arg::Tuple{T}, ::Val{1}) where {T} = arg[1] / abs(arg[1]) + +function derivative(::typeof(*), args::NTuple{N,Any}, ::Val{I}) where {N,I} + if N == 2 + return I == 1 ? args[2] : args[1] + else + return Node(*, deleteat!(collect(args), I)...) #TODO: simplify_check_cache will only be called for 2 arguments or less. Need to extend to nary *, n> 2, if this is necessary. 
+ end +end + +derivative(::typeof(+), args::NTuple{N,Any}, ::Val{I}) where {I,N} = Node(1) +derivative(::NoOp, arg::Tuple{T}, ::Val{1}) where {T} = 1.0 + + +function_variable_derivative(a::Node, index::Val{i}) where {i} = check_cache((Differential, children(a)[i]), EXPRESSION_CACHE) +function derivative(a::Node, index::Val{1}) + # if is_variable(a) + # if arity(a) == 0 + if is_conditional(a) + throw(conditional_error(a)) + elseif is_variable_function(a) + return function_variable_derivative(a, index) + elseif arity(a) == 1 + return derivative(value(a), (children(a)[1],), index) + elseif arity(a) == 2 + return derivative(value(a), (children(a)[1], children(a)[2]), index) + else + throw(ErrorException("should never get here")) + end +end + +function derivative(a::Node, index::Val{2}) + if is_conditional(a) + throw(conditional_error(a)) + elseif is_variable_function(a) + return function_variable_derivative(a, index) + elseif arity(a) == 2 + return derivative(value(a), (children(a)[1], children(a)[2]), index) + else + throw(ErrorException("should never get here")) + end +end + +function derivative(a::Node, index::Val{i}) where {i} + if is_conditional(a) + throw(conditional_error(a)) + elseif is_variable_function(a) + return function_variable_derivative(a, index) + else + return derivative(value(a), (children(a)...,), index) + end +end \ No newline at end of file diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index a0131f88..3dee7104 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -337,8 +337,6 @@ end Base.:^(a::FastDifferentiation.Node, b::Integer) = simplify_check_cache(^, a, b, EXPRESSION_CACHE) -rules = Any[] - Base.convert(::Type{Node}, a::T) where {T<:Real} = Node(a) Base.promote_rule(::Type{<:Real}, ::Type{Node}) = Node Base.promote_rule(::Type{Bool}, ::Type{Node}) = Node @@ -355,23 +353,6 @@ Base.conj(a::Node) = a #need to define this because dot and probably other linea Base.adjoint(a::Node) = a Base.transpose(a::Node) = a -# 
Pre-defined derivatives -import DiffRules -for (modu, fun, arity) ∈ DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath)) - fun in [:*, :+, :abs, :mod, :rem, :max, :min] && continue # special - for i ∈ 1:arity - - expr = if arity == 1 - DiffRules.diffrule(modu, fun, :(args[1])) - else - DiffRules.diffrule(modu, fun, ntuple(k -> :(args[$k]), arity)...)[i] - end - - push!(rules, expr) - @eval derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i}) = $expr - end -end - function Base.inv(a::Node) if typeof(value(a)) === / return children(a)[2] / children(a)[1] @@ -380,23 +361,10 @@ function Base.inv(a::Node) end end - #need special case for sincos because it returns a 2 tuple. Also Diffrules.jl does not define a differentiation rule for sincos. Base.sincos(x::Node) = (sin(x), cos(x)) #this will be less efficient than sincos. TODO figure out a better way. -derivative(::typeof(abs), arg::Tuple{T}, ::Val{1}) where {T} = arg[1] / abs(arg[1]) - -function derivative(::typeof(*), args::NTuple{N,Any}, ::Val{I}) where {N,I} - if N == 2 - return I == 1 ? args[2] : args[1] - else - return Node(*, deleteat!(collect(args), I)...) #TODO: simplify_check_cache will only be called for 2 arguments or less. Need to extend to nary *, n> 2, if this is necessary. - end -end -derivative(::typeof(+), args::NTuple{N,Any}, ::Val{I}) where {I,N} = Node(1) - -function_variable_derivative(a::Node, index::Val{i}) where {i} = check_cache((Differential, children(a)[i]), EXPRESSION_CACHE) """When constructing `DerivativeGraph` with repeated values in roots, e.g., ```julia @@ -412,51 +380,15 @@ function create_NoOp(child) return Node(NoOp(), child) end -derivative(::NoOp, arg::Tuple{T}, ::Val{1}) where {T} = 1.0 + is_ifelse(a::Node) = value(a) == ifelse -conditional_error(a::Node) = ErrorException("Your expression contained $(value(a)). 
FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals $(Tuple(not_currently_differentiable))") +conditional_error(a::Node) = ErrorException("Your expression contained $(value(a)). FastDifferentiation does not yet support differentiation through this conditional or any of these $(Tuple(not_currently_differentiable))") is_conditional(a::Node) = is_ifelse(a) || value(a) in not_currently_differentiable -function derivative(a::Node, index::Val{1}) - # if is_variable(a) - # if arity(a) == 0 - if is_conditional(a) - throw(conditional_error(a)) - elseif is_variable_function(a) - return function_variable_derivative(a, index) - elseif arity(a) == 1 - return derivative(value(a), (children(a)[1],), index) - elseif arity(a) == 2 - return derivative(value(a), (children(a)[1], children(a)[2]), index) - else - throw(ErrorException("should never get here")) - end -end - -function derivative(a::Node, index::Val{2}) - if is_conditional(a) - throw(conditional_error(a)) - elseif is_variable_function(a) - return function_variable_derivative(a, index) - elseif arity(a) == 2 - return derivative(value(a), (children(a)[1], children(a)[2]), index) - else - throw(ErrorException("should never get here")) - end -end -function derivative(a::Node, index::Val{i}) where {i} - if is_conditional(a) - throw(conditional_error(a)) - elseif is_variable_function(a) - return function_variable_derivative(a, index) - else - return derivative(value(a), (children(a)...,), index) - end -end """ variables(node::Node) diff --git a/src/FastDifferentiation.jl b/src/FastDifferentiation.jl index c9399b5d..9f90f745 100644 --- a/src/FastDifferentiation.jl +++ b/src/FastDifferentiation.jl @@ -11,7 +11,6 @@ using UUIDs using SparseArrays using DataStructures - module AutomaticDifferentiation struct NoDeriv end @@ -40,6 +39,7 @@ include("Methods.jl") #functions and macros to generate Node specialized methods include("Utilities.jl") include("BitVectorFunctions.jl") 
include("ExpressionGraph.jl") #definition of Node type from which FD expression graphs are created +include("DifferentiationRules.jl") include("PathEdge.jl") #functions to create and manipulate edges in derivative graphs include("Conditionals.jl") include("DerivativeGraph.jl") #functions to compute derivative graph from an expression graph of Node From e7c1996281a390a16a21ddf2c0039d82680781b7 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Thu, 12 Sep 2024 13:55:30 -0700 Subject: [PATCH 20/29] changed code generation to use new if_else instead of ifelse because the latter evaluates all its arguments even if doing so would cause an exception changed simplify_check_cache and check_cache functions to not take a cache argument added ! to the list of supported bool expressions --- src/CodeGeneration.jl | 8 +++- src/DifferentiationRules.jl | 10 ++--- src/ExpressionGraph.jl | 80 ++++++++++++++++++++----------------- src/Methods.jl | 7 ++-- src/UnspecifiedFunction.jl | 3 +- 5 files changed, 59 insertions(+), 49 deletions(-) diff --git a/src/CodeGeneration.jl b/src/CodeGeneration.jl index ce8d14a4..b0995372 100644 --- a/src/CodeGeneration.jl +++ b/src/CodeGeneration.jl @@ -53,7 +53,13 @@ function function_body!(dag::Node, variable_to_index::IdDict{Node,Int64}, node_t if is_tree(node) args = _dag_to_function.(children(node)) - statement = :($(node_to_var[node]) = $(Symbol(value(node)))($(args...))) + + if value(node) === if_else + println(args) + statement = :($(node_to_var[node]) = $(args[1]) ? $(args[2]) : $(args[3])) + else + statement = :($(node_to_var[node]) = $(Symbol(value(node)))($(args...))) + end push!(body.args, statement) end end diff --git a/src/DifferentiationRules.jl b/src/DifferentiationRules.jl index f27ea130..a43c0300 100644 --- a/src/DifferentiationRules.jl +++ b/src/DifferentiationRules.jl @@ -5,7 +5,7 @@ import DiffRules #airybix and airyprimex don't work in FastDifferentiation. 
airybix(x) where x is a Node causes a stack overflow. So no diffrule defined, although airybix uses if...else #LogExpFunctions.xlogy uses ?: so doesn't work with FastDifferentiation. Most of the functions in this package don't work with FastDifferentiation. -DiffRules.@define_diffrule Base.:^(x, y) = :($y * ($x^($y - 1))), :(ifelse($x isa Real && $x <= 0, Base.oftype(float($x), NaN), ($x^$y) * log($x))) +DiffRules.@define_diffrule Base.:^(x, y) = :($y * ($x^($y - 1))), :(if_else($x isa Real && $x <= 0, Base.oftype(float($x), NaN), ($x^$y) * log($x))) for (modu, fun, arity) ∈ DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath)) @@ -36,13 +36,9 @@ derivative(::typeof(+), args::NTuple{N,Any}, ::Val{I}) where {I,N} = Node(1) derivative(::NoOp, arg::Tuple{T}, ::Val{1}) where {T} = 1.0 -function_variable_derivative(a::Node, index::Val{i}) where {i} = check_cache((Differential, children(a)[i]), EXPRESSION_CACHE) +function_variable_derivative(a::Node, index::Val{i}) where {i} = check_cache((Differential, children(a)[i])) function derivative(a::Node, index::Val{1}) - # if is_variable(a) - # if arity(a) == 0 - if is_conditional(a) - throw(conditional_error(a)) - elseif is_variable_function(a) + if is_variable_function(a) return function_variable_derivative(a, index) elseif arity(a) == 1 return derivative(value(a), (children(a)[1],), index) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 3dee7104..023f6266 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -48,13 +48,13 @@ end #until I can think of a better way of structuring the caching operation it will be a single global expression cache. This precludes multithreading, unfortunately. Many other parts of the algorithm are difficult to multithread. Processing is also so quick that only large graphs would benefit from multithreading. Don't know how common these will be. 
const EXPRESSION_CACHE = IdDict{Any,Node}() -function check_cache(a::Tuple{Vararg}, cache::IdDict{Any,Node})::Node - cache_val = get(cache, a, nothing) +function check_cache(a::Tuple{Vararg})::Node + cache_val = get(EXPRESSION_CACHE, a, nothing) if cache_val === nothing - cache[a] = Node(a[1], a[2:end]...) #this should wrap everything, including basic numbers, in a Node object + EXPRESSION_CACHE[a] = Node(a[1], a[2:end]...) #this should wrap everything, including basic numbers, in a Node object end - return cache[a] + return EXPRESSION_CACHE[a] end """ @@ -74,11 +74,11 @@ end num_derivatives(a::Differential) = length(a.variables_wrt) #convenience function to extract the fields from Node object to check cache -function check_cache(a::Node, cache) +function check_cache(a::Node) if children(a) !== nothing - check_cache((value(a), children(a)...), cache) + check_cache((value(a), children(a)...)) else - check_cache((a,), cache) + check_cache((a,)) end end @@ -143,13 +143,8 @@ function constant_value(a::Node) end end -error_message() = throw(ErrorException("FastDifferentiation.jl does not currently support comparison operations on FastDifferentiation expressions. 
Your code, or libraries called by your code, had a statement with a comparison operator such as x 2 is_times(a::Node) = value(a) == * is_nary_times(a::Node) = is_nary(a) && value(a) == typeof(*) - - -function simplify_check_cache(::typeof(^), a, b, cache) +function simplify_check_cache(::typeof(^), a, b) na = Node(a) nb = Node(b) if constant_value(na) !== nothing && constant_value(nb) !== nothing @@ -196,11 +204,11 @@ function simplify_check_cache(::typeof(^), a, b, cache) elseif value(nb) == 1 return a else - return check_cache((^, na, nb), cache) + return check_cache((^, na, nb)) end end -function simplify_check_cache(::typeof(*), na, nb, cache)::Node +function simplify_check_cache(::typeof(*), na, nb)::Node a = Node(na) b = Node(nb) @@ -229,7 +237,7 @@ function simplify_check_cache(::typeof(*), na, nb, cache)::Node elseif typeof(*) == typeof(value(a)) && typeof(*) == typeof(value(b)) && is_constant(children(b)[1]) && is_constant(children(a)[1]) return Node(value(children(a)[1]) * value(children(b)[1])) * (children(b)[2] * children(a)[2]) else - return check_cache((*, a, b), cache) + return check_cache((*, a, b)) end end @@ -256,7 +264,7 @@ function constant_sum_simplification(lchild::Node, rchild::Node) end end -function simplify_check_cache(::typeof(+), na, nb, cache)::Node +function simplify_check_cache(::typeof(+), na, nb)::Node a = Node(na) b = Node(nb) @@ -277,11 +285,11 @@ function simplify_check_cache(::typeof(+), na, nb, cache)::Node elseif (tmp = constant_sum_simplification(a, b)) !== nothing #simplify c1*a + c2*a => (c1+c2)*a where c1,c2 are constants return (tmp[1] + tmp[2]) * tmp[3] else - return check_cache((+, a, b), cache) + return check_cache((+, a, b)) end end -function simplify_check_cache(::typeof(-), na, nb, cache)::Node +function simplify_check_cache(::typeof(-), na, nb)::Node a = Node(na) b = Node(nb) if a === b @@ -297,11 +305,11 @@ function simplify_check_cache(::typeof(-), na, nb, cache)::Node elseif (tmp = constant_sum_simplification(a, 
b)) !== nothing #simplify c1*a - c2*a => (c1-c2)*a where c1,c2 are constants return (tmp[1] - tmp[2]) * tmp[3] else - return check_cache((-, a, b), cache) + return check_cache((-, a, b)) end end -function simplify_check_cache(::typeof(/), na, nb, cache)::Node +function simplify_check_cache(::typeof(/), na, nb)::Node a = Node(na) b = Node(nb) @@ -310,19 +318,19 @@ function simplify_check_cache(::typeof(/), na, nb, cache)::Node elseif is_constant(a) && is_constant(b) return Node(value(a) / value(b)) else - return check_cache((/, a, b), cache) + return check_cache((/, a, b)) end end -simplify_check_cache(f::Any, na, cache) = check_cache((f, na), cache)::Node + """ - simplify_check_cache(::typeof(-), a, cache) + simplify_check_cache(::typeof(-), a) Special case only for unary -. No simplifications are currently applied to any other unary functions""" -function simplify_check_cache(::typeof(-), a, cache)::Node +function simplify_check_cache(::typeof(-), a)::Node na = Node(a) #this is safe because Node constructor is idempotent if arity(na) == 1 && typeof(value(na)) == typeof(-) return children(na)[1] @@ -331,11 +339,11 @@ function simplify_check_cache(::typeof(-), a, cache)::Node elseif typeof(*) == typeof(value(na)) && constant_value(children(na)[1]) !== nothing return Node(-value(children(na)[1])) * children(na)[2] else - return check_cache((-, na), cache) + return check_cache((-, na)) end end -Base.:^(a::FastDifferentiation.Node, b::Integer) = simplify_check_cache(^, a, b, EXPRESSION_CACHE) +Base.:^(a::FastDifferentiation.Node, b::Integer) = simplify_check_cache(^, a, b) Base.convert(::Type{Node}, a::T) where {T<:Real} = Node(a) Base.promote_rule(::Type{<:Real}, ::Type{Node}) = Node @@ -382,11 +390,11 @@ end -is_ifelse(a::Node) = value(a) == ifelse +is_if_else(a::Node) = value(a) == if_else conditional_error(a::Node) = ErrorException("Your expression contained $(value(a)). 
FastDifferentiation does not yet support differentiation through this conditional or any of these $(Tuple(not_currently_differentiable))") -is_conditional(a::Node) = is_ifelse(a) || value(a) in not_currently_differentiable +is_conditional(a::Node) = is_if_else(a) || value(a) in not_currently_differentiable @@ -605,4 +613,4 @@ export make_variables #create methods that accept Node arguments for all mathematical functions. -@number_methods(Node, simplify_check_cache(f, a, EXPRESSION_CACHE), simplify_check_cache(f, a, b, EXPRESSION_CACHE)) #create methods for standard functions that take Node instead of Number arguments. Check cache to see if these arguments have been seen before. +@number_methods(Node, simplify_check_cache(f, a), simplify_check_cache(f, a, b)) #create methods for standard functions that take Node instead of Number arguments. Check cache to see if these arguments have been seen before. diff --git a/src/Methods.jl b/src/Methods.jl index e95d3bb1..4e95e632 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -32,7 +32,7 @@ const previously_declared_for = Set([]) const basic_monadic = [-, +] const basic_diadic = [+, -, *, /, //, \, ^] -const diadic_non_differentiable = [max, min, copysign, &, |, ⊻, <, >, ≤, ≥, ≠, ==] +const diadic_non_differentiable = [max, min, copysign, &, |, !, ⊻, <, >, ≤, ≥, ≠, ==, isless] const monadic_non_differentiable = [signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !] 
const not_currently_differentiable = vcat(diadic_non_differentiable, monadic_non_differentiable) @@ -60,7 +60,7 @@ function number_methods(T, rhs1, rhs2, options=nothing) (f::$(typeof(f)))(a::$T, b::Real) = $rhs2 (f::$(typeof(f)))(a::Real, b::$T) = $rhs2 end - + println(expr) push!(exprs, expr) end @@ -82,7 +82,8 @@ end # Base.isinf(x::Node) = !isnan(x) & !isfinite(x) macro number_methods(T, rhs1, rhs2, options=nothing) - eval(:(Base.ifelse(a::$T, b::$T, c::$T) = simplify_check_cache(Base.ifelse, a, b, c, EXPRESSION_CACHE))) + #special case for ifelse because it takes three arguments + eval(:(Base.ifelse(a::$T, b::$T, c::$T) = simplify_check_cache(Base.ifelse, a, b, c))) number_methods(T, rhs1, rhs2, options) |> esc end diff --git a/src/UnspecifiedFunction.jl b/src/UnspecifiedFunction.jl index 3bf80e34..9f7b443c 100644 --- a/src/UnspecifiedFunction.jl +++ b/src/UnspecifiedFunction.jl @@ -36,8 +36,7 @@ export function_of function derivative(uf::UnspecifiedFunction{V,D}, args::NTuple{N,Node{SymbolicUtils.BasicSymbolic{Real},0}}, ::Val{I}) where {I,V,D,N} new_derivs = SVector{D + 1,Node{SymbolicUtils.BasicSymbolic{Real},0}}(uf.derivatives..., args[I]) return check_cache( - (UnspecifiedFunction(uf.name, uf.variables, new_derivs), uf.variables...), - EXPRESSION_CACHE) + (UnspecifiedFunction(uf.name, uf.variables, new_derivs), uf.variables...)) end function Base.show(io::IO, a::UnspecifiedFunction) From 6a6025db66e99abe3e313c3b6e9326794879f764 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Sat, 14 Sep 2024 14:36:55 -0700 Subject: [PATCH 21/29] code generation for if_else and conditionals seems to work. Needs test functions written. 
--- src/CodeGeneration.jl | 25 +++++++++++++++---------- src/Methods.jl | 2 +- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/CodeGeneration.jl b/src/CodeGeneration.jl index b0995372..aa09d841 100644 --- a/src/CodeGeneration.jl +++ b/src/CodeGeneration.jl @@ -37,14 +37,12 @@ end ``` and the second return value will be the constant value. """ -function function_body!(dag::Node, variable_to_index::IdDict{Node,Int64}, node_to_var::Union{Nothing,IdDict{Node,Union{Symbol,Real,Expr}}}=nothing) +function function_body!(dag::Node, variable_to_index::IdDict{Node,Int64}, node_to_var::Union{Nothing,IdDict{Node,Union{Symbol,Real,Expr}}}=nothing, body::Expr=Expr(:block)) if node_to_var === nothing node_to_var = IdDict{Node,Union{Symbol,Real,Expr}}() end - body = Expr(:block) - - function _dag_to_function(node) + function _dag_to_function!(node, local_body) tmp = get(node_to_var, node, nothing) @@ -52,22 +50,29 @@ function function_body!(dag::Node, variable_to_index::IdDict{Node,Int64}, node_t node_to_var[node] = node_symbol(node, variable_to_index) if is_tree(node) - args = _dag_to_function.(children(node)) - if value(node) === if_else - println(args) - statement = :($(node_to_var[node]) = $(args[1]) ? 
$(args[2]) : $(args[3])) + true_body = Expr(:block) + false_body = Expr(:block) + if_cond_var = _dag_to_function!(children(node)[1], local_body) + _dag_to_function!(children(node)[2], true_body) + _dag_to_function!(children(node)[3], false_body) + statement = :($(node_to_var[node]) = if $(if_cond_var) + $(true_body) + else + $(false_body) + end) else + args = _dag_to_function!.(children(node), Ref(local_body)) statement = :($(node_to_var[node]) = $(Symbol(value(node)))($(args...))) end - push!(body.args, statement) + push!(local_body.args, statement) end end return node_to_var[node] end - return body, _dag_to_function(dag) + return body, _dag_to_function!(dag, body) end function zero_array_declaration(array::StaticArray{S,<:Any,N}) where {S,N} diff --git a/src/Methods.jl b/src/Methods.jl index 4e95e632..f1300e67 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -60,7 +60,7 @@ function number_methods(T, rhs1, rhs2, options=nothing) (f::$(typeof(f)))(a::$T, b::Real) = $rhs2 (f::$(typeof(f)))(a::Real, b::$T) = $rhs2 end - println(expr) + push!(exprs, expr) end From 86cfb4d1078e314d3a82700d5a80fa44c20b3ccc Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Sat, 14 Sep 2024 14:45:56 -0700 Subject: [PATCH 22/29] added comment to function_body! --- src/CodeGeneration.jl | 2 +- src/Methods.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/CodeGeneration.jl b/src/CodeGeneration.jl index aa09d841..0459d1c9 100644 --- a/src/CodeGeneration.jl +++ b/src/CodeGeneration.jl @@ -50,7 +50,7 @@ function function_body!(dag::Node, variable_to_index::IdDict{Node,Int64}, node_t node_to_var[node] = node_symbol(node, variable_to_index) if is_tree(node) - if value(node) === if_else + if value(node) === if_else #special case code generation for if...else. Need to generate nested code so only the statements in the true or false branch will be executed. 
true_body = Expr(:block) false_body = Expr(:block) if_cond_var = _dag_to_function!(children(node)[1], local_body) diff --git a/src/Methods.jl b/src/Methods.jl index f1300e67..86d9842a 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -70,6 +70,7 @@ function number_methods(T, rhs1, rhs2, options=nothing) end push!(exprs, :(push!($previously_declared_for, $T))) + return Expr(:block, exprs...) end From 2382a1c6a8aa5d97e1e3b0d0c99b141383944416 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Sat, 14 Sep 2024 15:03:29 -0700 Subject: [PATCH 23/29] pulled dag_to_function! out of function_body! --- src/CodeGeneration.jl | 63 ++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/src/CodeGeneration.jl b/src/CodeGeneration.jl index 0459d1c9..9d7b9d6b 100644 --- a/src/CodeGeneration.jl +++ b/src/CodeGeneration.jl @@ -14,6 +14,37 @@ function sparsity(sym_func::AbstractArray{<:Node}) end export sparsity + +function _dag_to_function!(node, local_body, variable_to_index, node_to_var) + + tmp = get(node_to_var, node, nothing) + + if tmp === nothing #if node not in node_to_var then it hasn't been visited. Otherwise it has so don't recurse. + node_to_var[node] = node_symbol(node, variable_to_index) + + if is_tree(node) + if value(node) === if_else #special case code generation for if...else. Need to generate nested code so only the statements in the true or false branch will be executed. 
+ true_body = Expr(:block) + false_body = Expr(:block) + if_cond_var = _dag_to_function!(children(node)[1], local_body, variable_to_index, node_to_var) + _dag_to_function!(children(node)[2], true_body, variable_to_index, node_to_var) + _dag_to_function!(children(node)[3], false_body, variable_to_index, node_to_var) + statement = :($(node_to_var[node]) = if $(if_cond_var) + $(true_body) + else + $(false_body) + end) + else + args = _dag_to_function!.(children(node), Ref(local_body), Ref(variable_to_index), Ref(node_to_var)) + statement = :($(node_to_var[node]) = $(Symbol(value(node)))($(args...))) + end + push!(local_body.args, statement) + end + end + + return node_to_var[node] +end + """ function_body!( dag::Node, @@ -42,37 +73,7 @@ function function_body!(dag::Node, variable_to_index::IdDict{Node,Int64}, node_t node_to_var = IdDict{Node,Union{Symbol,Real,Expr}}() end - function _dag_to_function!(node, local_body) - - tmp = get(node_to_var, node, nothing) - - if tmp === nothing #if node not in node_to_var then it hasn't been visited. Otherwise it has so don't recurse. - node_to_var[node] = node_symbol(node, variable_to_index) - - if is_tree(node) - if value(node) === if_else #special case code generation for if...else. Need to generate nested code so only the statements in the true or false branch will be executed. 
- true_body = Expr(:block) - false_body = Expr(:block) - if_cond_var = _dag_to_function!(children(node)[1], local_body) - _dag_to_function!(children(node)[2], true_body) - _dag_to_function!(children(node)[3], false_body) - statement = :($(node_to_var[node]) = if $(if_cond_var) - $(true_body) - else - $(false_body) - end) - else - args = _dag_to_function!.(children(node), Ref(local_body)) - statement = :($(node_to_var[node]) = $(Symbol(value(node)))($(args...))) - end - push!(local_body.args, statement) - end - end - - return node_to_var[node] - end - - return body, _dag_to_function!(dag, body) + return body, _dag_to_function!(dag, body, variable_to_index, node_to_var) end function zero_array_declaration(array::StaticArray{S,<:Any,N}) where {S,N} From 94fecf894bc2e794e8bd4a784889dded8cc2d21c Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 17 Sep 2024 15:28:58 -0700 Subject: [PATCH 24/29] conditionals seem to work now. derivative(x^y,y) generates proper code. 
--- src/CodeGeneration.jl | 29 +++++++++++++++++++++++++++-- test/FDTests.jl | 30 +++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/CodeGeneration.jl b/src/CodeGeneration.jl index 9d7b9d6b..b0f4cc18 100644 --- a/src/CodeGeneration.jl +++ b/src/CodeGeneration.jl @@ -27,8 +27,33 @@ function _dag_to_function!(node, local_body, variable_to_index, node_to_var) true_body = Expr(:block) false_body = Expr(:block) if_cond_var = _dag_to_function!(children(node)[1], local_body, variable_to_index, node_to_var) - _dag_to_function!(children(node)[2], true_body, variable_to_index, node_to_var) - _dag_to_function!(children(node)[3], false_body, variable_to_index, node_to_var) + + true_node = children(node)[2] + false_node = children(node)[3] + + if is_leaf(true_node) #handle leaf nodes properly + if is_constant(true_node) + temp_val = value(true_node) + else + temp_val = node_to_var[true_node] + end + + push!(true_body.args, :($(gensym(:s)) = $(temp_val))) #seems roundabout to use an assignment when really just want the value of the node but couldn't figure out how to make this work with Expr + else + _dag_to_function!(children(node)[2], true_body, variable_to_index, node_to_var) + end + + if is_leaf(false_node) + if is_constant(false_node) + temp_val = value(false_node) + else + temp_val = node_to_var[false_node] + end + push!(false_body.args, :($(gensym(:s)) = $(temp_val))) #seems roundabout to use an assignment when really just want the value of the node but couldn't figure out how to make this work with Expr + else + _dag_to_function!(children(node)[3], false_body, variable_to_index, node_to_var) + end + statement = :($(node_to_var[node]) = if $(if_cond_var) $(true_body) else diff --git a/test/FDTests.jl b/test/FDTests.jl index 9f969096..8cb773da 100644 --- a/test/FDTests.jl +++ b/test/FDTests.jl @@ -2128,11 +2128,39 @@ end @testitem "conditional code generation" begin @variables x y - f = ifelse(x < y, cos(x), sin(x)) + f = 
if_else(x < y, cos(x), sin(x)) input = [π / 2.0, 20.0] exe = make_function([f], [x, y]) @test isapprox(cos(input[1]), exe(input)[1]) input = [π / 2.0, 1.0] @test isapprox(sin(input[1]), exe(input)[1]) + + f = if_else(x < y, x, 2.0) + + exe = make_function([f], [x, y]) + @test exe([1.0, 2.0])[1] == 1.0 + @test exe([2.0, 1.0])[1] == 2.0 + + f = if_else(x < y, 1.0, 2.0) + exe = make_function([f], [x, y]) + @test exe([1.0, 2.0])[1] == 1.0 + @test exe([2.0, 1.0])[1] == 2.0 + + f = if_else(x < y, x, y) + exe = make_function([f], [x, y]) + @test exe([1.5, 3.0])[1] == 1.5 + @test exe([2.5, 1.2])[1] == 1.2 + + f = if_else(x < y, 2.7, y) + exe = make_function([f], [x, y]) + @test exe([1.5, 3.0])[1] == 2.7 + @test exe([2.5, 1.2])[1] == 1.2 + + #make sure derivative of x^y works because this derivative has a conditional statement. + f = x^y + g = derivative([f], y) + h = make_function(g, [x, y]) + @test isnan(h([0.0, 2.0])[1]) + @test isapprox(h([1.1, 2.0])[1], 0.11532531756323319) end From 6ea048a2f850ca1247a470366cfa80d09acf875c Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 17 Sep 2024 15:55:00 -0700 Subject: [PATCH 25/29] change version to 0.4.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index da5d9d3b..ca3dbb4b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FastDifferentiation" uuid = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" authors = ["BrianGuenter"] -version = "0.4.0" +version = "0.4.1" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" From 743cc32836a4602f6cc0ed7870c967aeeeeb31ea Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 17 Sep 2024 15:56:46 -0700 Subject: [PATCH 26/29] updates index.md to have new version 0.4.1 so nobody tries to use 0.4.0 --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index e1441abf..4b75b15a 100644 
--- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,7 +14,7 @@ For f:ℝⁿ->ℝᵐ with n,m large **FD** may have better performance than conv **FD** may take much less time to compute symbolic derivatives than Symbolics.jl even in the ℝ¹->ℝ¹ case. The executables generated by **FD** may also be much faster (see [Symbolic Processing](@ref)). -As of version 0.4.0 **FD** allows you to create expressions with conditionals: +As of version 0.4.1 **FD** allows you to create expressions with conditionals: ```julia julia> @variables x y From 428284cb2ce3cb27316cc4f8200bb74c666a0056 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Tue, 17 Sep 2024 16:18:47 -0700 Subject: [PATCH 27/29] fixed conditional test that was using ifelse instead of if_else added 2 new methods for ifelse so it can take inputs that are Real instead of just Node updated documentation --- docs/src/index.md | 67 +++++++++++++++++++++++++++++------------------ src/Methods.jl | 2 ++ test/FDTests.jl | 4 +-- 3 files changed, 45 insertions(+), 28 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 4b75b15a..b9786cf0 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,33 +14,7 @@ For f:ℝⁿ->ℝᵐ with n,m large **FD** may have better performance than conv **FD** may take much less time to compute symbolic derivatives than Symbolics.jl even in the ℝ¹->ℝ¹ case. The executables generated by **FD** may also be much faster (see [Symbolic Processing](@ref)). -As of version 0.4.1 **FD** allows you to create expressions with conditionals: -```julia - -julia> @variables x y -y -julia> f = ifelse(x a = make_function([f],[x,y]) - -julia> a(1.0,2.0) -1-element Vector{Float64}: - 1.0 - -julia> a(2.0,1.0) -1-element Vector{Float64}: - 2.0 -``` -Howver, you cannot yet compute derivatives of expressions that contain conditionals: -```julia -julia> jacobian([f],[x,y]) -ERROR: Your expression contained ifelse. 
FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals (max, min, copysign, &, |, xor, <, >, <=, >=, !=, ==, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !) -``` -This may be a breaking change for some users. In previous versions this threw an the expression `x==y` returned a `Bool`. Some data structures, such as `Dict` use `==` by default to determine if two entries are the same. This will no longer work since `x==y` will now return an expression graph. Use an `IDict` instead since this uses `===`. - -A future PR will add support for differentiating through conditionals. @@ -66,6 +40,47 @@ This is **beta** software being modified on a daily basis. Expect bugs and frequ ## Notes about special derivatives The derivative of `|u|` is `u/|u|` which is NaN when `u==0`. This is not a bug. The derivative of the absolute value function is undefined at 0 and the way **FD** signals this is by returning NaN. +## Conditionals + +As of version 0.4.1 **FD** allows you to create expressions with conditionals using either the builtin `ifelse` function or a new function `if_else`. `ifelse` will evaluate both inputs. By contrast `if_else` has the semantics of `if...else...end`; only the true or false branch will be executed. This is useful when your conditional is used to prevent exceptions because of illegal input values: +```julia +julia> f = if_else(x<0,NaN,sqrt(x)) +(if_else (x < 0) NaN sqrt(x)) + +julia> g = make_function([f],[x]) + + +julia> g([-1]) +1-element Vector{Float64}: + NaN + +julia> g([2.0]) +1-element Vector{Float64}: + 1.4142135623730951 +end +``` +In this case you wouldn't want to use `ifelse` because it evaluates both the true and false branches and causes a runtime exception: +```julia +julia> f = ifelse(x<0,NaN,sqrt(x)) +(ifelse (x < 0) NaN sqrt(x)) + +julia> g = make_function([f],[x]) +... 
+ +julia> g([-1]) +ERROR: DomainError with -1.0: +sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)). +``` + +Howver, you cannot yet compute derivatives of expressions that contain conditionals: +```julia +julia> jacobian([f],[x,y]) +ERROR: Your expression contained ifelse. FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals (max, min, copysign, &, |, xor, <, >, <=, >=, !=, ==, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !) +``` +This may be a breaking change for some users. In previous versions the expression `x==y` returned a `Bool`. Some data structures, such as `Dict` use `==` by default to determine if two entries are the same. This will no longer work since `x==y` will now return an expression graph. Use an `IdDict` instead since this uses `===`. + +A future PR will add support for differentiating through conditionals. + diff --git a/src/Methods.jl b/src/Methods.jl index 86d9842a..8708b412 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -85,6 +85,8 @@ end macro number_methods(T, rhs1, rhs2, options=nothing) #special case for ifelse because it takes three arguments eval(:(Base.ifelse(a::$T, b::$T, c::$T) = simplify_check_cache(Base.ifelse, a, b, c))) + eval(:(Base.ifelse(a::$T, b::$Real, c::$T) = simplify_check_cache(Base.ifelse, a, Node(b), c))) + eval(:(Base.ifelse(a::$T, b::$T, c::$Real) = simplify_check_cache(Base.ifelse, a, b, Node(c)))) number_methods(T, rhs1, rhs2, options) |> esc end diff --git a/test/FDTests.jl b/test/FDTests.jl index 8cb773da..05bec783 100644 --- a/test/FDTests.jl +++ b/test/FDTests.jl @@ -2119,8 +2119,8 @@ end #conditional expr = x < y - f = ifelse(expr, x, y) - @test ==(FastDifferentiation.value(f), ifelse) + f = if_else(expr, x, y) + @test ==(FastDifferentiation.value(f), if_else) @test ===(FastDifferentiation.children(f)[1], expr) @test 
===(FastDifferentiation.children(f)[2], x) @test ===(FastDifferentiation.children(f)[3], y) From bb4e574c625b89926ba4ef4c80c82866b3297219 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Wed, 18 Sep 2024 09:52:46 -0700 Subject: [PATCH 28/29] changed is_conditional to detect both ifelse and if_else changed derivative(a::Node, index::Val{1}) in DifferentiationRules.jl to throw correct conditional error if expression is a conditional updated README and docs about new conditional feature. --- README.md | 44 ++++++++++++++++++++++++++----------- docs/src/index.md | 2 +- src/DifferentiationRules.jl | 6 ++++- src/ExpressionGraph.jl | 5 +++-- 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 079467cc..3333927b 100644 --- a/README.md +++ b/README.md @@ -71,26 +71,44 @@ If you use FD in your work please share the functions you differentiate with me. **A**: If you multiply a matrix of **FD** variables times a vector of **FD** variables the matrix vector multiplication loop is effectively unrolled into scalar expressions. Matrix operations on large matrices will generate large executables and long preprocessing time. **FD** functions with up 10⁵ operations should still have reasonable preprocessing/compilation times (approximately 1 minute on a modern laptop) and good run time performance. **Q**: Does **FD** support conditionals? -**A**: **FD** does not yet support conditionals that involve the variables you are differentiating with respect to. You can do this: +**A**: As of version 0.4.1 **FD** expressions may contain conditionals which involve variables. However, you cannot yet differentiate an expression containing conditionals. A future PR will allow you to differentiate conditional expressions. + +You can use either the builtin `ifelse` function or a new function `if_else`. `ifelse` will evaluate both the true and false branches. 
By contrast `if_else` has the semantics of `if...else...end`; only one of the true or false branches will be executed. + +This is useful when your conditional is used to prevent exceptions because of illegal input values: ```julia -@variables x y #create FD variables +julia> f = if_else(x<0,NaN,sqrt(x)) +(if_else (x < 0) NaN sqrt(x)) + +julia> g = make_function([f],[x]) -julia> f(a,b,c) = a< 1.0 ? cos(b) : sin(c) -f (generic function with 2 methods) -julia> f(0.0,x,y) -cos(x) +julia> g([-1]) +1-element Vector{Float64}: + NaN -julia> f(1.0,x,y) -sin(y) +julia> g([2.0]) +1-element Vector{Float64}: + 1.4142135623730951 +end ``` -but you can't do this: +In this case you wouldn't want to use `ifelse` because it evaluates both the true and false branches and causes a runtime exception: ```julia -julia> f(a,b) = a < b ? cos(a) : sin(b) -f (generic function with 2 methods) +julia> f = ifelse(x<0,NaN,sqrt(x)) +(ifelse (x < 0) NaN sqrt(x)) -julia> f(x,y) -ERROR: MethodError: no method matching isless(::FastDifferentiation.Node{Symbol, 0}, ::FastDifferentiation.Node{Symbol, 0}) +julia> g = make_function([f],[x]) +... + +julia> g([-1]) +ERROR: DomainError with -1.0: +sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)). +``` + +However, you cannot yet compute derivatives of expressions that contain conditionals: +```julia +julia> jacobian([f],[x,y]) +ERROR: Your expression contained ifelse. FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals (max, min, copysign, &, |, xor, <, >, <=, >=, !=, ==, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !) 
``` # Release Notes diff --git a/docs/src/index.md b/docs/src/index.md index b9786cf0..17e45e3b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -72,7 +72,7 @@ ERROR: DomainError with -1.0: sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)). ``` -Howver, you cannot yet compute derivatives of expressions that contain conditionals: +However, you cannot yet compute derivatives of expressions that contain conditionals: ```julia julia> jacobian([f],[x,y]) ERROR: Your expression contained ifelse. FastDifferentiation does not yet support differentiation through ifelse or any of these conditionals (max, min, copysign, &, |, xor, <, >, <=, >=, !=, ==, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !) diff --git a/src/DifferentiationRules.jl b/src/DifferentiationRules.jl index a43c0300..f360fa3e 100644 --- a/src/DifferentiationRules.jl +++ b/src/DifferentiationRules.jl @@ -37,8 +37,12 @@ derivative(::NoOp, arg::Tuple{T}, ::Val{1}) where {T} = 1.0 function_variable_derivative(a::Node, index::Val{i}) where {i} = check_cache((Differential, children(a)[i])) + +# These functions are primarily used to do error checking on expressions function derivative(a::Node, index::Val{1}) - if is_variable_function(a) + if is_conditional(a) + throw(conditional_error(a)) + elseif is_variable_function(a) return function_variable_derivative(a, index) elseif arity(a) == 1 return derivative(value(a), (children(a)[1],), index) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 023f6266..b4e03840 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -391,10 +391,11 @@ end is_if_else(a::Node) = value(a) == if_else +is_ifelse(a::Node) = value(a) == ifelse -conditional_error(a::Node) = ErrorException("Your expression contained $(value(a)). 
FastDifferentiation does not yet support differentiation through this conditional or any of these $(Tuple(not_currently_differentiable))") +conditional_error(a::Node) = ErrorException("Your expression contained a $(value(a)) expression. FastDifferentiation does not yet support differentiation through this conditional or any of these $(Tuple(not_currently_differentiable))") -is_conditional(a::Node) = is_if_else(a) || value(a) in not_currently_differentiable +is_conditional(a::Node) = is_if_else(a) || is_ifelse(a) || value(a) in not_currently_differentiable From 37baab57b677a549f2184c115b00218ba426a6e1 Mon Sep 17 00:00:00 2001 From: brianguenter <1brianguenter@gmail.com> Date: Wed, 18 Sep 2024 10:54:58 -0700 Subject: [PATCH 29/29] removed LogExpFunctions from dependencies because can't define diffrules for it that will work with FastDifferentiation added diff rule for mod2pi moved import DiffRules from DifferentiationRules.jl to FastDifferentiation.jl to make it easier to see at a glance the deps of FD renamed diadic_non_differentiable,monadic_non_differentiable to special_diadic and special_monadic --- Project.toml | 1 - src/DifferentiationRules.jl | 17 +++++++++++------ src/ExpressionGraph.jl | 17 ++++++++++++----- src/FastDifferentiation.jl | 1 + src/Methods.jl | 10 +++++----- 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index ca3dbb4b..0310acae 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ version = "0.4.1" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/DifferentiationRules.jl b/src/DifferentiationRules.jl index f360fa3e..b56506e1 100644 --- 
a/src/DifferentiationRules.jl +++ b/src/DifferentiationRules.jl @@ -1,12 +1,17 @@ # Pre-defined derivatives -import DiffRules -#Special case rules for rules in DiffRules that use ?: or if...else, neither of which will work when add conditionals +#Special case rules for differentiation rules in DiffRules that use ?: or if...else, neither of which will work when we add conditionals +#Some functions use ?: or if...else in the function definition itself; these are not compatible with FastDifferentiation. They can only be made compatible with either a custom derivative rule, a feature which doesn't exist yet, or by being redefined to use FastDifferentiation if_else. The latter is impractical and unlikely to ever happen. #airybix and airyprimex don't work in FastDifferentiation. airybix(x) where x is a Node causes a stack overflow. So no diffrule defined, although airybix uses if...else -#LogExpFunctions.xlogy uses ?: so doesn't work with FastDifferentiation. Most of the functions in this package don't work with FastDifferentiation. +#Most of the functions in package LogExpFunctions use ?: or if...else so don't work with FastDifferentiation. DiffRules.@define_diffrule Base.:^(x, y) = :($y * ($x^($y - 1))), :(if_else($x isa Real && $x <= 0, Base.oftype(float($x), NaN), ($x^$y) * log($x))) +DiffRules.@define_diffrule Base.mod2pi(x) = :(if_else(isinteger($x / $DiffRules.twoπ), oftype(float($x), NaN), one(float($x)))) + +# We provide this hook for special number types like `Interval` +# that need their own special definition of `abs`. +_abs_deriv(x) = signbit(x) ? 
-one(x) : one(x) for (modu, fun, arity) ∈ DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath)) fun in [:*, :+, :abs, :mod, :rem, :max, :min] && continue # special @@ -40,7 +45,7 @@ function_variable_derivative(a::Node, index::Val{i}) where {i} = check_cache((Di # These functions are primarily used to do error checking on expressions function derivative(a::Node, index::Val{1}) - if is_conditional(a) + if is_unsupported_function(a) throw(conditional_error(a)) elseif is_variable_function(a) return function_variable_derivative(a, index) @@ -54,7 +59,7 @@ function derivative(a::Node, index::Val{1}) end function derivative(a::Node, index::Val{2}) - if is_conditional(a) + if is_unsupported_function(a) throw(conditional_error(a)) elseif is_variable_function(a) return function_variable_derivative(a, index) @@ -66,7 +71,7 @@ function derivative(a::Node, index::Val{2}) end function derivative(a::Node, index::Val{i}) where {i} - if is_conditional(a) + if is_unsupported_function(a) throw(conditional_error(a)) elseif is_variable_function(a) return function_variable_derivative(a, index) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index b4e03840..c7b217fd 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -176,7 +176,7 @@ Special if_else to use for conditionals instead of builtin ifelse because the la during code generation. 
""" function if_else(condition::Node, true_branch=Node(true_branch), false_branch=Node(false_branch)) - @assert value(condition) in diadic_non_differentiable || value(condition) in monadic_non_differentiable + @assert value(condition) in special_diadic || value(condition) in special_monadic check_cache((if_else, condition, true_branch, false_branch)) end export if_else @@ -388,14 +388,21 @@ function create_NoOp(child) return Node(NoOp(), child) end - - +is_NoOp(a::Node) = isa(value(a), NoOp) is_if_else(a::Node) = value(a) == if_else is_ifelse(a::Node) = value(a) == ifelse -conditional_error(a::Node) = ErrorException("Your expression contained a $(value(a)) expression. FastDifferentiation does not yet support differentiation through this conditional or any of these $(Tuple(not_currently_differentiable))") +conditional_error(a::Node) = ErrorException("Your expression contained a $(value(a)) expression. FastDifferentiation does not yet support differentiation through this function)") -is_conditional(a::Node) = is_if_else(a) || is_ifelse(a) || value(a) in not_currently_differentiable +function is_unsupported_function(a::Node) + if is_NoOp(a) + return false + elseif is_if_else(a) || is_ifelse(a) || !in(value(a), all_supported_functions) + return true + else + return false + end +end diff --git a/src/FastDifferentiation.jl b/src/FastDifferentiation.jl index 9f90f745..6705b03c 100644 --- a/src/FastDifferentiation.jl +++ b/src/FastDifferentiation.jl @@ -10,6 +10,7 @@ import Base: iterate using UUIDs using SparseArrays using DataStructures +import DiffRules module AutomaticDifferentiation struct NoDeriv diff --git a/src/Methods.jl b/src/Methods.jl index 8708b412..7019b042 100644 --- a/src/Methods.jl +++ b/src/Methods.jl @@ -32,10 +32,10 @@ const previously_declared_for = Set([]) const basic_monadic = [-, +] const basic_diadic = [+, -, *, /, //, \, ^] -const diadic_non_differentiable = [max, min, copysign, &, |, !, ⊻, <, >, ≤, ≥, ≠, ==, isless] -const 
monadic_non_differentiable = [signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !] +const special_diadic = [max, min, copysign, &, |, !, ⊻, <, >, ≤, ≥, ≠, ==, isless] +const special_monadic = [mod2pi, rem2pi, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !] -const not_currently_differentiable = vcat(diadic_non_differentiable, monadic_non_differentiable) +const all_supported_functions = vcat(monadic, diadic, basic_monadic, basic_diadic, special_diadic, special_monadic) # TODO: keep domains tighter than this function number_methods(T, rhs1, rhs2, options=nothing) @@ -45,7 +45,7 @@ function number_methods(T, rhs1, rhs2, options=nothing) only_basics = options !== nothing ? options == :onlybasics : false skips = Meta.isexpr(options, [:vcat, :hcat, :vect]) ? Set(options.args) : [] - for f in (skip_basics ? diadic : only_basics ? basic_diadic : vcat(basic_diadic, diadic, diadic_non_differentiable)) + for f in (skip_basics ? diadic : only_basics ? basic_diadic : vcat(basic_diadic, diadic, special_diadic)) nameof(f) in skips && continue for S in previously_declared_for push!(exprs, quote @@ -64,7 +64,7 @@ function number_methods(T, rhs1, rhs2, options=nothing) push!(exprs, expr) end - for f in (skip_basics ? monadic : only_basics ? basic_monadic : vcat(basic_monadic, monadic, monadic_non_differentiable)) + for f in (skip_basics ? monadic : only_basics ? basic_monadic : vcat(basic_monadic, monadic, special_monadic)) nameof(f) in skips && continue push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1)) end