Skip to content

Commit

Permalink
Add OptimizationFunction conversion and adtype arg and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Nov 7, 2023
1 parent c838cf2 commit 5f200fa
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/problems/basic_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -688,11 +688,18 @@ function OptimizationProblem(f, args...; kwargs...)
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
end

function OptimizationProblem(prob::NonlinearLeastSquaresProblem; kwargs...)
function OptimizationFunction(f::NonlinearFunction, adtype::AbstractADType = NoAD(); kwargs...)
if isinplace(f)
throw(ArgumentError("Converting NonlinearFunction to OptimizationFunction is not supported with in-place functions yet."))

Check warning on line 693 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L691-L693

Added lines #L691 - L693 were not covered by tests
end
OptimizationFunction((u, p) -> sum(abs2, f(u, p)), adtype; kwargs...)

Check warning on line 695 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L695

Added line #L695 was not covered by tests
end

function OptimizationProblem(prob::NonlinearLeastSquaresProblem, adtype::AbstractADType = NoAD(); kwargs...)
if isinplace(prob)
throw(ArgumentError("Converting NonlinearLeastSquaresProblem to OptimizationProblem is not supported with in-place functions yet."))

Check warning on line 700 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L698-L700

Added lines #L698 - L700 were not covered by tests
end
optf = OptimizationFunction(sum prob.f, grad = (Jv, u, p) -> prob.f.jvp(Jv, prob.f(u, p), u, p), kwargs...)
optf = OptimizationFunction(prob.f, adtype; kwargs...)
return OptimizationProblem(optf, prob.u0, prob.p; prob.kwargs..., kwargs...)

Check warning on line 703 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L702-L703

Added lines #L702 - L703 were not covered by tests
end

Expand Down
29 changes: 29 additions & 0 deletions test/downstream/nllsopt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using NonlinearSolve, Optimization, OptimizationNLopt, ForwardDiff
import FastLevenbergMarquardt, LeastSquaresOptim

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])

θ_true = [1.0, 0.1, 2.0, 0.5]

x = [-1.0, -0.5, 0.0, 0.5, 1.0]

y_target = true_function(x, θ_true)

function loss_function(θ, p)
= true_function(p, θ)
return.- y_target
end

θ_init = θ_true .+ randn!(similar(θ_true)) * 0.1
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)

solver = LevenbergMarquardt()

@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)

optf = OptimizationFunction(prob_oop.f, AutoForwardDiff())
optprob = OptimizationProblem(optf, prob_oop.u0, prob_oop.p)
@time sol = solve(optprob, NLopt.LD_LBFGS(); maxiters = 10000, abstol = 1e-8)

optprob = OptimizationProblem(prob_oop, AutoForwardDiff())
@time sol = solve(optprob, NLopt.LD_LBFGS(); maxiters = 10000, abstol = 1e-8)

0 comments on commit 5f200fa

Please sign in to comment.