From fd44bd41236d290af517003895ef7b6eadc1bb02 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 13 May 2024 18:13:59 +0530 Subject: [PATCH] fix: correct type params to ODESolution adjoint --- ext/SciMLBaseChainRulesCoreExt.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 0e84ccbc3..1deaa9579 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -97,30 +97,30 @@ end function ChainRulesCore.rrule( ::Type{ <:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, - T11, T12 + T11, T12, T13, T14 }}, u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, - T12} + T12, T13, T14} function ODESolutionAdjoint(ȳ) (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) end - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14}(u, args...), ODESolutionAdjoint end -function ChainRulesCore.rrule( - ::Type{ - <:ODESolution{uType, tType, isinplace, P, NP, F, G, K, - ND - }}, u, - args...) where {uType, tType, isinplace, P, NP, F, G, K, ND} - function SDESolutionAdjoint(ȳ) - (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) - end - - SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint -end +# function ChainRulesCore.rrule( +# ::Type{ +# <:SDESolution{uType, tType, isinplace, P, NP, F, G, K, +# ND +# }}, u, +# args...) where {uType, tType, isinplace, P, NP, F, G, K, ND} +# function SDESolutionAdjoint(ȳ) +# (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) +# end +# +# SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint +# end function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged) out = EnsembleSolution(sim, time, converged)