diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index bc9be9400..1bd81e8a7 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -198,7 +198,11 @@ function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 4) _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) - _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) + _f = if iip + ODEFunction{iip, FullSpecialize}(DUMB_WRAPPER(first(_tspan), ODE_F_WRAPPER(f))) + else + _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) + end ODEProblem(_f, _u0, _tspan, p; kwargs...) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 239791d61..299a4f25c 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2415,8 +2415,27 @@ end (f::BVPFunction)(args...) = f.f(args...) (f::DynamicalBVPFunction)(args...) = f.f(args...) -######### Basic Constructor +#### fun hack to make NOSPECIALIZE fast +mutable struct ODE_F_WRAPPER{F} + const f::F +end +mutable struct DUMB_WRAPPER{P, T} + p::P + t::T + f::ODE_F_WRAPPER +end +function (w::ODE_F_WRAPPER)(out, u, d::DUMB_WRAPPER) + w.f(out, u, d.p, d.t) + return nothing +end +function (w::DUMB_WRAPPER)(out, u, p, t) + w.p = p + w.t = t + w.f(out, u, w) + return nothing +end +######### Basic Constructor function ODEFunction{iip, specialize}(f; mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,