From 994bb0fbea6e0a23132dbabe51349e72e1d89530 Mon Sep 17 00:00:00 2001 From: Craig Gower-Page Date: Mon, 13 Jan 2025 10:49:06 +0000 Subject: [PATCH] Improve Reproducibility(#470) --- R/utilities.R | 23 ++++++++- tests/testthat/test-reproducibility.R | 74 +++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/R/utilities.R b/R/utilities.R index ca55175ac..81c3e6b72 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -556,11 +556,30 @@ clear_model_cache <- function(cache_dir = getOption("rbmi.cache_dir")) { unlink(files) } + #' Get Compiled Stan Object #' #' Gets a compiled Stan object that can be used with `rstan::sampling()` #' @keywords internal get_stan_model <- function() { + + # Compiling Stan models updates the current seed state. This can lead to + # non-reproducibility as compiling is conditional on wether there is a cached + # model available or not. Thus we save the current seed state and restore it + # at the end of this function so that it is in the same state regardless of + # whether the model was compiled or not. + # See https://github.com/insightsengineering/rbmi/issues/469 + # Note that .Random.seed is only set if the seed has been set or if a random number + # has been generated. + current_seed_state <- globalenv()$.Random.seed + on.exit({ + if (is.null(current_seed_state) && exists(".Random.seed", envir = globalenv())) { + rm(".Random.seed", envir = globalenv(), inherits = FALSE) + } else { + assign(".Random.seed", value = current_seed_state, envir = globalenv(), inherits = FALSE) + } + }) + ensure_rstan() local_file <- file.path("inst", "stan", "MMRM.stan") system_file <- system.file(file.path("stan", "MMRM.stan"), package = "rbmi") @@ -580,11 +599,13 @@ get_stan_model <- function() { file.copy(file_loc, model_file, overwrite = TRUE) } - rstan::stan_model( + model <- rstan::stan_model( file = model_file, auto_write = TRUE, model_name = "rbmi_mmrm" ) + + return(model) } diff --git a/tests/testthat/test-reproducibility.R b/tests/testthat/test-reproducibility.R index 9c10249ce..4532f6cae 100644 --- a/tests/testthat/test-reproducibility.R +++ b/tests/testthat/test-reproducibility.R @@ -126,3 +126,77 @@ test_that("bayes - set.seed produces identical results", { }) expect_equal(x$samples, y$samples) }) + + +test_that("Results are if model is recompiled", { + + skip_if_not(is_full_test()) + + run_test <- function() { + set.seed(4642) + sigma <- as_vcov(c(2, 1, 0.7), c(0.5, 0.3, 0.2)) + dat <- get_sim_data(40, sigma, trt = 8) %>% + mutate(outcome = if_else(rbinom(n(), 1, 0.3) == 1, NA_real_, outcome)) + + dat_ice <- dat %>% + group_by(id) %>% + arrange(id, visit) %>% + filter(is.na(outcome)) %>% + slice(1) %>% + ungroup() %>% + select(id, visit) %>% + mutate(strategy = "JR") + + vars <- set_vars( + outcome = "outcome", + group = "group", + strategy = "strategy", + subjid = "id", + visit = "visit", + covariates = c("age", "sex", "visit * group") + ) + + vars2 <- vars + vars2$covariates <- c("age", "sex") + + set.seed(984) + drawobj <- suppressWarnings({ + draws( + data = dat, + data_ice = dat_ice, + vars = vars, + method = method_bayes(n_samples = 20), + quiet = TRUE + ) + }) + imputeobj <- impute(draws = drawobj, references = c("A" = "B", "B" = "B")) + anaobj <- analyse(imputeobj, fun = rbmi::ancova, vars = vars2) + poolobj <- pool(results = anaobj) + + ## Tidy up things that will never be the same: + drawobj$formula <- NULL # Formulas contain environments specific to their build + drawobj$fit <- NULL # Bayes object has "fit" which contains a timestamp + anaobj$call <- NULL # Argument names are different (imputeobj2) + + return(list( + draws = drawobj, + impute = imputeobj, + analyse = anaobj, + pool = poolobj + )) + } + + old_cache <- options("rbmi.cache_dir") + tmp_dir <- tempfile(tmpdir = tempdir(check = TRUE)) + dir.create(tmp_dir) + options("rbmi.cache_dir" = tmp_dir) + results_no_cache <- run_test() + results_cache <- run_test() # Now rerun but using the same cache + options("rbmi.cache_dir" = old_cache) + + expect_equal(results_no_cache$draws, results_cache$draws) + expect_equal(results_no_cache$impute, results_cache$impute) + expect_equal(results_no_cache$analyse, results_cache$analyse) + expect_equal(results_no_cache$pool, results_cache$pool) + +})