diff --git a/R/tailor.R b/R/tailor.R
index 9abc314..bb69383 100644
--- a/R/tailor.R
+++ b/R/tailor.R
@@ -19,9 +19,6 @@
 #' with the [tidymodels](https://tidymodels.org) framework; for greatest ease
 #' of use, situate tailors in model workflows with `?workflows::add_tailor()`.
 #'
-#' @param type Character. The model sub-mode. Possible values are
-#' `"unknown"`, `"regression"`, `"binary"`, or `"multiclass"`. Only required
-#' when used independently of `?workflows::add_tailor()`.
 #' @param outcome <[`tidy-select`][dplyr::dplyr_tidy_select]> Only required
 #' when used independently of `?workflows::add_tailor()`, and can also be passed
 #' at `fit()` time instead. The column name of the outcome variable.
@@ -64,18 +61,16 @@
 #' # adjust hard class predictions
 #' predict(tlr_fit, two_class_example) %>% count(predicted)
 #' @export
-tailor <- function(type = "unknown", outcome = NULL, estimate = NULL,
-                   probabilities = NULL) {
+tailor <- function(outcome = NULL, estimate = NULL, probabilities = NULL) {
   columns <- list(
     outcome = outcome,
-    type = type,
     estimate = estimate,
     probabilities = probabilities
   )
 
   new_tailor(
-    type,
+    "unknown",
     adjustments = list(),
     columns = columns,
     ptype = tibble::new_tibble(list()),
@@ -84,8 +79,6 @@
 new_tailor <- function(type, adjustments, columns, ptype, call) {
-  type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))
-
   if (!is.list(adjustments)) {
     cli_abort("The {.arg adjustments} argument should be a list.", call = call)
   }
@@ -97,8 +90,14 @@ new_tailor <- function(type, adjustments, columns, ptype, call) {
               {.val adjustment}: {bad_adjustment}.", call = call)
   }
 
+  orderings <- adjustment_orderings(adjustments)
+
+  if (type == "unknown") {
+    type <- infer_type(orderings)
+  }
+
   # validate adjustment order and check duplicates
-  validate_order(adjustments, type, call)
+  validate_order(orderings, type, call)
 
   # check columns
   res <- list(
@@ -233,5 +232,5 @@ set_tailor_type <- function(object, y) {
 # todo setup eval_time
 # todo missing methods:
 # todo tune_args
-# todo tidy
+# todo tidy (this should probably just be `adjustment_orderings()`)
 # todo extract_parameter_set_dials
diff --git a/R/utils.R b/R/utils.R
index 4207cb5..483b4a0 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -88,6 +88,19 @@ tailor_adjustment_requires_fit <- function(x) {
   isTRUE(x$requires_fit)
 }
 
+# a tidy-esque method for adjustment lists, used in validating
+# compatibility of adjustments
+adjustment_orderings <- function(adjustments) {
+  tibble::new_tibble(list(
+    name = purrr::map_chr(adjustments, ~ class(.x)[1]),
+    input = purrr::map_chr(adjustments, ~ .x$inputs),
+    output_numeric = purrr::map_lgl(adjustments, ~ grepl("numeric", .x$outputs)),
+    output_prob = purrr::map_lgl(adjustments, ~ grepl("probability", .x$outputs)),
+    output_class = purrr::map_lgl(adjustments, ~ grepl("class", .x$outputs)),
+    output_all = purrr::map_lgl(adjustments, ~ grepl("everything", .x$outputs))
+  ))
+}
+
 # ad-hoc checking --------------------------------------------------------------
 check_tailor <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
   if (!is_tailor(x)) {
diff --git a/R/validation-rules.R b/R/validation-rules.R
index fac8129..022f224 100644
--- a/R/validation-rules.R
+++ b/R/validation-rules.R
@@ -1,21 +1,9 @@
-validate_order <- function(adjustments, type, call = caller_env()) {
-  orderings <-
-    tibble::new_tibble(list(
-      name = purrr::map_chr(adjustments, ~ class(.x)[1]),
-      input = purrr::map_chr(adjustments, ~ .x$inputs),
-      output_numeric = purrr::map_lgl(adjustments, ~ grepl("numeric", .x$outputs)),
-      output_prob = purrr::map_lgl(adjustments, ~ grepl("probability", .x$outputs)),
-      output_class = purrr::map_lgl(adjustments, ~ grepl("class", .x$outputs)),
-      output_all = purrr::map_lgl(adjustments, ~ grepl("everything", .x$outputs))
-    ))
-
-  if (length(adjustments) < 2) {
+validate_order <- function(orderings, type, call = caller_env()) {
+  if (nrow(orderings) < 2) {
     return(invisible(orderings))
   }
 
-  if (type == "unknown") {
-    type <- infer_type(orderings)
-  }
+  check_incompatible_types(orderings, call)
 
   switch(
     type,
@@ -27,6 +15,24 @@ validate_order <- function(adjustments, type, call = caller_env()) {
   invisible(orderings)
 }
 
+check_incompatible_types <- function(orderings, call) {
+  if (all(c("numeric", "probability") %in% orderings$input)) {
+    numeric_adjustments <- orderings$name[which(orderings$input == "numeric")]
+    probability_adjustments <- orderings$name[which(orderings$input == "probability")]
+    cli_abort(
+      c(
+        "Can't compose adjustments for different prediction types.",
+        "i" = "{cli::qty(numeric_adjustments)}
+               Adjustment{?s} {.fn {paste0('adjust_', numeric_adjustments)}}
+               {cli::qty(numeric_adjustments[-1])} operate{?s} on numerics while
+               {.fn {paste0('adjust_', probability_adjustments)}}
+               {cli::qty(probability_adjustments[-1])} operate{?s} on probabilities."
+      ),
+      call = call
+    )
+  }
+}
+
 check_classification_order <- function(x, call) {
   cal_ind <- which(grepl("calibration$", x$name))
   eq_ind <- which(grepl("equivocal", x$name))
diff --git a/man/tailor.Rd b/man/tailor.Rd
index b0c6ac4..41763fb 100644
--- a/man/tailor.Rd
+++ b/man/tailor.Rd
@@ -4,13 +4,9 @@
 \alias{tailor}
 \title{Declare post-processing for model predictions}
 \usage{
-tailor(type = "unknown", outcome = NULL, estimate = NULL, probabilities = NULL)
+tailor(outcome = NULL, estimate = NULL, probabilities = NULL)
 }
 \arguments{
-\item{type}{Character. The model sub-mode. Possible values are
-\code{"unknown"}, \code{"regression"}, \code{"binary"}, or \code{"multiclass"}. Only required
-when used independently of \code{?workflows::add_tailor()}.}
-
 \item{outcome}{<\code{\link[dplyr:dplyr_tidy_select]{tidy-select}}> Only required
 when used independently of \code{?workflows::add_tailor()}, and can also be passed
 at \code{fit()} time instead. The column name of the outcome variable.}
diff --git a/tests/testthat/_snaps/adjust-equivocal-zone.md b/tests/testthat/_snaps/adjust-equivocal-zone.md
index 314da8c..0427c70 100644
--- a/tests/testthat/_snaps/adjust-equivocal-zone.md
+++ b/tests/testthat/_snaps/adjust-equivocal-zone.md
@@ -5,7 +5,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A binary postprocessor with 1 adjustment:
 
       * Add equivocal zone of size 0.1.
 
@@ -16,7 +16,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A binary postprocessor with 1 adjustment:
 
       * Add equivocal zone of optimized size.
 
diff --git a/tests/testthat/_snaps/adjust-numeric-calibration.md b/tests/testthat/_snaps/adjust-numeric-calibration.md
index 6f3bd2f..3e9baa9 100644
--- a/tests/testthat/_snaps/adjust-numeric-calibration.md
+++ b/tests/testthat/_snaps/adjust-numeric-calibration.md
@@ -5,7 +5,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A regression postprocessor with 1 adjustment:
 
       * Re-calibrate numeric predictions.
 
@@ -20,15 +20,7 @@
 ---
 
     Code
-      tailor("binary") %>% adjust_numeric_calibration("linear")
-    Condition
-      Error in `adjust_numeric_calibration()`:
-      ! A binary tailor is incompatible with the adjustment `adjust_numeric_calibration()`.
-
----
-
-    Code
-      tailor("regression") %>% adjust_numeric_calibration("binary")
+      tailor() %>% adjust_numeric_calibration("binary")
     Condition
       Error in `adjust_numeric_calibration()`:
       ! `method` must be one of "linear", "isotonic", or "isotonic_boot", not "binary".
diff --git a/tests/testthat/_snaps/adjust-numeric-range.md b/tests/testthat/_snaps/adjust-numeric-range.md
index 050da3b..1daaeec 100644
--- a/tests/testthat/_snaps/adjust-numeric-range.md
+++ b/tests/testthat/_snaps/adjust-numeric-range.md
@@ -5,7 +5,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A regression postprocessor with 1 adjustment:
 
       * Constrain numeric predictions to be between [-Inf, Inf].
 
@@ -16,7 +16,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A regression postprocessor with 1 adjustment:
 
       * Constrain numeric predictions to be between [?, Inf].
 
@@ -27,7 +27,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A regression postprocessor with 1 adjustment:
 
       * Constrain numeric predictions to be between [-1, ?].
 
@@ -38,7 +38,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A regression postprocessor with 1 adjustment:
 
       * Constrain numeric predictions to be between [?, 1].
 
diff --git a/tests/testthat/_snaps/adjust-probability-calibration.md b/tests/testthat/_snaps/adjust-probability-calibration.md
index e814d17..432ea81 100644
--- a/tests/testthat/_snaps/adjust-probability-calibration.md
+++ b/tests/testthat/_snaps/adjust-probability-calibration.md
@@ -5,7 +5,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A binary postprocessor with 1 adjustment:
 
       * Re-calibrate classification probabilities.
 
@@ -20,15 +20,7 @@
 ---
 
    Code
-      tailor("regression") %>% adjust_probability_calibration("binary")
-    Condition
-      Error in `adjust_probability_calibration()`:
-      ! A regression tailor is incompatible with the adjustment `adjust_probability_calibration()`.
-
----
-
-    Code
-      tailor("binary") %>% adjust_probability_calibration("linear")
+      tailor() %>% adjust_probability_calibration("linear")
    Condition
      Error in `adjust_probability_calibration()`:
      ! `method` must be one of "logistic", "multinomial", "beta", "isotonic", or "isotonic_boot", not "linear".
diff --git a/tests/testthat/_snaps/adjust-probability-threshold.md b/tests/testthat/_snaps/adjust-probability-threshold.md
index d5a1475..0cf3857 100644
--- a/tests/testthat/_snaps/adjust-probability-threshold.md
+++ b/tests/testthat/_snaps/adjust-probability-threshold.md
@@ -5,7 +5,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A binary postprocessor with 1 adjustment:
 
       * Adjust probability threshold to 0.5.
 
@@ -16,7 +16,7 @@
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A postprocessor with 1 adjustment:
+      A binary postprocessor with 1 adjustment:
 
      * Adjust probability threshold to optimized value.
 
diff --git a/tests/testthat/_snaps/tailor.md b/tests/testthat/_snaps/tailor.md
index ddf992c..438907b 100644
--- a/tests/testthat/_snaps/tailor.md
+++ b/tests/testthat/_snaps/tailor.md
@@ -10,16 +10,16 @@
 ---
 
     Code
-      tailor(type = "binary")
+      tailor()
     Message
       -- tailor ----------------------------------------------------------------------
 
-      A binary postprocessor with 0 adjustments.
+      A postprocessor with 0 adjustments.
 
 ---
 
     Code
-      tailor(type = "binary") %>% adjust_probability_threshold(0.2)
+      tailor() %>% adjust_probability_threshold(0.2)
     Message
       -- tailor ----------------------------------------------------------------------
 
@@ -30,8 +30,7 @@
 ---
 
     Code
-      tailor(type = "binary") %>% adjust_probability_threshold(0.2) %>%
-        adjust_equivocal_zone()
+      tailor() %>% adjust_probability_threshold(0.2) %>% adjust_equivocal_zone()
     Message
       -- tailor ----------------------------------------------------------------------
 
diff --git a/tests/testthat/_snaps/validation-rules.md b/tests/testthat/_snaps/validation-rules.md
index 287ca17..cb63223 100644
--- a/tests/testthat/_snaps/validation-rules.md
+++ b/tests/testthat/_snaps/validation-rules.md
@@ -1,7 +1,7 @@
 # validation of adjustments (regression)
 
     Code
-      tailor(type = "regression") %>% adjust_numeric_range(lower_limit = 2) %>%
+      tailor() %>% adjust_numeric_range(lower_limit = 2) %>%
        adjust_numeric_calibration() %>% adjust_predictions_custom(squared = .pred^2)
     Condition
       Error in `adjust_numeric_calibration()`:
@@ -10,7 +10,7 @@
 # validation of adjustments (classification)
 
     Code
-      tailor(type = "binary") %>% adjust_probability_threshold(threshold = 0.4) %>%
+      tailor() %>% adjust_probability_threshold(threshold = 0.4) %>%
        adjust_probability_calibration()
     Condition
       Error in `adjust_probability_calibration()`:
@@ -28,7 +28,7 @@
 ---
 
     Code
-      tailor(type = "binary") %>% adjust_predictions_custom(veg = "potato") %>%
+      tailor() %>% adjust_predictions_custom(veg = "potato") %>%
        adjust_probability_threshold(threshold = 0.4) %>%
        adjust_probability_calibration()
     Condition
@@ -48,7 +48,7 @@
 ---
 
    Code
-      tailor(type = "binary") %>% adjust_predictions_custom(veg = "potato") %>%
+      tailor() %>% adjust_predictions_custom(veg = "potato") %>%
        adjust_probability_threshold(threshold = 0.4) %>%
        adjust_probability_threshold(threshold = 0.5) %>%
        adjust_probability_calibration()
@@ -70,7 +70,7 @@
 ---
 
    Code
-      tailor(type = "binary") %>% adjust_equivocal_zone(value = 0.2) %>%
+      tailor() %>% adjust_equivocal_zone(value = 0.2) %>%
        adjust_probability_threshold(threshold = 0.4)
     Condition
       Error in `adjust_probability_threshold()`:
@@ -85,3 +85,65 @@
       Error in `adjust_probability_threshold()`:
       ! Equivocal zone addition should come after adjustments that update the class probability estimates or hard class predictions.
 
+# validation of adjustments (incompatible types)
+
+    Code
+      tailor() %>% adjust_numeric_calibration() %>% adjust_probability_threshold()
+    Condition
+      Error in `adjust_probability_threshold()`:
+      ! Can't compose adjustments for different prediction types.
+      i Adjustment `adjust_numeric_calibration()` operates on numerics while `adjust_probability_threshold()` operates on probabilities.
+
+---
+
+    Code
+      tailor() %>% adjust_probability_calibration("logistic") %>%
+        adjust_probability_threshold(threshold = 0.4) %>% adjust_numeric_range(
+        lower_limit = 2)
+    Condition
+      Error in `adjust_numeric_range()`:
+      ! Can't compose adjustments for different prediction types.
+      i Adjustment `adjust_numeric_range()` operates on numerics while `adjust_probability_calibration()` and `adjust_probability_threshold()` operate on probabilities.
+
+---
+
+    Code
+      tailor() %>% adjust_numeric_calibration() %>% adjust_numeric_range(lower_limit = 2) %>%
+        adjust_probability_threshold(threshold = 0.4)
+    Condition
+      Error in `adjust_probability_threshold()`:
+      ! Can't compose adjustments for different prediction types.
+      i Adjustments `adjust_numeric_calibration()` and `adjust_numeric_range()` operate on numerics while `adjust_probability_threshold()` operates on probabilities.
+
+---
+
+    Code
+      tailor() %>% adjust_predictions_custom(veg = "potato") %>%
+        adjust_numeric_calibration() %>% adjust_probability_threshold()
+    Condition
+      Error in `adjust_probability_threshold()`:
+      ! Can't compose adjustments for different prediction types.
+      i Adjustment `adjust_numeric_calibration()` operates on numerics while `adjust_probability_threshold()` operates on probabilities.
+
+---
+
+    Code
+      tailor() %>% adjust_predictions_custom(veg = "potato") %>%
+        adjust_probability_calibration("logistic") %>% adjust_probability_threshold(
+        threshold = 0.4) %>% adjust_numeric_range(lower_limit = 2)
+    Condition
+      Error in `adjust_numeric_range()`:
+      ! Can't compose adjustments for different prediction types.
+      i Adjustment `adjust_numeric_range()` operates on numerics while `adjust_probability_calibration()` and `adjust_probability_threshold()` operate on probabilities.
+
+---
+
+    Code
+      tailor() %>% adjust_predictions_custom(veg = "potato") %>%
+        adjust_numeric_calibration() %>% adjust_numeric_range(lower_limit = 2) %>%
+        adjust_probability_threshold(threshold = 0.4)
+    Condition
+      Error in `adjust_probability_threshold()`:
+      ! Can't compose adjustments for different prediction types.
+      i Adjustments `adjust_numeric_calibration()` and `adjust_numeric_range()` operate on numerics while `adjust_probability_threshold()` operates on probabilities.
+
diff --git a/tests/testthat/test-adjust-numeric-calibration.R b/tests/testthat/test-adjust-numeric-calibration.R
index 703b201..20e0534 100644
--- a/tests/testthat/test-adjust-numeric-calibration.R
+++ b/tests/testthat/test-adjust-numeric-calibration.R
@@ -44,11 +44,7 @@ test_that("errors informatively with bad input", {
   expect_snapshot(error = TRUE, adjust_numeric_calibration(tailor(), "boop"))
   expect_snapshot(
     error = TRUE,
-    tailor("binary") %>% adjust_numeric_calibration("linear")
-  )
-  expect_snapshot(
-    error = TRUE,
-    tailor("regression") %>% adjust_numeric_calibration("binary")
+    tailor() %>% adjust_numeric_calibration("binary")
   )
 
   expect_no_condition(adjust_numeric_calibration(tailor()))
diff --git a/tests/testthat/test-adjust-probability-calibration.R b/tests/testthat/test-adjust-probability-calibration.R
index a19cd90..09d82bc 100644
--- a/tests/testthat/test-adjust-probability-calibration.R
+++ b/tests/testthat/test-adjust-probability-calibration.R
@@ -54,11 +54,7 @@ test_that("errors informatively with bad input", {
   expect_snapshot(error = TRUE, adjust_probability_calibration(tailor(), "boop"))
   expect_snapshot(
     error = TRUE,
-    tailor("regression") %>% adjust_probability_calibration("binary")
-  )
-  expect_snapshot(
-    error = TRUE,
-    tailor("binary") %>% adjust_probability_calibration("linear")
+    tailor() %>% adjust_probability_calibration("linear")
   )
 
   expect_no_condition(adjust_numeric_calibration(tailor()))
diff --git a/tests/testthat/test-tailor.R b/tests/testthat/test-tailor.R
index 0c45150..e951e87 100644
--- a/tests/testthat/test-tailor.R
+++ b/tests/testthat/test-tailor.R
@@ -1,12 +1,12 @@
 test_that("tailor printing", {
   expect_snapshot(tailor())
-  expect_snapshot(tailor(type = "binary"))
+  expect_snapshot(tailor())
   expect_snapshot(
-    tailor(type = "binary") %>%
+    tailor() %>%
       adjust_probability_threshold(.2)
   )
   expect_snapshot(
-    tailor(type = "binary") %>%
+    tailor() %>%
       adjust_probability_threshold(.2) %>%
       adjust_equivocal_zone()
   )
diff --git a/tests/testthat/test-validation-rules.R b/tests/testthat/test-validation-rules.R
index 4a50c22..77168e2 100644
--- a/tests/testthat/test-validation-rules.R
+++ b/tests/testthat/test-validation-rules.R
@@ -1,7 +1,7 @@
 test_that("validation of adjustments (regression)", {
   expect_no_condition(
     reg_tailor <-
-      tailor(type = "regression") %>%
+      tailor() %>%
       adjust_numeric_calibration() %>%
       adjust_numeric_range(lower_limit = 2) %>%
       adjust_predictions_custom(squared = .pred^2)
@@ -9,7 +9,7 @@ test_that("validation of adjustments (regression)", {
 
   expect_snapshot(
     error = TRUE,
-    tailor(type = "regression") %>%
+    tailor() %>%
      adjust_numeric_range(lower_limit = 2) %>%
      adjust_numeric_calibration() %>%
      adjust_predictions_custom(squared = .pred^2)
@@ -19,7 +19,7 @@
   # modifies the prediction?
   expect_no_condition(
     reg_tailor <-
-      tailor(type = "regression") %>%
+      tailor() %>%
       adjust_predictions_custom(squared = .pred^2) %>%
       adjust_numeric_calibration() %>%
       adjust_numeric_range(lower_limit = 2)
@@ -29,14 +29,14 @@ test_that("validation of adjustments (classification)", {
 
   expect_no_condition(
     cls_tailor_1 <-
-      tailor(type = "binary") %>%
+      tailor() %>%
       adjust_probability_calibration("logistic") %>%
       adjust_probability_threshold(threshold = .4)
   )
 
   expect_no_condition(
     cls_tailor_2 <-
-      tailor(type = "binary") %>%
+      tailor() %>%
       adjust_predictions_custom(starch = "potato") %>%
       adjust_predictions_custom(veg = "green beans") %>%
       adjust_probability_calibration("logistic") %>%
@@ -45,7 +45,7 @@ test_that("validation of adjustments (classification)", {
 
   expect_snapshot(
     error = TRUE,
-    tailor(type = "binary") %>%
+    tailor() %>%
      adjust_probability_threshold(threshold = .4) %>%
      adjust_probability_calibration()
   )
@@ -59,7 +59,7 @@ test_that("validation of adjustments (classification)", {
 
   expect_snapshot(
     error = TRUE,
-    tailor(type = "binary") %>%
+    tailor() %>%
      adjust_predictions_custom(veg = "potato") %>%
      adjust_probability_threshold(threshold = .4) %>%
      adjust_probability_calibration()
@@ -75,7 +75,7 @@ test_that("validation of adjustments (classification)", {
 
   expect_snapshot(
     error = TRUE,
-    tailor(type = "binary") %>%
+    tailor() %>%
      adjust_predictions_custom(veg = "potato") %>%
      adjust_probability_threshold(threshold = .4) %>%
      adjust_probability_threshold(threshold = .5) %>%
@@ -93,7 +93,7 @@ test_that("validation of adjustments (classification)", {
 
   expect_snapshot(
     error = TRUE,
-    tailor(type = "binary") %>%
+    tailor() %>%
      adjust_equivocal_zone(value = .2) %>%
      adjust_probability_threshold(threshold = .4)
   )
@@ -116,3 +116,75 @@ test_that("validation of adjustments (ambiguous type)", {
 
   expect_equal(ambiguous_tailor$type, "unknown")
 })
+
+test_that("validation of adjustments (incompatible types)", {
+  # one bad adjustment each
+  expect_snapshot(
+    error = TRUE,
+    tailor() %>%
+      adjust_numeric_calibration() %>%
+      adjust_probability_threshold()
+  )
+
+  # varying the pluralization...
+  expect_snapshot(
+    error = TRUE,
+    tailor() %>%
+      adjust_probability_calibration("logistic") %>%
+      adjust_probability_threshold(threshold = .4) %>%
+      adjust_numeric_range(lower_limit = 2)
+  )
+
+  expect_snapshot(
+    error = TRUE,
+    tailor() %>%
+      adjust_numeric_calibration() %>%
+      adjust_numeric_range(lower_limit = 2) %>%
+      adjust_probability_threshold(threshold = .4)
+  )
+
+  # ensure that mixing in ambiguous adjustments doesn't cause issues
+  expect_snapshot(
+    error = TRUE,
+    tailor() %>%
+      adjust_predictions_custom(veg = "potato") %>%
+      adjust_numeric_calibration() %>%
+      adjust_probability_threshold()
+  )
+
+  expect_snapshot(
+    error = TRUE,
+    tailor() %>%
+      adjust_predictions_custom(veg = "potato") %>%
+      adjust_probability_calibration("logistic") %>%
+      adjust_probability_threshold(threshold = .4) %>%
+      adjust_numeric_range(lower_limit = 2)
+  )
+
+  expect_snapshot(
+    error = TRUE,
+    tailor() %>%
+      adjust_predictions_custom(veg = "potato") %>%
+      adjust_numeric_calibration() %>%
+      adjust_numeric_range(lower_limit = 2) %>%
+      adjust_probability_threshold(threshold = .4)
+  )
+
+  expect_no_condition(
+    tailor() %>%
+      adjust_predictions_custom(veg = "potato") %>%
+      adjust_numeric_calibration()
+  )
+
+  expect_no_condition(
+    tailor() %>%
+      adjust_numeric_calibration() %>%
+      adjust_predictions_custom(veg = "potato")
+  )
+
+  expect_no_condition(
+    tailor() %>%
+      adjust_probability_threshold(threshold = .4) %>%
+      adjust_predictions_custom(veg = "potato")
+  )
+})
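Reviewer note, not part of the patch: the short sketch below illustrates the revised API. It assumes this branch of the tailor package is installed and that magrittr's `%>%` is available; the object name `tlr` is purely illustrative. The point, per the changes above, is that `tailor()` no longer takes a `type` argument -- `new_tailor()` infers the sub-mode from the adjustments via `adjustment_orderings()`/`infer_type()`, and mixing numeric- and probability-type adjustments now errors at composition time.

# A hedged usage sketch of the post-change API (assumes tailor + magrittr installed).
library(tailor)
library(magrittr)

# No `type = "binary"` argument; adding a probability-based adjustment is what
# makes this print as a binary postprocessor (see the updated snapshots above).
tlr <- tailor() %>%
  adjust_probability_threshold(threshold = 0.2)
print(tlr)
#> -- tailor --
#> A binary postprocessor with 1 adjustment: ...

# Composing adjustments for different prediction types now fails immediately:
# tailor() %>% adjust_numeric_calibration() %>% adjust_probability_threshold()
#> Error: Can't compose adjustments for different prediction types.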