Skip to content

Commit

Permalink
Propagate IIP information in the Wrapper Functions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 15, 2023
1 parent ece1966 commit 9017479
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.8.1"
version = "2.8.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
72 changes: 53 additions & 19 deletions src/function_wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,90 @@
mutable struct TimeGradientWrapper{fType, uType, P} <: Function
mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunction{iip}
f::fType
uprev::uType
p::P
end
(ff::TimeGradientWrapper)(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)
(ff::TimeGradientWrapper)(du2, t) = ff.f(du2, ff.uprev, ff.p, t)

mutable struct UJacobianWrapper{fType, tType, P} <: Function
function TimeGradientWrapper(f::F, uprev, p) where {F}

Check warning on line 7 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L7

Added line #L7 was not covered by tests
return TimeGradientWrapper{isinplace(f), F, typeof(uprev), typeof(p)}(f, uprev, p)
end

(ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)

Check warning on line 11 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L11

Added line #L11 was not covered by tests
(ff::TimeGradientWrapper{true})(du2, t) = ff.f(du2, ff.uprev, ff.p, t)

(ff::TimeGradientWrapper{false})(t) = ff.f(ff.uprev, ff.p, t)

Check warning on line 14 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L14

Added line #L14 was not covered by tests

mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{iip}
f::fType
t::tType
p::P
end

(ff::UJacobianWrapper)(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
(ff::UJacobianWrapper)(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
(ff::UJacobianWrapper)(du1, uprev, p, t) = ff.f(du1, uprev, p, t)
(ff::UJacobianWrapper)(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1)
function UJacobianWrapper(f::F, t, p) where {F}

Check warning on line 22 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L22

Added line #L22 was not covered by tests
return UJacobianWrapper{isinplace(f), F, typeof(t), typeof(p)}(f, t, p)
end

(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
(ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
(ff::UJacobianWrapper{true})(du1, uprev, p, t) = ff.f(du1, uprev, p, t)
(ff::UJacobianWrapper{true})(uprev, p, t) = (du1 = similar(uprev); ff.f(du1, uprev, p, t); du1)

Check warning on line 29 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L27-L29

Added lines #L27 - L29 were not covered by tests

mutable struct TimeDerivativeWrapper{F, uType, P} <: Function
(ff::UJacobianWrapper{false})(uprev) = ff.f(uprev, ff.p, ff.t)
(ff::UJacobianWrapper{false})(uprev, p, t) = ff.f(uprev, p, t)

Check warning on line 32 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L31-L32

Added lines #L31 - L32 were not covered by tests

mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{iip}
f::F
u::uType
p::P
end
(ff::TimeDerivativeWrapper)(t) = ff.f(ff.u, ff.p, t)

mutable struct UDerivativeWrapper{F, tType, P} <: Function
function TimeDerivativeWrapper(f::F, u, p) where {F}
return TimeDerivativeWrapper{isinplace(f), F, typeof(u), typeof(p)}(f, u, p)

Check warning on line 41 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
end

(ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t)
(ff::TimeDerivativeWrapper{true})(du1, t) = ff.f(du1, ff.u, ff.p, t)
(ff::TimeDerivativeWrapper{true})(t) = (du1 = similar(ff.u); ff.f(du1, ff.u, ff.p, t); du1)

Check warning on line 46 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L44-L46

Added lines #L44 - L46 were not covered by tests

mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip}
f::F
t::tType
p::P
end
(ff::UDerivativeWrapper)(u) = ff.f(u, ff.p, ff.t)

mutable struct ParamJacobianWrapper{fType, tType, uType} <: Function
function UDerivativeWrapper(f::F, t, p) where {F}
return UDerivativeWrapper{isinplace(f), F, typeof(t), typeof(p)}(f, t, p)

Check warning on line 55 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L54-L55

Added lines #L54 - L55 were not covered by tests
end

(ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t)
(ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t)
(ff::UDerivativeWrapper{true})(u) = (du1 = similar(u); ff.f(du1, u, ff.p, ff.t); du1)

Check warning on line 60 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L58-L60

Added lines #L58 - L60 were not covered by tests

mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFunction{iip}
f::fType
t::tType
u::uType
end

function (ff::ParamJacobianWrapper)(du1, p)
ff.f(du1, ff.u, p, ff.t)
function ParamJacobianWrapper(f::F, t, u) where {F}
return ParamJacobianWrapper{isinplace(f), F, typeof(t), typeof(u)}(f, t, u)

Check warning on line 69 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L68-L69

Added lines #L68 - L69 were not covered by tests
end

function (ff::ParamJacobianWrapper)(p)
(ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t)
function (ff::ParamJacobianWrapper{true})(p)

Check warning on line 73 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L72-L73

Added lines #L72 - L73 were not covered by tests
du1 = similar(p, size(ff.u))
ff.f(du1, ff.u, p, ff.t)
return du1
end
(ff::ParamJacobianWrapper{false})(p) = ff.f(ff.u, p, ff.t)

Check warning on line 78 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L78

Added line #L78 was not covered by tests

mutable struct JacobianWrapper{fType, pType}
mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip}
f::fType
p::pType
end

(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
function JacobianWrapper(f::F, p) where {F}
return JacobianWrapper{isinplace(f), F, typeof(p)}(f, p)

Check warning on line 86 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
end

(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

Check warning on line 90 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L89-L90

Added lines #L89 - L90 were not covered by tests

0 comments on commit 9017479

Please sign in to comment.