diff --git a/NEWS.md b/NEWS.md
index 47ff525..6d4cf27 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,7 @@
 # brulee (development version)
 
+* Fixed a bug where SGD was always being used as the optimizer (#61).
+
 # brulee 0.2.0
 
 * Several learning rate schedulers were added to the modeling functions (#12).
diff --git a/R/linear_reg-fit.R b/R/linear_reg-fit.R
index 39af2f1..a470b0e 100644
--- a/R/linear_reg-fit.R
+++ b/R/linear_reg-fit.R
@@ -452,12 +452,17 @@ linear_reg_fit_imp <-
   }
   loss_label <- "\tLoss (scaled):"
 
+  if (optimizer == "LBFGS" & !is.null(batch_size)) {
+    rlang::warn("'batch_size' is only used for the SGD optimizer.")
+    batch_size <- NULL
+  }
   if (is.null(batch_size)) {
     batch_size <- nrow(x)
   } else {
     batch_size <- min(batch_size, nrow(x))
   }
+
   ## ---------------------------------------------------------------------------
 
   # Convert to index sampler and data loader
   ds <- brulee::matrix_to_dataset(x, y)
@@ -472,17 +477,7 @@ linear_reg_fit_imp <-
   # Initialize model and optimizer
   model <- linear_reg_module(ncol(x))
   loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
-
-  # Write a optim wrapper
-  if (optimizer == "LBFGS") {
-    optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
-                                    history_size = 5)
-  } else if (optimizer == "SGD") {
-    optimizer <-
-      torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
-  } else {
-    rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
-  }
+  optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)
 
   ## ---------------------------------------------------------------------------
 
@@ -505,13 +500,13 @@ linear_reg_fit_imp <-
     coro::loop(
       for (batch in dl) {
         cl <- function() {
-          optimizer$zero_grad()
+          optimizer_obj$zero_grad()
           pred <- model(batch$x)
           loss <- loss_fn(pred, batch$y)
           loss$backward()
           loss
         }
-        optimizer$step(cl)
+        optimizer_obj$step(cl)
       }
     )
 
diff --git a/R/logistic_reg-fit.R b/R/logistic_reg-fit.R
index 859bd36..f2e0aab 100644
--- a/R/logistic_reg-fit.R
+++ b/R/logistic_reg-fit.R
@@ -463,6 +463,10 @@ logistic_reg_fit_imp <-
   y_stats <- list(mean = NA_real_, sd = NA_real_)
   loss_label <- "\tLoss:"
 
+  if (optimizer == "LBFGS" & !is.null(batch_size)) {
+    rlang::warn("'batch_size' is only used for the SGD optimizer.")
+    batch_size <- NULL
+  }
   if (is.null(batch_size)) {
     batch_size <- nrow(x)
   } else {
@@ -483,17 +487,7 @@ logistic_reg_fit_imp <-
   # Initialize model and optimizer
   model <- logistic_module(ncol(x), y_dim)
   loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
-
-  # Write a optim wrapper
-  if (optimizer == "LBFGS") {
-    optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
-                                    history_size = 5)
-  } else if (optimizer == "SGD") {
-    optimizer <-
-      torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
-  } else {
-    rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
-  }
+  optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)
 
   ## ---------------------------------------------------------------------------
 
@@ -517,13 +511,13 @@ logistic_reg_fit_imp <-
     coro::loop(
      for (batch in dl) {
        cl <- function() {
-          optimizer$zero_grad()
+          optimizer_obj$zero_grad()
          pred <- model(batch$x)
          loss <- loss_fn(pred, batch$y, class_weights)
          loss$backward()
          loss
        }
-        optimizer$step(cl)
+        optimizer_obj$step(cl)
      }
    )
 
diff --git a/R/mlp-fit.R b/R/mlp-fit.R
index 3ccbd45..e158e71 100644
--- a/R/mlp-fit.R
+++ b/R/mlp-fit.R
@@ -201,7 +201,7 @@
 #' set.seed(2)
 #' cls_fit <- brulee_mlp(class ~ ., data = parabolic_tr, hidden_units = 2,
 #'                       epochs = 200L, learn_rate = 0.1, activation = "elu",
activation = "elu", -#' penalty = 0.1, batch_size = 2^8) +#' penalty = 0.1, batch_size = 2^8, optimizer = "SGD") #' autoplot(cls_fit) #' #' grid_points <- seq(-4, 4, length.out = 100) @@ -435,9 +435,9 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation, if (length(hidden_units) != length(activation)) { rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'") } - if (optimizer == "LBFGS" & !is.null(batch_size)) { - rlang::warn("'batch_size' is only use for the SGD optimizer.") + rlang::warn("'batch_size' is only used for the SGD optimizer.") + batch_size <- NULL } check_integer(epochs, single = TRUE, 1, fn = f_nm) @@ -656,16 +656,8 @@ mlp_fit_imp <- model <- mlp_module(ncol(x), hidden_units, activation, dropout, y_dim) loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture) - # Set the optimizer - if (optimizer == "LBFGS") { - optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate, - history_size = 5) - } else if (optimizer == "SGD") { - optimizer <- - torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum) - } else { - rlang::abort(paste0("Unknown optimizer '", optimizer, "'")) - } + # Set the optimizer (will be set again below) + optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum) ## --------------------------------------------------------------------------- @@ -694,19 +686,19 @@ mlp_fit_imp <- # resetting them can interfere in training." learn_rate <- set_learn_rate(epoch - 1, learn_rate, type = rate_schedule, ...) - optimizer <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum) + optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum) # training loop coro::loop( for (batch in dl) { cl <- function() { - optimizer$zero_grad() + optimizer_obj$zero_grad() pred <- model(batch$x) loss <- loss_fn(pred, batch$y, class_weights) loss$backward() loss } - optimizer$step(cl) + optimizer_obj$step(cl) } ) @@ -874,3 +866,14 @@ get_activation_fn <- function(arg, ...) 
   }
   res
 }
+
+set_optimizer <- function(optimizer, model, learn_rate, momentum) {
+  if (optimizer == "LBFGS") {
+    res <- torch::optim_lbfgs(model$parameters, lr = learn_rate, history_size = 5)
+  } else if (optimizer == "SGD") {
+    res <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
+  } else {
+    rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
+  }
+  res
+}
diff --git a/R/multinomial_reg-fit.R b/R/multinomial_reg-fit.R
index acc5672..ff7d45e 100644
--- a/R/multinomial_reg-fit.R
+++ b/R/multinomial_reg-fit.R
@@ -444,6 +444,10 @@ multinomial_reg_fit_imp <-
   y_stats <- list(mean = NA_real_, sd = NA_real_)
   loss_label <- "\tLoss:"
 
+  if (optimizer == "LBFGS" & !is.null(batch_size)) {
+    rlang::warn("'batch_size' is only used for the SGD optimizer.")
+    batch_size <- NULL
+  }
   if (is.null(batch_size)) {
     batch_size <- nrow(x)
   } else {
@@ -464,17 +468,7 @@ multinomial_reg_fit_imp <-
   # Initialize model and optimizer
   model <- multinomial_module(ncol(x), y_dim)
   loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
-
-  # Write a optim wrapper
-  if (optimizer == "LBFGS") {
-    optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
-                                    history_size = 5)
-  } else if (optimizer == "SGD") {
-    optimizer <-
-      torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
-  } else {
-    rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
-  }
+  optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)
 
   ## ---------------------------------------------------------------------------
 
@@ -498,13 +492,13 @@ multinomial_reg_fit_imp <-
     coro::loop(
      for (batch in dl) {
        cl <- function() {
-          optimizer$zero_grad()
+          optimizer_obj$zero_grad()
          pred <- model(batch$x)
          loss <- loss_fn(pred, batch$y, class_weights)
          loss$backward()
          loss
        }
-        optimizer$step(cl)
+        optimizer_obj$step(cl)
      }
    )
 
diff --git a/inst/WORDLIST b/inst/WORDLIST
index ea02565..df7d874 100644
--- a/inst/WORDLIST
+++ b/inst/WORDLIST
@@ -1,28 +1,16 @@
 CMD
 Codecov
-Hoerl
-Kennard
 LBFGS
 Lifecycle
-Multilayer
-Nonorthogonal
 ORCID
+PBC
 SGD
-Springer
-Technometrics
 elu
 extensibility
-optimizers
-perceptron
+funder
+mlp
+multilayer
+perceptrons
 relu
-relu’
 tanh
-CMD
-Lifecycle
-LBFGS
-SGD
-optimizers
-ggplot
-mlp
 tibble
-multilayer
diff --git a/man/brulee_mlp.Rd b/man/brulee_mlp.Rd
index d73a870..530488b 100644
--- a/man/brulee_mlp.Rd
+++ b/man/brulee_mlp.Rd
@@ -319,7 +319,7 @@ if (torch::torch_is_installed()) {
 set.seed(2)
 cls_fit <- brulee_mlp(class ~ ., data = parabolic_tr, hidden_units = 2,
                       epochs = 200L, learn_rate = 0.1, activation = "elu",
-                      penalty = 0.1, batch_size = 2^8)
+                      penalty = 0.1, batch_size = 2^8, optimizer = "SGD")
 autoplot(cls_fit)
 
 grid_points <- seq(-4, 4, length.out = 100)
diff --git a/tests/testthat/_snaps/logistic_reg-fit.md b/tests/testthat/_snaps/logistic_reg-fit.md
index 56e5dad..a225b83 100644
--- a/tests/testthat/_snaps/logistic_reg-fit.md
+++ b/tests/testthat/_snaps/logistic_reg-fit.md
@@ -3,7 +3,7 @@
     Code
       set.seed(1)
       fit <- brulee_logistic_reg(y ~ ., df, epochs = 2, verbose = TRUE, penalty = 0)
-    Message 
+    Message
       epoch: 1 Loss: 0.173 
       epoch: 2 Loss: 0.173 x
 
@@ -25,7 +25,7 @@
       set.seed(1)
       fit_imbal <- brulee_logistic_reg(y ~ ., df_imbal, verbose = TRUE,
         class_weights = 20, optimizer = "SGD", penalty = 0)
-    Message 
+    Message
       epoch: 1 Loss: 0.329 
       epoch: 2 Loss: 0.302 
       epoch: 3 Loss: 0.282 
@@ -53,7 +53,7 @@
       set.seed(1)
       fit <- brulee_logistic_reg(y ~ ., df_imbal, epochs = 2, verbose = TRUE,
         class_weights = c(a = 12, b = 1), penalty = 0)
-    Message 
+    Message
       epoch: 1 Loss: 0.113
       epoch: 2 Loss: 0.113 x
 
diff --git a/tests/testthat/_snaps/mlp-binary.md b/tests/testthat/_snaps/mlp-binary.md
index 6e21240..7a097f6 100644
--- a/tests/testthat/_snaps/mlp-binary.md
+++ b/tests/testthat/_snaps/mlp-binary.md
@@ -13,114 +13,71 @@
       dropout proportion: 0
       batch size: 900
       learn rate: 0.01
-      validation loss after 10 epochs: 0.543
+      validation loss after 10 epochs: 0.294
 
 # class weights - mlp
 
     Code
       set.seed(1)
       fit_imbal <- brulee_mlp(y ~ ., df_imbal, verbose = TRUE, class_weights = 20)
-    Message 
-      epoch: 1 learn rate 0.01 Loss: 0.578 
-      epoch: 2 learn rate 0.01 Loss: 0.574 
-      epoch: 3 learn rate 0.01 Loss: 0.571 
-      epoch: 4 learn rate 0.01 Loss: 0.568 
-      epoch: 5 learn rate 0.01 Loss: 0.565 
-      epoch: 6 learn rate 0.01 Loss: 0.563 
-      epoch: 7 learn rate 0.01 Loss: 0.56 
-      epoch: 8 learn rate 0.01 Loss: 0.557 
-      epoch: 9 learn rate 0.01 Loss: 0.554 
-      epoch: 10 learn rate 0.01 Loss: 0.552 
-      epoch: 11 learn rate 0.01 Loss: 0.549 
-      epoch: 12 learn rate 0.01 Loss: 0.547 
-      epoch: 13 learn rate 0.01 Loss: 0.544 
-      epoch: 14 learn rate 0.01 Loss: 0.542 
-      epoch: 15 learn rate 0.01 Loss: 0.54 
-      epoch: 16 learn rate 0.01 Loss: 0.537 
-      epoch: 17 learn rate 0.01 Loss: 0.535 
-      epoch: 18 learn rate 0.01 Loss: 0.533 
-      epoch: 19 learn rate 0.01 Loss: 0.531 
-      epoch: 20 learn rate 0.01 Loss: 0.529 
-      epoch: 21 learn rate 0.01 Loss: 0.527 
-      epoch: 22 learn rate 0.01 Loss: 0.525 
-      epoch: 23 learn rate 0.01 Loss: 0.523 
-      epoch: 24 learn rate 0.01 Loss: 0.521 
-      epoch: 25 learn rate 0.01 Loss: 0.519 
-      epoch: 26 learn rate 0.01 Loss: 0.517 
-      epoch: 27 learn rate 0.01 Loss: 0.515 
-      epoch: 28 learn rate 0.01 Loss: 0.513 
-      epoch: 29 learn rate 0.01 Loss: 0.511 
-      epoch: 30 learn rate 0.01 Loss: 0.51 
-      epoch: 31 learn rate 0.01 Loss: 0.508 
-      epoch: 32 learn rate 0.01 Loss: 0.506 
-      epoch: 33 learn rate 0.01 Loss: 0.505 
-      epoch: 34 learn rate 0.01 Loss: 0.503 
-      epoch: 35 learn rate 0.01 Loss: 0.502 
-      epoch: 36 learn rate 0.01 Loss: 0.5 
-      epoch: 37 learn rate 0.01 Loss: 0.499 
-      epoch: 38 learn rate 0.01 Loss: 0.497 
-      epoch: 39 learn rate 0.01 Loss: 0.496 
-      epoch: 40 learn rate 0.01 Loss: 0.494 
-      epoch: 41 learn rate 0.01 Loss: 0.493 
-      epoch: 42 learn rate 0.01 Loss: 0.491 
-      epoch: 43 learn rate 0.01 Loss: 0.49 
-      epoch: 44 learn rate 0.01 Loss: 0.489 
-      epoch: 45 learn rate 0.01 Loss: 0.487 
-      epoch: 46 learn rate 0.01 Loss: 0.486 
-      epoch: 47 learn rate 0.01 Loss: 0.485 
-      epoch: 48 learn rate 0.01 Loss: 0.484 
-      epoch: 49 learn rate 0.01 Loss: 0.482 
-      epoch: 50 learn rate 0.01 Loss: 0.481 
-      epoch: 51 learn rate 0.01 Loss: 0.48 
-      epoch: 52 learn rate 0.01 Loss: 0.479 
-      epoch: 53 learn rate 0.01 Loss: 0.478 
-      epoch: 54 learn rate 0.01 Loss: 0.477 
-      epoch: 55 learn rate 0.01 Loss: 0.475 
-      epoch: 56 learn rate 0.01 Loss: 0.474 
-      epoch: 57 learn rate 0.01 Loss: 0.473 
-      epoch: 58 learn rate 0.01 Loss: 0.472 
-      epoch: 59 learn rate 0.01 Loss: 0.471 
-      epoch: 60 learn rate 0.01 Loss: 0.47 
-      epoch: 61 learn rate 0.01 Loss: 0.469 
-      epoch: 62 learn rate 0.01 Loss: 0.468 
-      epoch: 63 learn rate 0.01 Loss: 0.467 
-      epoch: 64 learn rate 0.01 Loss: 0.466 
-      epoch: 65 learn rate 0.01 Loss: 0.465 
-      epoch: 66 learn rate 0.01 Loss: 0.464 
-      epoch: 67 learn rate 0.01 Loss: 0.463 
-      epoch: 68 learn rate 0.01 Loss: 0.462 
-      epoch: 69 learn rate 0.01 Loss: 0.461 
-      epoch: 70 learn rate 0.01 Loss: 0.461 
-      epoch: 71 learn rate 0.01 Loss: 0.46 
-      epoch: 72 learn rate 0.01 Loss: 0.459 
-      epoch: 73 learn rate 0.01 Loss: 0.458 
-      epoch: 74 learn rate 0.01 Loss: 0.457 
-      epoch: 75 learn rate 0.01 Loss: 0.456 
-      epoch: 76 learn rate 0.01 Loss: 0.455 
-      epoch: 77 learn rate 0.01 Loss: 0.455
-      epoch: 78 learn rate 0.01 Loss: 0.454 
-      epoch: 79 learn rate 0.01 Loss: 0.453 
-      epoch: 80 learn rate 0.01 Loss: 0.452 
-      epoch: 81 learn rate 0.01 Loss: 0.451 
-      epoch: 82 learn rate 0.01 Loss: 0.451 
-      epoch: 83 learn rate 0.01 Loss: 0.45 
-      epoch: 84 learn rate 0.01 Loss: 0.449 
-      epoch: 85 learn rate 0.01 Loss: 0.448 
-      epoch: 86 learn rate 0.01 Loss: 0.448 
-      epoch: 87 learn rate 0.01 Loss: 0.447 
-      epoch: 88 learn rate 0.01 Loss: 0.446 
-      epoch: 89 learn rate 0.01 Loss: 0.446 
-      epoch: 90 learn rate 0.01 Loss: 0.445 
-      epoch: 91 learn rate 0.01 Loss: 0.444 
-      epoch: 92 learn rate 0.01 Loss: 0.443 
-      epoch: 93 learn rate 0.01 Loss: 0.443 
-      epoch: 94 learn rate 0.01 Loss: 0.442 
-      epoch: 95 learn rate 0.01 Loss: 0.441 
-      epoch: 96 learn rate 0.01 Loss: 0.441 
-      epoch: 97 learn rate 0.01 Loss: 0.44 
-      epoch: 98 learn rate 0.01 Loss: 0.44 
-      epoch: 99 learn rate 0.01 Loss: 0.439 
-      epoch: 100 learn rate 0.01 Loss: 0.438 
+    Message
+      epoch: 1 learn rate 0.01 Loss: 0.351 
+      epoch: 2 learn rate 0.01 Loss: 0.224 
+      epoch: 3 learn rate 0.01 Loss: 0.178 
+      epoch: 4 learn rate 0.01 Loss: 0.143 
+      epoch: 5 learn rate 0.01 Loss: 0.115 
+      epoch: 6 learn rate 0.01 Loss: 0.0973 
+      epoch: 7 learn rate 0.01 Loss: 0.0828 
+      epoch: 8 learn rate 0.01 Loss: 0.0717 
+      epoch: 9 learn rate 0.01 Loss: 0.063 
+      epoch: 10 learn rate 0.01 Loss: 0.0558 
+      epoch: 11 learn rate 0.01 Loss: 0.0506 
+      epoch: 12 learn rate 0.01 Loss: 0.0459 
+      epoch: 13 learn rate 0.01 Loss: 0.0422 
+      epoch: 14 learn rate 0.01 Loss: 0.0397 
+      epoch: 15 learn rate 0.01 Loss: 0.037 
+      epoch: 16 learn rate 0.01 Loss: 0.0352 
+      epoch: 17 learn rate 0.01 Loss: 0.0345 
+      epoch: 18 learn rate 0.01 Loss: 0.0334 
+      epoch: 19 learn rate 0.01 Loss: 0.033 
+      epoch: 20 learn rate 0.01 Loss: 0.0324 
+      epoch: 21 learn rate 0.01 Loss: 0.0324 
+      epoch: 22 learn rate 0.01 Loss: 0.0323 
+      epoch: 23 learn rate 0.01 Loss: 0.0323 
+      epoch: 24 learn rate 0.01 Loss: 0.0323 
+      epoch: 25 learn rate 0.01 Loss: 0.0323 
+      epoch: 26 learn rate 0.01 Loss: 0.0323 
+      epoch: 27 learn rate 0.01 Loss: 0.0323 
+      epoch: 28 learn rate 0.01 Loss: 0.0323 
+      epoch: 29 learn rate 0.01 Loss: 0.0323 
+      epoch: 30 learn rate 0.01 Loss: 0.0323 
+      epoch: 31 learn rate 0.01 Loss: 0.0322 
+      epoch: 32 learn rate 0.01 Loss: 0.0322 
+      epoch: 33 learn rate 0.01 Loss: 0.0322 
+      epoch: 34 learn rate 0.01 Loss: 0.0322 
+      epoch: 35 learn rate 0.01 Loss: 0.0322 
+      epoch: 36 learn rate 0.01 Loss: 0.0322 
+      epoch: 37 learn rate 0.01 Loss: 0.0322 
+      epoch: 38 learn rate 0.01 Loss: 0.0322 
+      epoch: 39 learn rate 0.01 Loss: 0.0322 
+      epoch: 40 learn rate 0.01 Loss: 0.0322 
+      epoch: 41 learn rate 0.01 Loss: 0.0322 
+      epoch: 42 learn rate 0.01 Loss: 0.0322 
+      epoch: 43 learn rate 0.01 Loss: 0.0322 
+      epoch: 44 learn rate 0.01 Loss: 0.0322 
+      epoch: 45 learn rate 0.01 Loss: 0.0322 
+      epoch: 46 learn rate 0.01 Loss: 0.0322 
+      epoch: 47 learn rate 0.01 Loss: 0.0322 
+      epoch: 48 learn rate 0.01 Loss: 0.0317 
+      epoch: 49 learn rate 0.01 Loss: 0.0313 
+      epoch: 50 learn rate 0.01 Loss: 0.0314 x
+      epoch: 51 learn rate 0.01 Loss: 0.031 
+      epoch: 52 learn rate 0.01 Loss: 0.0307 
+      epoch: 53 learn rate 0.01 Loss: 0.0307 x
+      epoch: 54 learn rate 0.01 Loss: 0.0307 x
+      epoch: 55 learn rate 0.01 Loss: 0.0307 x
+      epoch: 56 learn rate 0.01 Loss: 0.0307 x
+      epoch: 57 learn rate 0.01 Loss: 0.0308 x
 
 ---
 
@@ -128,7 +85,7 @@
       set.seed(1)
       fit <- brulee_mlp(y ~ ., df_imbal, epochs = 2, verbose = TRUE, class_weights = c(
         a = 12, b = 1))
-    Message 
-      epoch: 1 learn rate 0.01 Loss: 0.666 
-      epoch: 2 learn rate 0.01 Loss: 0.664 
+    Message
+      epoch: 1 learn rate 0.01 Loss: 0.372 
+      epoch: 2 learn rate 0.01 Loss: 0.26 
 
diff --git a/tests/testthat/_snaps/mlp-multinomial.md b/tests/testthat/_snaps/mlp-multinomial.md
index 65923df..3be36ec 100644
--- a/tests/testthat/_snaps/mlp-multinomial.md
+++ b/tests/testthat/_snaps/mlp-multinomial.md
@@ -3,9 +3,9 @@
     Code
       set.seed(1)
       fit <- brulee_mlp(y ~ ., df, epochs = 2, verbose = TRUE)
-    Message 
-      epoch: 1 learn rate 0.01 Loss: 1.14 
-      epoch: 2 learn rate 0.01 Loss: 1.14 
+    Message
+      epoch: 1 learn rate 0.01 Loss: 1.08 
+      epoch: 2 learn rate 0.01 Loss: 1.05 
 
 ---
 
@@ -22,114 +22,41 @@
       dropout proportion: 0
       batch size: 9000
       learn rate: 0.01
-      validation loss after 2 epochs: 1.14
+      validation loss after 2 epochs: 1.05
 
 # class weights - mlp
 
     Code
       set.seed(1)
       fit_imbal <- brulee_mlp(y ~ ., df, verbose = TRUE, class_weights = 20)
-    Message 
-      epoch: 1 learn rate 0.01 Loss: 1.36 
-      epoch: 2 learn rate 0.01 Loss: 1.35 
-      epoch: 3 learn rate 0.01 Loss: 1.35 
-      epoch: 4 learn rate 0.01 Loss: 1.34 
-      epoch: 5 learn rate 0.01 Loss: 1.33 
-      epoch: 6 learn rate 0.01 Loss: 1.32 
-      epoch: 7 learn rate 0.01 Loss: 1.31 
-      epoch: 8 learn rate 0.01 Loss: 1.3 
-      epoch: 9 learn rate 0.01 Loss: 1.29 
-      epoch: 10 learn rate 0.01 Loss: 1.29 
-      epoch: 11 learn rate 0.01 Loss: 1.28 
-      epoch: 12 learn rate 0.01 Loss: 1.27 
-      epoch: 13 learn rate 0.01 Loss: 1.26 
-      epoch: 14 learn rate 0.01 Loss: 1.25 
-      epoch: 15 learn rate 0.01 Loss: 1.25 
-      epoch: 16 learn rate 0.01 Loss: 1.24 
-      epoch: 17 learn rate 0.01 Loss: 1.23 
-      epoch: 18 learn rate 0.01 Loss: 1.22 
-      epoch: 19 learn rate 0.01 Loss: 1.22 
-      epoch: 20 learn rate 0.01 Loss: 1.21 
-      epoch: 21 learn rate 0.01 Loss: 1.2 
-      epoch: 22 learn rate 0.01 Loss: 1.2 
-      epoch: 23 learn rate 0.01 Loss: 1.19 
-      epoch: 24 learn rate 0.01 Loss: 1.18 
-      epoch: 25 learn rate 0.01 Loss: 1.17 
-      epoch: 26 learn rate 0.01 Loss: 1.17 
-      epoch: 27 learn rate 0.01 Loss: 1.16 
-      epoch: 28 learn rate 0.01 Loss: 1.15 
-      epoch: 29 learn rate 0.01 Loss: 1.15 
-      epoch: 30 learn rate 0.01 Loss: 1.14 
-      epoch: 31 learn rate 0.01 Loss: 1.13 
-      epoch: 32 learn rate 0.01 Loss: 1.13 
-      epoch: 33 learn rate 0.01 Loss: 1.12 
-      epoch: 34 learn rate 0.01 Loss: 1.12 
-      epoch: 35 learn rate 0.01 Loss: 1.11 
-      epoch: 36 learn rate 0.01 Loss: 1.1 
-      epoch: 37 learn rate 0.01 Loss: 1.1 
-      epoch: 38 learn rate 0.01 Loss: 1.09 
-      epoch: 39 learn rate 0.01 Loss: 1.09 
-      epoch: 40 learn rate 0.01 Loss: 1.08 
-      epoch: 41 learn rate 0.01 Loss: 1.07 
-      epoch: 42 learn rate 0.01 Loss: 1.07 
-      epoch: 43 learn rate 0.01 Loss: 1.06 
-      epoch: 44 learn rate 0.01 Loss: 1.06 
-      epoch: 45 learn rate 0.01 Loss: 1.05 
-      epoch: 46 learn rate 0.01 Loss: 1.05 
-      epoch: 47 learn rate 0.01 Loss: 1.04 
-      epoch: 48 learn rate 0.01 Loss: 1.03 
-      epoch: 49 learn rate 0.01 Loss: 1.03 
-      epoch: 50 learn rate 0.01 Loss: 1.02 
-      epoch: 51 learn rate 0.01 Loss: 1.02 
-      epoch: 52 learn rate 0.01 Loss: 1.01 
-      epoch: 53 learn rate 0.01 Loss: 1.01 
-      epoch: 54 learn rate 0.01 Loss: 1 
-      epoch: 55 learn rate 0.01 Loss: 0.998 
-      epoch: 56 learn rate 0.01 Loss: 0.993 
-      epoch: 57 learn rate 0.01 Loss: 0.988 
-      epoch: 58 learn rate 0.01 Loss: 0.983 
-      epoch: 59 learn rate 0.01 Loss: 0.978 
-      epoch: 60 learn rate 0.01 Loss: 0.974 
-      epoch: 61 learn rate 0.01 Loss: 0.969 
-      epoch: 62 learn rate 0.01 Loss: 0.964 
-      epoch: 63 learn rate 0.01 Loss: 0.959 
-      epoch: 64 learn rate 0.01 Loss: 0.955 
-      epoch: 65 learn rate 0.01 Loss: 0.95 
-      epoch: 66 learn rate 0.01 Loss: 0.946 
-      epoch: 67 learn rate 0.01 Loss: 0.941 
-      epoch: 68 learn rate 0.01 Loss: 0.937 
-      epoch: 69 learn rate 0.01 Loss: 0.932 
-      epoch: 70 learn rate 0.01 Loss: 0.928 
-      epoch: 71 learn rate 0.01 Loss: 0.924
-      epoch: 72 learn rate 0.01 Loss: 0.919 
-      epoch: 73 learn rate 0.01 Loss: 0.915 
-      epoch: 74 learn rate 0.01 Loss: 0.911 
-      epoch: 75 learn rate 0.01 Loss: 0.907 
-      epoch: 76 learn rate 0.01 Loss: 0.903 
-      epoch: 77 learn rate 0.01 Loss: 0.899 
-      epoch: 78 learn rate 0.01 Loss: 0.895 
-      epoch: 79 learn rate 0.01 Loss: 0.891 
-      epoch: 80 learn rate 0.01 Loss: 0.887 
-      epoch: 81 learn rate 0.01 Loss: 0.883 
-      epoch: 82 learn rate 0.01 Loss: 0.879 
-      epoch: 83 learn rate 0.01 Loss: 0.875 
-      epoch: 84 learn rate 0.01 Loss: 0.871 
-      epoch: 85 learn rate 0.01 Loss: 0.867 
-      epoch: 86 learn rate 0.01 Loss: 0.864 
-      epoch: 87 learn rate 0.01 Loss: 0.86 
-      epoch: 88 learn rate 0.01 Loss: 0.856 
-      epoch: 89 learn rate 0.01 Loss: 0.853 
-      epoch: 90 learn rate 0.01 Loss: 0.849 
-      epoch: 91 learn rate 0.01 Loss: 0.846 
-      epoch: 92 learn rate 0.01 Loss: 0.842 
-      epoch: 93 learn rate 0.01 Loss: 0.839 
-      epoch: 94 learn rate 0.01 Loss: 0.835 
-      epoch: 95 learn rate 0.01 Loss: 0.832 
-      epoch: 96 learn rate 0.01 Loss: 0.828 
-      epoch: 97 learn rate 0.01 Loss: 0.825 
-      epoch: 98 learn rate 0.01 Loss: 0.822 
-      epoch: 99 learn rate 0.01 Loss: 0.818 
-      epoch: 100 learn rate 0.01 Loss: 0.815 
+    Message
+      epoch: 1 learn rate 0.01 Loss: 0.675 
+      epoch: 2 learn rate 0.01 Loss: 0.508 
+      epoch: 3 learn rate 0.01 Loss: 0.467 
+      epoch: 4 learn rate 0.01 Loss: 0.453 
+      epoch: 5 learn rate 0.01 Loss: 0.444 
+      epoch: 6 learn rate 0.01 Loss: 0.438 
+      epoch: 7 learn rate 0.01 Loss: 0.434 
+      epoch: 8 learn rate 0.01 Loss: 0.432 
+      epoch: 9 learn rate 0.01 Loss: 0.43 
+      epoch: 10 learn rate 0.01 Loss: 0.429 
+      epoch: 11 learn rate 0.01 Loss: 0.428 
+      epoch: 12 learn rate 0.01 Loss: 0.427 
+      epoch: 13 learn rate 0.01 Loss: 0.427 
+      epoch: 14 learn rate 0.01 Loss: 0.427 
+      epoch: 15 learn rate 0.01 Loss: 0.427 
+      epoch: 16 learn rate 0.01 Loss: 0.427 
+      epoch: 17 learn rate 0.01 Loss: 0.427 
+      epoch: 18 learn rate 0.01 Loss: 0.427 
+      epoch: 19 learn rate 0.01 Loss: 0.427 
+      epoch: 20 learn rate 0.01 Loss: 0.427 
+      epoch: 21 learn rate 0.01 Loss: 0.427 
+      epoch: 22 learn rate 0.01 Loss: 0.426 
+      epoch: 23 learn rate 0.01 Loss: 0.426 x
+      epoch: 24 learn rate 0.01 Loss: 0.426 x
+      epoch: 25 learn rate 0.01 Loss: 0.426 x
+      epoch: 26 learn rate 0.01 Loss: 0.426 x
+      epoch: 27 learn rate 0.01 Loss: 0.426 x
 
 ---
 
@@ -137,7 +64,7 @@
       set.seed(1)
       fit <- brulee_mlp(y ~ ., df, epochs = 2, verbose = TRUE, class_weights = c(a = 12,
         b = 1, c = 1))
-    Message 
-      epoch: 1 learn rate 0.01 Loss: 0.728 
-      epoch: 2 learn rate 0.01 Loss: 0.725 
+    Message
+      epoch: 1 learn rate 0.01 Loss: 0.479 
+      epoch: 2 learn rate 0.01 Loss: 0.435 
 
diff --git a/tests/testthat/_snaps/mlp-regression.md b/tests/testthat/_snaps/mlp-regression.md
index 496ffce..8a51b5c 100644
--- a/tests/testthat/_snaps/mlp-regression.md
+++ b/tests/testthat/_snaps/mlp-regression.md
@@ -12,5 +12,5 @@
       dropout proportion: 0
       batch size: 2637
       learn rate: 0.01
-      scaled validation loss after 10 epochs: 1.31
+      scaled validation loss after 10 epochs: 1.1
 
diff --git a/tests/testthat/_snaps/multinomial_reg-fit.md b/tests/testthat/_snaps/multinomial_reg-fit.md
index b259a77..bf1817b 100644
--- a/tests/testthat/_snaps/multinomial_reg-fit.md
+++ b/tests/testthat/_snaps/multinomial_reg-fit.md
@@ -3,7 +3,7 @@
     Code
       set.seed(1)
       fit <- brulee_multinomial_reg(y ~ ., df, epochs = 2, verbose = TRUE, penalty = 0)
-    Message 
+    Message
       epoch: 1 Loss: 1.02 
       epoch: 2 Loss: 1.02 
 
@@ -25,7 +25,7 @@
       set.seed(1)
       fit_imbal <- brulee_multinomial_reg(y ~ ., df, verbose = TRUE, class_weights = 20,
         optimizer = "SGD")
-    Message 
+    Message
       epoch: 1 Loss: 0.921 
       epoch: 2 Loss: 0.593
       epoch: 3 Loss: 0.501 
@@ -53,7 +53,7 @@
       set.seed(1)
       fit <- brulee_multinomial_reg(y ~ ., df, epochs = 2, verbose = TRUE,
         class_weights = c(a = 12, b = 1, c = 1), penalty = 0)
-    Message 
+    Message
       epoch: 1 Loss: 0.384 
       epoch: 2 Loss: 0.384