Skip to content

Commit

Permalink
Tests and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Kolaru committed Feb 23, 2024
1 parent c452cc5 commit a3a00b9
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 17 deletions.
33 changes: 24 additions & 9 deletions ext/IntervalArithmeticForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@ module IntervalArithmeticForwardDiffExt
using IntervalArithmetic, ForwardDiff
using ForwardDiff: Dual, , value, partials

function isconstant_interval(x)
all(isthinzero.(values(partials(x))))
end

function Base.:(^)(x::Dual{Txy, <:Interval}, y::Dual{Txy, <:Interval}) where Txy
vx, vy = value(x), value(y)
primal = vx^vy
expv = vx^vy
powval = vy * vx^(vy - interval(1))
logval = primal * log(vx)
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(primal, new_partials)
if isconstant_interval(y)
logval = one(expv)
elseif isthinzero(vx) && inf(vy) > 0
logval = zero(vx)
else
logval = expv * log(vx)
end
new_partials = ForwardDiff._mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(expv, new_partials)
end

function Base.:(^)(x::Dual{Tx, <:Interval}, y::Dual{Ty, <:Interval}) where {Tx, Ty}
Expand All @@ -22,15 +32,20 @@ end

function Base.:(^)(x::Dual{Tx, <:Interval}, y::Interval) where Tx
v = value(x)
new_partials = partials(x) * y * v^(y - interval(1))
return Dual{Tx}(v^y, new_partials)
expv = v^y
if isthinzero(y) || isconstant_interval(x)
new_partials = zero(partials(x))
else
new_partials = partials(x) * y * v^(y - interval(1))
end
return Dual{Tx}(expv, new_partials)
end

function Base.:(^)(x::Interval, y::Dual{Ty, <:Interval}) where Ty
v = value(y)
primal = x^v
deriv = primal*log(x)
return Dual{Ty}(primal, deriv * partials(y))
expv = x^v
deriv = (isthinzero(x) && inf(v) > 0) ? zero(expv) : expv*log(x)
return Dual{Ty}(expv, deriv * partials(y))
end

end
46 changes: 39 additions & 7 deletions test/interval_tests/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ end
@test ForwardDiff.derivative(abs, interval(-2, 2)) === interval(-1, 1, trv)

f(x) = abs(x)^interval(2)
@test_broken ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv)
@test ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv)

g(x) = abs(x)^2
@test ForwardDiff.derivative(g, interval(-1, 1) ) === interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)
Expand Down Expand Up @@ -56,12 +56,44 @@ end
end

@testset "Power" begin
f(x) = interval(2)^x
f′(x) = log(interval(2)) * f(x)
df(t) = ForwardDiff.derivative(f, t)
fxy(xy) = xy[1]^xy[2]

# g(x) = 2^x # not guaranteed
for x in [0.0, 1.1, 2.2]
for y in [-3.3, 0.0, 4.4]
fx(xx) = xx^y
fxi(xx) = xx^interval(y)
fy(yy) = x^yy
fyi(yy) = interval(x)^yy

@test f′(0) === df(0)
dfdx = ForwardDiff.derivative(fxi, interval(x))
dfdy = ForwardDiff.derivative(fyi, interval(y))
grad = ForwardDiff.gradient(fxy, [interval(x), interval(y)])

@test isguaranteed(dfdx)
@test isguaranteed(dfdy)
@test isguaranteed(grad[1])
@test isguaranteed(grad[2])

if iszero(x) && y < 0
@test decoration(dfdx) == trv
else
@test in_interval(ForwardDiff.derivative(fx, x), dfdx)
end

if iszero(x) && y <= 0
@test decoration(dfdy) == trv
else
@test in_interval(ForwardDiff.derivative(fy, y), dfdy)
end

if iszero(x) && iszero(y)
@test decoration(grad[1]) == trv
@test decoration(dfdx) == com
else
@test isequal_interval(dfdx, grad[1])
end
@test isequal_interval(dfdy, grad[2])
end
end
end
end
end
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Test

using ForwardDiff
using IntervalArithmetic
using InteractiveUtils

Expand All @@ -22,4 +24,4 @@ for f ∈ readdir("ITF1788_tests"; join = true)
@testset "$f" begin
include(f)
end
end
end

0 comments on commit a3a00b9

Please sign in to comment.