Skip to content

Commit

Permalink
Full check for dev version 0.2.20240109
Browse files Browse the repository at this point in the history
  • Loading branch information
tripartio committed Jan 9, 2024
1 parent 2f1b31a commit 07ad8ff
Show file tree
Hide file tree
Showing 13 changed files with 5,827 additions and 5,558 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: ale
Title: Interpretable Machine Learning and Statistical Inference with Accumulated Local Effects (ALE)
Version: 0.2.20240105
Version: 0.2.20240109
Authors@R: c(
person("Chitu", "Okoli", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-5574-7572")),
Expand Down Expand Up @@ -30,12 +30,14 @@ Imports:
ellipsis,
glue,
grDevices,
insight,
labeling,
methods,
purrr,
stats,
stringr,
tidyr,
univariateML,
yaImpute
Depends:
R (>= 3.5.0),
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

export(ale)
export(ale_ixn)
export(create_p_funs)
export(model_bootstrap)
import(assertthat)
import(dplyr)
import(ggplot2)
import(purrr)
importFrom(glue,glue)
importFrom(stats,median)
importFrom(stats,quantile)
importFrom(stats,sd)
3 changes: 3 additions & 0 deletions R/ale_core.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@
#'
#' }
#'
#'
#' @importFrom glue glue
#'
ale <- function (
data,
model,
Expand Down
5 changes: 3 additions & 2 deletions R/calc_ale.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ calc_ale <- function(
# local variables.
`:=` <- NULL
ale_y <- NULL
Var1 <- NULL
estimate <- NULL
Freq <- NULL
p.value <- NULL
statistic <- NULL
estimate <- NULL
Var1 <- NULL



Expand Down
16 changes: 9 additions & 7 deletions R/model_bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
#' @param model_call_string character string. If NULL, [model_bootstrap()] tries to
#' automatically detect and construct the call for bootstrapped datasets. If it cannot, the
#' function will fail early. In that case, a character string of the full call
#' for the model must be provided that includes `boot_data` as the data argument for the call. See examples
#' for the model must be provided that includes `boot_data` as the data argument for the call.
#' See examples.
# Future arguments:
# * y_col: name of y column in data. This would allow SD and MAD to be calculated.
# * pred_fun,pred_type: allows the prediction function to be called; this would
Expand Down Expand Up @@ -127,7 +128,7 @@
#' attitude,
#' model_call_string = 'mgcv::gam(
#' rating ~ complaints + privileges + s(learning) +
#' raises + s(critical) + advance),
#' raises + s(critical) + advance,
#' data = boot_data
#' )',
#' boot_it = 3
Expand Down Expand Up @@ -255,14 +256,15 @@ model_bootstrap <- function (
ale_y <- NULL
ale_y_mean <- NULL
ale_y_median <- NULL
aled <- NULL
estimate <- NULL
it <- NULL
naler_max <- NULL
name <- NULL
value <- NULL
term <- NULL
estimate <- NULL
p.value <- NULL
statistic <- NULL
aled <- NULL
naler_max <- NULL
term <- NULL
value <- NULL


# Create bootstrap tbl
Expand Down
32 changes: 24 additions & 8 deletions R/stats.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@
#' # Plot the ALE data. The horizontal bands in the plots use the p-values.
#' gridExtra::grid.arrange(grobs = ale_gam_diamonds$plots, ncol = 2)
#' }
#'
#'
#' @importFrom glue glue
#'
create_p_funs <- function(
training_data,
test_data,
Expand All @@ -178,6 +182,7 @@ create_p_funs <- function(
silent = FALSE,
.testing_mode = FALSE
) {

# Validate arguments

assert_that(training_data |> inherits('data.frame'))
Expand Down Expand Up @@ -245,6 +250,17 @@ create_p_funs <- function(
assert_that(is.flag(silent))


# Hack to prevent devtools::check from thinking that masked variables are global:
# Make them null local variables within the function with the issues. So,
# when masking applies, the masked variables will be prioritized over these null
# local variables.

# This super-assignment might be problematic. See this tip from ChatGPT to resolve it:
# https://chat.openai.com/c/08b68562-c339-4c37-baab-3c71d2e9fb73
.rand_model <<- NULL
.rand_test <<- NULL
.rand_train <<- NULL



# Determine the closest distribution of the residuals
Expand Down Expand Up @@ -294,7 +310,7 @@ create_p_funs <- function(
else { # use the automatically detected model call
# Update the model to call to add .random_variable and to train on .rand_train
model_call$formula <- model_call$formula |>
update.formula(~ . + .random_variable)
stats::update.formula(~ . + .random_variable)
model_call$data <- .rand_train

.rand_model <<- eval(model_call)
Expand All @@ -314,11 +330,11 @@ create_p_funs <- function(
.rand_ale
})

ale_y_norm <- ale:::create_ale_y_norm_function(test_data[[y_col]])
ale_y_norm <- create_ale_y_norm_function(test_data[[y_col]])

rand_stats <-
map(rand_ales, \(.rand) {
ale:::ale_stats(
ale_stats(
ale_y = .rand$data$.random_variable$ale_y,
ale_n = .rand$data$.random_variable$ale_n,
ale_y_norm_fun = ale_y_norm,
Expand All @@ -336,11 +352,11 @@ create_p_funs <- function(

# For aler_min and naler_min, the p-value is the simple ECDF
if (stringr::str_sub(.name_stat, -4, -1) == '_min') {
ecdf(.stat_vals)(x)
stats::ecdf(.stat_vals)(x)
}
# For other statistics, the p-value is 1 - ECDF.
else {
1 - ecdf(.stat_vals)(x)
1 - stats::ecdf(.stat_vals)(x)
}
}
}
Expand All @@ -361,13 +377,13 @@ create_p_funs <- function(
if (stringr::str_sub(.name_stat, -4, -1) == '_min') {
.stat_vals |>
quantile(probs = p) |>
setNames(p)
stats::setNames(p)
}
# For other statistics, the value is the quantile of 1 - p
else {
.stat_vals |>
quantile(probs = 1 - p) |>
setNames(p)
stats::setNames(p)
}
}
}
Expand All @@ -383,7 +399,7 @@ create_p_funs <- function(

# Set S3 class information for the p_funs object
class(p_funs) <- c('p_funs', 'ale', 'list')
attr(p_funs, 'ale_version') <- packageVersion('ale')
attr(p_funs, 'ale_version') <- utils::packageVersion('ale')

return(p_funs)

Expand Down
3 changes: 3 additions & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ interpretable
Interpretable
Jingyu
MMDD
Mmmm
mtcars
NALED
NALER
Noam
noncrucial
overfit
PDP
pkgdown
preprocesses
pretrained
programmatically
Expand All @@ -56,6 +58,7 @@ www
x1
x2
x3
yyyy
YYYY
Zhu

Expand Down
41 changes: 28 additions & 13 deletions man/ale.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 07ad8ff

Please sign in to comment.