From b24a59e68f5d6efce246b2d114ac042bc4c21cd3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 14:38:53 +0530 Subject: [PATCH 1/5] fix: fix `anyeltypedual` for `ReverseDiff.TrackedArray` --- ext/DiffEqBaseReverseDiffExt.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index 397c234e0..6d30ad2fc 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -5,6 +5,10 @@ import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface +function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ReverseDiff.TrackedArray} + Any +end + DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V function DiffEqBase.value(x::Type{ ReverseDiff.TrackedArray{V, D, N, VA, DA}, From 036c6699ae7a0e176883e14f82571581de66d1c5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 14:39:29 +0530 Subject: [PATCH 2/5] feat: use tunable portion of parameter object in `promote_u0` --- Project.toml | 2 ++ src/DiffEqBase.jl | 2 ++ src/forwarddiff.jl | 12 ++++++++++++ 3 files changed, 16 insertions(+) diff --git a/Project.toml b/Project.toml index 1ed62fd70..47372939d 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -95,6 +96,7 @@ Reexport = "1.0" ReverseDiff = "1" SciMLBase = "2.28.0" SciMLOperators = "0.3" +SciMLStructures = "1.5" Setfield = "1" SparseArrays = "1.9" Static = "1" diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 02d5fa05b..a4dc83a7d 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -101,6 +101,8 @@ import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficien import SciMLBase: AbstractDiffEqLinearOperator # deprecation path +import SciMLStructures + import Tricks using Reexport diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 6eb3e1c7a..1e9cbfcd5 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -352,6 +352,12 @@ DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter} @inline promote_u0(::Nothing, p, t0) = nothing @inline function promote_u0(u0, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if _p != p + return promote_u0(u0, _p, t0) + end + end Tu = eltype(u0) if Tu <: ForwardDiff.Dual return u0 @@ -373,6 +379,12 @@ DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter} end @inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if _p != p + return promote_u0(u0, _p, t0) + end + end Tu = real(eltype(u0)) if Tu <: ForwardDiff.Dual return u0 From b81868946bd53f36250ffebda41c97817bbcd866 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 16:01:42 +0530 Subject: [PATCH 3/5] fix: recursively descend `anyeltypedual` for `TrackedArray` Co-authored-by: Christopher Rackauckas --- ext/DiffEqBaseReverseDiffExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index 6d30ad2fc..652eca185 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -5,8 +5,8 @@ import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface -function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ReverseDiff.TrackedArray} - Any +function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}} + DiffEqBase.anyeltypedual(V, Val{counter}) end DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V From eb42bc2a6c2cfb956225a38bbcfddf5955cb0e39 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 16:51:30 +0530 Subject: [PATCH 4/5] fix: fix promote_u0 when `p isa TrackedArray` and `u0` is not --- ext/DiffEqBaseReverseDiffExt.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index 652eca185..9272bc1ec 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -4,6 +4,7 @@ using DiffEqBase import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface +import DiffEqBase.ForwardDiff function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}} DiffEqBase.anyeltypedual(V, Val{counter}) @@ -37,6 +38,7 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, u0 end DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) +DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual} = ReverseDiff.track(T.(u0)) DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) # Support adaptive with non-tracked time From 4e60a8767868f289cbc4e79caaf6841417c69dc7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Sep 2024 16:51:41 +0530 Subject: [PATCH 5/5] test: add tests for ReverseDiff dual detection and promotion --- Project.toml | 2 +- test/forwarddiff_dual_detection.jl | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 47372939d..76b9359fc 100644 --- a/Project.toml +++ b/Project.toml @@ -133,4 +133,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"] +test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"] diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 8f269c478..1a2cbad65 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -1,4 +1,5 @@ using DiffEqBase, ForwardDiff, Test, InteractiveUtils +using ReverseDiff, SciMLStructures using Plots u0 = 2.0 @@ -348,3 +349,20 @@ foo = SciMLBase.build_solution( prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0]) DiffEqBase.anyeltypedual((; x = foo)) DiffEqBase.anyeltypedual((; x = foo, y = prob.f)) + +@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any +@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any +@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3)) +@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3)) + +struct FakeParameterObject{T} + tunables::T +end + +SciMLStructures.isscimlstructure(::FakeParameterObject) = true +SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::FakeParameterObject) = f.tunables, x -> FakeParameterObject(x), true + +@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray +@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal +@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual} +@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}