Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update solve docstring with new callback syntax #590

Merged
merged 1 commit into from
Jan 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,17 @@ The callback function `callback` is a function which is called after every optim
step. Its signature is:

```julia
callback = (u, loss_val, other_args) -> false
callback = (state, loss_val, other_args) -> false
```

where `u` and `loss_val` are the current optimization variables and loss/objective value
in the optimization loop and `other_args` can be the extra things returned from
the optimization `f`. This allows for saving values from the optimization and
where `state` is a `OptimizationState` and stores information for the current
iteration of the solver and `loss_val` is loss/objective value. For more
information about the fields of the `state` look at the `OptimizationState`
documentation. The `other_args` can be the extra things returned from the
optimization `f`. This allows for saving values from the optimization and
using them for plotting and display without recalculating. The callback should
return a Boolean value, and the default should be `false`, such that the optimization
gets stopped if it returns `true`.
return a Boolean value, and the default should be `false`, such that the
optimization gets stopped if it returns `true`.

### Callback Example

Expand All @@ -76,7 +78,7 @@ function loss(u, p)
sum(abs2, batch .- pred), pred
end

callback = function (p, l, pred; doplot = false) #callback function to observe training
callback = function (state, l, pred; doplot = false) #callback function to observe training
display(l)
# plot current prediction against data
if doplot
Expand Down
Loading