Skip to content

Commit

Permalink
add design doc to align on approach
Browse files Browse the repository at this point in the history
  • Loading branch information
danielinteractive committed Jan 20, 2025
1 parent 994bb0f commit 48bda12
Showing 1 changed file with 205 additions and 0 deletions.
205 changes: 205 additions & 0 deletions misc/design_mcmc_improve.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Design doc for improving MCMC sampling

This design doc proposes several updates to the `rbmi` package to improve the options for MCMC sampling. The main goal is to provide more flexibility to the user in terms of the sampling behavior, while still keeping the package easy to use.

## Wish list

1. Stan sampling argument as input parameters (`rstan::sampling` control should be controllable)
1. Keep current option (single chain; MMRM output as initial value as default)
1. Add option to have multiple chains with random initial values
1. Always have default prior and don't make prior the input parameter

## Sampling controls

Currently the `rstan` sampling is controlled by a hard coded list in [`fit_mcmc`](https://github.com/insightsengineering/rbmi/blob/main/R/mcmc.R#L94).

The `sampling()` function is described [here](https://mc-stan.org/rstan/reference/stanmodel-method-sampling.html). In particular, the `control` argument allows to further specify the sampling behavior, which is documented [here](https://mc-stan.org/rstan/reference/stan.html).

Idea:

- Provide a `control_bayes()` function that returns a list of control parameters for `rstan::stan()`
- This function should have a default behavior that is the same as the current behavior
- The user can override the default behavior by providing a list of control parameters
- Takes additional arguments which are then included in the returned list
- In order to simplify the user interface, the previous `method` parts `burn_in`, `burn_between`, `n_samples` are moved to the `control_bayes()` function, and will properly be deprecated from `method_bayes()`.
- The `fit_mcmc` function should take the `control_bayes()` output as an additional argument and passes it internally to `rstan::sampling()`

## Multiple chains

Currently just a single chain is used.

Idea:

- As part of the `control_bayes()` function, we could add an argument to specify the number of chains.
- If the number of chains is greater than 1, then the `init` argument of `rstan::sampling()` should be set to `random` (see [here](https://mc-stan.org/rstan/reference/stanmodel-method-sampling.html)).
- We can make some experiments to ensure that chains converge to the stationary distribution within reasonable time frame if started from random values. If the Stan random initial values are not good enough, we could later have the option to use initial values sampled from the MLE distribution as the initial values for the chains.

We note that due to the definition of the prior distribution for the covariance matrix, we still need to fit the frequentist MMRM model in any case to get the required prior parameter.

Question: What do we need to adapt downstream when using multiple chains?

## Default prior

We would want to keep using a default prior, as it is currently implemented for the unstructured covariance structure already.

That implies that we will need to come up with such default priors for the other covariance structures. This will be handled in the separate design doc.

## Prototype

This is how it could look like:

```r
control_bayes <- function(
burn_in = 200,
burn_between = 50,
n_samples = 20,
chains = 1,
init = ife(chains > 1, "random", "mmrm"),
quiet = FALSE,
refresh = ife(
quiet,
0,
iter / 10
),
seed = sample.int(.Machine$integer.max, 1),
...
) {
list(
warmup = burn_in,
thin = burn_between,
iter = burn_in + burn_between * n_samples,
chains = chains,
warmup = warmup,
thin = thin,
iter = iter,
quiet = quiet,
refresh = refresh,
init = init,
seed = seed,
additional = list(...)
)
}
```

This function's output `control` can then be used in `fit_mcmc()` like this (copied and adapted function code):

```r
fit_mcmc <- function(
designmat,
outcome,
group,
subjid,
visit,
method,
control) {

# Fit MMRM (needed for Sigma prior parameter and possibly initial values).
mmrm_initial <- fit_mmrm(
designmat = designmat,
outcome = outcome,
subjid = subjid,
visit = visit,
group = group,
cov_struct = "us",
REML = TRUE,
same_cov = method$same_cov
)

if (mmrm_initial$failed) {
stop("Fitting MMRM to original dataset failed")
}

stan_data <- prepare_stan_data(
ddat = designmat,
subjid = subjid,
visit = visit,
outcome = outcome,
group = ife(method$same_cov == TRUE, rep(1, length(group)), group)
)

stan_data$Sigma_init <- ife(
same_cov == TRUE,
list(mmrm_initial$sigma[[1]]),
mmrm_initial$sigma
)

sampling_args <- c(
list(
object = get_stan_model(),
data = stan_data,
pars = c("beta", "Sigma"),
chains = control$chains,
warmup = control$warmup,
thin = control$thin,
iter = control$iter,
init = ife(
control$init == "mmrm",
list(list(
theta = as.vector(stan_data$R %*% mmrm_initial$beta),
sigma = mmrm_initial$sigma
)),
control$init
),
refresh = control$refresh,
seed = control$seed
),
control$additional
)

stan_fit <- record({
do.call(rstan::sampling, sampling_args)
})

if (!is.null(stan_fit$errors)) {
stop(stan_fit$errors)
}

ignorable_warnings <- c(
"Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#bulk-ess",
"Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.\nRunning the chains for more iterations may help. See\nhttps://mc-stan.org/misc/warnings.html#tail-ess"
)

# handle warning: display only warnings if
# 1) the warning is not in ignorable_warnings
warnings <- stan_fit$warnings
warnings_not_allowed <- warnings[!warnings %in% ignorable_warnings]
for (i in warnings_not_allowed) warning(warnings_not_allowed)

fit <- stan_fit$results
check_mcmc(fit, n_imputations)

draws <- extract_draws(fit)

ret_obj <- list(
"samples" = draws,
"fit" = fit
)

return(ret_obj)
}
```

Finally, the user can call the `draws()` method now including the new `control` argument. This then also comprises the `quiet` argument. The method signature changes to:

```{r}
draws.bayes <- function(data,
data_ice = NULL,
vars,
method,
ncores = 1,
control = control_bayes())
```

and internally the `control` argument is just forwarded to `fit_mcmc()`.

Note that it is important for easy usability that we just rely on all default arguments in the `control_bayes()` function. The user can then just override the defaults they need to override by
specifying the corresponding arguments. In addition, we can still pass the Stan `control` argument inside `control_bayes()`, and it is not a problem that this argument has the same name as the `draws()` argument.

So a call could e.g. just be:

```r
draws(
data = data, vars = c("beta", "Sigma"), method = method, chains = 4,
control = control_bayes(seed = 123, control = list(adapt_delta = 0.99))
)
```

0 comments on commit 48bda12

Please sign in to comment.