Skip to content

Commit

Permalink
removed LogExpFunctions from dependencies because can't define diffru…
Browse files Browse the repository at this point in the history
…les for it that will work with FastDifferentiation

added diff rule for mod2pi

moved import DiffRules from DifferentiationRules.jl to FastDifferentiation.jl to make it easier to see at a glance the deps of FD

renamed diadic_non_differentiable,monadic_non_differentiable to special_diadic and special_monadic
  • Loading branch information
brianguenter committed Sep 18, 2024
1 parent bb4e574 commit 37baab5
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 17 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "0.4.1"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
17 changes: 11 additions & 6 deletions src/DifferentiationRules.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# Pre-defined derivatives
import DiffRules
#Special case rules for rules in DiffRules that use ?: or if...else, neither of which will work when add conditionals

#Special case rules for diffrentiation rules in DiffRules that use ?: or if...else, neither of which will work when add conditionals
#Some functions use ?: or if...else in the function definition itself; these are not compatible with FastDifferentiation. They can only be made compatible with either a custom derivative rule, a feature which doesn't exist yet, or by being redefined to use FastDifferentiation if_else. The latter is impractical and unlikely to ever happen.
#airybix and airyprimex don't work in FastDifferentiation. airybix(x) where x is a Node causes a stack overflow. So no diffrule defined, although airybix uses if...else
#LogExpFunctions.xlogy uses ?: so doesn't work with FastDifferentiation. Most of the functions in this package don't work with FastDifferentiation.
#Most of the functions in package LogExpFunctions use ?: or if...else so don't work with FastDifferentiation.

DiffRules.@define_diffrule Base.:^(x, y) = :($y * ($x^($y - 1))), :(if_else($x isa Real && $x <= 0, Base.oftype(float($x), NaN), ($x^$y) * log($x)))

DiffRules.@define_diffrule Base.mod2pi(x) = :(if_else(isinteger($x / $DiffRules.twoπ), oftype(float($x), NaN), one(float($x))))

# We provide this hook for special number types like `Interval`
# that need their own special definition of `abs`.
_abs_deriv(x) = signbit(x) ? -one(x) : one(x)

for (modu, fun, arity) DiffRules.diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath))
fun in [:*, :+, :abs, :mod, :rem, :max, :min] && continue # special
Expand Down Expand Up @@ -40,7 +45,7 @@ function_variable_derivative(a::Node, index::Val{i}) where {i} = check_cache((Di

# These functions are primarily used to do error checking on expressions
function derivative(a::Node, index::Val{1})
if is_conditional(a)
if is_unsupported_function(a)
throw(conditional_error(a))
elseif is_variable_function(a)
return function_variable_derivative(a, index)
Expand All @@ -54,7 +59,7 @@ function derivative(a::Node, index::Val{1})
end

function derivative(a::Node, index::Val{2})
if is_conditional(a)
if is_unsupported_function(a)
throw(conditional_error(a))
elseif is_variable_function(a)
return function_variable_derivative(a, index)
Expand All @@ -66,7 +71,7 @@ function derivative(a::Node, index::Val{2})
end

function derivative(a::Node, index::Val{i}) where {i}
if is_conditional(a)
if is_unsupported_function(a)
throw(conditional_error(a))
elseif is_variable_function(a)
return function_variable_derivative(a, index)
Expand Down
17 changes: 12 additions & 5 deletions src/ExpressionGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ Special if_else to use for conditionals instead of builtin ifelse because the la
during code generation.
"""
function if_else(condition::Node, true_branch=Node(true_branch), false_branch=Node(false_branch))
@assert value(condition) in diadic_non_differentiable || value(condition) in monadic_non_differentiable
@assert value(condition) in special_diadic || value(condition) in special_monadic
check_cache((if_else, condition, true_branch, false_branch))
end
export if_else
Expand Down Expand Up @@ -388,14 +388,21 @@ function create_NoOp(child)
return Node(NoOp(), child)
end



is_NoOp(a::Node) = isa(value(a), NoOp)
is_if_else(a::Node) = value(a) == if_else
is_ifelse(a::Node) = value(a) == ifelse

conditional_error(a::Node) = ErrorException("Your expression contained a $(value(a)) expression. FastDifferentiation does not yet support differentiation through this conditional or any of these $(Tuple(not_currently_differentiable))")
conditional_error(a::Node) = ErrorException("Your expression contained a $(value(a)) expression. FastDifferentiation does not yet support differentiation through this function)")

is_conditional(a::Node) = is_if_else(a) || is_ifelse(a) || value(a) in not_currently_differentiable
function is_unsupported_function(a::Node)
if is_NoOp(a)
return false
elseif is_if_else(a) || is_ifelse(a) || !in(value(a), all_supported_functions)
return true
else
return false
end
end



Expand Down
1 change: 1 addition & 0 deletions src/FastDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Base: iterate
using UUIDs
using SparseArrays
using DataStructures
import DiffRules

module AutomaticDifferentiation
struct NoDeriv
Expand Down
10 changes: 5 additions & 5 deletions src/Methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ const previously_declared_for = Set([])

const basic_monadic = [-, +]
const basic_diadic = [+, -, *, /, //, \, ^]
const diadic_non_differentiable = [max, min, copysign, &, |, !, , <, >, , , , ==, isless]
const monadic_non_differentiable = [signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !]
const special_diadic = [max, min, copysign, &, |, !, , <, >, , , , ==, isless]
const special_monadic = [mod2pi, rem2pi, signbit, isreal, iszero, isfinite, isnan, isinf, isinteger, !]

const not_currently_differentiable = vcat(diadic_non_differentiable, monadic_non_differentiable)
const all_supported_functions = vcat(monadic, diadic, basic_monadic, basic_diadic, special_diadic, special_monadic)

# TODO: keep domains tighter than this
function number_methods(T, rhs1, rhs2, options=nothing)
Expand All @@ -45,7 +45,7 @@ function number_methods(T, rhs1, rhs2, options=nothing)
only_basics = options !== nothing ? options == :onlybasics : false
skips = Meta.isexpr(options, [:vcat, :hcat, :vect]) ? Set(options.args) : []

for f in (skip_basics ? diadic : only_basics ? basic_diadic : vcat(basic_diadic, diadic, diadic_non_differentiable))
for f in (skip_basics ? diadic : only_basics ? basic_diadic : vcat(basic_diadic, diadic, special_diadic))
nameof(f) in skips && continue
for S in previously_declared_for
push!(exprs, quote
Expand All @@ -64,7 +64,7 @@ function number_methods(T, rhs1, rhs2, options=nothing)
push!(exprs, expr)
end

for f in (skip_basics ? monadic : only_basics ? basic_monadic : vcat(basic_monadic, monadic, monadic_non_differentiable))
for f in (skip_basics ? monadic : only_basics ? basic_monadic : vcat(basic_monadic, monadic, special_monadic))
nameof(f) in skips && continue
push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1))
end
Expand Down

0 comments on commit 37baab5

Please sign in to comment.