From 911263732ef80a7ea5e0c8d8a9b49738479fb8a3 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 11 Nov 2024 18:25:36 -0500 Subject: [PATCH 1/4] Wrap ODE functions in a fully transparent wrapper function --- src/problems/ode_problems.jl | 7 ++++++- src/scimlfunctions.jl | 25 ++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index bc9be9400..b36856605 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -198,7 +198,12 @@ function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 4) _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) - _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) + _f = if iip + out = copy(u0) # TODO: do this properly + ODEFunction{iip, FullSpecialize}(DUMB_WRAPPER(out, u0, p, first(_tspan), ODE_F_WRAPPER(f))) + else + _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) + end ODEProblem(_f, _u0, _tspan, p; kwargs...) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 239791d61..8430a9ded 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2415,8 +2415,31 @@ end (f::BVPFunction)(args...) = f.f(args...) (f::DynamicalBVPFunction)(args...) = f.f(args...) -######### Basic Constructor +#### fun hack to make NOSPECIALIZE fast +mutable struct ODE_F_WRAPPER{F} + const f::F +end +mutable struct DUMB_WRAPPER{OUT, U, P, T} + out::OUT + u::U + p::P + t::T + f::ODE_F_WRAPPER +end +function (w::ODE_F_WRAPPER)(d::DUMB_WRAPPER) + w.f(d.out, d.u, d.p, d.t) + return nothing +end +function (w::DUMB_WRAPPER)(out, u, p, t) + w.out = out + w.u = u + w.p = p + w.t = t + w.f(w) + return nothing +end +######### Basic Constructor function ODEFunction{iip, specialize}(f; mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, From 67b0cc32c3d37f03181f40f070ae6d4c2e8ac10f Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 12 Nov 2024 15:59:29 -0500 Subject: [PATCH 2/4] simplify --- src/problems/ode_problems.jl | 2 +- src/scimlfunctions.jl | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index b36856605..ada370729 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -200,7 +200,7 @@ function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...) _tspan = promote_tspan(tspan) _f = if iip out = copy(u0) # TODO: do this properly - ODEFunction{iip, FullSpecialize}(DUMB_WRAPPER(out, u0, p, first(_tspan), ODE_F_WRAPPER(f))) + ODEFunction{iip, FullSpecialize}(DUMB_WRAPPER(p, first(_tspan), ODE_F_WRAPPER(f))) else _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 8430a9ded..299a4f25c 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2419,23 +2419,19 @@ end mutable struct ODE_F_WRAPPER{F} const f::F end -mutable struct DUMB_WRAPPER{OUT, U, P, T} - out::OUT - u::U +mutable struct DUMB_WRAPPER{P, T} p::P t::T f::ODE_F_WRAPPER end -function (w::ODE_F_WRAPPER)(d::DUMB_WRAPPER) - w.f(d.out, d.u, d.p, d.t) +function (w::ODE_F_WRAPPER)(out, u, d::DUMB_WRAPPER) + w.f(out, u, d.p, d.t) return nothing end function (w::DUMB_WRAPPER)(out, u, p, t) - w.out = out - w.u = u w.p = p w.t = t - w.f(w) + w.f(out, u, w) return nothing end From d93e431abcd5605be10cce2dfe86417d07ae97ee Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 12 Nov 2024 15:59:51 -0500 Subject: [PATCH 3/4] fix --- src/problems/ode_problems.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index ada370729..9d867f9cb 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -199,7 +199,6 @@ function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...) _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) _f = if iip - out = copy(u0) # TODO: do this properly ODEFunction{iip, FullSpecialize}(DUMB_WRAPPER(p, first(_tspan), ODE_F_WRAPPER(f))) else _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) From 39bd54a4ad603a75982cbad0d89e3b53e1df4dc4 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 13 Nov 2024 11:08:25 -0500 Subject: [PATCH 4/4] fix --- src/problems/ode_problems.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index 9d867f9cb..1bd81e8a7 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -199,7 +199,7 @@ function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...) _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) _f = if iip - ODEFunction{iip, FullSpecialize}(DUMB_WRAPPER(p, first(_tspan), ODE_F_WRAPPER(f))) + ODEFunction{iip, FullSpecialize}(DUMB_WRAPPER(first(_tspan), ODE_F_WRAPPER(f))) else _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) end