diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 0c6569fda..176e86b11 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -518,7 +518,7 @@ information on generating the SplitFunction from this symbolic engine. struct SplitFunction{ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, - TCV, SYS} <: AbstractODEFunction{iip} + TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip} f1::F1 f2::F2 mass_matrix::TMM @@ -536,6 +536,8 @@ struct SplitFunction{ observed::O colorvec::TCV sys::SYS + initializeprob::IProb + initializeprobmap::IProbMap end @doc doc""" @@ -2599,7 +2601,7 @@ end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys) + observed, colorvec, sys, initializeprob, initializeprobmap) f1 = ODEFunction(f1) f2 = ODEFunction(f2) @@ -2613,8 +2615,10 @@ end typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys)}(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, - jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys) + typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(f1, f2, mass_matrix, + cache, analytic, tgrad, jac, jvp, vjp, + jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, + initializeprob, initializeprobmap) end function SplitFunction{iip, specialize}(f1, f2; mass_matrix = __has_mass_matrix(f1) ? @@ -2642,18 +2646,24 @@ function SplitFunction{iip, specialize}(f1, f2; DEFAULT_OBSERVED, colorvec = __has_colorvec(f1) ? f1.colorvec : nothing, - sys = __has_sys(f1) ? f1.sys : nothing) where {iip, + sys = __has_sys(f1) ? f1.sys : nothing, + initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing, + initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing + ) where {iip, specialize } sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + @assert typeof(initializeprob) <: + Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} + if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Any}(f1, f2, mass_matrix, _func_cache, + Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys) + observed, colorvec, sys, initializeprob, initializeprobmap) else SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(_func_cache), typeof(analytic), @@ -2661,9 +2671,11 @@ function SplitFunction{iip, specialize}(f1, f2; typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys)}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, + typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(f1, f2, + mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, - sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys) + sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, + initializeprob, initializeprobmap) end end