diff --git a/blackjax/util.py b/blackjax/util.py index f667d147a..43f9977ec 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -9,6 +9,7 @@ from jax.tree_util import tree_leaves from blackjax.base import Info, State +from blackjax.progress_bar import progress_bar_scan from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -144,6 +145,7 @@ def run_inference_algorithm( initial_state_or_position, inference_algorithm, num_steps, + progress_bar: bool = False, ) -> tuple[State, State, Info]: """Wrapper to run an inference algorithm. @@ -171,14 +173,20 @@ def run_inference_algorithm( except TypeError: # We assume initial_state is already in the right format. initial_state = initial_state_or_position - initial_state = initial_state_or_position keys = split(rng_key, num_steps) @jit - def one_step(state, rng_key): + def _one_step(state, xs): + _, rng_key = xs state, info = inference_algorithm.step(rng_key, state) return state, (state, info) - final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys) + if progress_bar: + one_step = progress_bar_scan(num_steps)(_one_step) + else: + one_step = _one_step + + xs = (jnp.arange(num_steps), keys) + final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs) return final_state, state_history, info_history diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 000000000..5bc35e50a --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,50 @@ +import chex +import jax +import jax.numpy as jnp +from absl.testing import absltest, parameterized + +from blackjax.mcmc.hmc import hmc +from blackjax.util import run_inference_algorithm + + +class RunInferenceAlgorithmTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + self.algorithm = hmc( + logdensity_fn=self.logdensity_fn, + inverse_mass_matrix=jnp.eye(2), + step_size=1.0, + num_integration_steps=1000, + ) + self.num_steps = 10 + + def check_compatible(self, initial_state_or_position, progress_bar): + """ + Runs 10 steps with `run_inference_algorithm` starting with + `initial_state_or_position` and potentially a progress bar. + """ + _ = run_inference_algorithm( + self.key, + initial_state_or_position, + self.algorithm, + self.num_steps, + progress_bar, + ) + + @parameterized.parameters([True, False]) + def test_compatible_with_initial_pos(self, progress_bar): + self.check_compatible(jnp.array([1.0, 1.0]), progress_bar) + + @parameterized.parameters([True, False]) + def test_compatible_with_initial_state(self, progress_bar): + state = self.algorithm.init(jnp.array([1.0, 1.0])) + self.check_compatible(state, progress_bar) + + @staticmethod + def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + + +if __name__ == "__main__": + absltest.main()