From 5eeb3e11e7492aeebb80480b3091286df9c5994e Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 15:30:07 -0400 Subject: [PATCH 1/9] UPDATE DOCSTRING --- blackjax/util.py | 3 +-- explore.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 explore.py diff --git a/blackjax/util.py b/blackjax/util.py index df527ed01..e2654481c 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -159,8 +159,7 @@ def run_inference_algorithm( rng_key The random state used by JAX's random numbers generator. initial_state_or_position - The initial state OR the initial position of the inference algorithm. If an initial position - is passed in, the function will automatically convert it into an initial state. + The initial state of the inference algorithm. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. num_steps diff --git a/explore.py b/explore.py new file mode 100644 index 000000000..514029420 --- /dev/null +++ b/explore.py @@ -0,0 +1,53 @@ +import jax +import jax.numpy as jnp +from benchmarks.mcmc.sampling_algorithms import samplers +import blackjax +from blackjax.mcmc.mhmclmc import mhmclmc, rescale +from blackjax.mcmc.hmc import hmc +from blackjax.mcmc.dynamic_hmc import dynamic_hmc +from blackjax.mcmc.integrators import isokinetic_mclachlan +from blackjax.util import run_inference_algorithm + + + + + +init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3) + +def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + +initial_position = jnp.ones(10,) + + +def run_mclmc(logdensity_fn, num_steps, initial_position): + key = jax.random.PRNGKey(0) + init_key, tune_key, run_key = jax.random.split(key, 3) + + + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + ) + + kernel = blackjax.mcmc.mclmc.build_kernel( + logdensity_fn=logdensity_fn, + integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, + ) + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + ) + + print(blackjax_mclmc_sampler_params) + +# out = run_hmc(initial_position) +out = samplers["mhmclmc"](logdensity_fn=logdensity_fn, num_steps=5000, initial_position=initial_position, key=jax.random.PRNGKey(0)) +print(out.mean(axis=0) ) + + From 4a0915673663302ceffbd33400478840576dee4b Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 16:02:30 -0400 Subject: [PATCH 2/9] ADD STREAMING VERSION --- blackjax/util.py | 55 +++++++++++++++++++++++++++++++++------------- tests/test_util.py | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 15 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index e2654481c..55b8b3e47 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -8,7 +8,7 @@ from jax.random import normal, split from jax.tree_util import tree_leaves -from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm +from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import progress_bar_scan from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -142,12 +142,13 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree: def run_inference_algorithm( rng_key: PRNGKey, - initial_state_or_position: ArrayLikeTree, + initial_state: ArrayLikeTree, inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm], num_steps: int, progress_bar: bool = False, transform: Callable = lambda x: x, -) -> tuple[State, State, Info]: + streaming=False, +) -> tuple: """Wrapper to run an inference algorithm. Note that this utility function does not work for Stochastic Gradient MCMC samplers @@ -158,8 +159,8 @@ def run_inference_algorithm( ---------- rng_key The random state used by JAX's random numbers generator. - initial_state_or_position - The initial state of the inference algorithm. + initial_state + The initial state of the inference algorithm. inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. num_steps @@ -170,6 +171,8 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. + streaming + if True, `run_inference_algorithm` will take a streaming average of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. Returns ------- @@ -178,14 +181,8 @@ def run_inference_algorithm( 2. The trace of states of the inference algorithm (contains the MCMC samples). 3. The trace of the info of the inference algorithm for diagnostics. """ - init_key, sample_key = split(rng_key, 2) - try: - initial_state = inference_algorithm.init(initial_state_or_position, init_key) - except (TypeError, ValueError, AttributeError): - # We assume initial_state is already in the right format. - initial_state = initial_state_or_position - keys = split(sample_key, num_steps) + keys = split(rng_key, num_steps) @jit def _one_step(state, xs): @@ -193,11 +190,39 @@ def _one_step(state, xs): state, info = inference_algorithm.step(rng_key, state) return state, (transform(state), info) + def _online_one_step(average_and_state, xs): + _, rng_key = xs + average, state = average_and_state + state, _ = inference_algorithm.step(rng_key, state) + average = streaming_average(transform, state, average) + return (average, state), None + if progress_bar: one_step = progress_bar_scan(num_steps)(_one_step) + online_one_step = progress_bar_scan(num_steps)(_online_one_step) else: one_step = _one_step + online_one_step = _online_one_step - xs = (jnp.arange(num_steps), keys) - final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs) - return final_state, state_history, info_history + if streaming: + xs = (jnp.arange(num_steps), keys) + (average, final_state), _ = lax.scan( + online_one_step, ((0, transform(initial_state)), initial_state), xs + ) + return average, transform(final_state) + + else: + xs = (jnp.arange(num_steps), keys) + final_state, (state_history, info_history) = lax.scan( + one_step, initial_state, xs + ) + return final_state, state_history, info_history + + +def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): + """streaming average of f(x)""" + total, average = streaming_avg + average = (total * average + weight * O(x)) / (total + weight + zero_prevention) + total += weight + streaming_avg = (total, average) + return streaming_avg diff --git a/tests/test_util.py b/tests/test_util.py index a6e023074..ed5cb12f0 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -33,6 +33,46 @@ def check_compatible(self, initial_state_or_position, progress_bar): transform=lambda x: x.position, ) + def test_streamning(self): + def logdensity_fn(x): + return -0.5 * jnp.sum(jnp.square(x)) + + initial_position = jnp.ones( + 10, + ) + + init_key, run_key = jax.random.split(self.key, 2) + + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + ) + + alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) + + average, states = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=50, + progress_bar=True, + transform=lambda x: x.position, + streaming=True, + ) + + print(average) + + _, states, _ = run_inference_algorithm( + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=alg, + num_steps=50, + progress_bar=False, + transform=lambda x: x.position, + streaming=False, + ) + + assert jnp.array_equal(states.mean(axis=0), average) + @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): self.check_compatible(jnp.array([1.0, 1.0]), progress_bar) From dbab9a3077165234c87beea50018e0d3f33befe7 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:27:26 -0400 Subject: [PATCH 3/9] UPDATE TESTS --- explore.py | 53 ------------------------------------- tests/mcmc/test_sampling.py | 2 +- tests/test_util.py | 6 ++--- 3 files changed, 4 insertions(+), 57 deletions(-) delete mode 100644 explore.py diff --git a/explore.py b/explore.py deleted file mode 100644 index 514029420..000000000 --- a/explore.py +++ /dev/null @@ -1,53 +0,0 @@ -import jax -import jax.numpy as jnp -from benchmarks.mcmc.sampling_algorithms import samplers -import blackjax -from blackjax.mcmc.mhmclmc import mhmclmc, rescale -from blackjax.mcmc.hmc import hmc -from blackjax.mcmc.dynamic_hmc import dynamic_hmc -from blackjax.mcmc.integrators import isokinetic_mclachlan -from blackjax.util import run_inference_algorithm - - - - - -init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3) - -def logdensity_fn(x): - return -0.5 * jnp.sum(jnp.square(x)) - -initial_position = jnp.ones(10,) - - -def run_mclmc(logdensity_fn, num_steps, initial_position): - key = jax.random.PRNGKey(0) - init_key, tune_key, run_key = jax.random.split(key, 3) - - - initial_state = blackjax.mcmc.mclmc.init( - position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key - ) - - kernel = blackjax.mcmc.mclmc.build_kernel( - logdensity_fn=logdensity_fn, - integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, - ) - - ( - blackjax_state_after_tuning, - blackjax_mclmc_sampler_params, - ) = blackjax.mclmc_find_L_and_step_size( - mclmc_kernel=kernel, - num_steps=num_steps, - state=initial_state, - rng_key=tune_key, - ) - - print(blackjax_mclmc_sampler_params) - -# out = run_hmc(initial_position) -out = samplers["mhmclmc"](logdensity_fn=logdensity_fn, num_steps=5000, initial_position=initial_position, key=jax.random.PRNGKey(0)) -print(out.mean(axis=0) ) - - diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 39c1b811b..19f72a7c2 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -104,7 +104,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): _, samples, _ = run_inference_algorithm( rng_key=run_key, - initial_state_or_position=blackjax_state_after_tuning, + initial_state=blackjax_state_after_tuning, inference_algorithm=sampling_alg, num_steps=num_steps, transform=lambda x: x.position, diff --git a/tests/test_util.py b/tests/test_util.py index ed5cb12f0..97aba5205 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -19,14 +19,14 @@ def setUp(self): ) self.num_steps = 10 - def check_compatible(self, initial_state_or_position, progress_bar): + def check_compatible(self, initial_state, progress_bar): """ Runs 10 steps with `run_inference_algorithm` starting with - `initial_state_or_position` and potentially a progress bar. + `initial_state` and potentially a progress bar. """ _ = run_inference_algorithm( self.key, - initial_state_or_position, + initial_state, self.algorithm, self.num_steps, progress_bar, From 5bd2a3f4c12aab6ba333ac2baa936e7d3df64ee0 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:41:26 -0400 Subject: [PATCH 4/9] ADD DOCSTRING --- blackjax/util.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 55b8b3e47..a9ed821f4 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -220,9 +220,25 @@ def _online_one_step(average_and_state, xs): def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): - """streaming average of f(x)""" + """Compute the streaming average of a function O(x) using a weight. + Parameters: + ---------- + O + function to be averaged + x + current state + streaming_avg + tuple of (total, average) where total is the sum of weights and average is the current average + weight + weight of the current state + zero_prevention + small value to prevent division by zero + Returns: + ---------- + new streaming average + """ total, average = streaming_avg average = (total * average + weight * O(x)) / (total + weight + zero_prevention) total += weight streaming_avg = (total, average) - return streaming_avg + return streaming_avg \ No newline at end of file From 4fc1453b4b4430b5350bafc6b5fbb9b6fc7721e7 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 13 May 2024 17:56:18 -0400 Subject: [PATCH 5/9] ADD TEST --- blackjax/util.py | 4 ++-- tests/test_util.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index a9ed821f4..2efb93f12 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -206,7 +206,7 @@ def _online_one_step(average_and_state, xs): if streaming: xs = (jnp.arange(num_steps), keys) - (average, final_state), _ = lax.scan( + ((_, average), final_state), _ = lax.scan( online_one_step, ((0, transform(initial_state)), initial_state), xs ) return average, transform(final_state) @@ -241,4 +241,4 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): average = (total * average + weight * O(x)) / (total + weight + zero_prevention) total += weight streaming_avg = (total, average) - return streaming_avg \ No newline at end of file + return streaming_avg diff --git a/tests/test_util.py b/tests/test_util.py index 97aba5205..1291b09e7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -33,7 +33,7 @@ def check_compatible(self, initial_state, progress_bar): transform=lambda x: x.position, ) - def test_streamning(self): + def test_streaming(self): def logdensity_fn(x): return -0.5 * jnp.sum(jnp.square(x)) @@ -54,13 +54,11 @@ def logdensity_fn(x): initial_state=initial_state, inference_algorithm=alg, num_steps=50, - progress_bar=True, + progress_bar=False, transform=lambda x: x.position, streaming=True, ) - print(average) - _, states, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -71,7 +69,7 @@ def logdensity_fn(x): streaming=False, ) - assert jnp.array_equal(states.mean(axis=0), average) + assert jnp.allclose(states.mean(axis=0), average) @parameterized.parameters([True, False]) def test_compatible_with_initial_pos(self, progress_bar): From 49410f9037a99f87677e1e89b32d48a350526b7f Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:04:37 +0200 Subject: [PATCH 6/9] REFACTOR RUN_INFERENCE_ALGORITHM --- blackjax/util.py | 67 +++++++++++++++++++++++----------------------- tests/test_util.py | 21 ++++++++++++--- 2 files changed, 51 insertions(+), 37 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 2efb93f12..4c58ad597 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -2,6 +2,7 @@ from functools import partial from typing import Callable, Union +import jax import jax.numpy as jnp from jax import jit, lax from jax.flatten_util import ravel_pytree @@ -147,7 +148,8 @@ def run_inference_algorithm( num_steps: int, progress_bar: bool = False, transform: Callable = lambda x: x, - streaming=False, + return_state_history=True, + expectation: Callable = lambda x: x, ) -> tuple: """Wrapper to run an inference algorithm. @@ -171,52 +173,44 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - streaming + return_expectation if True, `run_inference_algorithm` will take a streaming average of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. Returns ------- Tuple[State, State, Info] - 1. The final state of the inference algorithm. - 2. The trace of states of the inference algorithm (contains the MCMC samples). - 3. The trace of the info of the inference algorithm for diagnostics. + 1. The expectation of transform(state) over the chain. + 2. The final state of the inference algorithm. + 3. The trace of the state and info of the inference algorithm for diagnostics. """ keys = split(rng_key, num_steps) - @jit - def _one_step(state, xs): + def one_step(average_and_state, xs, return_state): _, rng_key = xs + average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - return state, (transform(state), info) + average = streaming_average(expectation, state, average) + if return_state: + return (average, state), (transform(state), info) + else: + return (average, state), None - def _online_one_step(average_and_state, xs): - _, rng_key = xs - average, state = average_and_state - state, _ = inference_algorithm.step(rng_key, state) - average = streaming_average(transform, state, average) - return (average, state), None + one_step = jax.jit(partial(one_step, return_state=return_state_history)) if progress_bar: - one_step = progress_bar_scan(num_steps)(_one_step) - online_one_step = progress_bar_scan(num_steps)(_online_one_step) - else: - one_step = _one_step - online_one_step = _online_one_step - - if streaming: - xs = (jnp.arange(num_steps), keys) - ((_, average), final_state), _ = lax.scan( - online_one_step, ((0, transform(initial_state)), initial_state), xs - ) - return average, transform(final_state) + one_step = progress_bar_scan(num_steps)(one_step) + xs = (jnp.arange(num_steps), keys) + ((_, average), final_state), history = lax.scan( + one_step, ((0, expectation(initial_state)), initial_state), xs + ) + + if not return_state_history: + return average, transform(final_state) else: - xs = (jnp.arange(num_steps), keys) - final_state, (state_history, info_history) = lax.scan( - one_step, initial_state, xs - ) - return final_state, state_history, info_history + state_history, info_history = history + return transform(final_state), state_history, info_history def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): @@ -237,8 +231,15 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): ---------- new streaming average """ + + expectation = O(x) + flat_expectation, unravel_fn = ravel_pytree(expectation) total, average = streaming_avg - average = (total * average + weight * O(x)) / (total + weight + zero_prevention) + flat_average, _ = ravel_pytree(average) + average = (total * flat_average + weight * flat_expectation) / ( + total + weight + zero_prevention + ) total += weight - streaming_avg = (total, average) + streaming_avg = (total, unravel_fn(average)) return streaming_avg + diff --git a/tests/test_util.py b/tests/test_util.py index 1291b09e7..1665bc2c3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -49,26 +49,39 @@ def logdensity_fn(x): alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - average, states = run_inference_algorithm( + alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) + + _, states, info = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, inference_algorithm=alg, num_steps=50, progress_bar=False, + expectation=lambda x: x.position, transform=lambda x: x.position, - streaming=True, + return_state_history=True, ) - _, states, _ = run_inference_algorithm( + print(states.mean(axis=0)) + + + average, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, inference_algorithm=alg, num_steps=50, progress_bar=False, + expectation=lambda x: x.position, transform=lambda x: x.position, - streaming=False, + return_state_history=False, ) + print(average) + print(states.mean(axis=0)[1]==average[1]) + + print(jnp.allclose(states.mean(axis=0), average)) + + assert jnp.allclose(states.mean(axis=0), average) @parameterized.parameters([True, False]) From ffdca93147726882fb4dc13fe0778fcd8f435d65 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:13:59 +0200 Subject: [PATCH 7/9] UPDATE DOCSTRING --- blackjax/util.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 4c58ad597..77a90ba56 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -173,15 +173,20 @@ def run_inference_algorithm( A transformation of the trace of states to be returned. This is useful for computing determinstic variables, or returning a subset of the states. By default, the states are returned as is. - return_expectation - if True, `run_inference_algorithm` will take a streaming average of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. + expectation + A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states. + return_state_history + if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck. Returns ------- - Tuple[State, State, Info] - 1. The expectation of transform(state) over the chain. + If return_state_history is True: + 1. The final state. + 2. The trace of the state. + 3. The trace of the info of the inference algorithm for diagnostics. + If return_state_history is False: + 1. This is the expectation of state over the chain. Otherwise the final state. 2. The final state of the inference algorithm. - 3. The trace of the state and info of the inference algorithm for diagnostics. """ keys = split(rng_key, num_steps) From b7b7084f92ea59847ea301c44b8dbc7f64d22e3b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:14:28 +0200 Subject: [PATCH 8/9] Precommit --- blackjax/util.py | 1 - tests/test_util.py | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 77a90ba56..e579c126d 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -247,4 +247,3 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): total += weight streaming_avg = (total, unravel_fn(average)) return streaming_avg - diff --git a/tests/test_util.py b/tests/test_util.py index 1665bc2c3..6a7efd6b5 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -50,7 +50,7 @@ def logdensity_fn(x): alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - + _, states, info = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -64,7 +64,6 @@ def logdensity_fn(x): print(states.mean(axis=0)) - average, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -77,11 +76,10 @@ def logdensity_fn(x): ) print(average) - print(states.mean(axis=0)[1]==average[1]) + print(states.mean(axis=0)[1] == average[1]) print(jnp.allclose(states.mean(axis=0), average)) - assert jnp.allclose(states.mean(axis=0), average) @parameterized.parameters([True, False]) From 97cfc9eccd92ae5a2616e4bf379d6a75102abf54 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 May 2024 20:18:54 +0200 Subject: [PATCH 9/9] CLEAN TESTS --- tests/test_util.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 6a7efd6b5..3bafca894 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -49,8 +49,6 @@ def logdensity_fn(x): alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1) - _, states, info = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -62,8 +60,6 @@ def logdensity_fn(x): return_state_history=True, ) - print(states.mean(axis=0)) - average, _ = run_inference_algorithm( rng_key=run_key, initial_state=initial_state, @@ -75,11 +71,6 @@ def logdensity_fn(x): return_state_history=False, ) - print(average) - print(states.mean(axis=0)[1] == average[1]) - - print(jnp.allclose(states.mean(axis=0), average)) - assert jnp.allclose(states.mean(axis=0), average) @parameterized.parameters([True, False])