Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Use NonlinearSolve for all root finding needs #203

Merged
merged 12 commits into from
Feb 22, 2024
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqCallbacks"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
authors = ["Chris Rackauckas <[email protected]>"]
version = "2.37.0"
version = "2.38.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -10,7 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -30,13 +30,13 @@ ForwardDiff = "0.10.19"
Functors = "0.4"
LinearAlgebra = "1.10"
Markdown = "1.10"
NLsolve = "4.5"
NonlinearSolve = "3.6"
ODEProblemLibrary = "0.1.5"
OrdinaryDiffEq = "6.68"
Parameters = "0.12"
QuadGK = "2.4"
RecipesBase = "1.1"
RecursiveArrayTools = "2.38, 3"
RecursiveArrayTools = "3"
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
SciMLBase = "2.9"
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
SciMLSensitivity = "7.49"
StaticArrays = "1.8"
Expand All @@ -50,6 +50,7 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand All @@ -61,5 +62,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DataInterpolations", "OrdinaryDiffEq", "ODEProblemLibrary", "Test", "QuadGK", "SciMLSensitivity", "StaticArrays", "Tracker", "Zygote"]

test = ["Aqua", "DataInterpolations", "OrdinaryDiffEq", "ODEProblemLibrary", "Test", "QuadGK", "SciMLSensitivity", "StaticArrays", "Tracker", "Zygote", "NonlinearSolve"]
2 changes: 1 addition & 1 deletion src/DiffEqCallbacks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DiffEqCallbacks

using DiffEqBase, RecursiveArrayTools, DataStructures, RecipesBase, LinearAlgebra,
StaticArraysCore, NLsolve, ForwardDiff, Functors
StaticArraysCore, NonlinearSolve, ForwardDiff, Functors

import Base.Iterators

Expand Down
38 changes: 17 additions & 21 deletions src/domain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,10 @@ end
# callback definitions

"""
```julia
GeneralDomain(g, u = nothing; nlsolve = NLSOLVEJL_SETUP(), save = true,
abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict(:ftol => 10 * eps()))
```
GeneralDomain(
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
abstol = 10 * eps()), kwargs...)

A `GeneralDomain` callback in DiffEqCallbacks.jl generalizes the concept of
a `PositiveDomain` callback to arbitrary domains. Domains are specified by
Expand All @@ -242,41 +240,39 @@ preferred.

## Keyword Arguments

- `nlsolve`: A nonlinear solver as defined [in the nlsolve format](https://docs.sciml.ai/DiffEqDocs/stable/features/linear_nonlinear/)
which is passed to a `ManifoldProjection`.
- `save`: Whether to do the standard saving (applied after the callback).
- `abstol`: Tolerance up to, which residuals are accepted. Element-wise tolerances
are allowed. If it is not specified, every application of the callback uses the
current absolute tolerances of the integrator.
- `scalefactor`: Factor by which an unaccepted time step is reduced. If it is not
specified, time steps are halved.
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)`.
- `nlopts`: Optional arguments to nonlinear solver of a `ManifoldProjection` which
can be any of the [NLsolve keywords](https://github.com/JuliaNLSolvers/NLsolve.jl#fine-tunings).
The default value of `ftol = 10*eps()` ensures that convergence is only declared
if the infinite norm of residuals is very small and hence the state vector is very
close to the domain.
If it is not specified, it is determined automatically.
- `kwargs`: All other keyword arguments are passed to `ManifoldProjection`.
- `nlsolve_kwargs`: All keyword arguments are passed to the nonlinear solver in
`ManifoldProjection`. The default is `(; abstol = 10 * eps())`.

## References

Shampine, Lawrence F., Skip Thompson, Jacek Kierzenka and G. D. Byrne.
Non-negative solutions of ODEs. Applied Mathematics and Computation 170
(2005): 556-569.
"""
function GeneralDomain(g, u = nothing; nlsolve = NLSOLVEJL_SETUP(), save = true,
abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict(:ftol => 10 * eps()))
function GeneralDomain(
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
abstol = 10 * eps()), kwargs...)
_autonomous = SciMLBase._unwrap_val(autonomous)
if u isa Nothing
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, nothing, nothing)
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, nothing, nothing)
else
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, deepcopy(u),
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, deepcopy(u),
deepcopy(u))
end
condition = (u, t, integrator) -> true
CallbackSet(
ManifoldProjection(g; nlsolve = nlsolve, save = false,
autonomous = autonomous, nlopts = nlopts),
ManifoldProjection(
g; save = false, autonomous, isinplace = Val(true), kwargs..., nlsolve_kwargs...),
DiscreteCallback(condition, affect!; save_positions = (false, save)))
end

Expand Down
181 changes: 102 additions & 79 deletions src/manifold.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,28 @@
Base.@pure function determine_chunksize(u, alg::DiffEqBase.DEAlgorithm)
determine_chunksize(u, get_chunksize(alg))
end
Base.@pure function determine_chunksize(u, CS)
if CS != 0
return CS
else
return ForwardDiff.pickchunksize(length(u))
end
end

struct NLSOLVEJL_SETUP{CS, AD} end
Base.@pure function NLSOLVEJL_SETUP(; chunk_size = 0, autodiff = true)
NLSOLVEJL_SETUP{chunk_size, autodiff}()
end
(::NLSOLVEJL_SETUP)(f, u0; kwargs...) = (res = NLsolve.nlsolve(f, u0; kwargs...); res.zero)
function (p::NLSOLVEJL_SETUP{CS, AD})(::Type{Val{:init}}, f, u0_prototype) where {CS, AD}
AD ? autodiff = :forward : autodiff = :central
OnceDifferentiable(f, u0_prototype, u0_prototype, autodiff,
ForwardDiff.Chunk(determine_chunksize(u0_prototype, CS)))
end

# wrapper for non-autonomous functions
mutable struct NonAutonomousFunction{F, autonomous}
mutable struct NonAutonomousFunction{iip, F, autonomous}
f::F
t::Any
p::Any
end
(p::NonAutonomousFunction{F, true})(res, u) where {F} = p.f(res, u, p.p)
(p::NonAutonomousFunction{F, false})(res, u) where {F} = p.f(res, u, p.p, p.t)

(f::NonAutonomousFunction{true, F, true})(res, u, p) where {F} = f.f(res, u, p)
(f::NonAutonomousFunction{true, F, false})(res, u, p) where {F} = f.f(res, u, p, f.t)

(f::NonAutonomousFunction{false, F, true})(u, p) where {F} = f.f(u, p)
(f::NonAutonomousFunction{false, F, false})(u, p) where {F} = f.f(u, p, f.t)

SciMLBase.isinplace(::NonAutonomousFunction{iip}) where {iip} = iip

"""
```julia
ManifoldProjection(g; nlsolve = NLSOLVEJL_SETUP(), save = true)
```

In many cases, you may want to declare a manifold on which a solution lives.
Mathematically, a manifold `M` is defined by a function `g` as the set of points
where `g(u)=0`. An embedded manifold can be a lower dimensional object which
constrains the solution. For example, `g(u)=E(u)-C` where `E` is the energy
of the system in state `u`, meaning that the energy must be constant (energy
preservation). Thus by defining the manifold the solution should live on, you
can retain desired properties of the solution.
ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
isinplace = Val(true), autonomous = nothing, resid_prototype = nothing,
kwargs...)

In many cases, you may want to declare a manifold on which a solution lives. Mathematically,
a manifold `M` is defined by a function `g` as the set of points where `g(u) = 0`. An
embedded manifold can be a lower dimensional object which constrains the solution. For
example, `g(u) = E(u) - C` where `E` is the energy of the system in state `u`, meaning that
the energy must be constant (energy preservation). Thus by defining the manifold the
solution should live on, you can retain desired properties of the solution.

`ManifoldProjection` projects the solution of the differential equation to the chosen
manifold `g`, conserving a property while conserving the order. It is a consequence of
Expand All @@ -52,80 +34,121 @@ properties.

## Arguments

- `g`: The residual function for the manifold. This is an inplace function of form
`g(resid, u)` or `g(resid, u, p, t)` which writes to the residual `resid` the
difference from the manifold components. Here, it is assumed that `resid` is of
the same shape as `u`.
- `g`: The residual function for the manifold.

* This is an inplace function of form `g(resid, u, p)` or `g(resid, u, p, t)` which
writes to the residual `resid` the difference from the manifold components. Here, it
is assumed that `resid` is of the same shape as `u`.
* If `isinplace = Val(false)`, then `g` should be a function of the form `g(u, p)` or
`g(u, p, t)` which returns the residual.

## Keyword Arguments

- `nlsolve`: A nonlinear solver as defined [in the nlsolve format](https://docs.sciml.ai/DiffEqDocs/stable/features/linear_nonlinear/)
- `nlsolve`: A nonlinear solver as defined in the
[NonlinearSolve.jl format](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/)
- `save`: Whether to do the standard saving (applied after the callback)
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u)`.
- `nlopts`: Optional arguments to nonlinear solver which can be any of the [NLsolve keywords](https://github.com/JuliaNLSolvers/NLsolve.jl#fine-tunings).
- `nlls`: If the problem is a nonlinear least squares problem. `nlls = Val(false)`
generates a `NonlinearProblem` which is typically faster than
`NonlinearLeastSquaresProblem`, but is only applicable if the residual size is same as
the state size.
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)` or
`g(u, p)`. Specify it as `Val(::Bool)` to ensure this function call is type stable.
- `residual_prototype`: This needs to be specified if `nlls = Val(true)` and
`inplace = Val(true)` are specified together, else it is taken to be same as `u`.
- `kwargs`: All other keyword arguments are passed to
[NonlinearSolve.jl](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/).

### Saveat Warning

Note that the `ManifoldProjection` callback modifies the endpoints of the integration intervals
and thus breaks assumptions of internal interpolations. Because of this, the values for given by
saveat will not be order-matching. However, the interpolation error can be proportional to the
change by the projection, so if the projection is making small changes then one is still safe.
However, if there are large changes from each projection, you should consider only saving at
stopping/projection times. To do this, set `tstops` to the same values as `saveat`. There is a
performance hit by doing so because now the integrator is forced to stop at every saving point,
but this is guerenteed to match the order of the integrator even with the ManifoldProjection.
Note that the `ManifoldProjection` callback modifies the endpoints of the integration
intervals and thus breaks assumptions of internal interpolations. Because of this, the
values for given by saveat will not be order-matching. However, the interpolation error can
be proportional to the change by the projection, so if the projection is making small
changes then one is still safe. However, if there are large changes from each projection,
you should consider only saving at stopping/projection times. To do this, set `tstops` to
the same values as `saveat`. There is a performance hit by doing so because now the
integrator is forced to stop at every saving point, but this is guerenteed to match the
order of the integrator even with the ManifoldProjection.

## References

Ernst Hairer, Christian Lubich, Gerhard Wanner. Geometric Numerical Integration:
Structure-Preserving Algorithms for Ordinary Differential Equations. Berlin ;
New York :Springer, 2002.
"""
mutable struct ManifoldProjection{autonomous, F, NL, O}
mutable struct ManifoldProjection{iip, nlls, autonomous, F, NL, R, K}
g::F
nl_rhs::Any
nlcache::Any
nlsolve::NL
nlopts::O
resid_prototype::R
kwargs::K

function ManifoldProjection{autonomous}(g, nlsolve, nlopts) where {autonomous}
function ManifoldProjection{iip, nlls, autonomous}(
g, nlsolve, resid_prototype, kwargs) where {iip, nlls, autonomous}
# replace residual function if it is time-dependent
# since NLsolve only accepts functions with two arguments
_g = NonAutonomousFunction{typeof(g), autonomous}(g, 0, 0)
new{autonomous, typeof(_g), typeof(nlsolve), typeof(nlopts)}(_g, _g, nlsolve,
nlopts)
_g = NonAutonomousFunction{iip, typeof(g), autonomous}(g, 0)
return new{iip, nlls, autonomous, typeof(_g), typeof(nlsolve),
typeof(resid_prototype), typeof(kwargs)}(
_g, nothing, nlsolve, resid_prototype, kwargs)
end
end

# Now make `affect!` for this:
function (p::ManifoldProjection{autonomous, NL})(integrator) where {autonomous, NL}
function (p::ManifoldProjection{
iip, nlls, autonomous, NL})(integrator) where {iip, nlls,
autonomous, NL}
# update current time if residual function is time-dependent
if !autonomous
p.g.t = integrator.t
end
p.g.p = integrator.p
!autonomous && (p.g.t = integrator.t)

integrator.u .= p.nlsolve(p.nl_rhs, integrator.u; p.nlopts...)
end
# solve the nonlinear problem
reinit!(p.nlcache, integrator.u; p = integrator.p)
sol = solve!(p.nlcache)

function Manifold_initialize(cb, u::Number, t, integrator)
cb.affect!.nl_rhs = cb.affect!.nlsolve(Val{:init}, cb.affect!.g, [u])
u_modified!(integrator, false)
if !SciMLBase.successful_retcode(sol)
SciMLBase.terminate!(integrator, sol.retcode)
return
end

copyto!(integrator.u, sol.u)
end

function Manifold_initialize(cb, u, t, integrator)
cb.affect!.nl_rhs = cb.affect!.nlsolve(Val{:init}, cb.affect!.g, u)
return Manifold_initialize(cb.affect!, u, t, integrator)
end
function Manifold_initialize(
affect!::ManifoldProjection{iip, nlls}, u, t, integrator) where {iip, nlls}
nlfunc = NonlinearFunction{iip}(affect!.g; affect!.resid_prototype)
nlprob = if nlls
NonlinearLeastSquaresProblem(nlfunc, u, integrator.p)
else
NonlinearProblem(nlfunc, u, integrator.p)
end
affect!.nlcache = init(nlprob, affect!.nlsolve; affect!.kwargs...)
u_modified!(integrator, false)
end

function ManifoldProjection(g; nlsolve = NLSOLVEJL_SETUP(), save = true,
autonomous = maximum(SciMLBase.numargs(g)) == 3,
nlopts = Dict{Symbol, Any}())
affect! = ManifoldProjection{autonomous}(g, nlsolve, nlopts)
# Since this is applied to every point, we can reasonably assume that the solution is close
# to the initial guess, so we would want to use NewtonRaphson / RobustMultiNewton instead of
# the default one.
function ManifoldProjection(g; nlsolve = missing, save = true, nlls = Val(true),
isinplace = Val(true), autonomous = nothing, resid_prototype = nothing,
kwargs...)
# `nothing` is a valid solver, so this need to be `missing`
_nlls = SciMLBase._unwrap_val(nlls)
_nlsolve = nlsolve === missing ? (_nlls ? GaussNewton() : NewtonRaphson()) : nlsolve
iip = SciMLBase._unwrap_val(isinplace)
if autonomous === nothing
if iip
autonomous = maximum(SciMLBase.numargs(g)) == 3
else
autonomous = maximum(SciMLBase.numargs(g)) == 2
end
end
affect! = ManifoldProjection{iip, _nlls, SciMLBase._unwrap_val(autonomous)}(
g, _nlsolve, resid_prototype, kwargs)
condition = (u, t, integrator) -> true
save_positions = (false, save)
DiscreteCallback(condition, affect!;
initialize = Manifold_initialize,
save_positions = save_positions)
return DiscreteCallback(condition, affect!; initialize = Manifold_initialize,
save_positions = (false, save))
end

export ManifoldProjection
Loading
Loading