Commit 48bda12 (1 parent: 994bb0f)
# Design doc for improving MCMC sampling

This design doc proposes several updates to the `rbmi` package to improve the options for MCMC sampling. The main goal is to provide more flexibility to the user in terms of the sampling behavior, while still keeping the package easy to use.

## Wish list

1. Expose the Stan sampling arguments as input parameters (the arguments of `rstan::sampling()` should be user-controllable)
1. Keep the current option (single chain, with the MMRM estimates as initial values) as the default
1. Add an option to run multiple chains with random initial values
1. Always use a default prior; do not make the prior an input parameter

## Sampling controls

Currently, the `rstan` sampling is controlled by a hard-coded list in [`fit_mcmc`](https://github.com/insightsengineering/rbmi/blob/main/R/mcmc.R#L94).

The `sampling()` function is described [here](https://mc-stan.org/rstan/reference/stanmodel-method-sampling.html). In particular, its `control` argument allows the sampling behavior to be specified further; it is documented [here](https://mc-stan.org/rstan/reference/stan.html).

Idea:

- Provide a `control_bayes()` function that returns a list of control parameters for `rstan::sampling()`
- This function should have a default behavior that is the same as the current behavior
- The user can override the default behavior by providing a list of control parameters
- It takes additional arguments, which are then included in the returned list
- To simplify the user interface, the previous `method` parts `burn_in`, `burn_between`, and `n_samples` are moved to the `control_bayes()` function, and will be properly deprecated in `method_bayes()`.
- The `fit_mcmc` function should take the `control_bayes()` output as an additional argument and pass it internally to `rstan::sampling()`

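The deprecation of the old `method_bayes()` arguments could be handled along the following lines. This is only a minimal sketch: `method_bayes_sketch()` is a hypothetical stand-in, not the `rbmi` implementation, and the warning text is an assumption.

```r
# Hypothetical sketch: `method_bayes_sketch()` is an illustrative stand-in,
# not the actual rbmi function.
method_bayes_sketch <- function(
    same_cov = TRUE,
    burn_in = NULL,
    burn_between = NULL,
    n_samples = NULL
) {
    deprecated <- list(
        burn_in = burn_in,
        burn_between = burn_between,
        n_samples = n_samples
    )
    # Keep only the deprecated arguments the caller actually supplied.
    used <- names(Filter(Negate(is.null), deprecated))
    if (length(used) > 0) {
        warning(
            "Arguments ", paste(used, collapse = ", "),
            " are deprecated in method_bayes(); use control_bayes() instead.",
            call. = FALSE
        )
    }
    list(same_cov = same_cov)
}
```

Defaulting the deprecated arguments to `NULL` makes it possible to tell "not supplied" apart from "supplied with the old default value", so existing code that never set them stays silent.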
## Multiple chains

Currently just a single chain is used.

Idea:

- As part of the `control_bayes()` function, we could add an argument to specify the number of chains.
- If the number of chains is greater than 1, then the `init` argument of `rstan::sampling()` should be set to `"random"` (see [here](https://mc-stan.org/rstan/reference/stanmodel-method-sampling.html)).
- We can run some experiments to check that chains started from random values converge to the stationary distribution within a reasonable time frame. If the Stan random initial values are not good enough, we could later add the option to use values sampled from the distribution of the MLE as the initial values for the chains.

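If MLE-based initial values turn out to be needed, one possible approach is to draw them from a normal approximation around the MMRM estimate. The sketch below is an assumption-laden illustration in base R: `sample_mle_inits()` is a hypothetical helper, the input values are made up, and the element name `beta` is a placeholder that would have to match the parameter names of the Stan model.

```r
# Hypothetical helper: draw one set of initial values per chain from
# N(beta_hat, vcov_hat), a normal approximation to the distribution of the MLE.
sample_mle_inits <- function(beta_hat, vcov_hat, chains) {
    p <- length(beta_hat)
    # chol() returns an upper-triangular U with t(U) %*% U == vcov_hat.
    chol_upper <- chol(vcov_hat)
    lapply(seq_len(chains), function(i) {
        z <- rnorm(p)
        # t(U) %*% z has covariance vcov_hat.
        list(beta = as.vector(beta_hat + crossprod(chol_upper, z)))
    })
}

set.seed(1)
inits <- sample_mle_inits(
    beta_hat = c(1.5, -0.3),                                  # made-up values
    vcov_hat = matrix(c(0.04, 0.01, 0.01, 0.09), nrow = 2),   # made-up values
    chains = 4
)
```

The result is a list with one named list per chain, which is the shape `rstan::sampling()` accepts for its `init` argument.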
Note that, because of how the prior distribution for the covariance matrix is defined, we still need to fit the frequentist MMRM in any case to obtain the required prior parameter.

Question: What do we need to adapt downstream when using multiple chains?

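One input to this question is how many draws come back in total. The arithmetic below is a rough sketch using the default values from the prototype in this document; it assumes that `iter` applies per chain and that Stan retains `(iter - warmup) / thin` post-warmup draws per chain.

```r
# Illustrative arithmetic only, using the prototype's default settings.
burn_in <- 200
burn_between <- 50
n_samples <- 20
chains <- 4

iter <- burn_in + burn_between * n_samples         # iterations per chain
kept_per_chain <- (iter - burn_in) / burn_between  # retained draws per chain
kept_total <- chains * kept_per_chain              # retained draws overall
```

Under these assumptions each chain yields `n_samples` draws, so downstream code would receive `chains * n_samples` draws unless the per-chain length is reduced accordingly.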
## Default prior

We want to keep using a default prior, as is already implemented for the unstructured covariance structure.

This implies that we will need to come up with such default priors for the other covariance structures. This will be handled in a separate design doc.

## Prototype

This is what it could look like:

```r
control_bayes <- function(
    burn_in = 200,
    burn_between = 50,
    n_samples = 20,
    chains = 1,
    init = ife(chains > 1, "random", "mmrm"),
    quiet = FALSE,
    # `iter` is not an argument, so the default spells out the same formula.
    refresh = ife(
        quiet,
        0,
        (burn_in + burn_between * n_samples) / 10
    ),
    seed = sample.int(.Machine$integer.max, 1),
    ...
) {
    list(
        chains = chains,
        warmup = burn_in,
        thin = burn_between,
        iter = burn_in + burn_between * n_samples,
        quiet = quiet,
        refresh = refresh,
        init = init,
        seed = seed,
        # Anything else is forwarded to rstan::sampling() as is.
        additional = list(...)
    )
}
```

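The prototype relies on R evaluating default arguments lazily inside the function's own environment: a default may refer to other arguments such as `chains`, but not to a name that is neither an argument nor defined in the function body. The toy example below illustrates this; `demo_defaults()` is hypothetical, and `ife()` is redefined here as a stand-in with the behavior assumed for the internal `rbmi` helper.

```r
# Stand-in for rbmi's internal `ife()` helper (assumed: scalar if/else).
ife <- function(x, a, b) if (x) a else b

# Toy function (hypothetical) mirroring the default-argument pattern.
demo_defaults <- function(chains = 1, init = ife(chains > 1, "random", "mmrm")) {
    list(chains = chains, init = init)
}

demo_defaults()$init                           # "mmrm"
demo_defaults(chains = 4)$init                 # "random"
demo_defaults(chains = 4, init = "mmrm")$init  # "mmrm": explicit value wins
```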
This function's output `control` can then be used in `fit_mcmc()` like this (copied and adapted function code):

```r
fit_mcmc <- function(
    designmat,
    outcome,
    group,
    subjid,
    visit,
    method,
    control
) {

    # Fit MMRM (needed for Sigma prior parameter and possibly initial values).
    mmrm_initial <- fit_mmrm(
        designmat = designmat,
        outcome = outcome,
        subjid = subjid,
        visit = visit,
        group = group,
        cov_struct = "us",
        REML = TRUE,
        same_cov = method$same_cov
    )

    if (mmrm_initial$failed) {
        stop("Fitting MMRM to original dataset failed")
    }

    stan_data <- prepare_stan_data(
        ddat = designmat,
        subjid = subjid,
        visit = visit,
        outcome = outcome,
        group = ife(method$same_cov == TRUE, rep(1, length(group)), group)
    )

    stan_data$Sigma_init <- ife(
        method$same_cov == TRUE,
        list(mmrm_initial$sigma[[1]]),
        mmrm_initial$sigma
    )

    # MMRM-based initial values for a single chain.
    init_mmrm <- list(
        theta = as.vector(stan_data$R %*% mmrm_initial$beta),
        sigma = mmrm_initial$sigma
    )

    sampling_args <- c(
        list(
            object = get_stan_model(),
            data = stan_data,
            pars = c("beta", "Sigma"),
            chains = control$chains,
            warmup = control$warmup,
            thin = control$thin,
            iter = control$iter,
            # identical() also works when `init` is a list or a function.
            # rstan expects one list of initial values per chain.
            init = ife(
                identical(control$init, "mmrm"),
                rep(list(init_mmrm), control$chains),
                control$init
            ),
            refresh = control$refresh,
            seed = control$seed
        ),
        control$additional
    )

    stan_fit <- record({
        do.call(rstan::sampling, sampling_args)
    })

    if (!is.null(stan_fit$errors)) {
        stop(stan_fit$errors)
    }

    ignorable_warnings <- c(
        "Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#bulk-ess",
        "Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#tail-ess"
    )

    # Handle warnings: re-emit only those not listed in ignorable_warnings.
    warnings <- stan_fit$warnings
    warnings_not_allowed <- warnings[!warnings %in% ignorable_warnings]
    for (i in warnings_not_allowed) warning(i)

    fit <- stan_fit$results
    # NOTE: `n_imputations` is not defined in this excerpt; it corresponds to
    # the requested number of samples (`n_samples`).
    check_mcmc(fit, n_imputations)

    draws <- extract_draws(fit)

    ret_obj <- list(
        "samples" = draws,
        "fit" = fit
    )

    return(ret_obj)
}
```

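The warning handling at the end of `fit_mcmc()` can be illustrated in isolation. The sketch below does not use the `rbmi` `record()` helper; `record_warnings()` is a hypothetical stand-in built on `withCallingHandlers()`, and the warning messages are shortened placeholders.

```r
# Hypothetical stand-in for rbmi's `record()`: evaluate an expression and
# collect warning messages instead of emitting them immediately.
record_warnings <- function(expr) {
    collected <- character(0)
    result <- withCallingHandlers(
        expr,
        warning = function(w) {
            collected <<- c(collected, conditionMessage(w))
            invokeRestart("muffleWarning")
        }
    )
    list(results = result, warnings = collected)
}

ignorable <- c("ESS is too low")  # placeholder message

out <- record_warnings({
    warning("ESS is too low")
    warning("divergent transitions after warmup")
    42
})

# Keep only the warnings that are not on the ignorable list.
not_allowed <- out$warnings[!out$warnings %in% ignorable]
```

Each element of `not_allowed` would then be passed to `warning()` one at a time, so that every remaining message is surfaced separately.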
Finally, the user can now call the `draws()` method with the new `control` argument, which also takes over the role of the `quiet` argument. The method signature changes to:

```r
draws.bayes <- function(data,
                        data_ice = NULL,
                        vars,
                        method,
                        ncores = 1,
                        control = control_bayes()) {
    # ... (`control` is forwarded to fit_mcmc())
}
```

and internally the `control` argument is just forwarded to `fit_mcmc()`.

Note that, for ease of use, it is important that all arguments of `control_bayes()` have defaults. The user then only needs to specify the arguments whose defaults they want to override. In addition, we can still pass the Stan `control` argument inside `control_bayes()`, and it is not a problem that this argument has the same name as the `draws()` argument.

A call could then look like this, for example:

```r
draws(
    data = data, vars = vars, method = method,
    control = control_bayes(chains = 4, seed = 123, control = list(adapt_delta = 0.99))
)
```
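To illustrate why the Stan `control` list does not clash with the `draws()` argument of the same name, the sketch below traces it through the argument merging used in `fit_mcmc()`. It is an illustration only: `mock_sampling()` is a hypothetical stand-in for `rstan::sampling()`, and `ctrl` is written out by hand in the shape a `control_bayes()` result is assumed to have.

```r
# Mock sampler (hypothetical) standing in for rstan::sampling(); it simply
# returns some of the arguments it received.
mock_sampling <- function(object, chains = 1, seed = NULL, control = NULL, ...) {
    list(chains = chains, seed = seed, control = control)
}

# Assumed shape of a control_bayes() result, written out by hand.
ctrl <- list(
    chains = 4,
    seed = 123,
    additional = list(control = list(adapt_delta = 0.99))
)

sampling_args <- c(
    list(object = "model", chains = ctrl$chains, seed = ctrl$seed),
    ctrl$additional
)

# Guard against an additional argument duplicating a fixed one, which would
# make do.call() fail because the formal argument is matched twice.
stopifnot(!anyDuplicated(names(sampling_args)))

res <- do.call(mock_sampling, sampling_args)
res$control$adapt_delta
```

The Stan `control` list only ever lives inside `additional`, so it reaches the sampler as a regular named argument. A duplicate-name check like the one above is one way to catch a user passing, say, `iter` through `...` as well.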