-
-
Notifications
You must be signed in to change notification settings - Fork 105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add constructor to convert NLLS to OptimizationProblem #538
Conversation
Codecov Report
@@ Coverage Diff @@
## master #538 +/- ##
==========================================
- Coverage 41.07% 40.98% -0.10%
==========================================
Files 53 53
Lines 4051 4060 +9
==========================================
Hits 1664 1664
- Misses 2387 2396 +9
📣 Codecov offers a browser extension for seamless coverage viewing on GitHub. Try it in Chrome or Firefox today! |
src/problems/basic_problems.jl
Outdated
if isinplace(prob) | ||
throw(ArgumentError("Converting NonlinearLeastSquaresProblem to OptimizationProblem is not supported with in-place functions yet.")) | ||
end | ||
optf = OptimizationFunction(sum ∘ prob.f, grad = (Jv, u, p) -> prob.f.jvp(Jv, prob.f(u, p), u, p), kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this assumes jvp is defined? That's a pretty strong assumption?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah..I realise now that I was assuming this to come from being populated by an instatiate_function equivalent but it'll have to be provided by the user to be available here right?
src/problems/basic_problems.jl
Outdated
if isinplace(prob) | ||
throw(ArgumentError("Converting NonlinearLeastSquaresProblem to OptimizationProblem is not supported with in-place functions yet.")) | ||
end | ||
optf = OptimizationFunction(sum ∘ prob.f, grad = (Jv, u, p) -> prob.f.jvp(Jv, prob.f(u, p), u, p), kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sum(abs2, prob.f(u,p))
, not just sum
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the loss function for NLLS already has the square
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@avik-pal wait, I thought the interface would be the same as nonlinearsolve, i.e. returning the vector of residuals. f(du,u,p)
and then sum(abs2,du)
is the loss. Now I'm looking at the tests and see the user is required to write it. Having that summation be implicit enforces it's actually a least squares problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right it is same as NonlinearProblem, sum ∘ prob.f
should be sum(abs2,....)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC what Chris meant was the other way around, the NLLS problem should do the squaring implicitly and not ask the user to do that (and then the loss here would just be the sum)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay I see what happened. SciML/NonlinearSolve.jl@b8aca89 the original test was just weird, but it was since fixed SciML/NonlinearSolve.jl@883a392. This means it's sum(abs2,...)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes, I had seen the previous version and added this
5f200fa
to
4856a9d
Compare
4856a9d
to
169d5c2
Compare
Looks ready to merge? In-place is going to be much more common than out of place, so that would be good to support, but a proper error message is fine for now. |
test/downstream/nllsopt.jl
Outdated
using NonlinearSolve, Optimization, OptimizationNLopt, ForwardDiff | ||
import FastLevenbergMarquardt, LeastSquaresOptim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These downstream deps would be required. But just remove FastLevenbergMarquardt
and LeastSquaresOptim
?
Co-authored-by: Christopher Rackauckas <[email protected]>
Yeah should be ready. But I don't completely like the interface, it needs the adtype as positional arg for the OptimizationProblem constructor for the "automatic" conversion but can't see any other way to do it for now. |
Fixes SciML/Optimization.jl#622