diff --git a/blackjax/util.py b/blackjax/util.py index 43f9977ec..9fb461c8d 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -146,6 +146,7 @@ def run_inference_algorithm( inference_algorithm, num_steps, progress_bar: bool = False, + transform=lambda x: x, ) -> tuple[State, State, Info]: """Wrapper to run an inference algorithm. @@ -160,6 +161,8 @@ def run_inference_algorithm( One of blackjax's sampling algorithms or variational inference algorithms. num_steps : int Number of learning steps. + transform: + a transformation of the sequence of states to be returned. By default, the states are returned as is. Returns ------- @@ -180,7 +183,7 @@ def run_inference_algorithm( def _one_step(state, xs): _, rng_key = xs state, info = inference_algorithm.step(rng_key, state) - return state, (state, info) + return state, (transform(state), info) if progress_bar: one_step = progress_bar_scan(num_steps)(_one_step) diff --git a/tests/test_util.py b/tests/test_util.py index 5bc35e50a..d3eed1193 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -30,6 +30,7 @@ def check_compatible(self, initial_state_or_position, progress_bar): self.algorithm, self.num_steps, progress_bar, + transform=lambda x: x.position, ) @parameterized.parameters([True, False])