Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MCLMC sampler #586

Merged
merged 91 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from 59 commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
fbe7f75
initial draft of mclmc
reubenharry Nov 10, 2023
3a23242
refactor
reubenharry Nov 11, 2023
86b3a90
wip
reubenharry Nov 11, 2023
e82550f
wip
reubenharry Nov 11, 2023
f0e1bec
wip
reubenharry Nov 11, 2023
4d7dc57
wip
reubenharry Nov 11, 2023
82b8466
wip
reubenharry Nov 11, 2023
a4d403b
fix pre-commit
reubenharry Nov 11, 2023
a67ecb7
remove dim from class
reubenharry Nov 11, 2023
3dd4f74
add docstrings
reubenharry Nov 11, 2023
5d8061d
add mclmc to init
reubenharry Nov 13, 2023
5428f2c
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Nov 13, 2023
59ecc8a
Merge branch 'main' into refactor
reubenharry Nov 13, 2023
2bf639e
move minimal_norm to integrators
reubenharry Nov 13, 2023
172fee0
move update pos and momentum
reubenharry Nov 13, 2023
b710e62
remove params
reubenharry Nov 13, 2023
3cc52fd
Infer the shape from inverse_mass_matrix outside the function step
reubenharry Nov 14, 2023
57d5c3b
use tree_map
reubenharry Nov 14, 2023
7e70d78
integration now aligned with mclmc repo
reubenharry Nov 15, 2023
1343463
dE and logdensity align too (fixed sign error)
reubenharry Nov 15, 2023
e53a877
make L and step size arguments to kernel
reubenharry Nov 15, 2023
05517b6
rough draft of tuning: works
reubenharry Nov 15, 2023
d84a23d
remove inv mass matrix
reubenharry Nov 15, 2023
de1e5cf
almost correct
reubenharry Nov 15, 2023
263ab3a
almost correct
reubenharry Nov 16, 2023
777213d
move tuning to adaptation
reubenharry Nov 16, 2023
e75274a
tuning works in this commit
reubenharry Nov 16, 2023
8a89f13
clean up 1
reubenharry Nov 16, 2023
49b3bec
remove sigma from tuning
reubenharry Nov 16, 2023
81999f9
wip
reubenharry Nov 16, 2023
8ab01f2
fix linting
reubenharry Nov 17, 2023
6266bc4
rename T and V
reubenharry Nov 17, 2023
ca984e7
uniformity wip
reubenharry Nov 17, 2023
59ffb21
make uniform implementation of integrators
reubenharry Nov 17, 2023
8f9214f
make uniform implementation of integrators
reubenharry Nov 18, 2023
b2e3b8e
fix minimal norm integrator
reubenharry Nov 18, 2023
2fb2293
add warning to tune3
reubenharry Nov 18, 2023
59e4424
Refactor integrators.py to make it more general.
junpenglao Nov 19, 2023
6684413
temp: explore
reubenharry Nov 19, 2023
4284092
Refactor to use integrator generation functions
junpenglao Nov 20, 2023
4a514dd
Additional refactoring
junpenglao Nov 20, 2023
ef1f62d
Minor clean up.
junpenglao Nov 21, 2023
af43521
Use standard JAX ops
junpenglao Nov 21, 2023
0dd419d
new integrator
reubenharry Nov 23, 2023
0c8330e
add references
reubenharry Nov 23, 2023
e6fa2bb
merge
reubenharry Nov 23, 2023
40fc61c
flake
reubenharry Nov 24, 2023
6ea5320
temporarily add 'explore'
reubenharry Nov 24, 2023
c83dc1a
temporarily add 'explore'
reubenharry Nov 25, 2023
c8b43be
Adding a test for energy preservation.
junpenglao Nov 26, 2023
8894248
fix formatting
junpenglao Nov 26, 2023
9865145
wip: tests
reubenharry Nov 26, 2023
68464bc
Merge branch 'integrator_refactor' into refactor
reubenharry Nov 26, 2023
0c61412
use pytrees for partially_refresh_momentum, and add test
reubenharry Nov 26, 2023
a66af60
Merge branch 'main' into refactor
junpenglao Nov 27, 2023
be07631
update docstring
reubenharry Nov 27, 2023
71d934b
resolve conflict
reubenharry Nov 27, 2023
a170d0b
remove 'explore'
reubenharry Nov 27, 2023
8cfb75f
fix pre-commit
reubenharry Nov 27, 2023
b42e77e
adding randomized MCHMC
JakobRobnik Nov 29, 2023
2b323ce
wip checkpoint on tuning
reubenharry Dec 1, 2023
9a41cdf
align blackjax and mclmc repos, for tuning
reubenharry Dec 1, 2023
cdbb4f6
use effective_sample_size
reubenharry Dec 1, 2023
947d717
patial rename
reubenharry Dec 1, 2023
e9ab7b4
rename
reubenharry Dec 1, 2023
72d70c6
clean up tuning
reubenharry Dec 1, 2023
c121beb
clean up tuning
reubenharry Dec 1, 2023
fe99163
IN THIS COMMIT, BLACKJAX AND ORIGINAL REPO AGREE. SEED IS FIXED.
reubenharry Dec 2, 2023
c456efe
RANDOMIZE KEYS
reubenharry Dec 2, 2023
d0a008a
ADD TEST
reubenharry Dec 2, 2023
d692498
ADD TEST
reubenharry Dec 2, 2023
3e8d8ea
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 2, 2023
eda029a
Merge branch 'main' into refactor
reubenharry Dec 2, 2023
a45f58f
MERGE MAIN
reubenharry Dec 2, 2023
2a21c56
INCREASE CODE COVERAGE
reubenharry Dec 2, 2023
67f0de9
REMOVE REDUNDANT LINE
reubenharry Dec 2, 2023
3f55f5f
ADD NAME 'mclmc'
reubenharry Dec 2, 2023
666c540
SPLIT KEYS AND FIX DOCSTRING
reubenharry Dec 2, 2023
c1615f5
FIX MINOR ERRORS
reubenharry Dec 2, 2023
ae1bf30
FIX MINOR ERRORS
reubenharry Dec 2, 2023
3c2dbad
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 2, 2023
c396aa1
FIX CONFLICT IN BIB
reubenharry Dec 2, 2023
0902a1c
RANDOMIZE KEYS (reversion)
reubenharry Dec 2, 2023
2e3c80b
PRECOMMIT CLEAN UP
reubenharry Dec 2, 2023
604b5a9
ADD KWARGS FOR DEFAULT HYPERPARAMS
reubenharry Dec 3, 2023
50b1c95
Merge branch 'main' of https://github.com/blackjax-devs/blackjax
reubenharry Dec 4, 2023
fecd82b
Merge branch 'main' into refactor
reubenharry Dec 4, 2023
50a8243
UPDATE ESS
reubenharry Dec 5, 2023
a20a681
NAME CHANGES
reubenharry Dec 5, 2023
75e71de
NAME CHANGES
reubenharry Dec 5, 2023
70f1dd5
MINOR FIXES
reubenharry Dec 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .mcmc.hmc import dynamic_hmc, hmc
from .mcmc.mala import mala
from .mcmc.marginal_latent_gaussian import mgrad_gaussian
from .mcmc.mclmc import mclmc
reubenharry marked this conversation as resolved.
Show resolved Hide resolved
from .mcmc.nuts import nuts
from .mcmc.periodic_orbital import orbital_hmc
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
Expand Down Expand Up @@ -38,6 +39,7 @@
"additive_step_random_walk",
"rmh",
"irmh",
"mclmc",
"elliptical_slice",
"ghmc",
"sgld", # stochastic gradient mcmc
Expand Down
286 changes: 286 additions & 0 deletions blackjax/adaptation/step_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Step size adaptation"""
import warnings
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from scipy.fft import next_fast_len

from blackjax.mcmc.hmc import HMCState
from blackjax.mcmc.integrators import noneuclidean_mclachlan
from blackjax.mcmc.mclmc import IntegratorState, build_kernel, init
from blackjax.optimizers.dual_averaging import dual_averaging
from blackjax.types import PRNGKey

Expand Down Expand Up @@ -257,3 +261,285 @@
rss_state = jax.lax.while_loop(do_continue, update, rss_state)

return rss_state.step_size


class MCLMCAdaptationState(NamedTuple):
"""Tunable parameters for MCLMC"""

L: float
step_size: float


def ess_corr(x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of copying large proportion of code, you should refactor the origin function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, will definitely do that.

"""Taken from: https://blackjax-devs.github.io/blackjax/diagnostics.html
shape(x) = (num_samples, d)"""

input_array = jnp.array(

Check warning on line 277 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L277

Added line #L277 was not covered by tests
[
x,
]
)

num_chains = 1 # input_array.shape[0]
num_samples = input_array.shape[1]

Check warning on line 284 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L283-L284

Added lines #L283 - L284 were not covered by tests

mean_across_chain = input_array.mean(axis=1, keepdims=True)

Check warning on line 286 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L286

Added line #L286 was not covered by tests
# Compute autocovariance estimates for every lag for the input array using FFT.
centered_array = input_array - mean_across_chain
m = next_fast_len(2 * num_samples)
ifft_ary = jnp.fft.rfft(centered_array, n=m, axis=1)
ifft_ary *= jnp.conjugate(ifft_ary)
autocov_value = jnp.fft.irfft(ifft_ary, n=m, axis=1)
autocov_value = (

Check warning on line 293 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L288-L293

Added lines #L288 - L293 were not covered by tests
jnp.take(autocov_value, jnp.arange(num_samples), axis=1) / num_samples
)
mean_autocov_var = autocov_value.mean(0, keepdims=True)
mean_var0 = (

Check warning on line 297 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L296-L297

Added lines #L296 - L297 were not covered by tests
jnp.take(mean_autocov_var, jnp.array([0]), axis=1)
* num_samples
/ (num_samples - 1.0)
)
weighted_var = mean_var0 * (num_samples - 1.0) / num_samples

Check warning on line 302 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L302

Added line #L302 was not covered by tests

weighted_var = jax.lax.cond(

Check warning on line 304 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L304

Added line #L304 was not covered by tests
num_chains > 1,
lambda _: weighted_var + mean_across_chain.var(axis=0, ddof=1, keepdims=True),
lambda _: weighted_var,
operand=None,
)

# Geyer's initial positive sequence
num_samples_even = num_samples - num_samples % 2
mean_autocov_var_tp1 = jnp.take(

Check warning on line 313 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L312-L313

Added lines #L312 - L313 were not covered by tests
mean_autocov_var, jnp.arange(1, num_samples_even), axis=1
)
rho_hat = jnp.concatenate(

Check warning on line 316 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L316

Added line #L316 was not covered by tests
[
jnp.ones_like(mean_var0),
1.0 - (mean_var0 - mean_autocov_var_tp1) / weighted_var,
],
axis=1,
)

rho_hat = jnp.moveaxis(rho_hat, 1, 0)
rho_hat_even = rho_hat[0::2]
rho_hat_odd = rho_hat[1::2]

Check warning on line 326 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L324-L326

Added lines #L324 - L326 were not covered by tests

mask0 = (rho_hat_even + rho_hat_odd) > 0.0
carry_cond = jnp.ones_like(mask0[0])
max_t = jnp.zeros_like(mask0[0], dtype=int)

Check warning on line 330 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L328-L330

Added lines #L328 - L330 were not covered by tests

def positive_sequence_body_fn(state, mask_t):
t, carry_cond, max_t = state
next_mask = carry_cond & mask_t
next_max_t = jnp.where(next_mask, jnp.ones_like(max_t) * t, max_t)
return (t + 1, next_mask, next_max_t), next_mask

Check warning on line 336 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L332-L336

Added lines #L332 - L336 were not covered by tests

(*_, max_t_next), mask = jax.lax.scan(

Check warning on line 338 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L338

Added line #L338 was not covered by tests
positive_sequence_body_fn, (0, carry_cond, max_t), mask0
)
indices = jnp.indices(max_t_next.shape)
indices = tuple([max_t_next + 1] + [indices[i] for i in range(max_t_next.ndim)])
rho_hat_odd = jnp.where(mask, rho_hat_odd, jnp.zeros_like(rho_hat_odd))

Check warning on line 343 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L341-L343

Added lines #L341 - L343 were not covered by tests
# improve estimation
mask_even = mask.at[indices].set(rho_hat_even[indices] > 0)
rho_hat_even = jnp.where(mask_even, rho_hat_even, jnp.zeros_like(rho_hat_even))

Check warning on line 346 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L345-L346

Added lines #L345 - L346 were not covered by tests

# Geyer's initial monotone sequence
def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t):
update_mask = rho_hat_sum_t > rho_hat_sum_tm1
next_rho_hat_sum_t = jnp.where(update_mask, rho_hat_sum_tm1, rho_hat_sum_t)
return next_rho_hat_sum_t, (update_mask, next_rho_hat_sum_t)

Check warning on line 352 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L349-L352

Added lines #L349 - L352 were not covered by tests

rho_hat_sum = rho_hat_even + rho_hat_odd
_, (update_mask, update_value) = jax.lax.scan(

Check warning on line 355 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L354-L355

Added lines #L354 - L355 were not covered by tests
monotone_sequence_body_fn, rho_hat_sum[0], rho_hat_sum
)

rho_hat_even_final = jnp.where(update_mask, update_value / 2.0, rho_hat_even)
rho_hat_odd_final = jnp.where(update_mask, update_value / 2.0, rho_hat_odd)

Check warning on line 360 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L359-L360

Added lines #L359 - L360 were not covered by tests

# compute effective sample size
ess_raw = num_chains * num_samples
tau_hat = (

Check warning on line 364 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L363-L364

Added lines #L363 - L364 were not covered by tests
-1.0
+ 2.0 * jnp.sum(rho_hat_even_final + rho_hat_odd_final, axis=0)
- rho_hat_even_final[indices]
)

tau_hat = jnp.maximum(tau_hat, 1 / jnp.log10(ess_raw))
ess = ess_raw / tau_hat

Check warning on line 371 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L370-L371

Added lines #L370 - L371 were not covered by tests

neff = ess.squeeze() / num_samples
return 1.0 / jnp.average(1 / neff)

Check warning on line 374 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L373-L374

Added lines #L373 - L374 were not covered by tests


def nan_reject(x, u, l, g, xx, uu, ll, gg, eps, eps_max, dK):
"""if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case."""

nonans = jnp.all(jnp.isfinite(xx))
return nonans, *jax.tree_util.tree_map(

Check warning on line 381 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L380-L381

Added lines #L380 - L381 were not covered by tests
lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
(xx, uu, ll, gg, eps_max, dK),
(x, u, l, g, eps * 0.8, 0.0),
)


def dynamics_adaptive(dynamics, state, L):
"""One step of the dynamics with the adaptive stepsize"""

x, u, l, g, E, Feps, Weps, eps_max, key = state

Check warning on line 391 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L391

Added line #L391 was not covered by tests

eps = jnp.power(

Check warning on line 393 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L393

Added line #L393 was not covered by tests
Feps / Weps, -1.0 / 6.0
) # We use the Var[E] = O(eps^6) relation here.
eps = (eps < eps_max) * eps + (

Check warning on line 396 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L396

Added line #L396 was not covered by tests
eps > eps_max
) * eps_max # if the proposed stepsize is above the stepsize where we have seen divergences

state, info = dynamics(

Check warning on line 400 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L400

Added line #L400 was not covered by tests
jax.random.PRNGKey(0), IntegratorState(x, u, l, g), L=L, step_size=eps
)

xx, uu, ll, gg = state

Check warning on line 404 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L404

Added line #L404 was not covered by tests
# ll, gg = -ll, -gg
kinetic_change = info.kinetic_change

Check warning on line 406 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L406

Added line #L406 was not covered by tests

varEwanted = 5e-4
sigma_xi = 1.5
neff = 150 # effective number of steps used to determine the stepsize in the adaptive step
gamma = (neff - 1.0) / (neff + 1.0) # forgeting factor in the adaptive step

Check warning on line 411 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L408-L411

Added lines #L408 - L411 were not covered by tests

# step updating
success, xx, uu, ll, gg, eps_max, kinetic_change = nan_reject(

Check warning on line 414 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L414

Added line #L414 was not covered by tests
x, u, l, g, xx, uu, ll, gg, eps, eps_max, kinetic_change
)

DE = info.dE # energy difference
EE = E + DE # energy

Check warning on line 419 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L418-L419

Added lines #L418 - L419 were not covered by tests
# Warning: var = 0 if there were nans, but we will give it a very small weight
xi = (

Check warning on line 421 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L421

Added line #L421 was not covered by tests
(DE**2) / (xx.shape[0] * varEwanted)
) + 1e-8 # 1e-8 is added to avoid divergences in log xi
w = jnp.exp(

Check warning on line 424 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L424

Added line #L424 was not covered by tests
-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))
) # the weight which reduces the impact of stepsizes which are much larger on much smaller than the desired one.
Feps = gamma * Feps + w * (

Check warning on line 427 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L427

Added line #L427 was not covered by tests
xi / jnp.power(eps, 6.0)
) # Kalman update the linear combinations
Weps = gamma * Weps + w

Check warning on line 430 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L430

Added line #L430 was not covered by tests

return xx, uu, ll, gg, EE, Feps, Weps, eps_max, key, eps * success

Check warning on line 432 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L432

Added line #L432 was not covered by tests


def tune12(kernel, x, u, l, g, random_key, L, eps, num_steps1, num_steps2):
"""cheap hyperparameter tuning"""

def step(state, outer_weight):

Check warning on line 438 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L438

Added line #L438 was not covered by tests
"""one adaptive step of the dynamics"""
x, u, l, g, E, Feps, Weps, eps_max, key, eps = dynamics_adaptive(

Check warning on line 440 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L440

Added line #L440 was not covered by tests
kernel, state[0], L
)
W, F1, F2 = state[1]
w = outer_weight * eps
zero_prevention = 1 - outer_weight
F1 = (W * F1 + w * x) / (

Check warning on line 446 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L443-L446

Added lines #L443 - L446 were not covered by tests
W + w + zero_prevention
) # Update <f(x)> with a Kalman filter
F2 = (W * F2 + w * jnp.square(x)) / (

Check warning on line 449 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L449

Added line #L449 was not covered by tests
W + w + zero_prevention
) # Update <f(x)> with a Kalman filter
W += w

Check warning on line 452 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L452

Added line #L452 was not covered by tests

return ((x, u, l, g, E, Feps, Weps, eps_max, key), (W, F1, F2)), eps

Check warning on line 454 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L454

Added line #L454 was not covered by tests

# we use the last num_steps2 to compute the diagonal preconditioner
outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

Check warning on line 457 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L457

Added line #L457 was not covered by tests

# initial state
state = (

Check warning on line 460 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L460

Added line #L460 was not covered by tests
(x, u, l, g, 0.0, jnp.power(eps, -6.0) * 1e-5, 1e-5, jnp.inf, random_key),
(0.0, jnp.zeros(len(x)), jnp.zeros(len(x))),
)
# run the steps
state, eps = jax.lax.scan(

Check warning on line 465 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L465

Added line #L465 was not covered by tests
step, init=state, xs=outer_weights, length=num_steps1 + num_steps2
)
# determine L
if num_steps2 != 0.0:
F1, F2 = state[1][1], state[1][2]
variances = F2 - jnp.square(F1)
sigma2 = jnp.average(variances)

Check warning on line 472 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L469-L472

Added lines #L469 - L472 were not covered by tests

L = jnp.sqrt(sigma2 * x.shape[0])

Check warning on line 474 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L474

Added line #L474 was not covered by tests

xx, uu, ll, gg, _, _, _, _, _ = state[0] # the final state
return (

Check warning on line 477 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L476-L477

Added lines #L476 - L477 were not covered by tests
L,
eps[-1],
IntegratorState(xx, uu, ll, gg),
) # return the tuned hyperparameters and the final state


def tune3(kernel, state, rng_key, L, eps, num_steps):
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

state, info = jax.lax.scan(

Check warning on line 487 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L487

Added line #L487 was not covered by tests
lambda s, k: (kernel(k, s, L, eps)), state, jax.random.split(rng_key, num_steps)
)

Lfactor = 0.4

Check warning on line 491 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L491

Added line #L491 was not covered by tests
# ESS2 = effective_sample_size(info.transformed_x)
# neff = ESS2.squeeze() / info.transformed_x.shape[0]
# ESS_alt = 1.0 / jnp.average(1 / neff)
ESS = ess_corr(info.transformed_x)
if ESS * num_steps <= 10:
warnings.warn(

Check warning on line 497 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L495-L497

Added lines #L495 - L497 were not covered by tests
"tune3 cannot be expected to work with 10 or fewer effective samples"
)

Lnew = Lfactor * eps / ESS
return Lnew, state

Check warning on line 502 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L501-L502

Added lines #L501 - L502 were not covered by tests


def tune(
position,
logdensity_fn,
num_steps: int,
rng_key: PRNGKey,
params: MCLMCAdaptationState,
) -> tuple[MCLMCAdaptationState, IntegratorState]:
num_tune_step_ratio_1 = 0.1
num_tune_step_ratio_2 = 0.1

Check warning on line 513 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L512-L513

Added lines #L512 - L513 were not covered by tests

kernel = build_kernel(

Check warning on line 515 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L515

Added line #L515 was not covered by tests
logdensity_fn, integrator=noneuclidean_mclachlan, transform=lambda x: x
)

init_key, tune1_key, tune2_key = jax.random.split(rng_key, 3)

Check warning on line 519 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L519

Added line #L519 was not covered by tests

x, u, l, g = init(position, logdensity_fn=logdensity_fn, rng_key=init_key)

Check warning on line 521 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L521

Added line #L521 was not covered by tests
# x, u, l, g = (
# jnp.array([0.1, 0.1]),
# jnp.array([-0.6755803, 0.73728645]),
# -0.010000001,
# -jnp.array([0.1, 0.1]),
# )

L, eps, state = tune12(

Check warning on line 529 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L529

Added line #L529 was not covered by tests
kernel,
x,
u,
l,
g,
tune1_key,
params.L,
params.step_size,
int(num_steps * num_tune_step_ratio_1),
int(num_steps * num_tune_step_ratio_1),
)

L, state = tune3(

Check warning on line 542 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L542

Added line #L542 was not covered by tests
kernel, state, tune2_key, L, eps, int(num_steps * num_tune_step_ratio_2)
)
return MCLMCAdaptationState(L, eps), state

Check warning on line 545 in blackjax/adaptation/step_size.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/step_size.py#L545

Added line #L545 was not covered by tests
2 changes: 2 additions & 0 deletions blackjax/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
hmc,
mala,
marginal_latent_gaussian,
mclmc,
nuts,
periodic_orbital,
random_walk,
Expand All @@ -18,4 +19,5 @@
"periodic_orbital",
"marginal_latent_gaussian",
"random_walk",
"mclmc",
]
1 change: 1 addition & 0 deletions blackjax/mcmc/integrators.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Splitting the changes here into #589

Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# @title `integrators.py` from https://github.com/blackjax-devs/blackjax/pull/589
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean remove this change (since it was added because you were copying the changes from the integrator PR)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder to remove this line.

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Loading