Skip to content

Commit

Permalink
Make sensitivity interpolations type preserving
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jan 7, 2024
1 parent 24f0105 commit 5d511a6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 27 deletions.
56 changes: 36 additions & 20 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
function enable_interpolation_sensitivitymode end

enable_interpolation_sensitivitymode(interp::Nothing) = nothing

Check warning on line 3 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L3

Added line #L3 was not covered by tests

# Pass through should be deprecated in the future, made for backwards compat
enable_interpolation_sensitivitymode(interp::AbstractDiffEqInterpolation) = interp

Check warning on line 6 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L6

Added line #L6 was not covered by tests

"""
$(TYPEDEF)
"""
struct HermiteInterpolation{T1, T2, T3} <: AbstractDiffEqInterpolation
t::T1
u::T2
du::T3
sensitivitymode::Bool
end

function HermiteInterpolation(t,u,du; sensitivitymode=false)
HermiteInterpolation(t,u,du,sensitivitymode)

Check warning on line 19 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L18-L19

Added lines #L18 - L19 were not covered by tests
end

function enable_interpolation_sensitivitymode(interp::HermiteInterpolation)
HermiteInterpolation(interp.t,interp.u,interp.du,true)

Check warning on line 23 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L22-L23

Added lines #L22 - L23 were not covered by tests
end

"""
Expand All @@ -13,6 +29,15 @@ $(TYPEDEF)
struct LinearInterpolation{T1, T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
sensitivitymode::Bool
end

function LinearInterpolation(t,u; sensitivitymode=false)
LinearInterpolation(t,u,sensitivitymode)

Check warning on line 36 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
end

function enable_interpolation_sensitivitymode(interp::LinearInterpolation)
LinearInterpolation(interp.t,interp.u,true)

Check warning on line 40 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L39-L40

Added lines #L39 - L40 were not covered by tests
end

"""
Expand All @@ -21,24 +46,22 @@ $(TYPEDEF)
struct ConstantInterpolation{T1, T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
sensitivitymode::Bool
end

"""
$(TYPEDEF)
"""
struct SensitivityInterpolation{T1, T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
function ConstantInterpolation(t,u; sensitivitymode=false)
ConstantInterpolation(t,u,sensitivitymode)

Check warning on line 53 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L52-L53

Added lines #L52 - L53 were not covered by tests
end

function enable_interpolation_sensitivitymode(interp::ConstantInterpolation)
ConstantInterpolation(interp.t,interp.u,true)

Check warning on line 57 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L56-L57

Added lines #L56 - L57 were not covered by tests
end

interp_summary(::AbstractDiffEqInterpolation) = "Unknown"
interp_summary(::HermiteInterpolation) = "3rd order Hermite"
interp_summary(::LinearInterpolation) = "1st order linear"
interp_summary(::ConstantInterpolation) = "Piecewise constant interpolation"
interp_summary(::Nothing) = "No interpolation"
function interp_summary(::SensitivityInterpolation)
"Interpolation disabled due to sensitivity analysis"
end
interp_summary(sol::AbstractSciMLSolution) = interp_summary(sol.interp)

const SENSITIVITY_INTERP_MESSAGE = """
Expand Down Expand Up @@ -69,13 +92,6 @@ end
function (id::ConstantInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation!(val, tvals, id, idxs, deriv, p, continuity)
end
function (id::SensitivityInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation(tvals, id, idxs, deriv, p, continuity)
end
function (id::SensitivityInterpolation)(val, tvals, idxs, deriv, p,
continuity::Symbol = :left)
interpolation!(val, tvals, id, idxs, deriv, p, continuity)
end

@inline function interpolation(tvals, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
Expand Down Expand Up @@ -118,7 +134,7 @@ end
vals[j] = u[k][idxs]
end
else
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE)

Check warning on line 137 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L137

Added line #L137 was not covered by tests
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
Expand Down Expand Up @@ -173,7 +189,7 @@ times t (sorted), with values u and derivatives ks
vals[j] = u[k][idxs]
end
else
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE)

Check warning on line 192 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L192

Added line #L192 was not covered by tests
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
Expand Down Expand Up @@ -232,7 +248,7 @@ times t (sorted), with values u and derivatives ks
val = u[i - 1][idxs]
end
else
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE)

Check warning on line 251 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L251

Added line #L251 was not covered by tests
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
Expand Down Expand Up @@ -282,7 +298,7 @@ times t (sorted), with values u and derivatives ks
copy!(out, u[i - 1][idxs])
end
else
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE)

Check warning on line 301 in src/interpolation.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation.jl#L301

Added line #L301 was not covered by tests
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
Expand Down
8 changes: 1 addition & 7 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,7 @@ function sensitivity_solution(sol::ODESolution, u, t)
N = length((size(u0)..., length(u)))
end

interp = if sol.interp isa LinearInterpolation
LinearInterpolation(t, u)
elseif sol.interp isa ConstantInterpolation
ConstantInterpolation(t, u)
else
SensitivityInterpolation(t, u)
end
interp = enable_interpolation_sensitivitymode(sol.interp)

Check warning on line 399 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L399

Added line #L399 was not covered by tests

ODESolution{T, N}(u, sol.u_analytic, sol.errors, t,
nothing, sol.prob,
Expand Down

0 comments on commit 5d511a6

Please sign in to comment.