Replaced gridextra with patchwork. Relegated ggplot2 from Depends to Import.
tripartio committed Feb 8, 2024
1 parent df40044 commit 087e09f
Showing 16 changed files with 125 additions and 95 deletions.
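
As a rough illustration of what the Depends-to-Imports move means for user code (a sketch, assuming `ggplot2` is installed): once `library(ale)` no longer attaches `ggplot2`, the bare name `diamonds` is not guaranteed to be on the search path, which is why the examples in this commit switch to `ggplot2::diamonds`. The variable `ggplot2_attached` is a hypothetical name used only for this illustration.

```r
# Assumption: ggplot2 is installed. Nothing in this sketch attaches it.

# Reference the dataset through its namespace instead of relying on
# `diamonds` being visible on the search path.
set.seed(0)
diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]

# `::` loads the ggplot2 namespace but does not attach the package, so
# "package:ggplot2" is in search() only if something else attached it.
ggplot2_attached <- "package:ggplot2" %in% search()
```

Qualifying the dataset this way works whether or not the user has called `library(ggplot2)` themselves.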
8 changes: 4 additions & 4 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: ale
Title: Interpretable Machine Learning and Statistical Inference with Accumulated Local Effects (ALE)
Version: 0.2.0.20240207
Version: 0.2.0.20240208
Authors@R: c(
person("Chitu", "Okoli", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-5574-7572")),
@@ -14,9 +14,9 @@ Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Suggests:
ALEPlot,
gridExtra,
knitr,
mgcv,
patchwork,
readr,
rmarkdown,
testthat (>= 3.0.0)
@@ -28,6 +28,7 @@ Imports:
ellipsis,
furrr,
future,
ggplot2,
glue,
grDevices,
insight,
@@ -42,8 +43,7 @@ Imports:
univariateML,
yaImpute
Depends:
R (>= 3.5.0),
ggplot2
R (>= 3.5.0)
Remotes:
tidyverse/ggplot2#5592
URL: https://github.com/tripartio/ale, https://tripartio.github.io/ale/
26 changes: 15 additions & 11 deletions R/ale_core.R
@@ -273,10 +273,9 @@
#' arguments to understand how these are determined.
#'
#' @examples
# Sample 1000 rows from the diamonds dataset (for a simple example)
#' diamonds
# Sample 1000 rows from the ggplot2::diamonds dataset (for a simple example)
#' set.seed(0)
#' diamonds_sample <- diamonds[sample(nrow(diamonds), 1000), ]
#' diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]
#'
#' # Split the dataset into training and test sets
#' # https://stackoverflow.com/a/54892459/2449926
@@ -309,7 +308,8 @@
#' )
#'
#' # Plot the ALE data
#' gridExtra::grid.arrange(grobs = ale_gam_diamonds$plots, ncol = 2)
#' ale_gam_diamonds$plots |>
#' patchwork::wrap_plots()
#'
#' # Bootstrapped ALE
#' # This can be slow, since bootstrapping runs the algorithm boot_it times
@@ -322,7 +322,8 @@
#' )
#'
#' # Bootstrapped ALEs print with confidence intervals
#' gridExtra::grid.arrange(grobs = ale_gam_diamonds_boot$plots, ncol = 2)
#' ale_gam_diamonds_boot$plots |>
#' patchwork::wrap_plots()
#'
#'
#' # If the predict function you want is non-standard, you may define a
@@ -339,7 +340,8 @@
#' )
#'
#' # Plot the ALE data
#' gridExtra::grid.arrange(grobs = ale_gam_diamonds_custom$plots, ncol = 2)
#' ale_gam_diamonds_custom$plots |>
#' patchwork::wrap_plots()
#'
#' }
#'
@@ -449,10 +451,9 @@ ale <- function (
#'
#' @examples
#'
# Sample 1000 rows from the diamonds dataset (for a simple example)
#' diamonds
# Sample 1000 rows from the ggplot2::diamonds dataset (for a simple example)
#' set.seed(0)
#' diamonds_sample <- diamonds[sample(nrow(diamonds), 1000), ]
#' diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]
#'
#' # Split the dataset into training and test sets
#' # https://stackoverflow.com/a/54892459/2449926
@@ -484,8 +485,11 @@ ale <- function (
#'
#' # Print interaction plots
#' ale_ixn_gam_diamonds$plots |>
#' purrr::walk(\(.x1) { # extract list of x1 ALE outputs
#' gridExtra::grid.arrange(grobs = .x1, ncol = 2) # plot all x1 plots
#' # extract list of x1 ALE outputs
#' purrr::walk(\(.x1) {
#' # plot all x2 plots in each .x1 element
#' patchwork::wrap_plots(.x1) |>
#' print()
#' })
#' }
#'
3 changes: 2 additions & 1 deletion R/model_bootstrap.R
@@ -150,7 +150,8 @@
#' mb_gam$model_coefs
#'
#' # Plot ALE
#' gridExtra::grid.arrange(grobs = mb_gam$ale$plots, ncol = 2)
#' mb_gam$ale$plots |>
#' patchwork::wrap_plots()
#' }
#'
#'
7 changes: 4 additions & 3 deletions R/stats.R
@@ -121,9 +121,9 @@
#'
#' @examples
#' \donttest{
#' # Sample 1000 rows from the diamonds dataset (for a simple example)diamonds
#' # Sample 1000 rows from the ggplot2::diamonds dataset (for a simple example)
#' set.seed(0)
#' diamonds_sample <- diamonds[sample(nrow(diamonds), 1000), ]
#' diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]
#'
#' # Split the dataset into training and test sets
#' # https://stackoverflow.com/a/54892459/2449926
@@ -168,7 +168,8 @@
#' )
#'
#' # Plot the ALE data. The horizontal bands in the plots use the p-values.
#' gridExtra::grid.arrange(grobs = ale_gam_diamonds$plots, ncol = 2)
#' ale_gam_diamonds$plots |>
#' patchwork::wrap_plots()
#'
#' }
#'
13 changes: 7 additions & 6 deletions README.Rmd
@@ -74,10 +74,9 @@ Here is a simple example that demonstrates the usage of the model. First, we tra
```{r gam-and-ale, fig.width=7, fig.height=11}
library(ale)
# Sample 1000 rows from the diamonds dataset (for a simple example).
# diamonds is included with ggplot2, which is imported by the ale package.
# Sample 1000 rows from the ggplot2::diamonds dataset (for a simple example).
set.seed(0)
diamonds_sample <- diamonds[sample(nrow(diamonds), 1000), ]
diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]
# Split the dataset into training and test sets
# https://stackoverflow.com/a/54892459/2449926
@@ -97,17 +96,19 @@ gam_diamonds <- mgcv::gam(
data = diamonds_train
)
# Create ALE data and plot it
# Create ALE data
ale_gam_diamonds <- ale(
diamonds_test, gam_diamonds,
model_packages = 'mgcv' # required for parallel processing
)
gridExtra::grid.arrange(grobs = ale_gam_diamonds$plots, ncol = 2)
# Plot the ALE data
ale_gam_diamonds$plots |>
patchwork::wrap_plots(ncol = 2)
```

## Getting help

If you find a bug, please report it on [GitHub](https://github.com/tripartio/ale/issues). If you have a question about how to use the package, you can post it on [Stack Overflow with the “ale” tag](https://stackoverflow.com/questions/tagged/ale). I will follow that tag, so I will try my best to respond quickly. However, be sure to always include a minimal reproducible example for your usage requests. If you cannot include your own dataset in the question, then use one of the built-in datasets to frame your help request: `var_cars`, `census`, or `diamonds`.
If you find a bug, please report it on [GitHub](https://github.com/tripartio/ale/issues). If you have a question about how to use the package, you can post it on [Stack Overflow with the “ale” tag](https://stackoverflow.com/questions/tagged/ale). I will follow that tag, so I will try my best to respond quickly. However, be sure to always include a minimal reproducible example for your usage requests. If you cannot include your own dataset in the question, then use one of the built-in datasets to frame your help request: `var_cars` or `census`. You may also use `ggplot2::diamonds` for a larger sample.


16 changes: 8 additions & 8 deletions README.md
@@ -107,12 +107,10 @@ the `ggplot` plot objects.

``` r
library(ale)
#> Loading required package: ggplot2

# Sample 1000 rows from the diamonds dataset (for a simple example).
# diamonds is included with ggplot2, which is imported by the ale package.
# Sample 1000 rows from the ggplot2::diamonds dataset (for a simple example).
set.seed(0)
diamonds_sample <- diamonds[sample(nrow(diamonds), 1000), ]
diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]

# Split the dataset into training and test sets
# https://stackoverflow.com/a/54892459/2449926
@@ -132,13 +130,15 @@ gam_diamonds <- mgcv::gam(
data = diamonds_train
)

# Create ALE data and plot it
# Create ALE data
ale_gam_diamonds <- ale(
diamonds_test, gam_diamonds,
model_packages = 'mgcv' # required for parallel processing
)

gridExtra::grid.arrange(grobs = ale_gam_diamonds$plots, ncol = 2)
# Plot the ALE data
ale_gam_diamonds$plots |>
patchwork::wrap_plots(ncol = 2)
```

<img src="man/figures/README-gam-and-ale-1.png" width="100%" />
@@ -153,5 +153,5 @@ tag](https://stackoverflow.com/questions/tagged/ale). I will follow that
tag, so I will try my best to respond quickly. However, be sure to
always include a minimal reproducible example for your usage requests.
If you cannot include your own dataset in the question, then use one of
the built-in datasets to frame your help request: `var_cars`, `census`,
or `diamonds`.
the built-in datasets to frame your help request: `var_cars` or
`census`. You may also use `ggplot2::diamonds` for a larger sample.
12 changes: 7 additions & 5 deletions man/ale.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions man/ale_ixn.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions man/create_p_funs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified man/figures/README-gam-and-ale-1.png
3 changes: 2 additions & 1 deletion man/model_bootstrap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 14 additions & 11 deletions vignettes/ale-intro.Rmd
@@ -32,7 +32,7 @@ library(dplyr)

## diamonds dataset

For this introduction, we use the `diamonds` dataset, built-in with the `ggplot2` graphics system. We cleaned the original version by [removing duplicates](https://lorentzen.ch/index.php/2021/04/16/a-curious-fact-on-the-diamonds-dataset/ "errors in the diamonds dataset") and invalid entries where the length (x), width (y), or depth (z) is 0.
For this introduction, we use the `diamonds` dataset, included with the `{ggplot2}` graphics system. We cleaned the original version by [removing duplicates](https://lorentzen.ch/index.php/2021/04/16/a-curious-fact-on-the-diamonds-dataset/ "errors in the diamonds dataset") and invalid entries where the length (x), width (y), or depth (z) is 0.

```{r diamonds_print}
# Clean up some invalid entries
@@ -130,7 +130,6 @@ The `ale()` function returns a list with various elements. The two main ones are
ale_gam_diamonds <- ale(
diamonds_test, gam_diamonds,
model_packages = 'mgcv', # required for parallel processing
silent = TRUE, # progress bars disabled for the vignette
parallel = 2 # CRAN limit (delete this line on your own computer)
)
```
@@ -144,11 +143,12 @@ To access the plot for a specific variable, we can call it by its variable name
ale_gam_diamonds$plots$carat
```

To iterate the list and plot all the ALE plots, we provide here some demonstration code using the `gridExtra` package for arranging multiple plots in a common plot grid using `gridExtra::grid.arrange`. We need to pass the list of plots to the `grobs` argument and we can specify that we want two plots per row with the `ncol` argument.
To iterate the list and plot all the ALE plots, we provide here some demonstration code using the `patchwork` package, which arranges multiple plots in a common plot grid with `patchwork::wrap_plots()`. We pass the list of plots as the first argument and specify that we want two plots per row with the `ncol` argument.

```{r print ale_simple, fig.width=7, fig.height=11}
# Print all plots
gridExtra::grid.arrange(grobs = ale_gam_diamonds$plots, ncol = 2)
ale_gam_diamonds$plots |>
patchwork::wrap_plots(ncol = 2)
```
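
The same call pattern can be tried without fitting any model. The following is only a sketch under stated assumptions: `ggplot2` and `patchwork` are installed, the built-in `mtcars` data stands in for real ALE output, and the helper `scatter()` and the names `plots` and `combined` are hypothetical.

```r
# Hypothetical helper: a minimal scatter plot from two numeric vectors.
scatter <- function(x, y) {
  ggplot2::ggplot(data.frame(x = x, y = y), ggplot2::aes(x, y)) +
    ggplot2::geom_point()
}

# A named list of ggplot objects, standing in for `ale_gam_diamonds$plots`.
plots <- list(
  wt   = scatter(mtcars$wt, mtcars$mpg),
  hp   = scatter(mtcars$hp, mtcars$mpg),
  disp = scatter(mtcars$disp, mtcars$mpg)
)

# The list is the first argument; `ncol = 2` requests two plots per row.
combined <- plots |>
  patchwork::wrap_plots(ncol = 2)

# The result is a single object that is drawn when printed.
print(combined)
```

Unlike `gridExtra::grid.arrange()`, there is no `grobs` argument here: the list itself is what gets piped in.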

## Bootstrapped ALE
@@ -173,12 +173,12 @@ ale_gam_diamonds_boot <- ale(
# Normally boot_it should be set to 100, but just 10 here for a faster demonstration
boot_it = 10,
model_packages = 'mgcv', # required for parallel processing
silent = TRUE, # progress bars disabled for the vignette
parallel = 2 # CRAN limit (delete this line on your own computer)
)
# Bootstrapping produces confidence intervals
gridExtra::grid.arrange(grobs = ale_gam_diamonds_boot$plots, ncol = 2)
ale_gam_diamonds_boot$plots |>
patchwork::wrap_plots(ncol = 2)
```

In this case, the bootstrapped results are mostly similar to the single (non-bootstrapped) ALE result. In principle, we should always bootstrap the results and trust only bootstrapped results. The most unusual result is that values of `x_length` (the length of the diamond) from 6.2 mm or so and higher are associated with lower diamond prices. When we compare this with the `y_width` value (width of the diamond), we suspect that when both the length and width (that is, the size) of a diamond become increasingly large, the price increases so much more rapidly with the width than with the length that the width has an inordinately high effect that is tempered by a decreased effect of the length at those high values. This would be worth further exploration for real analysis, but here we are just introducing the key features of the package.
@@ -192,25 +192,28 @@ Another advantage of ALE is that it provides data for two-way interactions betwe
ale_ixn_gam_diamonds <- ale_ixn(
diamonds_test, gam_diamonds,
model_packages = 'mgcv', # required for parallel processing
silent = TRUE, # progress bars disabled for the vignette
parallel = 2 # CRAN limit (delete this line on your own computer)
)
```

Like the `ale()` function, `ale_ixn()` returns a list with one element per input x1 variable, as well as a `.common_data` element with details about the outcome (y) variable. However, in this case, each variable's element consists of a list of all the x2 variables for which the x1 interaction is calculated. Each x2 element then has two elements: the ALE data for that variable and a `ggplot` plot object that plots that ALE data. In the interaction plots, the x1 variable is always shown on the x axis and the x2 variable on the y axis.

Again, we provide here some demonstration code to plot all the ALE plots. It is a little more complex this time because of the two levels of interacting variables in the output data, so we use the `purrr` package to iterate the list structure. `purrr::walk` takes a list as its first argument and then we specify an anonymous function for what we want to do with each element of the list. We specify the anonymous function as `\(.x1) {...}` where `.x1` in our case represents each individual element of `ale_ixn_gam_diamonds$plots` in turn, that is, a sublist of plots with which the x1 variable interacts. We print the plots of all the x1 interactions as a combined grid of plots with `gridExtra::grid.arrange`, as before.
Again, we provide here some demonstration code to plot all the ALE plots. It is a little more complex this time because of the two levels of interacting variables in the output data, so we use the `purrr` package to iterate the list structure. `purrr::walk()` takes a list as its first argument and then we specify an anonymous function for what we want to do with each element of the list. We specify the anonymous function as `\(.x1) {...}` where `.x1` in our case represents each individual element of `ale_ixn_gam_diamonds$plots` in turn, that is, a sublist of plots with which the x1 variable interacts. We print the plots of all the x1 interactions as a combined grid of plots with `patchwork::wrap_plots()`, as before.

```{r print all ale_ixn, fig.width=7, fig.height=7}
# Print all interaction plots
ale_ixn_gam_diamonds$plots |>
purrr::walk(\(.x1) { # extract list of x1 ALE outputs
gridExtra::grid.arrange(grobs = .x1, ncol = 2) # plot each x1 plot
# extract list of x1 ALE outputs
purrr::walk(\(.x1) {
# plot all x2 plots in each .x1 element
patchwork::wrap_plots(.x1, ncol = 2) |>
print()
})
```
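
The explicit `print()` in the new code appears to be needed because `patchwork::wrap_plots()` only returns a plot object, and a value produced inside a function body is not auto-printed the way a top-level expression is. The sketch below illustrates this with a nested list; it assumes `ggplot2` and `patchwork` are installed, uses `mtcars` as stand-in data, and replaces `purrr::walk()` with a base R `for` loop to avoid the extra dependency. The names `scatter()`, `nested_plots`, and `combined` are hypothetical.

```r
# Hypothetical helper: a minimal scatter plot from two numeric vectors.
scatter <- function(x, y) {
  ggplot2::ggplot(data.frame(x = x, y = y), ggplot2::aes(x, y)) +
    ggplot2::geom_point()
}

# Two levels of nesting, mimicking the x1 -> x2 structure of
# `ale_ixn_gam_diamonds$plots`.
nested_plots <- list(
  wt = list(
    hp   = scatter(mtcars$wt, mtcars$hp),
    disp = scatter(mtcars$wt, mtcars$disp)
  ),
  hp = list(
    wt   = scatter(mtcars$hp, mtcars$wt),
    disp = scatter(mtcars$hp, mtcars$disp)
  )
)

# Build one combined grid per x1 element.
combined <- lapply(nested_plots, \(.x1) patchwork::wrap_plots(.x1, ncol = 2))

# Nothing has been drawn yet; each grid is drawn only when printed.
for (grid_plot in combined) {
  print(grid_plot)
}
```

Separating construction from printing also makes it possible to save or further modify each grid before it is drawn.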

Because we are printing all plots together with the same `gridExtra::grid.arrange` statement, some of them might appear vertically distorted because each plot is forced to be of the same height. For more fine-tuned presentation, we would need to refer to a specific plot. For example, we can print the interaction plot between carat and depth by referring to it thus: `ale_ixn_gam_diamonds$plots$carat$depth`.
Because we are printing all plots together with the same `patchwork::wrap_plots()` statement, some of them might appear vertically distorted because each plot is forced to be of the same height. For more fine-tuned presentation, we would need to refer to a specific plot. For example, we can print the interaction plot between carat and depth by referring to it thus: `ale_ixn_gam_diamonds$plots$carat$depth`.

```{r print specific ixn, fig.width=5, fig.height=3}
ale_ixn_gam_diamonds$plots$carat$depth