diff --git a/Project.toml b/Project.toml index d77e57aaf..9d42a5b82 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "IntervalArithmetic" uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" repo = "https://github.com/JuliaIntervals/IntervalArithmetic.jl.git" -version = "0.22.7" +version = "0.22.8" [deps] CRlibm_jll = "4e9b3aee-d8a1-5a3d-ad8b-7d824db253f0" @@ -9,15 +9,18 @@ RoundingEmulator = "5eaf0fd0-dfba-4ccb-bf02-d820a40db705" [weakdeps] DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" [extensions] IntervalArithmeticDiffRulesExt = "DiffRules" +IntervalArithmeticForwardDiffExt = "ForwardDiff" IntervalArithmeticRecipesBaseExt = "RecipesBase" [compat] CRlibm_jll = "1" DiffRules = "1" +ForwardDiff = "0.10" RecipesBase = "1" RoundingEmulator = "0.2" julia = "1.9" diff --git a/ext/IntervalArithmeticForwardDiffExt.jl b/ext/IntervalArithmeticForwardDiffExt.jl new file mode 100644 index 000000000..be23e524e --- /dev/null +++ b/ext/IntervalArithmeticForwardDiffExt.jl @@ -0,0 +1,51 @@ +module IntervalArithmeticForwardDiffExt + +using IntervalArithmetic, ForwardDiff +using ForwardDiff: Dual, ≺, value, partials + +function isconstant_interval(x) + all(isthinzero.(values(partials(x)))) +end + +function Base.:(^)(x::Dual{Txy, <:Interval}, y::Dual{Txy, <:Interval}) where Txy + vx, vy = value(x), value(y) + expv = vx^vy + powval = vy * vx^(vy - interval(1)) + if isconstant_interval(y) + logval = one(expv) + elseif isthinzero(vx) && inf(vy) > 0 + logval = zero(vx) + else + logval = expv * log(vx) + end + new_partials = ForwardDiff._mul_partials(partials(x), partials(y), powval, logval) + return Dual{Txy}(expv, new_partials) +end + +function Base.:(^)(x::Dual{Tx, <:Interval}, y::Dual{Ty, <:Interval}) where {Tx, Ty} + if Ty ≺ Tx + return x^value(y) + else + return value(x)^y + end +end + +function Base.:(^)(x::Dual{Tx, <:Interval}, y::Interval) where Tx + v = value(x) + expv = v^y + if isthinzero(y) || isconstant_interval(x) + new_partials = zero(partials(x)) + else + new_partials = partials(x) * y * v^(y - interval(1)) + end + return Dual{Tx}(expv, new_partials) +end + +function Base.:(^)(x::Interval, y::Dual{Ty, <:Interval}) where Ty + v = value(y) + expv = x^v + deriv = (isthinzero(x) && inf(v) > 0) ? zero(expv) : expv*log(x) + return Dual{Ty}(expv, deriv * partials(y)) +end + +end diff --git a/test/interval_tests/forwarddiff.jl b/test/interval_tests/forwarddiff.jl index 7d48f3d0b..c417c6303 100644 --- a/test/interval_tests/forwarddiff.jl +++ b/test/interval_tests/forwarddiff.jl @@ -16,7 +16,7 @@ end @test ForwardDiff.derivative(abs, interval(-2, 2)) === interval(-1, 1, trv) f(x) = abs(x)^interval(2) - @test_broken ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv) + @test ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv) g(x) = abs(x)^2 @test ForwardDiff.derivative(g, interval(-1, 1) ) === interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv) @@ -56,12 +56,44 @@ end end @testset "Power" begin - f(x) = interval(2)^x - f′(x) = log(interval(2)) * f(x) - df(t) = ForwardDiff.derivative(f, t) + fxy(xy) = xy[1]^xy[2] - # g(x) = 2^x # not guaranteed + for x in [0.0, 1.1, 2.2] + for y in [-3.3, 0.0, 4.4] + fx(xx) = xx^y + fxi(xx) = xx^interval(y) + fy(yy) = x^yy + fyi(yy) = interval(x)^yy - @test f′(0) === df(0) + dfdx = ForwardDiff.derivative(fxi, interval(x)) + dfdy = ForwardDiff.derivative(fyi, interval(y)) + grad = ForwardDiff.gradient(fxy, [interval(x), interval(y)]) + + @test isguaranteed(dfdx) + @test isguaranteed(dfdy) + @test isguaranteed(grad[1]) + @test isguaranteed(grad[2]) + + if iszero(x) && y < 0 + @test decoration(dfdx) == trv + else + @test in_interval(ForwardDiff.derivative(fx, x), dfdx) + end + + if iszero(x) && y <= 0 + @test decoration(dfdy) == trv + else + @test in_interval(ForwardDiff.derivative(fy, y), dfdy) + end + + if iszero(x) && iszero(y) + @test decoration(grad[1]) == trv + @test decoration(dfdx) == com + else + @test isequal_interval(dfdx, grad[1]) + end + @test isequal_interval(dfdy, grad[2]) + end + end end -end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 98704b6b6..aabd85d91 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,6 @@ using Test + +using ForwardDiff using IntervalArithmetic using InteractiveUtils @@ -22,4 +24,4 @@ for f ∈ readdir("ITF1788_tests"; join = true) @testset "$f" begin include(f) end -end +end \ No newline at end of file