Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add constructor to convert NLLS to OptimizationProblem #538

Merged
merged 3 commits into from
Nov 8, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/problems/basic_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,14 @@
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
end

function OptimizationProblem(prob::NonlinearLeastSquaresProblem; kwargs...)
if isinplace(prob)
throw(ArgumentError("Converting NonlinearLeastSquaresProblem to OptimizationProblem is not supported with in-place functions yet."))

Check warning on line 693 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L691-L693

Added lines #L691 - L693 were not covered by tests
end
optf = OptimizationFunction(sum ∘ prob.f, grad = (Jv, u, p) -> prob.f.jvp(Jv, prob.f(u, p), u, p), kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this assumes jvp is defined? That's a pretty strong assumption?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah..I realise now that I was assuming this to come from being populated by an instatiate_function equivalent but it'll have to be provided by the user to be available here right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sum(abs2, prob.f(u,p)), not just sum

Copy link
Member Author

@Vaibhavdixit02 Vaibhavdixit02 Nov 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the loss function for NLLS already has the square

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@avik-pal wait, I thought the interface would be the same as nonlinearsolve, i.e. returning the vector of residuals. f(du,u,p) and then sum(abs2,du) is the loss. Now I'm looking at the tests and see the user is required to write it. Having that summation be implicit enforces it's actually a least squares problem.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right it is same as NonlinearProblem, sum ∘ prob.f should be sum(abs2,....)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC what Chris meant was the other way around, the NLLS problem should do the squaring implicitly and not ask the user to do that (and then the loss here would just be the sum)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay I see what happened. SciML/NonlinearSolve.jl@b8aca89 the original test was just weird, but it was since fixed SciML/NonlinearSolve.jl@883a392. This means it's sum(abs2,...).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, I had seen the previous version and added this

return OptimizationProblem(optf, prob.u0, prob.p; prob.kwargs..., kwargs...)

Check warning on line 696 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L695-L696

Added lines #L695 - L696 were not covered by tests
end

isinplace(f::OptimizationFunction{iip}) where {iip} = iip
isinplace(f::OptimizationProblem{iip}) where {iip} = iip

Expand Down
Loading