add design doc to align on MCMC options enhancements #472

Open
wants to merge 3 commits into
base: main
218 changes: 218 additions & 0 deletions misc/design_mcmc_improve.qmd
@@ -0,0 +1,218 @@
# Design doc for improving MCMC sampling

This design doc proposes several updates to the `rbmi` package to improve the options for MCMC sampling. The main goal is to provide more flexibility to the user in terms of the sampling behavior, while still keeping the package easy to use.

## Wish list

1. Expose the Stan sampling arguments as input parameters (the `control` argument of `rstan::sampling()` should be user-controllable)
1. Keep the current option (single chain, with the MMRM estimates as default initial values)
1. Add an option to run multiple chains with random initial values
1. Always provide a default prior, and do not make the prior an input parameter.

## Sampling controls

Currently the `rstan` sampling is controlled by a hard coded list in [`fit_mcmc`](https://github.com/insightsengineering/rbmi/blob/main/R/mcmc.R#L94).

The `sampling()` function is described [here](https://mc-stan.org/rstan/reference/stanmodel-method-sampling.html). In particular, the `control` argument allows further tuning of the sampling behavior; it is documented [here](https://mc-stan.org/rstan/reference/stan.html).

Idea:

- Provide a `control_bayes()` function that returns a list of control parameters for `rstan::sampling()`
- This function should have a default behavior that is the same as the current behavior
- The user can override the default behavior by providing a list of control parameters
- Takes additional arguments which are then included in the returned list
- In order to simplify the user interface, the previous `method` parts `burn_in`, `burn_between`, `n_samples` are moved to the `control_bayes()` function, and will properly be deprecated from `method_bayes()`.
- Arguments of the new control function will be briefly documented (not just copied from the `rstan` documentation, but rather explained in the context of the `rbmi` package). The user will be referred to the `rstan` documentation for more details for additionally passed arguments.
- The `fit_mcmc` function should take the `control_bayes()` output as an additional argument and pass it internally to `rstan::sampling()`
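
The deprecation step in the list above could look roughly like the following. This is only a sketch under assumptions: `method_bayes_sketch()` is a hypothetical stand-in for `method_bayes()`, and the actual deprecation mechanism and message wording are still to be decided.

```r
# Hypothetical stand-in for method_bayes(); not the actual rbmi function.
method_bayes_sketch <- function(
    same_cov = TRUE,
    burn_in = NULL,
    burn_between = NULL,
    n_samples = NULL) {

    # Collect the arguments that are meant to move to control_bayes().
    deprecated <- list(
        burn_in = burn_in,
        burn_between = burn_between,
        n_samples = n_samples
    )
    supplied <- names(deprecated)[!vapply(deprecated, is.null, logical(1))]

    # Warn only if the user still passes one of the moved arguments.
    if (length(supplied) > 0) {
        warning(
            "Arguments ", paste(supplied, collapse = ", "),
            " are deprecated here; please use control_bayes() instead.",
            call. = FALSE
        )
    }

    list(same_cov = same_cov)
}
```

Calling `method_bayes_sketch(burn_in = 100)` would then still work, but emit a warning pointing the user to `control_bayes()`.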

License considerations:

- The `rstan` package is licensed under GPL-3.0, while `rbmi` is licensed under the MIT license.
- Only using the `rstan::sampling()` function and specifying reasonable default values for some of its control arguments should not require changing the license of the `rbmi` package. This is because such use does not constitute a derivative work. Instead, the `rbmi` package would continue to be "linking" to the GPL package.
- Importantly, the `rbmi` package only has the `rstan` package in the `Suggests` field of the `DESCRIPTION` file, not in the `Imports` field. This means that the `rstan` package is not required to be installed to use the `rbmi` package.
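
Because `rstan` is only a suggested package, any code path that needs it would typically check for it at run time. A minimal sketch of such a guard, assuming a hypothetical helper name `assert_pkg_available()` (not part of `rbmi`):

```r
# Hypothetical helper: stop with an informative message if a suggested
# package is not installed.
assert_pkg_available <- function(pkg) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
        stop(
            "Package '", pkg, "' is required for this functionality; ",
            "please install it first.",
            call. = FALSE
        )
    }
    invisible(TRUE)
}
```

A function such as `fit_mcmc()` could call `assert_pkg_available("rstan")` before touching any `rstan::` function.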

## Multiple chains

Currently just a single chain is used.

Idea:

- As part of the `control_bayes()` function, we could add an argument to specify the number of chains.
- If the number of chains is greater than 1, then the `init` argument of `rstan::sampling()` should be set to `"random"` (see [here](https://mc-stan.org/rstan/reference/stanmodel-method-sampling.html)).
- We can run some experiments to check that chains started from random values converge to the stationary distribution within a reasonable time frame. If the Stan random initial values are not good enough, we could later add the option to start the chains from values sampled from the MLE distribution.
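
The choice of initial values described above can be sketched as a small helper. This is an illustration only: `resolve_init()` and its arguments are hypothetical names, and it assumes that a list-valued `init` needs one list of initial values per chain, as `rstan::sampling()` expects.

```r
# Hypothetical helper: translate the `init` setting into what would be
# passed on to rstan::sampling().
resolve_init <- function(init, chains, mmrm_values) {
    if (identical(init, "mmrm")) {
        # Same MMRM-based starting values for every chain.
        rep(list(mmrm_values), chains)
    } else {
        # E.g. "random": let Stan generate the initial values itself.
        init
    }
}

# Single chain started from (made-up) MMRM estimates.
resolve_init("mmrm", chains = 1, mmrm_values = list(theta = c(0.5, 1.2)))

# Four chains with Stan's random initial values.
resolve_init("random", chains = 4, mmrm_values = list(theta = c(0.5, 1.2)))
```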

We note that due to the definition of the prior distribution for the covariance matrix, we still need to fit the frequentist MMRM model in any case to get the required prior parameter.

Question: What do we need to adapt downstream when using multiple chains?

Answer: We will be able to test this in detail and address it once we have implemented the possibility to use multiple chains in the development version. To avoid possible confusion, we can add a warning in the development version that informs the user that multiple chains are not yet fully supported while we are still working on it.

## Default prior

We want to keep using a default prior, as is already implemented for the unstructured covariance structure. The user should never have to specify priors actively, because `rbmi` should always provide a reasonable default prior.

That implies that we will need to come up with such default priors for the other covariance structures. This will be handled in a separate, forthcoming design doc.

## Prototype

This is what it could look like:

```r
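control_bayes <- function(
    burn_in = 200,
    burn_between = 50,
    n_samples = 20,
    chains = 1,
    init = ife(chains > 1, "random", "mmrm"),
    quiet = FALSE,
    refresh = ife(
        quiet,
        0,
        iter / 10
    ),
    seed = sample.int(.Machine$integer.max, 1),
    ...
) {
    # Total number of iterations per chain. This must be defined before
    # `refresh` is first used, because its lazily evaluated default
    # refers to `iter`.
    iter <- burn_in + burn_between * n_samples

    list(
        warmup = burn_in,
        thin = burn_between,
        iter = iter,
        chains = chains,
        refresh = refresh,
        init = init,
        seed = seed,
        # Anything else is forwarded as-is to rstan::sampling().
        additional = list(...)
    )
}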
```
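
As a worked check of the `iter` formula in the prototype, the defaults give `200 + 50 * 20 = 1200` iterations, of which 20 remain after discarding the warm-up and thinning. The snippet below only illustrates the arithmetic; the indices are an assumption for illustration, and the exact iterations that Stan retains may differ.

```r
burn_in <- 200
burn_between <- 50
n_samples <- 20

# Total iterations per chain, as computed in the prototype.
iter <- burn_in + burn_between * n_samples

# Illustrative indices of retained iterations: every `burn_between`-th
# iteration after the warm-up phase.
kept <- seq(from = burn_in + 1, to = iter, by = burn_between)

iter
length(kept)
```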

The rationale for using this control function approach is:

- We provide sensible defaults to the user, however they can now override them if they want to.
- The user can now also pass additional arguments to the `rstan::sampling()` function, but they are not required to do so.
- If the `rstan` API changes, we can adapt the `control_bayes()` function accordingly, and the user does not have to change their code. We would be notified by CRAN reverse dependencies checks if the `rstan` API changes because we will include tests for the actual use of the `rstan` package in the `rbmi` package.
- We have checked the past stability of the `rstan` API by looking at the GitHub repository ([link](https://github.com/stan-dev/rstan/blob/develop/rstan/rstan/R/stanmodel-class.R#L507)). It appears to have been stable over the last 6 years:
  - Via the [blame view](https://github.com/stan-dev/rstan/blame/develop/rstan/rstan/R/stanmodel-class.R#L507) we can see that the signature of the `sampling` method has not changed during the last 6 years. The signature explicitly includes all of our arguments, except for the `refresh` argument.
  - The `refresh` argument is passed via `...`. We can see that the arguments provided via `...` are checked by `is_arg_deprecated()`, which only lists an argument `enable_random_init` as deprecated, and this has not changed in the last 9 years. Further, we can see in the `stan()` function's [blame view](https://github.com/stan-dev/rstan/blame/develop/rstan/rstan/R/rstan.R#L280) that the `refresh` argument has been recognized for 11 years.

This function's output `control` can then be used in `fit_mcmc()` like this (copied and adapted function code):

```r
fit_mcmc <- function(
    designmat,
    outcome,
    group,
    subjid,
    visit,
    method,
    control) {

    # Fit MMRM (needed for Sigma prior parameter and possibly initial values).
    mmrm_initial <- fit_mmrm(
        designmat = designmat,
        outcome = outcome,
        subjid = subjid,
        visit = visit,
        group = group,
        cov_struct = "us",
        REML = TRUE,
        same_cov = method$same_cov
    )

    if (mmrm_initial$failed) {
        stop("Fitting MMRM to original dataset failed")
    }

    stan_data <- prepare_stan_data(
        ddat = designmat,
        subjid = subjid,
        visit = visit,
        outcome = outcome,
        group = ife(method$same_cov == TRUE, rep(1, length(group)), group)
    )

    stan_data$Sigma_init <- ife(
        method$same_cov == TRUE,
        list(mmrm_initial$sigma[[1]]),
        mmrm_initial$sigma
    )

    sampling_args <- c(
        list(
            object = get_stan_model(),
            data = stan_data,
            pars = c("beta", "Sigma"),
            chains = control$chains,
            warmup = control$warmup,
            thin = control$thin,
            iter = control$iter,
            # identical() is used because `init` need not be a single string.
            # Note: rstan expects one list of initial values per chain.
            init = ife(
                identical(control$init, "mmrm"),
                list(list(
                    theta = as.vector(stan_data$R %*% mmrm_initial$beta),
                    sigma = mmrm_initial$sigma
                )),
                control$init
            ),
            refresh = control$refresh,
            seed = control$seed
        ),
        control$additional
    )

    stan_fit <- record({
        do.call(rstan::sampling, sampling_args)
    })

    if (!is.null(stan_fit$errors)) {
        stop(stan_fit$errors)
    }

    ignorable_warnings <- c(
        "Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#bulk-ess",
        "Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#tail-ess"
    )

    # Handle warnings: only display a warning if
    # 1) the warning is not in ignorable_warnings
    warnings <- stan_fit$warnings
    warnings_not_allowed <- warnings[!warnings %in% ignorable_warnings]
    for (i in warnings_not_allowed) warning(i)

    fit <- stan_fit$results
    # `n_samples` is no longer part of `method`, so recover the number of
    # retained samples (per chain) from the control settings.
    n_samples <- (control$iter - control$warmup) / control$thin
    check_mcmc(fit, n_samples)

    draws <- extract_draws(fit)

    ret_obj <- list(
        "samples" = draws,
        "fit" = fit
    )

    return(ret_obj)
}
```

Finally, the user can call the `draws()` method with the new `control` argument, which now also carries the `quiet` setting. The method signature changes to:

```r
draws.bayes <- function(data,
                        data_ice = NULL,
                        vars,
                        method,
                        ncores = 1,
                        control = control_bayes())
```

and internally the `control` argument is just forwarded to `fit_mcmc()`.

Note that, for ease of use, every argument of `control_bayes()` has a default, so the user only needs to specify the arguments they want to override. In addition, we can still pass the Stan `control` argument inside `control_bayes()` (it ends up in the `additional` element via `...`), and it is not a problem that this argument has the same name as the `draws()` argument.

So a call could e.g. just be:

```r
draws(
    data = data, vars = vars, method = method,
    control = control_bayes(chains = 4, seed = 123, control = list(adapt_delta = 0.99))
)
```
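
To illustrate why the shared `control` name causes no clash, the following sketch traces how the Stan `control` list travels through the `additional` element into the final call. It makes two assumptions so that it runs without Stan: `fake_sampling()` is a hypothetical stand-in for `rstan::sampling()`, and the `control` list is written out by hand in the shape that the `control_bayes()` prototype would return.

```r
# Stand-in for rstan::sampling(), used only so the sketch runs without Stan.
fake_sampling <- function(object, chains = 1, seed = NULL, control = NULL, ...) {
    list(chains = chains, seed = seed, control = control)
}

# Hand-written equivalent of
# control_bayes(chains = 4, seed = 123, control = list(adapt_delta = 0.99)),
# reduced to the elements needed here.
control <- list(
    chains = 4,
    seed = 123,
    additional = list(control = list(adapt_delta = 0.99))
)

# Same pattern as in fit_mcmc(): named arguments first, extras appended.
sampling_args <- c(
    list(object = "model", chains = control$chains, seed = control$seed),
    control$additional
)

res <- do.call(fake_sampling, sampling_args)
res$control$adapt_delta
```

The Stan-level `control` list only becomes an argument of the sampling function at the `do.call()` step, so it never competes with the `control` argument of `draws()`.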