Reversediff with in optimization problem #2968
Comments
As I understand it, this is still a work in progress. See SciML/SciMLSensitivity.jl#1085 for more details. |
This should be good with today's release, which handles the SciMLStructures interface for forward mode. There's no MWE here to double-check, so I'll close under the assumption that it's just the ForwardDiffSensitivity handling of MTK, as the stack trace alludes to. |
This is still happening, but there seem to be some other errors along the way as well. With the existing example it gets a StackOverflow error
with
Somehow, changing the sensealg to Zygote leads to
Also note that the failure on master in Optimization.jl https://github.com/SciML/Optimization.jl/actions/runs/10550106197/job/29225903838#step:7:1169 seems to be related; it goes away on using |
This constructor already exists? https://github.com/SciML/ModelingToolkit.jl/blob/master/ext/MTKChainRulesCoreExt.jl#L7 Are you on the latest MTK? |
Yeah, v9.33.1. I guess that's not the big issue; the main ones are the last two. Though it probably is an issue for SciMLSensitivity. |
That really looks like an environment issue. It's not seeing that |
There's a few things that needed to be released. Let me try that first 😅 |
I think everything in SciMLSensitivity.jl is updated, so this should be fine. @AayushSabharwal can you check, and can you add reverse mode to our docs/tests in the right spots? |
ReverseDiff.jl doesn't work because DiffEqBase doesn't like it ( |
JuliaSymbolics/SymbolicUtils.jl#646 and SciML/DiffEqBase.jl#1078 are necessary to use |
Is there anything left to do here? |
ReverseDiff and Zygote are not working for me at the moment.
julia> using ModelingToolkit
julia> using ModelingToolkit: t_nounits as t, D_nounits as D
julia>
julia> @parameters α β γ δ
4-element Vector{Num}:
α
β
γ
δ
julia> @variables x(t) y(t)
2-element Vector{Num}:
x(t)
y(t)
julia> eqs = [D(x) ~ (α - β * y) * x
D(y) ~ (δ * x - γ) * y]
2-element Vector{Equation}:
Differential(t)(x(t)) ~ x(t)*(α - y(t)*β)
Differential(t)(y(t)) ~ (-γ + x(t)*δ)*y(t)
julia> @mtkbuild odesys = ODESystem(eqs, t)
┌ Warning: solve_for is deprecated, please use symbolic_linear_solve instead.
│ caller = tearing_reassemble(state::TearingState{ODESystem}, var_eq_matching::ModelingToolkit.BipartiteGraphs.Matching{Union{ModelingToolkit.BipartiteGraphs.Unassigned, ModelingToolkit.StructuralTransformations.SelectedState}, Vector{Union{ModelingToolkit.BipartiteGraphs.Unassigned, ModelingToolkit.StructuralTransformations.SelectedState, Int64}}}, full_var_eq_matching::Nothing; simplify::Bool, mm::ModelingToolkit.SparseMatrixCLIL{Int64, Int64}) at symbolics_tearing.jl:459
└ @ ModelingToolkit.StructuralTransformations ~/.julia/packages/ModelingToolkit/2KZCu/src/structural_transformation/symbolics_tearing.jl:459
Model odesys with 2 equations
Unknowns (2):
x(t)
y(t)
Parameters (4):
α
β
δ
γ
Incidence matrix:2×4 SparseArrays.SparseMatrixCSC{Num, Int64} with 6 stored entries:
× × × ⋅
× × ⋅ ×
julia>
julia> using OrdinaryDiffEq
julia>
julia> odeprob = ODEProblem(
odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: 2-element Vector{Float64}:
1.0
1.0
julia>
julia> timesteps = 0.0:0.1:10.0
0.0:0.1:10.0
julia> sol = solve(odeprob, Tsit5(); saveat=timesteps)
retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float64}:
0.0
0.1
0.2
0.3
0.4
0.5
0.6
0.7
0.8
0.9
1.0
1.1
1.2
1.3
1.4
1.5
1.6
1.7
1.8
1.9
2.0
2.1
2.2
2.3
2.4
2.5
2.6
2.7
2.8
2.9
3.0
3.1
3.2
3.3
3.4
3.5
3.6
3.7
3.8
3.9
⋮
6.1
6.2
6.3
6.4
6.5
6.6
6.7
6.8
6.9
7.0
7.1
7.2
7.3
7.4
7.5
7.6
7.7
7.8
7.9
8.0
8.1
8.2
8.3
8.4
8.5
8.6
8.7
8.8
8.9
9.0
9.1
9.2
9.3
9.4
9.5
9.6
9.7
9.8
9.9
10.0
u: 101-element Vector{Vector{Float64}}:
[1.0, 1.0]
[1.0610780673356455, 0.8210842775886171]
[1.1440276717257598, 0.6790526689784503]
[1.2491712125724483, 0.5668931465841179]
[1.3776445705636384, 0.47881295137951546]
[1.5312308177480134, 0.4101564670866144]
[1.7122697558187643, 0.35726544879948385]
[1.923578275830157, 0.31734720616177164]
[2.1683910896994067, 0.288388843787324]
[2.4502506671402524, 0.2690537093960071]
[2.7728223025987977, 0.25872441605303803]
[3.1397329894440613, 0.2574966932231203]
[3.553901355478236, 0.26645005590756804]
[4.016473052558066, 0.28817137186254316]
[4.525663272806641, 0.32730043827773214]
[5.074397765405284, 0.39143450178869377]
[5.641916476196301, 0.49553075376759187]
[6.192472161132441, 0.6621277791342365]
[6.648914999451691, 0.9354326153529998]
[6.891031980057643, 1.3703704468848408]
[6.773495154522366, 2.0153567192692066]
[6.16957697373435, 2.8703702051675317]
[5.150805370096711, 3.744898507982017]
[3.973483928044951, 4.377341076689389]
[2.935261043812171, 4.5779802856931875]
[2.1768955119119644, 4.364209539476808]
[1.6703212546949273, 3.91149854259555]
[1.3477004416010134, 3.3657947039986973]
[1.1491565951639178, 2.8230086295254524]
[1.0323069911195162, 2.3311220268787345]
[0.9706713993069102, 1.9081445902225405]
[0.9488493052093983, 1.5556308529856293]
[0.9577818215432641, 1.267363865724042]
[0.9922794958077016, 1.0348218681752417]
[1.0495828373263223, 0.8488784188403882]
[1.1288176929270866, 0.701111055792832]
[1.230210616609076, 0.5842710897329086]
[1.354684071606731, 0.492417079724541]
[1.5039050790725055, 0.4207861727158718]
[1.6801514655730152, 0.3654763255815079]
⋮
[1.1806510314183558, 2.9121763370943823]
[1.0512640300755585, 2.410953991987525]
[0.9813575934995471, 1.9763328963688114]
[0.9534959024195566, 1.612257955268154]
[0.957455228555967, 1.313908676077288]
[0.9876865454680123, 1.0725479040433192]
[1.0413016959542074, 0.8791865264865037]
[1.1169174187486561, 0.7254047577193622]
[1.2145060700479002, 0.6037649755582781]
[1.3350343536024947, 0.5080112218132274]
[1.4800614990225558, 0.43307075757976243]
[1.6517069024700077, 0.3750865476687141]
[1.852649743032995, 0.3310364165707537]
[2.0859783340597544, 0.2985534421000465]
[2.355082256708884, 0.2760570725478811]
[2.6636389478100773, 0.26273370570565524]
[3.01519682870142, 0.25853390704805435]
[3.413275447780306, 0.2639719454831945]
[3.8595573551795153, 0.2811283353912037]
[4.353034859443118, 0.3140433891601936]
[4.886514887071523, 0.3705992629672006]
[5.451076255258948, 0.4600903837582948]
[6.012316518139551, 0.6035694318576293]
[6.5041826615226, 0.8366901639767437]
[6.828968757688667, 1.2096283905662835]
[6.844507339381619, 1.7822733396782149]
[6.406175705545386, 2.571295274395615]
[5.501331190187849, 3.4676511853435175]
[4.346171370391272, 4.204698030188888]
[3.246586407744981, 4.5469283684689055]
[2.395666281592049, 4.457765583126276]
[1.8172823219555856, 4.064946595043356]
[1.442761298838369, 3.5397375780465627]
[1.208908107884453, 2.9914550030314535]
[1.0685925969627899, 2.4820729201626373]
[0.991022962327608, 2.037244570196892]
[0.957421348475827, 1.6632055724974297]
[0.9569793912886565, 1.3555870283301439]
[0.9835609063200599, 1.1062868199420042]
[1.033758125602055, 0.9063703842886213]
julia> data = Array(sol)
2×101 Matrix{Float64}:
1.0 1.06108 1.14403 1.24917 1.37764 1.53123 1.71227 1.92358 2.16839 2.45025 … 1.81728 1.44276 1.20891 1.06859 0.991023 0.957421 0.956979 0.983561 1.03376
1.0 0.821084 0.679053 0.566893 0.478813 0.410156 0.357265 0.317347 0.288389 0.269054 4.06495 3.53974 2.99146 2.48207 2.03724 1.66321 1.35559 1.10629 0.90637
julia> # add some random noise
julia> data = data + 0.01 * randn(size(data))
2×101 Matrix{Float64}:
1.01614 1.06674 1.12347 1.24773 1.37758 1.53041 1.7112 1.92646 2.18836 … 1.82871 1.44904 1.19395 1.04878 0.984745 0.956703 0.967545 0.970734 1.02028
0.98386 0.812587 0.68637 0.563933 0.481978 0.409203 0.354941 0.315103 0.299003 4.07288 3.53542 2.99219 2.47814 2.02978 1.64381 1.3585 1.10354 0.908083
julia>
julia> using SymbolicIndexingInterface: parameter_values, state_values
julia> using SciMLStructures: Tunable, replace, replace!
julia>
julia> function loss(x, p)
           odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
           ps = parameter_values(odeprob) # obtain the parameter object from the problem
           ps = replace(Tunable(), ps, x) # create a copy with the values passed to the loss function
           # remake the problem, passing in our new parameter object
           newprob = remake(odeprob; p=ps)
           timesteps = p[2]
           sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat=timesteps)
           truth = p[3]
           data = Array(sol)
           return sum((truth .- data) .^ 2) / length(truth)
       end
loss (generic function with 1 method)
julia>
julia> using Optimization
julia> using OptimizationOptimJL
julia>
julia> # manually create an OptimizationFunction to ensure usage of `ForwardDiff`, which will
julia> # require changing the types of parameters from `Float64` to `ForwardDiff.Dual`
julia> optfn = OptimizationFunction(loss, Optimization.AutoForwardDiff())
(::OptimizationFunction{true, AutoForwardDiff{nothing, Nothing}, typeof(loss), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}) (generic function with 1 method)
julia> # parameter object is a tuple, to store differently typed objects together
julia> optprob = OptimizationProblem(
optfn, rand(4), (odeprob, timesteps, data), lb=0.1zeros(4), ub=3ones(4))
OptimizationProblem. In-place: true
u0: 4-element Vector{Float64}:
0.1874913445775207
0.1696330068794284
0.5501872523515791
0.05346430654034351
julia> sol = solve(optprob, BFGS())
retcode: Success
u: 4-element Vector{Float64}:
0.8981332489799965
0.590120993334407
0.027383608655195707
5.197063407075763e-12
julia>
julia> using Zygote
julia> using ReverseDiff
julia> using SciMLSensitivity
julia>
julia> optfn = OptimizationFunction(loss, Optimization.AutoReverseDiff())
(::OptimizationFunction{true, AutoReverseDiff{false}, typeof(loss), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}) (generic function with 1 method)
julia> optprob = OptimizationProblem(
optfn, rand(4), (odeprob, timesteps, data), lb=0.1zeros(4), ub=3ones(4))
OptimizationProblem. In-place: true
u0: 4-element Vector{Float64}:
0.25273647439013414
0.3479945472068383
0.5370159213045103
0.5160481890509427
julia> sol = solve(optprob, BFGS())
ERROR: TrackedArrays do not support setindex!
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] setindex!(::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ::Int64)
@ ReverseDiff ~/.julia/packages/ReverseDiff/p1MzG/src/tracked.jl:390
[3] macro expansion
@ ~/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:430 [inlined]
[4] macro expansion
@ ~/.julia/packages/Symbolics/b4I7P/src/build_function.jl:546 [inlined]
[5] macro expansion
@ ~/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:387 [inlined]
[6] macro expansion
@ ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:163 [inlined]
[7] macro expansion
@ ./none:0 [inlined]
[8] generated_callfunc
@ ./none:0 [inlined]
[9] (::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…})(::ReverseDiff.TrackedArray{…}, ::ReverseDiff.TrackedArray{…}, ::ReverseDiff.TrackedArray{…}, ::Float64)
@ RuntimeGeneratedFunctions ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:150
[10] (::ModelingToolkit.var"#f#777"{…})(du::ReverseDiff.TrackedArray{…}, u::ReverseDiff.TrackedArray{…}, p::MTKParameters{…}, t::Float64)
@ ModelingToolkit ~/.julia/packages/ModelingToolkit/2KZCu/src/systems/diffeqs/abstractodesystem.jl:351
[11] (::ODEFunction{…})(::ReverseDiff.TrackedArray{…}, ::Vararg{…})
@ SciMLBase ~/.julia/packages/SciMLBase/DXnzJ/src/scimlfunctions.jl:2335
[12] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.Tsit5Cache{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/s27pa/src/perform_step/low_order_rk_perform_step.jl:799
[13] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.CompositeCache{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/s27pa/src/perform_step/composite_perform_step.jl:79
[14] __init(prob::ODEProblem{…}, alg::CompositeAlgorithm{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::StepRangeLen{…}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:524
[15] __init (repeats 5 times)
@ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:11 [inlined]
[16] __solve(::ODEProblem{…}, ::CompositeAlgorithm{…}; kwargs::@Kwargs{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:6
[17] __solve
@ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:1 [inlined]
[18] #solve_call#44
@ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:612 [inlined]
[19] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::MTKParameters{…}, args::CompositeAlgorithm{…}; kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1080
[20] solve_up
@ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1066 [inlined]
[21] #solve#51
@ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
[22] solve
@ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:993 [inlined]
[23] loss(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{…}, Vector{…}}, p::Tuple{ODEProblem{…}, StepRangeLen{…}, Matrix{…}})
@ Main ./REPL[16]:8
[24] (::OptimizationFunction{…})(::ReverseDiff.TrackedArray{…}, ::Vararg{…})
@ SciMLBase ~/.julia/packages/SciMLBase/DXnzJ/src/scimlfunctions.jl:3812
[25] (::OptimizationBase.var"#_f#22"{OptimizationFunction{…}})(θ::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{…}, Vector{…}})
@ OptimizationBase ~/.julia/packages/OptimizationBase/3r1wm/src/OptimizationDIExt.jl:36
[26] ReverseDiff.GradientTape(f::OptimizationBase.var"#_f#22"{OptimizationFunction{…}}, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{…}})
@ ReverseDiff ~/.julia/packages/ReverseDiff/p1MzG/src/api/tape.jl:199
[27] ReverseDiff.GradientTape(f::Function, input::Vector{Float64})
@ ReverseDiff ~/.julia/packages/ReverseDiff/p1MzG/src/api/tape.jl:198
[28] prepare_gradient(f::Function, ::AutoReverseDiff{false}, x::Vector{Float64})
@ DifferentiationInterfaceReverseDiffExt ~/.julia/packages/DifferentiationInterface/FTGtS/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl:53
[29] instantiate_function(f::OptimizationFunction{…}, x::Vector{…}, adtype::AutoReverseDiff{…}, p::Tuple{…}, num_cons::Int64; g::Bool, h::Bool, hv::Bool, fg::Bool, fgh::Bool, cons_j::Bool, cons_vjp::Bool, cons_jvp::Bool, cons_h::Bool, lag_h::Bool)
@ OptimizationBase ~/.julia/packages/OptimizationBase/3r1wm/src/OptimizationDIExt.jl:42
[30] instantiate_function
@ ~/.julia/packages/OptimizationBase/3r1wm/src/OptimizationDIExt.jl:28 [inlined]
[31] #instantiate_function#48
@ ~/.julia/packages/OptimizationBase/3r1wm/src/OptimizationDIExt.jl:294 [inlined]
[32] instantiate_function
@ ~/.julia/packages/OptimizationBase/3r1wm/src/OptimizationDIExt.jl:287 [inlined]
[33] OptimizationCache(prob::OptimizationProblem{…}, opt::Fminbox{…}; callback::Function, maxiters::Nothing, maxtime::Nothing, abstol::Nothing, reltol::Nothing, progress::Bool, structural_analysis::Bool, manifold::Nothing, kwargs::@Kwargs{})
@ OptimizationBase ~/.julia/packages/OptimizationBase/3r1wm/src/cache.jl:37
[34] OptimizationCache
@ ~/.julia/packages/OptimizationBase/3r1wm/src/cache.jl:25 [inlined]
[35] #__init#2
@ ~/.julia/packages/OptimizationOptimJL/BIkTp/src/OptimizationOptimJL.jl:109 [inlined]
[36] __init
@ ~/.julia/packages/OptimizationOptimJL/BIkTp/src/OptimizationOptimJL.jl:80 [inlined]
[37] #init#657
@ ~/.julia/packages/SciMLBase/DXnzJ/src/solve.jl:174 [inlined]
[38] init
@ ~/.julia/packages/SciMLBase/DXnzJ/src/solve.jl:172 [inlined]
[39] solve(::OptimizationProblem{…}, ::BFGS{…}; kwargs::@Kwargs{})
@ SciMLBase ~/.julia/packages/SciMLBase/DXnzJ/src/solve.jl:96
[40] solve(::OptimizationProblem{…}, ::BFGS{…})
@ SciMLBase ~/.julia/packages/SciMLBase/DXnzJ/src/solve.jl:93
[41] top-level scope
@ REPL[30]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia>
julia> optfn = OptimizationFunction(loss, Optimization.AutoZygote())
(::OptimizationFunction{true, AutoZygote, typeof(loss), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}) (generic function with 1 method)
julia> optprob = OptimizationProblem(
optfn, rand(4), (odeprob, timesteps, data), lb=0.1zeros(4), ub=3ones(4))
OptimizationProblem. In-place: true
u0: 4-element Vector{Float64}:
0.7369076868555151
0.5526940053454873
0.8645171650000937
0.6673653969086624
julia> sol = solve(optprob, BFGS())
ERROR: Need an adjoint for constructor MTKParameters{Vector{Float64}, StaticArraysCore.Siz
64}, StaticArraysCore.SizedVector{0, Any, Vector{Any}}, Tuple{}, Tuple{}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] (::Zygote.Jnew{MTKParameters{…}, Nothing, false})(Δ::MTKParameters{Vector{…}, Static
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:330
[3] (::Zygote.var"#2210#back#316"{Zygote.Jnew{MTKParameters{…}, Nothing, false}})(Δ::MTK
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[4] MTKParameters
@ ~/.julia/packages/ModelingToolkit/2KZCu/src/systems/parameter_buffer.jl:7 [inlined]
[5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MTKParameters{Vector{…}, StaticArraysCore
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[6] setfields_object
@ ~/.julia/packages/ConstructionBase/lUKuV/src/ConstructionBase.jl:195 [inlined]
[7] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MTKParameters{Vector{…}, StaticArraysCore
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[8] setproperties_object
@ ~/.julia/packages/ConstructionBase/lUKuV/src/ConstructionBase.jl:208 [inlined]
[9] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MTKParameters{Vector{…}, StaticArraysCore
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[10] setproperties
@ ~/.julia/packages/ConstructionBase/lUKuV/src/ConstructionBase.jl:136 [inlined]
[11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MTKParameters{Vector{…}, StaticArraysCore
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[12] set
@ ~/.julia/packages/Setfield/PdKfV/src/lens.jl:122 [inlined]
[13] replace
@ ~/.julia/packages/Setfield/PdKfV/src/sugar.jl:197 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::MTKParameters{Vector{…}, StaticArraysCore
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[15] loss
@ ./REPL[16]:4 [inlined]
[16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[17] #294
@ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206 [inlined]
[18] #2169#back
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
[19] OptimizationFunction
@ ~/.julia/packages/SciMLBase/DXnzJ/src/scimlfunctions.jl:3812 [inlined]
[20] (::Zygote.Pullback{Tuple{OptimizationFunction{…}, Vector{…}, Tuple{…}}, Tuple{Zygote
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[21] _f
@ ~/.julia/packages/OptimizationBase/3r1wm/ext/OptimizationZygoteExt.jl:29 [inlined]
[22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
[23] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:91
[24] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:148
[25] gradient
@ ~/.julia/packages/DifferentiationInterface/FTGtS/ext/DifferentiationInterfaceZygoteE
[26] gradient!(f::Function, grad::Vector{Float64}, backend::AutoZygote, x::Vector{Float64
@ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/FTGtS/e
[27] (::OptimizationZygoteExt.var"#grad#15"{AutoZygote, OptimizationZygoteExt.var"#_f#14"
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/3r1wm/ext/OptimizationZygot
[28] (::OptimizationOptimJL.var"#19#23"{OptimizationCache{…}, OptimizationOptimJL.var"#18
@ OptimizationOptimJL ~/.julia/packages/OptimizationOptimJL/BIkTp/src/OptimizationOpti
[29] value_gradient!!(obj::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}},
@ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
[30] value_gradient!!(bw::Optim.BarrierWrapper{OnceDifferentiable{…}, Optim.BoxBarrier{…}
@ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/solvers/constrained/fminbox.jl:
[31] initial_state(method::BFGS{…}, options::Optim.Options{…}, d::Optim.BarrierWrapper{…}
@ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/solvers/first_order/bfgs.jl:94
[32] optimize(df::OnceDifferentiable{…}, l::Vector{…}, u::Vector{…}, initial_x::Vector{…}
@ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/solvers/constrained/fminbox.jl:
[33] __solve(cache::OptimizationCache{…})
@ OptimizationOptimJL ~/.julia/packages/OptimizationOptimJL/BIkTp/src/OptimizationOpti
[34] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/DXnzJ/src/solve.jl:188
[35] solve(::OptimizationProblem{…}, ::BFGS{…}; kwargs::@Kwargs{})
@ SciMLBase ~/.julia/packages/SciMLBase/DXnzJ/src/solve.jl:96
[36] solve(::OptimizationProblem{…}, ::BFGS{…})
@ SciMLBase ~/.julia/packages/SciMLBase/DXnzJ/src/solve.jl:93
[37] top-level scope
@ REPL[33]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia> |
@SebastianM-C @AayushSabharwal what's the current state of this? |
One thing to note is that currently replacing the entire parameter vector is required as For (and for Enzyme I get compilation failures) |
I think it's because the
odeprob = ODEProblem{false}(
    odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
And re-create the
MTK tests Zygote AD through |
using ModelingToolkit
import ModelingToolkit.t_nounits as t
import ModelingToolkit.D_nounits as D
using SciMLStructures: SciMLStructures, canonicalize, Tunable
using SciMLSensitivity
using SymbolicIndexingInterface
using OrdinaryDiffEq, Plots
using Optimization
using OptimizationOptimJL
using BenchmarkTools
using Test
function mwe()
    @parameters p1 = 0.5 [tunable = true] (p23[1:2] = [1, 3.0]) [tunable = true] p4 = 3 [tunable = false] y0 = 1.2 [tunable = false]
    @variables x(t) = 2 y(t) = y0
    eqs = [
        D(x) ~ p1 * x - p23[1] * x * y
        D(y) ~ -p23[2] * y + p4 * x * y
    ]
    sys = ODESystem(eqs, t, tspan=(0, 3.0), name=:sys)
    structural_simplify(sys)
end
sys = mwe()
prob = ODEProblem{true,SciMLBase.FullSpecialize}(sys)
ref_prob = ODEProblem(sys, [], (0.0, 1,), [sys.p1 => 1.55, sys.p23 => [4.0, 3.44]])
sol_ref = solve(ref_prob, Tsit5(), sensealg=GaussAdjoint())
tp = reduce(vcat, Symbolics.scalarize(tunable_parameters(sys)))
x0 = getp(prob, tp)(prob)
oop_update = setp_oop(prob, tp)
cost = function (x, p)
    prob, oop_update, sol_ref = p
    new_p = oop_update(prob, x)
    new_prob = remake(prob, p=new_p)
    ts = sol_ref.t
    new_sol = solve(new_prob, Tsit5(), saveat=ts, sensealg=GaussAdjoint())
    loss = zero(eltype(x))
    for (i, sol_i) in enumerate(new_sol.u)
        for (j, sol_ij) in enumerate(sol_i)
            loss += sqrt(abs2(sol_ij - sol_ref.u[i][j]))
        end
    end
    loss
end
of = OptimizationFunction{true}(cost, AutoZygote())
opt_ps = (prob, oop_update, sol_ref)
@btime $of([1.5, 1.0, 0.0], $opt_ps)
op = OptimizationProblem(of, x0, opt_ps)
res = solve(op, LBFGS(); maxtime=60)
@test SciMLBase.successful_retcode(res.retcode)
errors with
|
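For readers following the in-place versus out-of-place point: the ReverseDiff stack trace in this thread fails at `setindex!` on a `TrackedArray`, which is the kind of write an in-place right-hand side performs on `du`, and `ODEProblem{false}` requests the out-of-place form instead. The following is a minimal sketch of that distinction in plain Julia, with no SciML packages. `ReadOnlyVector` is a hypothetical stand-in for an array type that rejects mutation (it is not ReverseDiff's `TrackedArray`), and `lv!` and `lv` are hand-written versions of the Lotka-Volterra equations from the example, not the code MTK generates.

```julia
# Hypothetical stand-in for an array type that rejects mutation.
struct ReadOnlyVector{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::ReadOnlyVector) = size(v.data)
Base.IndexStyle(::Type{<:ReadOnlyVector}) = IndexLinear()
Base.getindex(v::ReadOnlyVector, i::Int) = v.data[i]
Base.setindex!(::ReadOnlyVector, x, i::Int) =
    error("ReadOnlyVector does not support setindex!")

# In-place form: writes the derivative into the caller's `du` buffer.
function lv!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = (α - β * u[2]) * u[1]
    du[2] = (δ * u[1] - γ) * u[2]
    return nothing
end

# Out-of-place form: allocates and returns a fresh array, never mutates its inputs.
function lv(u, p, t)
    α, β, γ, δ = p
    return [(α - β * u[2]) * u[1], (δ * u[1] - γ) * u[2]]
end

u = ReadOnlyVector([1.0, 1.0])
p = (1.5, 1.0, 3.0, 1.0)  # α, β, γ, δ as in the example

# The out-of-place call only reads from `u`, so it succeeds.
du_oop = lv(u, p, 0.0)

# The in-place call needs to write into `du`, so it throws for this array type.
inplace_failed = try
    lv!(ReadOnlyVector(zeros(2)), u, p, 0.0)
    false
catch err
    err isa ErrorException
end
```

This only illustrates the mutation issue. It does not show that the MTK example in this thread works with ReverseDiff once the problem is built out-of-place.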
Where are we at on this one? |
It's a SciMLSensitivity bug. Specifically, it's trying to build a |
Question❓
If I change the differentiation method in the optimization example in "Optimizing through an ODE solve and re-creating MTK Problems" to AutoReverseDiff(), I obtain the following error when solving the problem:
This has to do with changing the parameter values of the ODE problem using Tunable() and remake.
Is there a way to use reverse differentiation in this case?