Skip to content

Commit

Permalink
Clean up of run_inference_algorithm docstring. (#624)
Browse files Browse the repository at this point in the history
* Clean up of `run_inference_algorithm` docstring.

Also fix a bug in MCLMC

* revert fix for further investigation.
  • Loading branch information
junpenglao authored Dec 11, 2023
1 parent 24a328f commit d8fd15a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 24 deletions.
13 changes: 6 additions & 7 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan
from blackjax.types import ArrayLike, PRNGKey
from blackjax.util import generate_unit_vector, pytree_size
from blackjax.util import generate_unit_vector

__all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"]

Expand All @@ -44,12 +44,12 @@ class MCLMCInfo(NamedTuple):
energy_change: float


def init(x_initial: ArrayLike, logdensity_fn, rng_key):
l, g = jax.value_and_grad(logdensity_fn)(x_initial)
def init(position: ArrayLike, logdensity_fn, rng_key):
l, g = jax.value_and_grad(logdensity_fn)(position)

return IntegratorState(
position=x_initial,
momentum=generate_unit_vector(rng_key, x_initial),
position=position,
momentum=generate_unit_vector(rng_key, position),
logdensity=l,
logdensity_grad=g,
)
Expand Down Expand Up @@ -83,8 +83,6 @@ def kernel(
state, step_size
)

dim = pytree_size(position)

# Langevin-like noise
momentum, dim = partially_refresh_momentum(
momentum=momentum, rng_key=rng_key, L=L, step_size=step_size
Expand All @@ -95,6 +93,7 @@ def kernel(
), MCLMCInfo(
logdensity=logdensity,
energy_change=kinetic_change - logdensity + state.logdensity,
# TODO: Potential bug here, see #625
kinetic_change=kinetic_change * (dim - 1),
)

Expand Down
36 changes: 20 additions & 16 deletions blackjax/util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Utility functions for BlackJax."""
from functools import partial
from typing import Union
from typing import Callable, Union

import jax.numpy as jnp
from jax import jit, lax
from jax.flatten_util import ravel_pytree
from jax.random import normal, split
from jax.tree_util import tree_leaves

from blackjax.base import Info, State
from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

Expand Down Expand Up @@ -141,35 +141,39 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree:


def run_inference_algorithm(
rng_key,
initial_state_or_position,
inference_algorithm,
num_steps,
rng_key: PRNGKey,
initial_state_or_position: ArrayLikeTree,
inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm],
num_steps: int,
progress_bar: bool = False,
transform=lambda x: x,
transform: Callable = lambda x: x,
) -> tuple[State, State, Info]:
"""Wrapper to run an inference algorithm.
Parameters
----------
rng_key : PRNGKey
rng_key
The random state used by JAX's random numbers generator.
initial_state_or_position: ArrayLikeTree
initial_state_or_position
The initial state OR the initial position of the inference algorithm. If an initial position
is passed in, the function will automatically convert it into an initial state.
inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm]
inference_algorithm
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps : int
Number of learning steps.
transform:
a transformation of the sequence of states to be returned. By default, the states are returned as is.
num_steps
Number of MCMC steps.
progress_bar
Whether to display a progress bar.
transform
A transformation of the trace of states to be returned. This is useful for
computing determinstic variables, or returning a subset of the states.
By default, the states are returned as is.
Returns
-------
Tuple[State, State, Info]
1. The final state of the inference algorithm.
2. The history of states of the inference algorithm.
3. The history of the info of the inference algorithm.
2. The trace of states of the inference algorithm (contains the MCMC samples).
3. The trace of the info of the inference algorithm for diagnostics.
"""
try:
initial_state = inference_algorithm.init(initial_state_or_position)
Expand Down
2 changes: 1 addition & 1 deletion tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):
init_key, tune_key, run_key = jax.random.split(key, 3)

initial_state = blackjax.mcmc.mclmc.init(
x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
)

kernel = blackjax.mcmc.mclmc.build_kernel(
Expand Down

0 comments on commit d8fd15a

Please sign in to comment.