diff --git a/DESCRIPTION b/DESCRIPTION index 181d507..abff23e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,9 +25,10 @@ Imports: lifecycle (>= 1.0.3), modelenv (>= 0.1.0), parsnip (>= 1.2.1.9000), + recipes (>= 1.0.10.9000), rlang (>= 1.1.0), tidyselect (>= 1.2.0), - sparsevctrs (>= 0.1.0.9002), + sparsevctrs (>= 0.1.0.9003), vctrs (>= 0.4.1), withr Suggests: @@ -42,7 +43,6 @@ Suggests: methods, modeldata (>= 1.0.0), probably, - recipes (>= 1.0.10.9000), rmarkdown, testthat (>= 3.0.0) VignetteBuilder: diff --git a/R/fit.R b/R/fit.R index 13d1315..8c42940 100644 --- a/R/fit.R +++ b/R/fit.R @@ -71,6 +71,8 @@ fit.workflow <- function(object, data, ..., calibration = NULL, control = contro ) } + object <- toggle_sparsity(object, data) + workflow <- object workflow <- .fit_pre(workflow, data) workflow <- .fit_model(workflow, control) diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index bbb865a..2111dea 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -1,3 +1,99 @@ is_sparse_matrix <- function(x) { methods::is(x, "sparseMatrix") } + +# This function takes a workflow and its data. If the model supports sparse data +# And there is a recipe, then it uses `should_use_sparsity()` to determine +# whether all the `sparse = "auto"` should be turned to `"yes"` or `"no"` in the +# recipe. +# +# Done using flow chart in https://github.com/tidymodels/workflows/issues/271 +toggle_sparsity <- function(object, data) { + if ( + allow_sparse(object$fit$actions$model$spec) && + has_preprocessor_recipe(object) + ) { + est_sparsity <- recipes::.recipes_estimate_sparsity( + extract_preprocessor(object) + ) + + toggle_sparse <- should_use_sparsity( + est_sparsity, + extract_spec_parsnip(object)$engine, + nrow(data) + ) + + object$pre$actions$recipe$recipe <- recipes::.recipes_toggle_sparse_args( + object$pre$actions$recipe$recipe, + choice = toggle_sparse + ) + } + + object +} + +allow_sparse <- function(x) { + if (inherits(x, "model_fit")) { + x <- x$spec + } + res <- parsnip::get_from_env(paste0(class(x)[1], "_encoding")) + all(res$allow_sparse_x[res$engine == x$engine]) +} + +# This function was created using from the output of a mars model fit on the +# simulation data generated in `analysis/time_analysis.R` +# https://github.com/tidymodels/benchmark-sparsity-threshold +# +# The model was extracted using {tidypredict} and hand-tuned for speed. +# +# The model was fit on `sparsity`, `engine` and `n_rows` and the outcome was +# `log_fold` which is defined as +# `log(time to fit with dense data / time to fit with sparse data)`. +# Meaning that values above above 0 would reflects longer fit times for dense, +# Hence we want to use sparse data. +# +# At this time the only engines that support sparse data are glmnet, LiblineaR, +# ranger, and xgboost. Which is why they are the only ones listed here. +# This is fine as this code will only run if `allow_sparse()` returns `TRUE` +# Which only happens for these engines. +# +# Ranger is hard-coded to always fail since they appear to use the same +# algorithm for sparse and dense data, resulting in identical times. +should_use_sparsity <- function(sparsity, engine, n_rows) { + if (is.null(engine) || engine == "ranger") { + return("no") + } + + log_fold <- -0.599333138645995 + + ifelse(sparsity < 0.836601307189543, 0.836601307189543 - sparsity, 0) * + -0.541581853008009 + + ifelse(n_rows < 16000, 16000 - n_rows, 0) * 3.23980908942813e-05 + + ifelse(n_rows > 16000, n_rows - 16000, 0) * -2.81001152147355e-06 + + ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) * + 9.82444255114058 + + ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) * + ifelse(n_rows > 8000, n_rows - 8000, 0) * + 7.27456967763306e-05 + + ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) * + ifelse(n_rows < 8000, 8000 - n_rows, 0) * + -0.000798307404212627 + + if (engine == "xgboost") { + log_fold <- log_fold + + ifelse(sparsity < 0.984615384615385, 0.984615384615385 - sparsity, 0) * + 0.113098025073806 + + ifelse(n_rows < 8000, 8000 - n_rows, 0) * -9.77914237255269e-05 + + ifelse(n_rows > 8000, n_rows - 8000, 0) * 3.22657666511869e-06 + + ifelse(sparsity > 0.984615384615385, sparsity - 0.984615384615385, 0) * + 41.5180348086939 + + 0.913457808326756 + } + + if (engine == "LiblineaR") { + log_fold <- log_fold + + ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) * + -5.39592564852111 + } + + ifelse(log_fold > 0, "yes", "no") +} diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 1116118..c627fbb 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -191,3 +191,153 @@ test_that("fit() errors if sparse matrix has no colnames", { fit(wf_spec, hotel_data) ) }) + +test_that("toggle_sparsity changes auto to yes", { + skip_if_not_installed("glmnet") + skip_if_not_installed("modeldata") + + data("ames", package = "modeldata") + fcts <- c( + 1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, + 17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L, + 37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L + ) + outcome <- 72 + + ames <- ames[c(fcts, outcome)] + ames <- ames[1:100, ] + + tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0) + + rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>% + recipes::step_dummy(recipes::all_nominal_predictors()) + + wf_spec <- workflow(rec_spec, tree_spec) + + res <- toggle_sparsity(wf_spec, ames) + + expect_identical( + extract_preprocessor(res)$steps[[1]]$sparse, + "yes" + ) +}) + +test_that("toggle_sparsity doesn't change no", { + skip_if_not_installed("glmnet") + skip_if_not_installed("modeldata") + + data("ames", package = "modeldata") + fcts <- c( + 1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, + 17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L, + 37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L + ) + outcome <- 72 + + ames <- ames[c(fcts, outcome)] + ames <- ames[1:100, ] + + tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0) + + rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>% + recipes::step_dummy(recipes::all_nominal_predictors(), sparse = "no") + + wf_spec <- workflow(rec_spec, tree_spec) + + res <- toggle_sparsity(wf_spec, ames) + + expect_identical( + extract_preprocessor(res)$steps[[1]]$sparse, + "no" + ) +}) + +test_that("toggle_sparsity changes auto to no", { + skip_if_not_installed("glmnet") + skip_if_not_installed("modeldata") + + data("ames", package = "modeldata") + fcts <- c( + 1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, + 17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L, + 37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L + ) + outcome <- 72 + + ames <- ames[c(fcts, outcome)] + ames <- ames[1:100, ] + + tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0) + + # if we only dummy 1 variable it doesn't make the data sparse enough + rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>% + recipes::step_dummy(MS_Zoning) + + wf_spec <- workflow(rec_spec, tree_spec) + + res <- toggle_sparsity(wf_spec, ames) + + expect_identical( + extract_preprocessor(res)$steps[[1]]$sparse, + "no" + ) +}) + +test_that("toggle_sparsity doesn't change yes", { + skip_if_not_installed("glmnet") + skip_if_not_installed("modeldata") + + data("ames", package = "modeldata") + fcts <- c( + 1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, + 17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L, + 37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L + ) + outcome <- 72 + + ames <- ames[c(fcts, outcome)] + ames <- ames[1:100, ] + + tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0) + + # if we only dummy 1 variable it doesn't make the data sparse enough + rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>% + recipes::step_dummy(MS_Zoning, sparse = "yes") + + wf_spec <- workflow(rec_spec, tree_spec) + + res <- toggle_sparsity(wf_spec, ames) + + expect_identical( + extract_preprocessor(res)$steps[[1]]$sparse, + "yes" + ) +}) + +test_that("toggle_sparsity doesn't break fit", { + skip_if_not_installed("glmnet") + skip_if_not_installed("modeldata") + + data("ames", package = "modeldata") + fcts <- c( + 1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, + 17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L, + 37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L + ) + outcome <- 72 + + ames <- ames[c(fcts, outcome)] + ames <- ames[1:100, ] + + tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0) + + rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>% + recipes::step_dummy(recipes::all_nominal_predictors()) + + wf_spec <- workflow(rec_spec, tree_spec) + + expect_no_error( + fit(wf_spec, ames) + ) +}) +