Add MCLMC sampler #586

Merged: 91 commits, Dec 5, 2023.

Changes from 78 commits

Commits (91)
fbe7f75
initial draft of mclmc
reubenharry Nov 10, 2023
3a23242
refactor
reubenharry Nov 11, 2023
86b3a90
wip
reubenharry Nov 11, 2023
e82550f
wip
reubenharry Nov 11, 2023
f0e1bec
wip
reubenharry Nov 11, 2023
4d7dc57
wip
reubenharry Nov 11, 2023
82b8466
wip
reubenharry Nov 11, 2023
a4d403b
fix pre-commit
reubenharry Nov 11, 2023
a67ecb7
remove dim from class
reubenharry Nov 11, 2023
3dd4f74
add docstrings
reubenharry Nov 11, 2023
5d8061d
add mclmc to init
reubenharry Nov 13, 2023
5428f2c
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Nov 13, 2023
59ecc8a
Merge branch 'main' into refactor
reubenharry Nov 13, 2023
2bf639e
move minimal_norm to integrators
reubenharry Nov 13, 2023
172fee0
move update pos and momentum
reubenharry Nov 13, 2023
b710e62
remove params
reubenharry Nov 13, 2023
3cc52fd
Infer the shape from inverse_mass_matrix outside the function step
reubenharry Nov 14, 2023
57d5c3b
use tree_map
reubenharry Nov 14, 2023
7e70d78
integration now aligned with mclmc repo
reubenharry Nov 15, 2023
1343463
dE and logdensity align too (fixed sign error)
reubenharry Nov 15, 2023
e53a877
make L and step size arguments to kernel
reubenharry Nov 15, 2023
05517b6
rough draft of tuning: works
reubenharry Nov 15, 2023
d84a23d
remove inv mass matrix
reubenharry Nov 15, 2023
de1e5cf
almost correct
reubenharry Nov 15, 2023
263ab3a
almost correct
reubenharry Nov 16, 2023
777213d
move tuning to adaptation
reubenharry Nov 16, 2023
e75274a
tuning works in this commit
reubenharry Nov 16, 2023
8a89f13
clean up 1
reubenharry Nov 16, 2023
49b3bec
remove sigma from tuning
reubenharry Nov 16, 2023
81999f9
wip
reubenharry Nov 16, 2023
8ab01f2
fix linting
reubenharry Nov 17, 2023
6266bc4
rename T and V
reubenharry Nov 17, 2023
ca984e7
uniformity wip
reubenharry Nov 17, 2023
59ffb21
make uniform implementation of integrators
reubenharry Nov 17, 2023
8f9214f
make uniform implementation of integrators
reubenharry Nov 18, 2023
b2e3b8e
fix minimal norm integrator
reubenharry Nov 18, 2023
2fb2293
add warning to tune3
reubenharry Nov 18, 2023
59e4424
Refactor integrators.py to make it more general.
junpenglao Nov 19, 2023
6684413
temp: explore
reubenharry Nov 19, 2023
4284092
Refactor to use integrator generation functions
junpenglao Nov 20, 2023
4a514dd
Additional refactoring
junpenglao Nov 20, 2023
ef1f62d
Minor clean up.
junpenglao Nov 21, 2023
af43521
Use standard JAX ops
junpenglao Nov 21, 2023
0dd419d
new integrator
reubenharry Nov 23, 2023
0c8330e
add references
reubenharry Nov 23, 2023
e6fa2bb
merge
reubenharry Nov 23, 2023
40fc61c
flake
reubenharry Nov 24, 2023
6ea5320
temporarily add 'explore'
reubenharry Nov 24, 2023
c83dc1a
temporarily add 'explore'
reubenharry Nov 25, 2023
c8b43be
Adding a test for energy preservation.
junpenglao Nov 26, 2023
8894248
fix formatting
junpenglao Nov 26, 2023
9865145
wip: tests
reubenharry Nov 26, 2023
68464bc
Merge branch 'integrator_refactor' into refactor
reubenharry Nov 26, 2023
0c61412
use pytrees for partially_refresh_momentum, and add test
reubenharry Nov 26, 2023
a66af60
Merge branch 'main' into refactor
junpenglao Nov 27, 2023
be07631
update docstring
reubenharry Nov 27, 2023
71d934b
resolve conflict
reubenharry Nov 27, 2023
a170d0b
remove 'explore'
reubenharry Nov 27, 2023
8cfb75f
fix pre-commit
reubenharry Nov 27, 2023
b42e77e
adding randomized MCHMC
JakobRobnik Nov 29, 2023
2b323ce
wip checkpoint on tuning
reubenharry Dec 1, 2023
9a41cdf
align blackjax and mclmc repos, for tuning
reubenharry Dec 1, 2023
cdbb4f6
use effective_sample_size
reubenharry Dec 1, 2023
947d717
patial rename
reubenharry Dec 1, 2023
e9ab7b4
rename
reubenharry Dec 1, 2023
72d70c6
clean up tuning
reubenharry Dec 1, 2023
c121beb
clean up tuning
reubenharry Dec 1, 2023
fe99163
IN THIS COMMIT, BLACKJAX AND ORIGINAL REPO AGREE. SEED IS FIXED.
reubenharry Dec 2, 2023
c456efe
RANDOMIZE KEYS
reubenharry Dec 2, 2023
d0a008a
ADD TEST
reubenharry Dec 2, 2023
d692498
ADD TEST
reubenharry Dec 2, 2023
3e8d8ea
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 2, 2023
eda029a
Merge branch 'main' into refactor
reubenharry Dec 2, 2023
a45f58f
MERGE MAIN
reubenharry Dec 2, 2023
2a21c56
INCREASE CODE COVERAGE
reubenharry Dec 2, 2023
67f0de9
REMOVE REDUNDANT LINE
reubenharry Dec 2, 2023
3f55f5f
ADD NAME 'mclmc'
reubenharry Dec 2, 2023
666c540
SPLIT KEYS AND FIX DOCSTRING
reubenharry Dec 2, 2023
c1615f5
FIX MINOR ERRORS
reubenharry Dec 2, 2023
ae1bf30
FIX MINOR ERRORS
reubenharry Dec 2, 2023
3c2dbad
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 2, 2023
c396aa1
FIX CONFLICT IN BIB
reubenharry Dec 2, 2023
0902a1c
RANDOMIZE KEYS (reversion)
reubenharry Dec 2, 2023
2e3c80b
PRECOMMIT CLEAN UP
reubenharry Dec 2, 2023
604b5a9
ADD KWARGS FOR DEFAULT HYPERPARAMS
reubenharry Dec 3, 2023
50b1c95
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 4, 2023
fecd82b
Merge branch 'main' into refactor
reubenharry Dec 4, 2023
50a8243
UPDATE ESS
reubenharry Dec 5, 2023
a20a681
NAME CHANGES
reubenharry Dec 5, 2023
75e71de
NAME CHANGES
reubenharry Dec 5, 2023
70f1dd5
MINOR FIXES
reubenharry Dec 5, 2023
4 changes: 4 additions & 0 deletions blackjax/__init__.py
@@ -4,6 +4,7 @@
from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc.barker import barker_proposal
@@ -12,7 +13,9 @@
from .mcmc.hmc import dynamic_hmc, hmc
from .mcmc.mala import mala
from .mcmc.marginal_latent_gaussian import mgrad_gaussian
from .mcmc.mclmc import mclmc
from .mcmc.nuts import nuts
from .mcmc.mclmc import mclmc
from .mcmc.periodic_orbital import orbital_hmc
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
from .optimizers import dual_averaging, lbfgs
@@ -39,6 +42,7 @@
"additive_step_random_walk",
"rmh",
"irmh",
"mclmc",
"elliptical_slice",
"ghmc",
"barker_proposal",
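
For orientation, a brief sketch (editorial, not part of the diff) of the two entry points this file now exports; a fuller end-to-end example follows the mclmc_adaptation.py file below:

import blackjax

algo = blackjax.mclmc                       # the new sampler (from blackjax.mcmc.mclmc)
tune = blackjax.mclmc_find_L_and_step_size  # the new tuning routine (from blackjax.adaptation.mclmc_adaptation)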
2 changes: 2 additions & 0 deletions blackjax/adaptation/__init__.py
@@ -1,5 +1,6 @@
from . import (
chees_adaptation,
mclmc_adaptation,
meads_adaptation,
pathfinder_adaptation,
window_adaptation,
@@ -10,4 +11,5 @@
"meads_adaptation",
"window_adaptation",
"pathfinder_adaptation",
"mclmc_adaptation",
]
243 changes: 243 additions & 0 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -0,0 +1,243 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L.

"""

from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size # type: ignore
from blackjax.util import pytree_size


class MCLMCAdaptationState(NamedTuple):
"""Represents the tunable parameters for MCLMC adaptation.

Attributes:
L (float): The momentum decoherence rate for the MCLMC algorithm.
step_size (float): The step size used for the MCLMC algorithm.
"""

L: float
step_size: float
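
# Note: mclmc_find_L_and_step_size below initializes this state with
# L = sqrt(dim) and step_size = 0.25 * sqrt(dim), i.e.
# MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25).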


def mclmc_find_L_and_step_size(
mclmc_kernel,
num_steps,
state,
rng_key,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1,
):
"""
Finds the optimal value of the parameters for the MCLMC algorithm.

Args:
kernel: The kernel function used for the MCMC algorithm.
num_steps: The number of MCMC steps that will subsequently be run, after tuning
state: The initial state of the MCMC algorithm.
frac_tune1: The fraction of tuning for the first step of the adaptation.
frac_tune2: The fraction of tuning for the second step of the adaptation.
frac_tune3: The fraction of tuning for the third step of the adaptation.
reubenharry marked this conversation as resolved.
Show resolved Hide resolved

Returns:
state: The final state of the MCMC algorithm.
params: The final hyperparameters of the MCMC algorithm.
"""
dim = pytree_size(state.position)
params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25)
varEwanted = 5e-4
part1_key, part2_key = jax.random.split(rng_key, 2)

state, params = make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
frac_tune2=frac_tune2,
varEwanted=varEwanted,
sigma_xi=1.5,
num_effective_samples=150,
)(state, params, num_steps, part1_key)

if frac_tune3 != 0:
state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
state, params, num_steps, part2_key
)

return state, params


def make_L_step_size_adaptation(
kernel,
dim,
frac_tune1,
frac_tune2,
varEwanted=1e-3,
sigma_xi=1.5,
num_effective_samples=150,
):
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""

gamma_forget = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)

def predictor(state_old, params, adaptive_state, rng_key):
Review comment (Member): state_old -> previous_state, state_new -> next_state. Same for below and other files.

Review comment (Member): Reminder to change the naming.

"""does one step with the dynamics and updates the prediction for the optimal step size.
Designed for the unadjusted MCLMC."""

W, F, step_size_max = adaptive_state

# dynamics
state_new, info = kernel(
rng_key=rng_key, state=state_old, L=params.L, step_size=params.step_size
)
energy_change = info.dE
# step updating
success, state, step_size_max, energy_change = handle_nans(
state_old, state_new, params.step_size, step_size_max, energy_change
)

# Warning: var = 0 if there were nans, but we will give it a very small weight
xi = (
jnp.square(energy_change) / (dim * varEwanted)
) + 1e-8 # 1e-8 is added to avoid divergences in log xi
w = jnp.exp(
-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))
) # the weight reduces the impact of stepsizes which are much larger or much smaller than the desired one.

F = gamma_forget * F + w * (xi / jnp.power(params.step_size, 6.0))
W = gamma_forget * W + w
step_size = jnp.power(
F / W, -1.0 / 6.0
) # We use the Var[E] = O(eps^6) relation here.
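# Rationale (editorial sketch): with Var[E]/dim proportional to step_size^6,
# F/W is a weighted estimate of xi / step_size^6, so solving for the step size
# that would give xi = 1 (i.e. Var[E]/dim = varEwanted) yields (F/W)^(-1/6).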
step_size = (step_size < step_size_max) * step_size + (
step_size > step_size_max
) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences
params_new = params._replace(step_size=step_size)

return state, params_new, params_new, (W, F, step_size_max), success

def update_kalman(x, state, outer_weight, success, step_size):
"""kalman filter to estimate the size of the posterior"""
W, F1, F2 = state
w = outer_weight * step_size * success
zero_prevention = 1 - outer_weight
F1 = (W * F1 + w * x) / (
W + w + zero_prevention
) # Update <f(x)> with a Kalman filter
F2 = (W * F2 + w * jnp.square(x)) / (
W + w + zero_prevention
) # Update <f(x)^2> with a Kalman filter
W += w
return (W, F1, F2)
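# Note: F1 and F2 are running estimates of <x> and <x^2>; the L-adaptation
# below uses F2 - F1^2 as the per-coordinate posterior variance.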

adap0 = (0.0, 0.0, jnp.inf)

def step(iteration_state, weight_and_key):
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
outer_weight, rng_key = weight_and_key
state, params, adaptive_state, kalman_state = iteration_state
state, params, params_final, adaptive_state, success = predictor(
state, params, adaptive_state, rng_key
)
position, _ = ravel_pytree(state.position)
kalman_state = update_kalman(
position, kalman_state, outer_weight, success, params.step_size
)

return (state, params_final, adaptive_state, kalman_state), None

def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
num_steps * frac_tune2
)
# L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2)
L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2))

Review comment (Member): Why use the same key? I see you commented out the previous line.

# we use the last num_steps2 steps to estimate the posterior variances, which determine L
outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

# initial state of the kalman filter
kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim))

# run the steps
kalman_state = jax.lax.scan(
step,
init=(state, params, adap0, kalman_state),
xs=(outer_weights, L_step_size_adaptation_keys),
length=num_steps1 + num_steps2,
)[0]
state, params, _, kalman_state_output = kalman_state

L = params.L
# determine L
if num_steps2 != 0.0:
_, F1, F2 = kalman_state_output
variances = F2 - jnp.square(F1)
L = jnp.sqrt(jnp.sum(variances))
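# L is set to the square root of the summed marginal variances, i.e.
# sqrt(trace(Cov)), an estimate of the typical size of the posterior.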

return state, MCLMCAdaptationState(L, params.step_size)

return L_step_size_adaptation


def make_adaptation_L(kernel, frac, Lfactor):
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

def adaptation_L(state, params, num_steps, key):
num_steps = int(num_steps * frac)
# adaptation_L_keys = jax.random.split(key, num_steps)
adaptation_L_keys = jnp.array([key] * (num_steps))

Review comment (Member): Same comment here; I assume different keys should be used.


# run kernel in the normal way
state, info = jax.lax.scan(
f=lambda s, k: (
kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)
),
init=state,
xs=adaptation_L_keys,
)
samples = info.transformed_position # transform is the identity here
flat_samples, _ = ravel_pytree(samples)
dim = pytree_size(state.position)
flat_samples = flat_samples.reshape(-1, dim)
ESS = 0.5 * effective_sample_size(
jnp.array([flat_samples, flat_samples])
) # TODO: should only use a single chain here

Review comment (Member): Ah, because ESS in Stan is designed to work on 1+ chains. In this case, don't repeat the sample, but reshape it into 2 chains (basically split ESS, which is how single-chain ESS is usually computed).

return state, params._replace(
L=Lfactor * params.step_size * jnp.average(num_steps / ESS)
)

return adaptation_L
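
A minimal sketch of the split-ESS variant the reviewer suggests (editorial illustration, not part of this diff): rather than duplicating the chain, split it into two halves and pass those as two chains.

# Sketch (assumption): split a single chain into two halves for split-ESS.
half = flat_samples.shape[0] // 2
split_chains = flat_samples[: 2 * half].reshape(2, half, dim)  # (chains, samples, dim)
ESS = effective_sample_size(split_chains)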


def handle_nans(state_old, state_new, step_size, step_size_max, kinetic_change):
"""If there are NaNs, reduce the step size and do not update the state; the old state is returned in this case."""

reduced_step_size = 0.8
p, unravel_fn = ravel_pytree(state_new.position)
nonans = jnp.all(jnp.isfinite(p))
state, step_size, kinetic_change = jax.tree_util.tree_map(
lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
(state_new, step_size_max, kinetic_change),
(state_old, step_size * reduced_step_size, 0.0),
)

return nonans, state, step_size, kinetic_change
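
To make the new pieces concrete, here is an end-to-end usage sketch (editorial; the exact init/build_kernel signatures of the mclmc module are assumptions based on the usual blackjax pattern, while the kernel call and mclmc_find_L_and_step_size signatures match the code above):

import jax
import jax.numpy as jnp
import blackjax

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))  # standard Gaussian
init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)

# Assumed API: build the state and kernel as other blackjax samplers do.
initial_state = blackjax.mcmc.mclmc.init(
    position=jnp.ones(10), logdensity_fn=logdensity_fn, rng_key=init_key
)
kernel = blackjax.mcmc.mclmc.build_kernel(
    logdensity_fn=logdensity_fn,
    integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan,
)

# Tune L and step_size with the adaptation added in this PR.
state, params = blackjax.mclmc_find_L_and_step_size(
    mclmc_kernel=kernel, num_steps=1000, state=initial_state, rng_key=tune_key
)

# Sample with the tuned hyperparameters; the kernel call matches predictor above.
run_keys = jax.random.split(run_key, 1000)
_, infos = jax.lax.scan(
    lambda s, k: kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size),
    state,
    run_keys,
)
samples = infos.transformed_position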
2 changes: 2 additions & 0 deletions blackjax/mcmc/__init__.py
@@ -5,6 +5,7 @@
hmc,
mala,
marginal_latent_gaussian,
mclmc,
nuts,
periodic_orbital,
random_walk,
@@ -20,4 +21,5 @@
"periodic_orbital",
"marginal_latent_gaussian",
"random_walk",
"mclmc",
]
2 changes: 1 addition & 1 deletion blackjax/mcmc/integrators.py

Review comment (Member): Splitting the changes here into #589.

@@ -365,5 +365,5 @@ def noneuclidean_integrator(


noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients)
noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)
noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients)
noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)