From 76e04ce5ab3934e356456d5931c60971a1c5671b Mon Sep 17 00:00:00 2001 From: elaineran Date: Mon, 24 Jun 2024 16:59:01 -0400 Subject: [PATCH] Overload several Base numerical methods --- src/ExpressionGraph.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/ExpressionGraph.jl b/src/ExpressionGraph.jl index 466051f2..ce9ccc2a 100644 --- a/src/ExpressionGraph.jl +++ b/src/ExpressionGraph.jl @@ -95,6 +95,18 @@ Base.zero(::Node) = Node(0) Base.one(::Type{Node}) = Node(1) Base.one(::Node) = Node(1) +# These are essentially copied from Symbolics.jl: +# https://github.com/JuliaSymbolics/Symbolics.jl/blob/e4c328103ece494eaaab2a265524a64bfbe43dbd/src/num.jl#L31-L34 +Base.eps(::Type{Node}) = Node(0) +Base.typemin(::Type{Node}) = Node(-Inf) +Base.typemax(::Type{Node}) = Node(Inf) +Base.float(x::Node) = x + +# This one is needed because julia/base/float.jl only defines `isinf` for `Real`, but `Node +# <: Number`. (See https://github.com/brianguenter/FastDifferentiation.jl/issues/73) +Base.isinf(x::Node) = !isnan(x) & !isfinite(x) + + Broadcast.broadcastable(a::Node) = (a,) value(a::Node) = a.node_value @@ -324,6 +336,7 @@ rules = Any[] Base.convert(::Type{Node}, a::T) where {T<:Real} = Node(a) Base.promote_rule(::Type{<:Real}, ::Type{Node}) = Node +Base.promote_rule(::Type{Bool}, ::Type{Node}) = Node function Base.:-(a::AbstractArray{<:Node,N}) where {N} if length(a) == 0