
fix optimizer bug in #61 (#70)
* fix optimizer bug in #61

* regenerate snapshots on intel hardware

* rephrase
topepo authored Nov 2, 2023
1 parent 087129b commit 7f7a421
Showing 12 changed files with 153 additions and 293 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
@@ -1,5 +1,7 @@
# brulee (development version)

* Fixed a bug where SGD was always being used as the optimizer (#61).

# brulee 0.2.0

* Several learning rate schedulers were added to the modeling functions (#12).
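For context, this is the user-facing effect of the change: before this commit, the `optimizer` argument was silently overridden and SGD was used for training regardless of the request. A minimal sketch of the fixed behavior, assuming a hypothetical data frame `df` (the call shape follows the package's formula interface):

# Hypothetical illustration, not from the package docs:
library(brulee)
set.seed(1)
df <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100))
# Previously this trained with SGD despite the argument; it now uses LBFGS.
fit <- brulee_linear_reg(y ~ ., data = df, optimizer = "LBFGS", epochs = 20)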
21 changes: 8 additions & 13 deletions R/linear_reg-fit.R
@@ -452,12 +452,17 @@ linear_reg_fit_imp <-
}
loss_label <- "\tLoss (scaled):"

if (optimizer == "LBFGS" & !is.null(batch_size)) {
rlang::warn("'batch_size' is only used for the SGD optimizer.")
batch_size <- NULL
}
if (is.null(batch_size)) {
batch_size <- nrow(x)
} else {
batch_size <- min(batch_size, nrow(x))
}


## ---------------------------------------------------------------------------
# Convert to index sampler and data loader
ds <- brulee::matrix_to_dataset(x, y)
@@ -472,17 +477,7 @@ linear_reg_fit_imp <-
# Initialize model and optimizer
model <- linear_reg_module(ncol(x))
loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)

# Write a optim wrapper
if (optimizer == "LBFGS") {
optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
history_size = 5)
} else if (optimizer == "SGD") {
optimizer <-
torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
} else {
rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
}
optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)

## ---------------------------------------------------------------------------

@@ -505,13 +500,13 @@ linear_reg_fit_imp <-
coro::loop(
for (batch in dl) {
cl <- function() {
optimizer$zero_grad()
optimizer_obj$zero_grad()
pred <- model(batch$x)
loss <- loss_fn(pred, batch$y)
loss$backward()
loss
}
optimizer$step(cl)
optimizer_obj$step(cl)
}
)

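The `cl` closure passed to `optimizer_obj$step()` above is what makes the LBFGS path work: torch's LBFGS re-evaluates the loss several times per step, so `step()` must receive a function that re-runs the forward pass and gradients. A self-contained sketch of the same pattern outside brulee, with made-up data:

library(torch)
x <- torch_randn(100, 3)
y <- torch_randn(100, 1)
model <- nn_linear(3, 1)
opt <- optim_lbfgs(model$parameters, lr = 0.1, history_size = 5)
cl <- function() {
  opt$zero_grad()                    # clear gradients from the last evaluation
  loss <- nnf_mse_loss(model(x), y)  # forward pass
  loss$backward()                    # recompute gradients
  loss                               # step() needs the loss value back
}
for (i in 1:5) opt$step(cl)          # SGD also accepts the same closure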
20 changes: 7 additions & 13 deletions R/logistic_reg-fit.R
@@ -463,6 +463,10 @@ logistic_reg_fit_imp <-
y_stats <- list(mean = NA_real_, sd = NA_real_)
loss_label <- "\tLoss:"

if (optimizer == "LBFGS" & !is.null(batch_size)) {
rlang::warn("'batch_size' is only used for the SGD optimizer.")
batch_size <- NULL
}
if (is.null(batch_size)) {
batch_size <- nrow(x)
} else {
@@ -483,17 +487,7 @@
# Initialize model and optimizer
model <- logistic_module(ncol(x), y_dim)
loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)

# Write a optim wrapper
if (optimizer == "LBFGS") {
optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
history_size = 5)
} else if (optimizer == "SGD") {
optimizer <-
torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
} else {
rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
}
optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)

## ---------------------------------------------------------------------------

@@ -517,13 +511,13 @@
coro::loop(
for (batch in dl) {
cl <- function() {
optimizer$zero_grad()
optimizer_obj$zero_grad()
pred <- model(batch$x)
loss <- loss_fn(pred, batch$y, class_weights)
loss$backward()
loss
}
optimizer$step(cl)
optimizer_obj$step(cl)
}
)

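From the user's side, the guard added near the top of each fitting function means a call like the following hypothetical one now warns that 'batch_size' is only used for the SGD optimizer, then fits with full-batch LBFGS instead of silently minibatching:

# Hypothetical call; df as before. batch_size is dropped with a warning:
fit <- brulee_logistic_reg(y ~ ., data = df, optimizer = "LBFGS", batch_size = 32)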
35 changes: 19 additions & 16 deletions R/mlp-fit.R
@@ -201,7 +201,7 @@
#' set.seed(2)
#' cls_fit <- brulee_mlp(class ~ ., data = parabolic_tr, hidden_units = 2,
#' epochs = 200L, learn_rate = 0.1, activation = "elu",
#' penalty = 0.1, batch_size = 2^8)
#' penalty = 0.1, batch_size = 2^8, optimizer = "SGD")
#' autoplot(cls_fit)
#'
#' grid_points <- seq(-4, 4, length.out = 100)
@@ -435,9 +435,9 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
if (length(hidden_units) != length(activation)) {
rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'")
}

if (optimizer == "LBFGS" & !is.null(batch_size)) {
rlang::warn("'batch_size' is only use for the SGD optimizer.")
rlang::warn("'batch_size' is only used for the SGD optimizer.")
batch_size <- NULL
}

check_integer(epochs, single = TRUE, 1, fn = f_nm)
@@ -656,16 +656,8 @@ mlp_fit_imp <-
model <- mlp_module(ncol(x), hidden_units, activation, dropout, y_dim)
loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)

# Set the optimizer
if (optimizer == "LBFGS") {
optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
history_size = 5)
} else if (optimizer == "SGD") {
optimizer <-
torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
} else {
rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
}
# Set the optimizer (will be set again below)
optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)

## ---------------------------------------------------------------------------

@@ -694,19 +686,19 @@ mlp_fit_imp <-
# resetting them can interfere in training."

learn_rate <- set_learn_rate(epoch - 1, learn_rate, type = rate_schedule, ...)
optimizer <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)

# training loop
coro::loop(
for (batch in dl) {
cl <- function() {
optimizer$zero_grad()
optimizer_obj$zero_grad()
pred <- model(batch$x)
loss <- loss_fn(pred, batch$y, class_weights)
loss$backward()
loss
}
optimizer$step(cl)
optimizer_obj$step(cl)
}
)

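This hunk looks like the root cause of #61 (an inference from the diff; the commit message does not say so explicitly): to apply the learning-rate schedule, the old code rebuilt the optimizer at the top of every epoch with `torch::optim_sgd()` unconditionally, replacing whatever optimizer the user had requested. Rebuilding through `set_optimizer()` keeps the requested optimizer while still picking up the scheduled rate:

# before (bug): the per-epoch rebuild always produced SGD
# optimizer <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
# after (fix): rebuild whichever optimizer the user requested
# optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)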
@@ -874,3 +866,14 @@ get_activation_fn <- function(arg, ...) {
}
res
}

set_optimizer <- function(optimizer, model, learn_rate, momentum) {
if (optimizer == "LBFGS") {
res <- torch::optim_lbfgs(model$parameters, lr = learn_rate, history_size = 5)
} else if (optimizer == "SGD") {
res <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
} else {
rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
}
res
}
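A quick standalone sketch of the new helper's dispatch; the `nn_linear()` stand-in model is hypothetical (brulee passes its own modules):

library(torch)
m <- nn_linear(2, 1)
opt_l <- set_optimizer("LBFGS", m, learn_rate = 0.1, momentum = 0)   # torch::optim_lbfgs, history_size = 5
opt_s <- set_optimizer("SGD",   m, learn_rate = 0.1, momentum = 0.9) # torch::optim_sgd with momentum
# set_optimizer("Adam", m, 0.1, 0)  # would abort: Unknown optimizer 'Adam'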
20 changes: 7 additions & 13 deletions R/multinomial_reg-fit.R
@@ -444,6 +444,10 @@ multinomial_reg_fit_imp <-
y_stats <- list(mean = NA_real_, sd = NA_real_)
loss_label <- "\tLoss:"

if (optimizer == "LBFGS" & !is.null(batch_size)) {
rlang::warn("'batch_size' is only used for the SGD optimizer.")
batch_size <- NULL
}
if (is.null(batch_size)) {
batch_size <- nrow(x)
} else {
@@ -464,17 +468,7 @@ multinomial_reg_fit_imp <-
# Initialize model and optimizer
model <- multinomial_module(ncol(x), y_dim)
loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)

# Write a optim wrapper
if (optimizer == "LBFGS") {
optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
history_size = 5)
} else if (optimizer == "SGD") {
optimizer <-
torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
} else {
rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
}
optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)

## ---------------------------------------------------------------------------

@@ -498,13 +492,13 @@
coro::loop(
for (batch in dl) {
cl <- function() {
optimizer$zero_grad()
optimizer_obj$zero_grad()
pred <- model(batch$x)
loss <- loss_fn(pred, batch$y, class_weights)
loss$backward()
loss
}
optimizer$step(cl)
optimizer_obj$step(cl)
}
)

22 changes: 5 additions & 17 deletions inst/WORDLIST
@@ -1,28 +1,16 @@
CMD
Codecov
Hoerl
Kennard
LBFGS
Lifecycle
Multilayer
Nonorthogonal
ORCID
PBC
SGD
Springer
Technometrics
elu
extensibility
optimizers
perceptron
funder
mlp
multilayer
perceptrons
relu
relu’
tanh
CMD
Lifecycle
LBFGS
SGD
optimizers
ggplot
mlp
tibble
multilayer
2 changes: 1 addition & 1 deletion man/brulee_mlp.Rd

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions tests/testthat/_snaps/logistic_reg-fit.md
@@ -3,7 +3,7 @@
Code
set.seed(1)
fit <- brulee_logistic_reg(y ~ ., df, epochs = 2, verbose = TRUE, penalty = 0)
Message <rlang_message>
Message
epoch: 1 Loss: 0.173
epoch: 2 Loss: 0.173 x

@@ -25,7 +25,7 @@
set.seed(1)
fit_imbal <- brulee_logistic_reg(y ~ ., df_imbal, verbose = TRUE,
class_weights = 20, optimizer = "SGD", penalty = 0)
Message <rlang_message>
Message
epoch: 1 Loss: 0.329
epoch: 2 Loss: 0.302
epoch: 3 Loss: 0.282
@@ -53,7 +53,7 @@
set.seed(1)
fit <- brulee_logistic_reg(y ~ ., df_imbal, epochs = 2, verbose = TRUE,
class_weights = c(a = 12, b = 1), penalty = 0)
Message <rlang_message>
Message
epoch: 1 Loss: 0.113
epoch: 2 Loss: 0.113 x

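The snapshot edits above drop the `<rlang_message>` qualifier (newer testthat releases record message snapshots without the condition class) and refresh the logged losses after regenerating on Intel hardware. A sketch of the kind of test that produces these snapshots, assuming the shape of the test file rather than quoting it:

# Assumed shape of the corresponding testthat test:
test_that("logistic regression logs per-epoch loss", {
  expect_snapshot({
    set.seed(1)
    fit <- brulee_logistic_reg(y ~ ., df, epochs = 2, verbose = TRUE, penalty = 0)
  })
})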
