Add MCLMC sampler #586
Merged
Commits (91)
fbe7f75 initial draft of mclmc (reubenharry)
3a23242 refactor (reubenharry)
86b3a90 wip (reubenharry)
e82550f wip (reubenharry)
f0e1bec wip (reubenharry)
4d7dc57 wip (reubenharry)
82b8466 wip (reubenharry)
a4d403b fix pre-commit (reubenharry)
a67ecb7 remove dim from class (reubenharry)
3dd4f74 add docstrings (reubenharry)
5d8061d add mclmc to init (reubenharry)
5428f2c Merge branch 'main' of https://github.com/blackjax-devs/blackjax (reubenharry)
59ecc8a Merge branch 'main' into refactor (reubenharry)
2bf639e move minimal_norm to integrators (reubenharry)
172fee0 move update pos and momentum (reubenharry)
b710e62 remove params (reubenharry)
3cc52fd Infer the shape from inverse_mass_matrix outside the function step (reubenharry)
57d5c3b use tree_map (reubenharry)
7e70d78 integration now aligned with mclmc repo (reubenharry)
1343463 dE and logdensity align too (fixed sign error) (reubenharry)
e53a877 make L and step size arguments to kernel (reubenharry)
05517b6 rough draft of tuning: works (reubenharry)
d84a23d remove inv mass matrix (reubenharry)
de1e5cf almost correct (reubenharry)
263ab3a almost correct (reubenharry)
777213d move tuning to adaptation (reubenharry)
e75274a tuning works in this commit (reubenharry)
8a89f13 clean up 1 (reubenharry)
49b3bec remove sigma from tuning (reubenharry)
81999f9 wip (reubenharry)
8ab01f2 fix linting (reubenharry)
6266bc4 rename T and V (reubenharry)
ca984e7 uniformity wip (reubenharry)
59ffb21 make uniform implementation of integrators (reubenharry)
8f9214f make uniform implementation of integrators (reubenharry)
b2e3b8e fix minimal norm integrator (reubenharry)
2fb2293 add warning to tune3 (reubenharry)
59e4424 Refactor integrators.py to make it more general. (junpenglao)
6684413 temp: explore (reubenharry)
4284092 Refactor to use integrator generation functions (junpenglao)
4a514dd Additional refactoring (junpenglao)
ef1f62d Minor clean up. (junpenglao)
af43521 Use standard JAX ops (junpenglao)
0dd419d new integrator (reubenharry)
0c8330e add references (reubenharry)
e6fa2bb merge (reubenharry)
40fc61c flake (reubenharry)
6ea5320 temporarily add 'explore' (reubenharry)
c83dc1a temporarily add 'explore' (reubenharry)
c8b43be Adding a test for energy preservation. (junpenglao)
8894248 fix formatting (junpenglao)
9865145 wip: tests (reubenharry)
68464bc Merge branch 'integrator_refactor' into refactor (reubenharry)
0c61412 use pytrees for partially_refresh_momentum, and add test (reubenharry)
a66af60 Merge branch 'main' into refactor (junpenglao)
be07631 update docstring (reubenharry)
71d934b resolve conflict (reubenharry)
a170d0b remove 'explore' (reubenharry)
8cfb75f fix pre-commit (reubenharry)
b42e77e adding randomized MCHMC (JakobRobnik)
2b323ce wip checkpoint on tuning (reubenharry)
9a41cdf align blackjax and mclmc repos, for tuning (reubenharry)
cdbb4f6 use effective_sample_size (reubenharry)
947d717 patial rename (reubenharry)
e9ab7b4 rename (reubenharry)
72d70c6 clean up tuning (reubenharry)
c121beb clean up tuning (reubenharry)
fe99163 IN THIS COMMIT, BLACKJAX AND ORIGINAL REPO AGREE. SEED IS FIXED. (reubenharry)
c456efe RANDOMIZE KEYS (reubenharry)
d0a008a ADD TEST (reubenharry)
d692498 ADD TEST (reubenharry)
3e8d8ea Merge branch 'main' of https://github.com/blackjax-devs/blackjax (reubenharry)
eda029a Merge branch 'main' into refactor (reubenharry)
a45f58f MERGE MAIN (reubenharry)
2a21c56 INCREASE CODE COVERAGE (reubenharry)
67f0de9 REMOVE REDUNDANT LINE (reubenharry)
3f55f5f ADD NAME 'mclmc' (reubenharry)
666c540 SPLIT KEYS AND FIX DOCSTRING (reubenharry)
c1615f5 FIX MINOR ERRORS (reubenharry)
ae1bf30 FIX MINOR ERRORS (reubenharry)
3c2dbad Merge branch 'main' of https://github.com/blackjax-devs/blackjax (reubenharry)
c396aa1 FIX CONFLICT IN BIB (reubenharry)
0902a1c RANDOMIZE KEYS (reversion) (reubenharry)
2e3c80b PRECOMMIT CLEAN UP (reubenharry)
604b5a9 ADD KWARGS FOR DEFAULT HYPERPARAMS (reubenharry)
50b1c95 Merge branch 'main' of https://github.com/blackjax-devs/blackjax (reubenharry)
fecd82b Merge branch 'main' into refactor (reubenharry)
50a8243 UPDATE ESS (reubenharry)
a20a681 NAME CHANGES (reubenharry)
75e71de NAME CHANGES (reubenharry)
70f1dd5 MINOR FIXES (reubenharry)
@@ -0,0 +1,233 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the MCLMC Kernel"""
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

import blackjax.mcmc.integrators as integrators
from blackjax.base import SamplingAlgorithm
from blackjax.types import Array, ArrayLike, PRNGKey

__all__ = ["MCLMCState", "MCLMCInfo", "init", "build_kernel", "mclmc", "Parameters"]

MCLMCState = integrators.IntegratorState


class MCLMCInfo(NamedTuple):
    """Additional information on the MCLMC transition.

    transformed_x
        The value of the samples after a transformation (e.g. projection onto a
        lower-dimensional subspace).
    logdensity
        The log-density at the new position.
    dE
        The energy difference across the transition.
    """

    transformed_x: Array
    logdensity: Array
    dE: float


class Parameters(NamedTuple):
    """Tunable parameters"""

    L: float
    step_size: float
    inverse_mass_matrix: Array


def init(x_initial: ArrayLike, logdensity_fn, rng_key):
    l, g = jax.value_and_grad(logdensity_fn)(x_initial)
    return MCLMCState(
        position=x_initial,
        momentum=random_unit_vector(rng_key, dim=x_initial.shape[0]),
        logdensity=l,
        logdensity_grad=g,
    )


def build_kernel(grad_logp, integrator, transform, params: Parameters):
    """Build an MCLMC kernel.

    Parameters
    ----------
    integrator
        The symplectic integrator used to integrate the Hamiltonian dynamics.
    transform
        A transformation applied to the samples before they are stored (e.g. a
        projection onto a lower-dimensional subspace).
    params
        The tunable parameters: momentum decoherence length L, step size, and
        inverse mass matrix.

    Returns
    -------
    A kernel that takes a rng_key and a Pytree that contains the current state
    of the chain and that returns a new state of the chain along with
    information about the transition.
    """
    step = integrator(T=update_position(grad_logp), V=update_momentum)

    def kernel(rng_key: PRNGKey, state: MCLMCState) -> tuple[MCLMCState, MCLMCInfo]:
        xx, uu, ll, gg, kinetic_change = step(state, params)
        dim = xx.shape[0]
        # Langevin-like noise
        nu = jnp.sqrt((jnp.exp(2 * params.step_size / params.L) - 1.0) / dim)
        uu = partially_refresh_momentum(u=uu, rng_key=rng_key, nu=nu)

        return MCLMCState(xx, uu, ll, gg), MCLMCInfo(
            transformed_x=transform(xx),
            logdensity=ll,
            dE=kinetic_change + ll - state.logdensity,
        )

    return kernel


def minimal_norm(T, V):
    # critical value of the lambda parameter for the minimal norm integrator
    lambda_c = 0.1931833275037836

    def step(state: MCLMCState, params: Parameters):
        """Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20."""

        # V T V T V
        dt = params.step_size
        sigma = jnp.sqrt(params.inverse_mass_matrix)
        uu, r1 = V(dt * lambda_c, state.momentum, state.logdensity_grad * sigma)
        xx, ll, gg = T(dt, state.position, 0.5 * uu * sigma)
        uu, r2 = V(dt * (1 - 2 * lambda_c), uu, gg * sigma)
        xx, ll, gg = T(dt, xx, 0.5 * uu * sigma)
        uu, r3 = V(dt * lambda_c, uu, gg * sigma)

        # kinetic energy change
        dim = xx.shape[0]
        kinetic_change = (r1 + r2 + r3) * (dim - 1)

        return xx, uu, ll, gg, kinetic_change

    return step


class mclmc:
    """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
    cumbersome to manipulate. Since most users only need to specify the kernel
    parameters at initialization time, we provide a helper function that
    specializes the general kernel.

    We also add the general kernel and state generator as an attribute to this class so
    users only need to pass `blackjax.mclmc` to SMC, adaptation, etc. algorithms.

    Examples
    --------

    A new mclmc kernel can be initialized and used with the following code:

    .. code::

        mclmc = blackjax.mcmc.mclmc.mclmc(
            logdensity_fn=logdensity_fn,
            transform=lambda x: x,
            params=params
        )
        state = mclmc.init(position)
        new_state, info = mclmc.step(rng_key, state)

    Kernels are not jit-compiled by default so you will need to do it manually:

    .. code::

        step = jax.jit(mclmc.step)
        new_state, info = step(rng_key, state)

    Parameters
    ----------
    logdensity_fn
        The log-density function we wish to draw samples from.
    transform
        A transformation applied to the samples before they are stored (e.g. a
        projection onto a lower-dimensional subspace).
    params
        The tunable parameters: momentum decoherence length L, step size, and
        inverse mass matrix.
    integrator
        The integrator to use. We recommend using the default here.

    Returns
    -------
    A ``SamplingAlgorithm``.
    """

    init = staticmethod(init)
    build_kernel = staticmethod(build_kernel)

    def __new__(  # type: ignore[misc]
        cls,
        logdensity_fn: Callable,
        transform: Callable,
        params: Parameters,
        integrator=minimal_norm,
    ) -> SamplingAlgorithm:
        grad_logp = jax.value_and_grad(logdensity_fn)

        kernel = cls.build_kernel(grad_logp, integrator, transform, params)

        def init_fn(position: ArrayLike):
            return cls.init(position, logdensity_fn, jax.random.PRNGKey(0))

        return SamplingAlgorithm(init_fn, kernel)


###
# helper funcs
###


def random_unit_vector(rng_key, dim):
    u = jax.random.normal(rng_key, shape=(dim,))
    u /= jnp.sqrt(jnp.sum(jnp.square(u)))
    return u


def update_position(grad_logp):
    def update(step_size, x, u):
        xx = x + step_size * u
        ll, gg = grad_logp(xx)
        return xx, ll, gg

    return update


def partially_refresh_momentum(u, rng_key, nu):
    """Add a small noise to u and normalize."""
    z = nu * jax.random.normal(rng_key, shape=(u.shape[0],))
    return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z)))


###
# integrator
###


def update_momentum(step_size, u, g):
    """The momentum updating map of the ESH dynamics (see https://arxiv.org/pdf/2111.02434.pdf),
    similar to the implementation: https://github.com/gregversteeg/esh_dynamics.
    There are no exponentials e^delta, which prevents overflows when the gradient norm is large.
    """
    g_norm = jnp.sqrt(jnp.sum(jnp.square(g)))
    e = g / g_norm
    ue = jnp.dot(u, e)
    dim = u.shape[0]
    delta = step_size * g_norm / (dim - 1)
    zeta = jnp.exp(-delta)
    uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u
    delta_r = delta - jnp.log(2) + jnp.log(1 + ue + (1 - ue) * zeta**2)
    return uu / jnp.sqrt(jnp.sum(jnp.square(uu))), delta_r
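`update_momentum` above is the closed-form ESH momentum map: it pulls the unit momentum `u` toward the normalized gradient direction `e = g / |g|` using only `exp(-delta)`, so it cannot overflow for large gradient norms, and the result is renormalized onto the unit sphere. A minimal NumPy sketch (a re-implementation for illustration only, not the BlackJAX API) makes two invariants easy to check: the output is always unit-norm, and `u = e` is a fixed point of the map with `delta_r = delta`:

```python
import numpy as np


def update_momentum_np(step_size, u, g):
    """NumPy mirror of the JAX update_momentum above (illustrative sketch)."""
    dim = u.shape[0]
    g_norm = np.sqrt(np.sum(g**2))
    e = g / g_norm                          # unit vector along the gradient
    ue = np.dot(u, e)                       # component of u along e
    delta = step_size * g_norm / (dim - 1)  # effective step
    zeta = np.exp(-delta)                   # only decaying exponentials appear
    uu = e * (1 - zeta) * (1 + zeta + ue * (1 - zeta)) + 2 * zeta * u
    delta_r = delta - np.log(2) + np.log(1 + ue + (1 - ue) * zeta**2)
    return uu / np.sqrt(np.sum(uu**2)), delta_r


u = np.array([0.6, 0.8, 0.0])
g = np.array([1.0, 0.0, 0.0])
uu, _ = update_momentum_np(0.5, u, g)
print(np.sum(uu**2))  # stays on the unit sphere
```

When `u` already points along `e`, the bracketed factor collapses to a positive multiple of `e`, so normalization returns `e` unchanged, and the `log` term reduces to `log(2)`, leaving `delta_r = delta`.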
Review comment: Use IntegratorState directly.
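The review comment above suggests dropping the `MCLMCState = integrators.IntegratorState` alias and annotating with `IntegratorState` directly. A hypothetical sketch of what that looks like, using a stand-in `NamedTuple` with the same fields as in the `init` function of the diff (the real class lives in `blackjax.mcmc.integrators`):

```python
from typing import Any, NamedTuple


class IntegratorState(NamedTuple):
    """Stand-in for blackjax.mcmc.integrators.IntegratorState; the fields
    match those populated by init in the diff above."""

    position: Any
    momentum: Any
    logdensity: Any
    logdensity_grad: Any


# With the suggestion applied, functions take and return IntegratorState
# directly instead of going through an MCLMCState alias:
def init_sketch(position, momentum, logdensity, logdensity_grad) -> IntegratorState:
    return IntegratorState(position, momentum, logdensity, logdensity_grad)
```

Reusing the shared state type means any kernel built on the common integrators can consume an MCLMC state without conversion.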