Commit d963de9

Merge pull request #281 from tidymodels/toggle-sparsity
EmilHvitfeldt authored Jan 17, 2025
2 parents 2b31e8f + 5d1c4f4 commit d963de9
Showing 4 changed files with 250 additions and 2 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -25,9 +25,10 @@ Imports:
     lifecycle (>= 1.0.3),
     modelenv (>= 0.1.0),
     parsnip (>= 1.2.1.9000),
+    recipes (>= 1.0.10.9000),
     rlang (>= 1.1.0),
     tidyselect (>= 1.2.0),
-    sparsevctrs (>= 0.1.0.9002),
+    sparsevctrs (>= 0.1.0.9003),
     vctrs (>= 0.4.1),
     withr
 Suggests:
@@ -42,7 +43,6 @@ Suggests:
     methods,
     modeldata (>= 1.0.0),
     probably,
-    recipes (>= 1.0.10.9000),
     rmarkdown,
     testthat (>= 3.0.0)
 VignetteBuilder:
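Both bumped constraints point at development versions (the `.9000`-series suffixes), so at the time of this commit they were only available from GitHub. A minimal install sketch using {pak}; the repository paths are assumed to be the standard tidymodels ones:

    # install.packages("pak")
    # Development versions required by this commit; not on CRAN at merge time.
    pak::pak(c(
      "tidymodels/recipes",    # >= 1.0.10.9000, moved from Suggests to Imports
      "tidymodels/sparsevctrs" # >= 0.1.0.9003
    ))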
2 changes: 2 additions & 0 deletions R/fit.R
@@ -71,6 +71,8 @@ fit.workflow <- function(object, data, ..., calibration = NULL, control = contro
     )
   }
 
+  object <- toggle_sparsity(object, data)
+
   workflow <- object
   workflow <- .fit_pre(workflow, data)
   workflow <- .fit_model(workflow, control)
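With this hook in place, sparsity toggling happens automatically inside `fit()`, with no user-facing API change. A minimal end-to-end sketch, mirroring the tests added below (assumes {glmnet} and {modeldata} are installed; inspecting the resolved setting via `extract_preprocessor()` is an illustration, not part of this commit):

    library(workflows)

    data("ames", package = "modeldata")

    spec <- parsnip::linear_reg(engine = "glmnet", penalty = 0)

    rec <- recipes::recipe(Sale_Price ~ ., data = ames) |>
      recipes::step_dummy(recipes::all_nominal_predictors())

    # fit() now resolves the step's `sparse = "auto"` default to "yes" or
    # "no" before .fit_pre() runs, based on estimated sparsity, engine,
    # and nrow(data)
    wf_fit <- fit(workflow(rec, spec), ames)

    # the resolved choice is recorded on the workflow's recipe step
    extract_preprocessor(wf_fit)$steps[[1]]$sparse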
96 changes: 96 additions & 0 deletions R/sparsevctrs.R
@@ -1,3 +1,99 @@
is_sparse_matrix <- function(x) {
methods::is(x, "sparseMatrix")
}

# This function takes a workflow and its data. If the model supports sparse
# data and there is a recipe, it uses `should_use_sparsity()` to determine
# whether all the `sparse = "auto"` arguments in the recipe should be turned
# to `"yes"` or `"no"`.
#
# Follows the flow chart in https://github.com/tidymodels/workflows/issues/271
toggle_sparsity <- function(object, data) {
if (
allow_sparse(object$fit$actions$model$spec) &&
has_preprocessor_recipe(object)
) {
est_sparsity <- recipes::.recipes_estimate_sparsity(
extract_preprocessor(object)
)

toggle_sparse <- should_use_sparsity(
est_sparsity,
extract_spec_parsnip(object)$engine,
nrow(data)
)

object$pre$actions$recipe$recipe <- recipes::.recipes_toggle_sparse_args(
object$pre$actions$recipe$recipe,
choice = toggle_sparse
)
}

object
}

allow_sparse <- function(x) {
if (inherits(x, "model_fit")) {
x <- x$spec
}
res <- parsnip::get_from_env(paste0(class(x)[1], "_encoding"))
all(res$allow_sparse_x[res$engine == x$engine])
}

# This function was created from the output of a MARS model fit on the
# simulation data generated in `analysis/time_analysis.R` at
# https://github.com/tidymodels/benchmark-sparsity-threshold
#
# The model was extracted using {tidypredict} and hand-tuned for speed.
#
# The model was fit on `sparsity`, `engine`, and `n_rows`, with outcome
# `log_fold`, defined as
# `log(time to fit with dense data / time to fit with sparse data)`.
# Values above 0 thus reflect longer fit times for dense data, in which
# case we want to use sparse data.
#
# At this time the only engines that support sparse data are glmnet,
# LiblineaR, ranger, and xgboost, which is why they are the only ones
# handled here. This is fine because this code only runs when
# `allow_sparse()` returns `TRUE`, which only happens for these engines.
#
# ranger is hard-coded to always return "no" since it appears to use the
# same algorithm for sparse and dense data, resulting in identical times.
should_use_sparsity <- function(sparsity, engine, n_rows) {
if (is.null(engine) || engine == "ranger") {
return("no")
}

log_fold <- -0.599333138645995 +
ifelse(sparsity < 0.836601307189543, 0.836601307189543 - sparsity, 0) *
-0.541581853008009 +
ifelse(n_rows < 16000, 16000 - n_rows, 0) * 3.23980908942813e-05 +
ifelse(n_rows > 16000, n_rows - 16000, 0) * -2.81001152147355e-06 +
ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) *
9.82444255114058 +
ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) *
ifelse(n_rows > 8000, n_rows - 8000, 0) *
7.27456967763306e-05 +
ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) *
ifelse(n_rows < 8000, 8000 - n_rows, 0) *
-0.000798307404212627

if (engine == "xgboost") {
log_fold <- log_fold +
ifelse(sparsity < 0.984615384615385, 0.984615384615385 - sparsity, 0) *
0.113098025073806 +
ifelse(n_rows < 8000, 8000 - n_rows, 0) * -9.77914237255269e-05 +
ifelse(n_rows > 8000, n_rows - 8000, 0) * 3.22657666511869e-06 +
ifelse(sparsity > 0.984615384615385, sparsity - 0.984615384615385, 0) *
41.5180348086939 +
0.913457808326756
}

if (engine == "LiblineaR") {
log_fold <- log_fold +
ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) *
-5.39592564852111
}

ifelse(log_fold > 0, "yes", "no")
}
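The decision logic can be exercised in isolation; a rough sketch poking at the unexported helpers via `:::` (the cutoffs follow the MARS coefficients above, so the labeled outcomes are expectations, not guarantees):

    # allow_sparse() consults parsnip's encoding registry for the engine
    spec <- parsnip::boost_tree(mode = "regression", engine = "xgboost")
    workflows:::allow_sparse(spec)  # TRUE: xgboost accepts sparse x

    # should_use_sparsity() trades estimated sparsity off against row count
    workflows:::should_use_sparsity(0.99, "xgboost", n_rows = 50000) # likely "yes"
    workflows:::should_use_sparsity(0.10, "xgboost", n_rows = 500)   # likely "no"
    workflows:::should_use_sparsity(0.99, "ranger",  n_rows = 50000) # always "no"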
150 changes: 150 additions & 0 deletions tests/testthat/test-sparsevctrs.R
@@ -191,3 +191,153 @@ test_that("fit() errors if sparse matrix has no colnames", {
fit(wf_spec, hotel_data)
)
})

test_that("toggle_sparsity changes auto to yes", {
skip_if_not_installed("glmnet")
skip_if_not_installed("modeldata")

data("ames", package = "modeldata")
fcts <- c(
1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L,
17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L,
37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L
)
outcome <- 72

ames <- ames[c(fcts, outcome)]
ames <- ames[1:100, ]

tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0)

rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>%
recipes::step_dummy(recipes::all_nominal_predictors())

wf_spec <- workflow(rec_spec, tree_spec)

res <- toggle_sparsity(wf_spec, ames)

expect_identical(
extract_preprocessor(res)$steps[[1]]$sparse,
"yes"
)
})

test_that("toggle_sparsity doesn't change no", {
skip_if_not_installed("glmnet")
skip_if_not_installed("modeldata")

data("ames", package = "modeldata")
fcts <- c(
1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L,
17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L,
37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L
)
outcome <- 72

ames <- ames[c(fcts, outcome)]
ames <- ames[1:100, ]

tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0)

rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>%
recipes::step_dummy(recipes::all_nominal_predictors(), sparse = "no")

wf_spec <- workflow(rec_spec, tree_spec)

res <- toggle_sparsity(wf_spec, ames)

expect_identical(
extract_preprocessor(res)$steps[[1]]$sparse,
"no"
)
})

test_that("toggle_sparsity changes auto to no", {
skip_if_not_installed("glmnet")
skip_if_not_installed("modeldata")

data("ames", package = "modeldata")
fcts <- c(
1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L,
17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L,
37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L
)
outcome <- 72

ames <- ames[c(fcts, outcome)]
ames <- ames[1:100, ]

tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0)

# if we only dummy 1 variable it doesn't make the data sparse enough
rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>%
recipes::step_dummy(MS_Zoning)

wf_spec <- workflow(rec_spec, tree_spec)

res <- toggle_sparsity(wf_spec, ames)

expect_identical(
extract_preprocessor(res)$steps[[1]]$sparse,
"no"
)
})

test_that("toggle_sparsity doesn't change yes", {
skip_if_not_installed("glmnet")
skip_if_not_installed("modeldata")

data("ames", package = "modeldata")
fcts <- c(
1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L,
17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L,
37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L
)
outcome <- 72

ames <- ames[c(fcts, outcome)]
ames <- ames[1:100, ]

tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0)

# if we only dummy 1 variable it doesn't make the data sparse enough
rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>%
recipes::step_dummy(MS_Zoning, sparse = "yes")

wf_spec <- workflow(rec_spec, tree_spec)

res <- toggle_sparsity(wf_spec, ames)

expect_identical(
extract_preprocessor(res)$steps[[1]]$sparse,
"yes"
)
})

test_that("toggle_sparsity doesn't break fit", {
skip_if_not_installed("glmnet")
skip_if_not_installed("modeldata")

data("ames", package = "modeldata")
fcts <- c(
1L, 2L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L,
17L, 20L, 21L, 22L, 23L, 24L, 26L, 27L, 28L, 29L, 30L, 32L, 36L,
37L, 38L, 39L, 50L, 52L, 53L, 56L, 57L, 64L, 65L, 66L, 70L, 71L
)
outcome <- 72

ames <- ames[c(fcts, outcome)]
ames <- ames[1:100, ]

tree_spec <- parsnip::linear_reg("regression", "glmnet", penalty = 0)

rec_spec <- recipes::recipe(Sale_Price ~ ., data = ames) %>%
recipes::step_dummy(recipes::all_nominal_predictors())

wf_spec <- workflow(rec_spec, tree_spec)

expect_no_error(
fit(wf_spec, ames)
)
})
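To run just this new test file from a development checkout of workflows, a typical {devtools} invocation (assuming the working directory is the package root):

    devtools::load_all()
    devtools::test(filter = "sparsevctrs")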
