diff --git a/Project.toml b/Project.toml index 1ed62fd70..76b9359fc 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -95,6 +96,7 @@ Reexport = "1.0" ReverseDiff = "1" SciMLBase = "2.28.0" SciMLOperators = "0.3" +SciMLStructures = "1.5" Setfield = "1" SparseArrays = "1.9" Static = "1" @@ -131,4 +133,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"] +test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"] diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index 397c234e0..9272bc1ec 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -4,6 +4,11 @@ using DiffEqBase import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface +import DiffEqBase.ForwardDiff + +function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}} + DiffEqBase.anyeltypedual(V, Val{counter}) +end DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V function DiffEqBase.value(x::Type{ @@ -33,6 +38,7 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, u0 end DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) +DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual} = ReverseDiff.track(T.(u0)) DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) # Support adaptive with non-tracked time diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 02d5fa05b..a4dc83a7d 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -101,6 +101,8 @@ import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficien import SciMLBase: AbstractDiffEqLinearOperator # deprecation path +import SciMLStructures + import Tricks using Reexport diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 6eb3e1c7a..1e9cbfcd5 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -352,6 +352,12 @@ DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter} @inline promote_u0(::Nothing, p, t0) = nothing @inline function promote_u0(u0, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if _p != p + return promote_u0(u0, _p, t0) + end + end Tu = eltype(u0) if Tu <: ForwardDiff.Dual return u0 @@ -373,6 +379,12 @@ DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter} end @inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if _p != p + return promote_u0(u0, _p, t0) + end + end Tu = real(eltype(u0)) if Tu <: ForwardDiff.Dual return u0 diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 8f269c478..1a2cbad65 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -1,4 +1,5 @@ using DiffEqBase, ForwardDiff, Test, InteractiveUtils +using ReverseDiff, SciMLStructures using Plots u0 = 2.0 @@ -348,3 +349,20 @@ foo = SciMLBase.build_solution( prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0]) DiffEqBase.anyeltypedual((; x = foo)) DiffEqBase.anyeltypedual((; x = foo, y = prob.f)) + +@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any +@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any +@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3)) +@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3)) + +struct FakeParameterObject{T} + tunables::T +end + +SciMLStructures.isscimlstructure(::FakeParameterObject) = true +SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::FakeParameterObject) = f.tunables, x -> FakeParameterObject(x), true + +@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray +@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal +@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual} +@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}