Skip to content

Commit

Permalink
Add transform TO run_inference_loop (#621)
Browse files Browse the repository at this point in the history
* ADD TRANSFORM

* ADD TRANSFORM

* ADD DOCSTRING AND TEST
  • Loading branch information
reubenharry authored and junpenglao committed Mar 12, 2024
1 parent 1351993 commit 8ad627d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def run_inference_algorithm(
inference_algorithm,
num_steps,
progress_bar: bool = False,
transform=lambda x: x,
) -> tuple[State, State, Info]:
"""Wrapper to run an inference algorithm.
Expand All @@ -160,6 +161,8 @@ def run_inference_algorithm(
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps : int
Number of learning steps.
transform:
a transformation of the sequence of states to be returned. By default, the states are returned as is.
Returns
-------
Expand All @@ -180,7 +183,7 @@ def run_inference_algorithm(
def _one_step(state, xs):
_, rng_key = xs
state, info = inference_algorithm.step(rng_key, state)
return state, (state, info)
return state, (transform(state), info)

if progress_bar:
one_step = progress_bar_scan(num_steps)(_one_step)
Expand Down
1 change: 1 addition & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def check_compatible(self, initial_state_or_position, progress_bar):
self.algorithm,
self.num_steps,
progress_bar,
transform=lambda x: x.position,
)

@parameterized.parameters([True, False])
Expand Down

0 comments on commit 8ad627d

Please sign in to comment.