From 9b9607834c2c4285ee254cf9eaa33c5751c74866 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 7 Oct 2024 12:35:08 +0530 Subject: [PATCH 1/2] test: use `split = false` system for remake autodiff tests --- test/downstream/remake_autodiff.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/downstream/remake_autodiff.jl b/test/downstream/remake_autodiff.jl index fd3d4bd85..ce8abce58 100644 --- a/test/downstream/remake_autodiff.jl +++ b/test/downstream/remake_autodiff.jl @@ -14,11 +14,11 @@ function lotka_volterra(; name = name) end @named lotka_volterra_sys = lotka_volterra() -lotka_volterra_sys = structural_simplify(lotka_volterra_sys) +lotka_volterra_sys = structural_simplify(lotka_volterra_sys, split = false) prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), []) sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6) -u0 = [1.0 1.0] -p = [1.5 1.0 1.0 1.0] +u0 = [1.0, 1.0] +p = [1.5, 1.0, 1.0, 1.0] function sum_of_solution(u0, p) _prob = remake(prob, u0 = u0, p = p) From 6184c695a77e2aa229c5dca87b081726dd58f911 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 7 Oct 2024 13:36:29 +0530 Subject: [PATCH 2/2] fix: fix observed variable adjoint --- ext/SciMLBaseZygoteExt.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 486d219aa..fa67f0646 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -106,11 +106,12 @@ end if is_observed(VA, sym) f = observed(VA, sym) p = parameter_values(VA) - tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) + tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) u = state_values(VA) t = current_time(VA) y, back = Zygote.pullback(u, tunables) do u, tunables - f.(u, Ref(tunables), t) + _p = repack(tunables) + f.(u, Ref(_p), t) end gs = back(Δ) (gs[1], nothing)