Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MCLMC sampler #586

Merged
merged 91 commits into from
Dec 5, 2023
Merged
Changes from 1 commit
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
fbe7f75
initial draft of mclmc
reubenharry Nov 10, 2023
3a23242
refactor
reubenharry Nov 11, 2023
86b3a90
wip
reubenharry Nov 11, 2023
e82550f
wip
reubenharry Nov 11, 2023
f0e1bec
wip
reubenharry Nov 11, 2023
4d7dc57
wip
reubenharry Nov 11, 2023
82b8466
wip
reubenharry Nov 11, 2023
a4d403b
fix pre-commit
reubenharry Nov 11, 2023
a67ecb7
remove dim from class
reubenharry Nov 11, 2023
3dd4f74
add docstrings
reubenharry Nov 11, 2023
5d8061d
add mclmc to init
reubenharry Nov 13, 2023
5428f2c
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Nov 13, 2023
59ecc8a
Merge branch 'main' into refactor
reubenharry Nov 13, 2023
2bf639e
move minimal_norm to integrators
reubenharry Nov 13, 2023
172fee0
move update pos and momentum
reubenharry Nov 13, 2023
b710e62
remove params
reubenharry Nov 13, 2023
3cc52fd
Infer the shape from inverse_mass_matrix outside the function step
reubenharry Nov 14, 2023
57d5c3b
use tree_map
reubenharry Nov 14, 2023
7e70d78
integration now aligned with mclmc repo
reubenharry Nov 15, 2023
1343463
dE and logdensity align too (fixed sign error)
reubenharry Nov 15, 2023
e53a877
make L and step size arguments to kernel
reubenharry Nov 15, 2023
05517b6
rough draft of tuning: works
reubenharry Nov 15, 2023
d84a23d
remove inv mass matrix
reubenharry Nov 15, 2023
de1e5cf
almost correct
reubenharry Nov 15, 2023
263ab3a
almost correct
reubenharry Nov 16, 2023
777213d
move tuning to adaptation
reubenharry Nov 16, 2023
e75274a
tuning works in this commit
reubenharry Nov 16, 2023
8a89f13
clean up 1
reubenharry Nov 16, 2023
49b3bec
remove sigma from tuning
reubenharry Nov 16, 2023
81999f9
wip
reubenharry Nov 16, 2023
8ab01f2
fix linting
reubenharry Nov 17, 2023
6266bc4
rename T and V
reubenharry Nov 17, 2023
ca984e7
uniformity wip
reubenharry Nov 17, 2023
59ffb21
make uniform implementation of integrators
reubenharry Nov 17, 2023
8f9214f
make uniform implementation of integrators
reubenharry Nov 18, 2023
b2e3b8e
fix minimal norm integrator
reubenharry Nov 18, 2023
2fb2293
add warning to tune3
reubenharry Nov 18, 2023
59e4424
Refactor integrators.py to make it more general.
junpenglao Nov 19, 2023
6684413
temp: explore
reubenharry Nov 19, 2023
4284092
Refactor to use integrator generation functions
junpenglao Nov 20, 2023
4a514dd
Additional refactoring
junpenglao Nov 20, 2023
ef1f62d
Minor clean up.
junpenglao Nov 21, 2023
af43521
Use standard JAX ops
junpenglao Nov 21, 2023
0dd419d
new integrator
reubenharry Nov 23, 2023
0c8330e
add references
reubenharry Nov 23, 2023
e6fa2bb
merge
reubenharry Nov 23, 2023
40fc61c
flake
reubenharry Nov 24, 2023
6ea5320
temporarily add 'explore'
reubenharry Nov 24, 2023
c83dc1a
temporarily add 'explore'
reubenharry Nov 25, 2023
c8b43be
Adding a test for energy preservation.
junpenglao Nov 26, 2023
8894248
fix formatting
junpenglao Nov 26, 2023
9865145
wip: tests
reubenharry Nov 26, 2023
68464bc
Merge branch 'integrator_refactor' into refactor
reubenharry Nov 26, 2023
0c61412
use pytrees for partially_refresh_momentum, and add test
reubenharry Nov 26, 2023
a66af60
Merge branch 'main' into refactor
junpenglao Nov 27, 2023
be07631
update docstring
reubenharry Nov 27, 2023
71d934b
resolve conflict
reubenharry Nov 27, 2023
a170d0b
remove 'explore'
reubenharry Nov 27, 2023
8cfb75f
fix pre-commit
reubenharry Nov 27, 2023
b42e77e
adding randomized MCHMC
JakobRobnik Nov 29, 2023
2b323ce
wip checkpoint on tuning
reubenharry Dec 1, 2023
9a41cdf
align blackjax and mclmc repos, for tuning
reubenharry Dec 1, 2023
cdbb4f6
use effective_sample_size
reubenharry Dec 1, 2023
947d717
patial rename
reubenharry Dec 1, 2023
e9ab7b4
rename
reubenharry Dec 1, 2023
72d70c6
clean up tuning
reubenharry Dec 1, 2023
c121beb
clean up tuning
reubenharry Dec 1, 2023
fe99163
IN THIS COMMIT, BLACKJAX AND ORIGINAL REPO AGREE. SEED IS FIXED.
reubenharry Dec 2, 2023
c456efe
RANDOMIZE KEYS
reubenharry Dec 2, 2023
d0a008a
ADD TEST
reubenharry Dec 2, 2023
d692498
ADD TEST
reubenharry Dec 2, 2023
3e8d8ea
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 2, 2023
eda029a
Merge branch 'main' into refactor
reubenharry Dec 2, 2023
a45f58f
MERGE MAIN
reubenharry Dec 2, 2023
2a21c56
INCREASE CODE COVERAGE
reubenharry Dec 2, 2023
67f0de9
REMOVE REDUNDANT LINE
reubenharry Dec 2, 2023
3f55f5f
ADD NAME 'mclmc'
reubenharry Dec 2, 2023
666c540
SPLIT KEYS AND FIX DOCSTRING
reubenharry Dec 2, 2023
c1615f5
FIX MINOR ERRORS
reubenharry Dec 2, 2023
ae1bf30
FIX MINOR ERRORS
reubenharry Dec 2, 2023
3c2dbad
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 2, 2023
c396aa1
FIX CONFLICT IN BIB
reubenharry Dec 2, 2023
0902a1c
RANDOMIZE KEYS (reversion)
reubenharry Dec 2, 2023
2e3c80b
PRECOMMIT CLEAN UP
reubenharry Dec 2, 2023
604b5a9
ADD KWARGS FOR DEFAULT HYPERPARAMS
reubenharry Dec 3, 2023
50b1c95
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 4, 2023
fecd82b
Merge branch 'main' into refactor
reubenharry Dec 4, 2023
50a8243
UPDATE ESS
reubenharry Dec 5, 2023
a20a681
NAME CHANGES
reubenharry Dec 5, 2023
75e71de
NAME CHANGES
reubenharry Dec 5, 2023
70f1dd5
MINOR FIXES
reubenharry Dec 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Additional refactoring
Also add test for esh momentum update.

Co-authored-by: Reuben Cohn-Gordon <[email protected]>
  • Loading branch information
junpenglao and reubenharry committed Nov 20, 2023
commit 4a514ddd6217eb1e2c22bc1fb34db29e9dbc57b7
83 changes: 55 additions & 28 deletions blackjax/mcmc/integrators.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Splitting the changes here into #589

Original file line number Diff line number Diff line change
@@ -21,7 +21,14 @@
from blackjax.mcmc.metrics import EuclideanKineticEnergy
from blackjax.types import ArrayTree

__all__ = ["mclachlan", "velocity_verlet", "yoshida"]
__all__ = [
"mclachlan",
"velocity_verlet",
"yoshida",
"noneuclidean_leapfrog",
"noneuclidean_mclachlan",
"noneuclidean_yoshida",
]


class IntegratorState(NamedTuple):
@@ -40,27 +47,37 @@
Integrator = Callable[[IntegratorState, float], IntegratorState]


def generalized_symplectic_integrator(
momentum_update_fn: Callable,
position_update_fn: Callable,
def generalized_two_stage_integrator(
operator1: Callable,
operator2: Callable,
coefficients: list[float],
format_output_fn: Callable = lambda x: x,
):
"""Generalized symplectic integrator.
"""Generalized numerical integrator for solving ODEs.

The generalized symplectic integrator performs numerical integration
of a Hamiltonian system by alernating between momentum and position updates.
The update scheme is decided by the coefficients and palindromic, i.e.
the coefficients of the update scheme should be symmetric with respect to the
The generalized integrator performs numerical integration of a ODE system by
alernating between stage 1 and stage 2 updates.
The update scheme is decided by the coefficients, The scheme should be palindromic,
i.e. the coefficients of the update scheme should be symmetric with respect to the
middle of the scheme.
[TODO]: expand this with information in https://github.com/blackjax-devs/blackjax/issues/587

For instance, for *any* differential equation of the form:

.. math:: \\frac{d}{dt}f = (O_1+O_2)f

The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}`
by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`.

In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and
:math:`e^{\\epsilon O_1}` are simple, but for other differential equations,
they may be more complex.

Parameters
----------
momentum_update_fn
Function that updates the momentum.
position_update_fn
Function that updates the position.
operator1
Stage 1 operator, a function that updates the momentum.
operator2
Stage 2 operator, a function that updates the position.
coefficients
Coefficients of the integrator.
format_output_fn
@@ -75,12 +92,12 @@
def one_step(state: IntegratorState, step_size: float):
position, momentum, _, logdensity_grad = state
# auxiliary infomation generated during integration for diagnostics. It is
# updated by the momentum_update_fn and position_update_fn at each call.
# updated by the operator1 and operator2 at each call.
momentum_update_info = None
position_update_info = None
for i, coef in enumerate(coefficients[:-1]):
if i % 2 == 0:
momentum, kinetic_grad, momentum_update_info = momentum_update_fn(
momentum, kinetic_grad, momentum_update_info = operator1(
momentum,
logdensity_grad,
step_size,
@@ -94,15 +111,15 @@
logdensity,
logdensity_grad,
position_update_info,
) = position_update_fn(
) = operator2(
position,
kinetic_grad,
step_size,
coef,
position_update_info,
)
# Separate the last steps to short circuit the computation of the kinetic_grad.
momentum, kinetic_grad, momentum_update_info = momentum_update_fn(
momentum, kinetic_grad, momentum_update_info = operator1(
momentum,
logdensity_grad,
step_size,
@@ -189,12 +206,18 @@


def generate_euclidean_integrator(cofficients):
"""Generate symplectic integrator for solving a Hamiltonian system.

The resulting integrator is volume-preserve and preserves the symplectic structure
of phase space.
"""

def euclidean_integrator(
logdensity_fn: Callable, kinetic_energy_fn: EuclideanKineticEnergy
) -> Integrator:
position_update_fn = euclidean_position_update_fn(logdensity_fn)
momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn)
one_step = generalized_symplectic_integrator(
one_step = generalized_two_stage_integrator(
momentum_update_fn,
position_update_fn,
cofficients,
@@ -234,6 +257,8 @@
method with respect to the value of `step_size`. The values used here are
the ones derived in :cite:p:`mclachlan1995numerical`; note that :cite:p:`blanes2014numerical`
is more focused on stability and derives different values.

Also known as the minimal norm integrator.
"""
b1 = 0.1931833275037836
a1 = 0.5
@@ -259,6 +284,11 @@


# Intergrators with non Euclidean updates
def normalized_flatten_array(x, tol=1e-13):
norm = jnp.sqrt(jnp.sum(jnp.square(x)))
return jnp.where(norm > tol, x / norm, x), norm


def esh_dynamics_momentum_update_one_step(
momentum: ArrayTree,
logdensity_grad: ArrayTree,
@@ -269,36 +299,33 @@
):
"""Momentum update based on Esh dynamics.

[TODO]: update this docstring with proper references and citations.
The momentum updating map of the esh dynamics (see https://arxiv.org/pdf/2111.02434.pdf)
similar to the implementation: https://github.com/gregversteeg/esh_dynamics
The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian`
There are no exponentials e^delta, which prevents overflows when the gradient norm
is large.
"""

flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
flatten_momentum, _ = ravel_pytree(momentum)
dims = flatten_momentum.shape[0]
gradient_norm = jnp.sqrt(jnp.sum(jnp.square(flatten_grads)))
normalized_gradient = -flatten_grads / gradient_norm
normalized_gradient, gradient_norm = normalized_flatten_array(flatten_grads)
momentum_proj = jnp.dot(flatten_momentum, normalized_gradient)
delta = step_size * coef * gradient_norm / (dims - 1)
zeta = jnp.exp(-delta)
new_momentum = (
new_momentum_raw = (
normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta))
+ 2 * zeta * flatten_momentum
)
new_momentum_norm = new_momentum / jnp.sqrt(jnp.sum(jnp.square(new_momentum)))
new_momentum_normalized, _ = normalized_flatten_array(new_momentum_raw)
next_momentum = unravel_fn(new_momentum_normalized)
kinetic_energy_change = (
delta
- jnp.log(2)
+ jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2)
)
next_momentum = unravel_fn(new_momentum_norm)
if previous_kinetic_energy_change is not None:
kinetic_energy_change += previous_kinetic_energy_change

Check warning on line 326 in blackjax/mcmc/integrators.py

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L326

Added line #L326 was not covered by tests
if is_last_call:
kinetic_energy_change *= dims - 1

Check warning on line 328 in blackjax/mcmc/integrators.py

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L328

Added line #L328 was not covered by tests
return next_momentum, next_momentum, kinetic_energy_change


@@ -311,8 +338,8 @@
position_update_info,
momentum_update_info,
):
del kinetic_grad, position_update_info
return (

Check warning on line 342 in blackjax/mcmc/integrators.py

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L341-L342

Added lines #L341 - L342 were not covered by tests
IntegratorState(position, momentum, logdensity, logdensity_grad),
momentum_update_info,
)
@@ -320,14 +347,14 @@

def generate_noneuclidean_integrator(cofficients):
def noneuclidean_integrator(logdensity_fn: Callable, *args, **kwargs) -> Callable:
position_update_fn = euclidean_position_update_fn(logdensity_fn)
one_step = generalized_symplectic_integrator(
one_step = generalized_two_stage_integrator(

Check warning on line 351 in blackjax/mcmc/integrators.py

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L350-L351

Added lines #L350 - L351 were not covered by tests
esh_dynamics_momentum_update_one_step,
position_update_fn,
cofficients,
format_output_fn=format_noneuclidean_state_output,
)
return one_step

Check warning on line 357 in blackjax/mcmc/integrators.py

Codecov / codecov/patch

blackjax/mcmc/integrators.py#L357

Added line #L357 was not covered by tests

return noneuclidean_integrator

18 changes: 18 additions & 0 deletions docs/refs.bib
Original file line number Diff line number Diff line change
@@ -360,3 +360,21 @@ @inproceedings{hoffman2021adaptive
year={2021},
organization={PMLR}
}

@misc{steeg2021hamiltonian,
title={Hamiltonian Dynamics with Non-Newtonian Momentum for Rapid Sampling},
author={Greg Ver Steeg and Aram Galstyan},
year={2021},
eprint={2111.02434},
archivePrefix={arXiv},
primaryClass={cs.LG}
}

@misc{robnik2023microcanonical,
title={Microcanonical Hamiltonian Monte Carlo},
author={Jakob Robnik and G. Bruno De Luca and Eva Silverstein and Uroš Seljak},
year={2023},
eprint={2212.08549},
archivePrefix={arXiv},
primaryClass={stat.CO}
}
112 changes: 90 additions & 22 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
@@ -3,9 +3,13 @@
import chex
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
from absl.testing import absltest, parameterized
from jax.flatten_util import ravel_pytree

import blackjax.mcmc.integrators as integrators
from blackjax.mcmc.integrators import esh_dynamics_momentum_update_one_step


def HarmonicOscillator(inv_mass_matrix, k=1.0, m=1.0):
@@ -47,20 +51,37 @@ def kinetic_energy(p):
return neg_potential_energy, kinetic_energy


algorithms = {
"velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4},
"mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5},
"yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6},
"non_euclidean_leapfrog": {
"algorithm": integrators.noneuclidean_leapfrog,
"precision": 1e-4,
},
"non_euclidean_mclachlan": {
"algorithm": integrators.noneuclidean_mclachlan,
"precision": 1e-5,
},
}
def MultivariateNormal(inv_mass_matrix):
"""Potential and kinetic energy for a multivariate normal distribution."""

def log_density(q):
q, _ = ravel_pytree(q)
return stats.multivariate_normal.logpdf(q, jnp.zeros_like(q), inv_mass_matrix)

def kinetic_energy(p):
p, _ = ravel_pytree(p)
return 0.5 * p.T @ inv_mass_matrix @ p

return log_density, kinetic_energy


mvnormal_position_init = {
"a": 0.0,
"b": jnp.asarray([1.0, 2.0, 3.0]),
"c": jnp.ones((2, 1)),
}
_, unravel_fn = ravel_pytree(mvnormal_position_init)
key0, key1 = jax.random.split(jax.random.key(52))
mvnormal_momentum_init = unravel_fn(jax.random.normal(key0, (6,)))
a = jax.random.normal(key1, (6, 6))
cov = jnp.matmul(a.T, a)
# Validated numerically
mvnormal_position_end = unravel_fn(
jnp.asarray([0.38887993, 0.85231394, 2.7879136, 3.0339851, 0.5856687, 1.9291426])
)
mvnormal_momentum_end = unravel_fn(
jnp.asarray([0.46576163, 0.23854092, 1.2518811, -0.35647452, -0.742138, 1.2552949])
)

examples = {
"free_fall": {
@@ -93,6 +114,22 @@ def kinetic_energy(p):
"p_final": {"x": 0.0, "y": 1.0},
"inv_mass_matrix": jnp.array([1.0, 1.0]),
},
"multivariate_normal": {
"model": MultivariateNormal,
"num_steps": 16,
"step_size": 0.005,
"q_init": mvnormal_position_init,
"p_init": mvnormal_momentum_init,
"q_final": mvnormal_position_end,
"p_final": mvnormal_momentum_end,
"inv_mass_matrix": cov,
},
}

algorithms = {
"velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4},
"mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5},
"yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6},
}


@@ -108,17 +145,20 @@ class IntegratorTest(chex.TestCase):
@chex.all_variants(with_pmap=False)
@parameterized.parameters(
itertools.product(
["free_fall", "harmonic_oscillator", "planetary_motion"],
[
"free_fall",
"harmonic_oscillator",
"planetary_motion",
"multivariate_normal",
],
[
"velocity_verlet",
"mclachlan",
"yoshida",
# "noneuclidean_leapfrog",
# "noneuclidean_mclachlan",
],
)
)
def test_integrator(self, example_name, integrator_name):
def test_euclidean_integrator(self, example_name, integrator_name):
integrator = algorithms[integrator_name]
example = examples[example_name]

@@ -134,14 +174,11 @@ def test_integrator(self, example_name, integrator_name):
initial_state = integrators.IntegratorState(
q, p, neg_potential(q), jax.grad(neg_potential)(q)
)
if integrator_name in ["non_euclidean_leapfrog", "minimal_norm"]:
one_step = lambda _, state: step(state, step_size)[0]
else:
one_step = lambda _, state: step(state, step_size)

final_state = jax.lax.fori_loop(
0,
example["num_steps"],
one_step,
lambda _, state: step(state, step_size),
initial_state,
)

@@ -155,6 +192,37 @@ def test_integrator(self, example_name, integrator_name):
)
self.assertAlmostEqual(energy, new_energy, delta=integrator["precision"])

@chex.all_variants(with_pmap=False)
@parameterized.parameters([3, 5])
def test_esh_momentum_update(self, dims):
"""
Test the numerically efficient version of the momentum update currently
implemented match the naive implementation according to the equation in
:cite:p:`robnik2023microcanonical`
"""
step_size = 1e-3
momentum = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(dims,))
momentum /= jnp.linalg.norm(momentum)
gradient = jax.random.uniform(key=jax.random.PRNGKey(1), shape=(dims,))

# Navie implementation
gradient_norm = jnp.linalg.norm(gradient)
gradient_normalized = gradient / gradient_norm
delta = step_size * gradient_norm / (dims - 1)
next_momentum = (
momentum
+ gradient_normalized
* (
jnp.sinh(delta)
+ jnp.dot(gradient_normalized, momentum * (jnp.cosh(delta) - 1))
)
) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta)))

# Efficient implementation
update_stable = self.variant(esh_dynamics_momentum_update_one_step)
next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)


if __name__ == "__main__":
absltest.main()