Skip to content

Commit

Permalink
Merge pull request #567 from SciML/ap/nlls_bvp
Browse files Browse the repository at this point in the history
Add a nlls trait to BVProblem
  • Loading branch information
ChrisRackauckas authored Mar 25, 2024
2 parents 643114c + 7ff95fc commit 9baab96
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 83 deletions.
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ $(TYPEDEF)
Base for types which define BVP problems.
"""
abstract type AbstractBVProblem{uType, tType, isinplace} <:
abstract type AbstractBVProblem{uType, tType, isinplace, nlls} <:
AbstractODEProblem{uType, tType, isinplace} end

"""
Expand Down
66 changes: 56 additions & 10 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct TwoPointBVProblem{iip} end # The iip is needed to make type stable constr
@doc doc"""
Defines an BVP problem.
Documentation Page: https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/
Documentation Page: [https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/](https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/)
## Mathematical Specification of a BVP Problem
Expand Down Expand Up @@ -41,16 +41,16 @@ u(t_f) = b
### Constructors
```julia
TwoPointBVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
TwoPointBVProblem{isinplace}(f, bc, u0, tspan, p=NullParameters(); kwargs...)
BVProblem{isinplace}(f, bc, u0, tspan, p=NullParameters(); kwargs...)
```
or if we have an initial guess function `initialGuess(t)` for the given BVP,
or if we have an initial guess function `initialGuess(p, t)` for the given BVP,
we can pass the initial guess to the problem constructors:
```julia
TwoPointBVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
TwoPointBVProblem{isinplace}(f, bc, initialGuess, tspan, p=NullParameters(); kwargs...)
BVProblem{isinplace}(f, bc, initialGuess, tspan, p=NullParameters(); kwargs...)
```
For any BVP problem type, `bc` must be inplace if `f` is inplace. Otherwise it must be
Expand Down Expand Up @@ -104,9 +104,17 @@ every solve call.
* `tspan`: The timespan for the problem.
* `p`: The parameters for the problem. Defaults to `NullParameters`
* `kwargs`: The keyword arguments passed onto the solves.
### Special Keyword Arguments
- `nlls`: Specify that the BVP is a nonlinear least squares problem. Use `Val(true)` or
`Val(false)` for type stability. By default this is automatically inferred based on the
size of the input and outputs, however this is type unstable for any array type that
doesn't store array size as part of type information. If we can't reliably infer this,
we set it to `Nothing`. Downstreams solvers must be setup to deal with this case.
"""
struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace}
struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace, nlls}
f::F
u0::uType
tspan::tType
Expand All @@ -115,18 +123,56 @@ struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
kwargs::K

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
p = NullParameters(); problem_type = nothing, kwargs...) where {iip, TP}
p = NullParameters(); problem_type = nothing, nlls = nothing,
kwargs...) where {iip, TP}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
prob_type = TP ? TwoPointBVProblem{iip}() : StandardBVProblem()

# Needed to ensure that `problem_type` doesn't get passed in kwargs
if problem_type === nothing
problem_type = prob_type
else
@assert prob_type===problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end
return new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f),

if nlls === nothing
if !hasmethod(length, Tuple{typeof(_u0)})
# If _u0 is a function for initial guess we won't be able to infer
__u0 = _u0 isa Function ?
(hasmethod(_u0, Tuple{typeof(p), typeof(first(_tspan))}) ?
_u0(p, first(_tspan)) : _u0(first(_tspan))) : nothing
else
__u0 = _u0
end
# Try to infer it
if __u0 isa Nothing
_nlls = Nothing
elseif problem_type isa TwoPointBVProblem
if f.bcresid_prototype !== nothing
l1, l2 = length(f.bcresid_prototype[1]), length(f.bcresid_prototype[2])
_nlls = l1 + l2 != length(__u0)
else
# iip without bcresid_prototype is not possible
if !iip
l1 = length(f.bc[1](u0, p))
l2 = length(f.bc[2](u0, p))
_nlls = l1 + l2 != length(__u0)
end
end
else
if f.bcresid_prototype !== nothing
_nlls = length(f.bcresid_prototype) != length(__u0)
else
_nlls = Nothing # Cannot reliably infer
end
end
else
_nlls = _unwrap_val(nlls)
end

return new{typeof(_u0), typeof(_tspan), iip, _nlls, typeof(p), typeof(f),
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
end

Expand Down
29 changes: 16 additions & 13 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ function isrecompile(prob::ODEProblem{iip}) where {iip}
(prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
end

function remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, kwargs...)
function remake(prob::AbstractSciMLProblem; u0 = missing,
p = missing, interpret_symbolicmap = true, kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., u0, p)
end
Expand All @@ -54,7 +55,8 @@ function remake(prob::AbstractNoiseProblem; kwargs...)
_remake_internal(prob; kwargs...)
end

function remake(prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
function remake(
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
p = updated_p(prob, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., p)
end
Expand Down Expand Up @@ -128,16 +130,15 @@ end
Remake the given `BVProblem`.
"""
function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, interpret_symbolicmap = true, _kwargs...)
function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing,
u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing,
interpret_symbolicmap = true, _kwargs...) where {uType, tType, iip, nlls}
if tspan === missing
tspan = prob.tspan
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)

iip = isinplace(prob)

if problem_type === missing
problem_type = prob.problem_type
end
Expand Down Expand Up @@ -170,9 +171,11 @@ function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan
end

if kwargs === missing
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, prob.kwargs..., _kwargs...)
BVProblem{iip}(
_f, bc, u0, tspan, p; problem_type, nlls = Val(nlls), prob.kwargs...,
_kwargs...)
else
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, kwargs...)
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls = Val(nlls), kwargs...)
end
end

Expand Down Expand Up @@ -254,7 +257,6 @@ function remake(prob::OptimizationProblem;
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
Expand Down Expand Up @@ -393,10 +395,11 @@ function updated_p(prob, p; interpret_symbolicmap = true)
end
if eltype(p) <: Pair
if interpret_symbolicmap
has_sys(prob.f) || throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
has_sys(prob.f) ||
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
else
return p
end
Expand Down
64 changes: 41 additions & 23 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1929,35 +1929,46 @@ $(TYPEDEF)
A representation of a BVP function `f`, defined by:
```math
\frac{du}{dt}=f(u,p,t)
\frac{du}{dt} = f(u, p, t)
```
and the constraints:
```math
\frac{du}{dt}=g(u,p,t)
g(u, p, t) = 0
```
If the size of `g(u, p, t)` is different from the size of `u`, then the constraints are
interpreted as a least squares problem, i.e. the objective function is:
```math
\min_{u} \| g_i(u, p, t) \|^2
```
and all of its related functions, such as the Jacobian of `f`, its gradient
with respect to time, and more. For all cases, `u0` is the initial condition,
`p` are the parameters, and `t` is the independent variable.
```julia
BVPFunction{iip,specialize}(f, bc;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
bcjac = __has_jac(bc) ? bc.jac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
bccolorvec = __has_colorvec(f) ? bc.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing)
BVPFunction{iip, specialize}(f, bc;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
bcjac = __has_jac(bc) ? bc.jac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
syms = nothing,
indepsym= nothing,
paramsyms = nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
bccolorvec = __has_colorvec(f) ? bc.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
twopoint::Union{Val, Bool} = Val(false)
```
Note that both the function `f` and boundary condition `bc` are required. `f` should
Expand Down Expand Up @@ -1985,7 +1996,7 @@ the usage of `f` and `bc`. These include:
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
The default is `nothing`, which means a dense Jacobian.
- `bcjac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
as the prototype and integrators will specialize on this structure where possible. Non-structured
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
The default is `nothing`, which means a dense Jacobian.
Expand All @@ -2003,6 +2014,11 @@ the usage of `f` and `bc`. These include:
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.
Additional Options:
- `twopoint`: Specify that the BVP is a two-point boundary value problem. Use `Val(true)` or
`Val(false)` for type stability.
## iip: In-Place vs Out-Of-Place
For more details on this argument, see the ODEFunction documentation.
Expand Down Expand Up @@ -3801,7 +3817,7 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;

_f = prepare_function(f)

sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))

if specialize === NoSpecialize
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any,
Expand All @@ -3813,9 +3829,9 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
sparsity, Wfact, Wfact_t, paramjac, observed,
_colorvec, _bccolorvec, sys)
else
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), typeof(mass_matrix),
typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp),
typeof(vjp), typeof(jac_prototype),
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc),
typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac),
typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(
Expand Down Expand Up @@ -3897,7 +3913,9 @@ end
function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
if sys === nothing &&
(syms !== nothing || paramsyms !== nothing || indepsym !== nothing)
Base.depwarn("The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead.", :syms)
Base.depwarn(
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead.",
:syms)
sys = SymbolCache(syms, paramsyms, indepsym)
end
return sys
Expand Down
Loading

0 comments on commit 9baab96

Please sign in to comment.