Skip to content

Commit

Permalink
merge pr #41: remove tailor(type)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Sep 17, 2024
2 parents 6704eb6 + a9bb433 commit dfd9b0e
Show file tree
Hide file tree
Showing 15 changed files with 214 additions and 91 deletions.
21 changes: 10 additions & 11 deletions R/tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
#' with the [tidymodels](https://tidymodels.org) framework; for greatest ease
#' of use, situate tailors in model workflows with `?workflows::add_tailor()`.
#'
#' @param type Character. The model sub-mode. Possible values are
#' `"unknown"`, `"regression"`, `"binary"`, or `"multiclass"`. Only required
#' when used independently of `?workflows::add_tailor()`.
#' @param outcome <[`tidy-select`][dplyr::dplyr_tidy_select]> Only required
#' when used independently of `?workflows::add_tailor()`, and can also be passed
#' at `fit()` time instead. The column name of the outcome variable.
Expand Down Expand Up @@ -64,18 +61,16 @@
#' # adjust hard class predictions
#' predict(tlr_fit, two_class_example) %>% count(predicted)
#' @export
tailor <- function(type = "unknown", outcome = NULL, estimate = NULL,
probabilities = NULL) {
tailor <- function(outcome = NULL, estimate = NULL, probabilities = NULL) {
columns <-
list(
outcome = outcome,
type = type,
estimate = estimate,
probabilities = probabilities
)

new_tailor(
type,
"unknown",
adjustments = list(),
columns = columns,
ptype = tibble::new_tibble(list()),
Expand All @@ -84,8 +79,6 @@ tailor <- function(type = "unknown", outcome = NULL, estimate = NULL,
}

new_tailor <- function(type, adjustments, columns, ptype, call) {
type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))

if (!is.list(adjustments)) {
cli_abort("The {.arg adjustments} argument should be a list.", call = call)
}
Expand All @@ -97,8 +90,14 @@ new_tailor <- function(type, adjustments, columns, ptype, call) {
{.val adjustment}: {bad_adjustment}.", call = call)
}

orderings <- adjustment_orderings(adjustments)

if (type == "unknown") {
type <- infer_type(orderings)
}

# validate adjustment order and check duplicates
validate_order(adjustments, type, call)
validate_order(orderings, type, call)

# check columns
res <- list(
Expand Down Expand Up @@ -233,5 +232,5 @@ set_tailor_type <- function(object, y) {
# todo setup eval_time
# todo missing methods:
# todo tune_args
# todo tidy
# todo tidy (this should probably just be `adjustment_orderings()`)
# todo extract_parameter_set_dials
13 changes: 13 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ tailor_adjustment_requires_fit <- function(x) {
isTRUE(x$requires_fit)
}

# an tidy-esque method for adjustment lists, used in validating
# compatibility of adjustments
adjustment_orderings <- function(adjustments) {
tibble::new_tibble(list(
name = purrr::map_chr(adjustments, ~ class(.x)[1]),
input = purrr::map_chr(adjustments, ~ .x$inputs),
output_numeric = purrr::map_lgl(adjustments, ~ grepl("numeric", .x$outputs)),
output_prob = purrr::map_lgl(adjustments, ~ grepl("probability", .x$outputs)),
output_class = purrr::map_lgl(adjustments, ~ grepl("class", .x$outputs)),
output_all = purrr::map_lgl(adjustments, ~ grepl("everything", .x$outputs))
))
}

# ad-hoc checking --------------------------------------------------------------
check_tailor <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
if (!is_tailor(x)) {
Expand Down
36 changes: 21 additions & 15 deletions R/validation-rules.R
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@
validate_order <- function(adjustments, type, call = caller_env()) {
orderings <-
tibble::new_tibble(list(
name = purrr::map_chr(adjustments, ~ class(.x)[1]),
input = purrr::map_chr(adjustments, ~ .x$inputs),
output_numeric = purrr::map_lgl(adjustments, ~ grepl("numeric", .x$outputs)),
output_prob = purrr::map_lgl(adjustments, ~ grepl("probability", .x$outputs)),
output_class = purrr::map_lgl(adjustments, ~ grepl("class", .x$outputs)),
output_all = purrr::map_lgl(adjustments, ~ grepl("everything", .x$outputs))
))

if (length(adjustments) < 2) {
validate_order <- function(orderings, type, call = caller_env()) {
if (nrow(orderings) < 2) {
return(invisible(orderings))
}

if (type == "unknown") {
type <- infer_type(orderings)
}
check_incompatible_types(orderings, call)

switch(
type,
Expand All @@ -27,6 +15,24 @@ validate_order <- function(adjustments, type, call = caller_env()) {
invisible(orderings)
}

check_incompatible_types <- function(orderings, call) {
if (all(c("numeric", "probability") %in% orderings$input)) {
numeric_adjustments <- orderings$name[which(orderings$input == "numeric")]
probability_adjustments <- orderings$name[which(orderings$input == "probability")]
cli_abort(
c(
"Can't compose adjustments for different prediction types.",
"i" = "{cli::qty(numeric_adjustments)}
Adjustment{?s} {.fn {paste0('adjust_', numeric_adjustments)}}
{cli::qty(numeric_adjustments[-1])} operate{?s} on numerics while
{.fn {paste0('adjust_', probability_adjustments)}}
{cli::qty(probability_adjustments[-1])} operate{?s} on probabilities."
),
call = call
)
}
}

check_classification_order <- function(x, call) {
cal_ind <- which(grepl("calibration$", x$name))
eq_ind <- which(grepl("equivocal", x$name))
Expand Down
6 changes: 1 addition & 5 deletions man/tailor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/_snaps/adjust-equivocal-zone.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:
* Add equivocal zone of size 0.1.

Expand All @@ -16,7 +16,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:
* Add equivocal zone of optimized size.

12 changes: 2 additions & 10 deletions tests/testthat/_snaps/adjust-numeric-calibration.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:
* Re-calibrate numeric predictions.

Expand All @@ -20,15 +20,7 @@
---

Code
tailor("binary") %>% adjust_numeric_calibration("linear")
Condition
Error in `adjust_numeric_calibration()`:
! A binary tailor is incompatible with the adjustment `adjust_numeric_calibration()`.

---

Code
tailor("regression") %>% adjust_numeric_calibration("binary")
tailor() %>% adjust_numeric_calibration("binary")
Condition
Error in `adjust_numeric_calibration()`:
! `method` must be one of "linear", "isotonic", or "isotonic_boot", not "binary".
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/_snaps/adjust-numeric-range.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [-Inf, Inf].

Expand All @@ -16,7 +16,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [?, Inf].

Expand All @@ -27,7 +27,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [-1, ?].

Expand All @@ -38,7 +38,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [?, 1].

12 changes: 2 additions & 10 deletions tests/testthat/_snaps/adjust-probability-calibration.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:
* Re-calibrate classification probabilities.

Expand All @@ -20,15 +20,7 @@
---

Code
tailor("regression") %>% adjust_probability_calibration("binary")
Condition
Error in `adjust_probability_calibration()`:
! A regression tailor is incompatible with the adjustment `adjust_probability_calibration()`.

---

Code
tailor("binary") %>% adjust_probability_calibration("linear")
tailor() %>% adjust_probability_calibration("linear")
Condition
Error in `adjust_probability_calibration()`:
! `method` must be one of "logistic", "multinomial", "beta", "isotonic", or "isotonic_boot", not "linear".
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/adjust-probability-threshold.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:
* Adjust probability threshold to 0.5.

Expand All @@ -16,7 +16,7 @@
Message
-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:
* Adjust probability threshold to optimized value.

9 changes: 4 additions & 5 deletions tests/testthat/_snaps/tailor.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
---

Code
tailor(type = "binary")
tailor()
Message
-- tailor ----------------------------------------------------------------------
A binary postprocessor with 0 adjustments.
A postprocessor with 0 adjustments.

---

Code
tailor(type = "binary") %>% adjust_probability_threshold(0.2)
tailor() %>% adjust_probability_threshold(0.2)
Message
-- tailor ----------------------------------------------------------------------
Expand All @@ -30,8 +30,7 @@
---

Code
tailor(type = "binary") %>% adjust_probability_threshold(0.2) %>%
adjust_equivocal_zone()
tailor() %>% adjust_probability_threshold(0.2) %>% adjust_equivocal_zone()
Message
-- tailor ----------------------------------------------------------------------
Expand Down
Loading

0 comments on commit dfd9b0e

Please sign in to comment.