Skip to content

Commit

Permalink
Refactor divergence check to each sampler (#579)
Browse files (browse the repository at this point in the history)
* partially fix #391

* fix formatting
  • Loading branch information
junpenglao authored Oct 29, 2023
1 parent df935cf commit 29dc2eb
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 57 deletions.
5 changes: 3 additions & 2 deletions blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def hmc_proposal(
"""
build_trajectory = trajectory.static_integration(integrator)
init_proposal, generate_proposal = proposal.proposal_generator(
hmc_energy(kinetic_energy), divergence_threshold
hmc_energy(kinetic_energy)
)

def generate(
Expand All @@ -286,7 +286,8 @@ def generate(
end_state = build_trajectory(state, step_size, num_integration_steps)
end_state = flip_momentum(end_state)
proposal = init_proposal(state)
new_proposal, is_diverging = generate_proposal(proposal.energy, end_state)
new_proposal = generate_proposal(proposal.energy, end_state)
is_diverging = -new_proposal.weight > divergence_threshold
sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal)
do_accept, p_accept = info

Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def transition_energy(state, new_state, step_size):
return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot

init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
transition_energy, divergence_threshold=jnp.inf
transition_energy
)
sample_proposal = proposal.static_binomial_sampling

Expand All @@ -107,7 +107,7 @@ def kernel(
new_state = MALAState(*new_state)

proposal = init_proposal(state)
new_proposal, _ = generate_proposal(state, new_state, step_size=step_size)
new_proposal = generate_proposal(state, new_state, step_size=step_size)
sampled_proposal, do_accept, p_accept = sample_proposal(
key_rmh, proposal, new_proposal
)
Expand Down
39 changes: 10 additions & 29 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,13 @@ class Proposal(NamedTuple):
sum_log_p_accept: float


def proposal_generator(
energy: Callable, divergence_threshold: float
) -> tuple[Callable, Callable]:
def proposal_generator(energy: Callable) -> tuple[Callable, Callable]:
"""
Parameters
----------
energy
A function that computes the energy associated to a given state
divergence_threshold
max value allowed for the difference in energies not to be considered a divergence
Returns
-------
Expand All @@ -61,7 +57,7 @@ def proposal_generator(
def new(state: TrajectoryState) -> Proposal:
return Proposal(state, energy(state), 0.0, -jnp.inf)

def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, bool]:
def update(initial_energy: float, state: TrajectoryState) -> Proposal:
"""Generate a new proposal from a trajectory state.
The trajectory state records information about the position in the state
Expand All @@ -83,32 +79,24 @@ def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, boo
"""
new_energy = energy(state)
return proposal_from_energy_diff(
initial_energy, new_energy, divergence_threshold, state
)
return proposal_from_energy_diff(initial_energy, new_energy, state)

return new, update


def proposal_from_energy_diff(
initial_energy: float,
new_energy: float,
divergence_threshold: float,
state: TrajectoryState,
) -> tuple[Proposal, bool]:
) -> Proposal:
"""Computes a new proposal from the energy difference between two states.
It also verifies whether this difference is a divergence, if the
energy diff is above divergence_threshold.
Parameters
----------
initial_energy
the energy from the initial state
new_energy
the energy at the proposed state
divergence_threshold
max value allowed for an increase in energies not to be considered a divergence
state
the proposed state
Expand All @@ -118,28 +106,23 @@ def proposal_from_energy_diff(
"""
delta_energy = initial_energy - new_energy
delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy)
is_transition_divergent = -delta_energy > divergence_threshold

# The weight of the new proposal is equal to H0 - H(z_new)
weight = delta_energy

# Acceptance statistic min(e^{H0 - H(z_new)}, 1)
sum_log_p_accept = jnp.minimum(delta_energy, 0.0)

return (
Proposal(
state,
new_energy,
weight,
sum_log_p_accept,
),
is_transition_divergent,
return Proposal(
state,
new_energy,
weight,
sum_log_p_accept,
)


def asymmetric_proposal_generator(
transition_energy_fn: Callable,
divergence_threshold: float,
proposal_factory: Callable = proposal_from_energy_diff,
) -> tuple[Callable, Callable]:
"""A proposal generator that takes into account the transition between
Expand All @@ -153,8 +136,6 @@ def asymmetric_proposal_generator(
transition_energy_fn
A function that computes the energy of a transition from an initial state
to a new state, given some optional keyword arguments.
divergence_threshold
The maximum value allowed for the difference in energies not to be considered a divergence.
proposal_factory
A function that builds a proposal from the transition energies.
Expand All @@ -174,7 +155,7 @@ def update(
) -> tuple[Proposal, bool]:
new_energy = transition_energy_fn(initial_state, state, **energy_params)
prev_energy = transition_energy_fn(state, initial_state, **energy_params)
return proposal_factory(prev_energy, new_energy, divergence_threshold, state)
return proposal_factory(prev_energy, new_energy, state)

return new, update

Expand Down
5 changes: 2 additions & 3 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
from typing import Callable, NamedTuple, Optional

import jax
import numpy as np
from jax import numpy as jnp

from blackjax.base import SamplingAlgorithm
Expand Down Expand Up @@ -391,7 +390,7 @@ def kernel(
transition_energy = build_rmh_transition_energy(proposal_logdensity_fn)

init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
transition_energy, np.inf
transition_energy
)

proposal_generator = rmh_proposal(
Expand Down Expand Up @@ -496,7 +495,7 @@ def build_trajectory(rng_key, initial_state: RWState) -> RWState:
def generate(rng_key, state: RWState) -> tuple[RWState, bool, float]:
key_proposal, key_accept = jax.random.split(rng_key, 2)
end_state = build_trajectory(key_proposal, state)
new_proposal, _ = generate_proposal(state, end_state)
new_proposal = generate_proposal(state, end_state)
previous_proposal = init_proposal(state)
sampled_proposal, do_accept, p_accept = sample_proposal(
key_accept, previous_proposal, new_proposal
Expand Down
16 changes: 7 additions & 9 deletions blackjax/mcmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def dynamic_progressive_integration(
which we say a transition is divergent.
"""
_, generate_proposal = proposal_generator(
hmc_energy(kinetic_energy), divergence_threshold
)
_, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
sample_proposal = progressive_uniform_sampling

def integrate(
Expand Down Expand Up @@ -215,7 +213,8 @@ def add_one_state(loop_state):
rng_key, proposal_key = jax.random.split(rng_key)

new_state = integrator(trajectory.rightmost_state, direction * step_size)
new_proposal, is_diverging = generate_proposal(initial_energy, new_state)
new_proposal = generate_proposal(initial_energy, new_state)
is_diverging = -new_proposal.weight > divergence_threshold

# At step 0, we always accept the proposal, since we
# take one step to get the leftmost state of the tree.
Expand Down Expand Up @@ -248,7 +247,7 @@ def add_one_state(loop_state):

return (rng_key, new_integration_state, (is_diverging, has_terminated))

proposal_placeholder, _ = generate_proposal(initial_energy, initial_state)
proposal_placeholder = generate_proposal(initial_energy, initial_state)
trajectory_placeholder = Trajectory(
initial_state, initial_state, initial_state.momentum, 0
)
Expand Down Expand Up @@ -319,9 +318,7 @@ def dynamic_recursive_integration(
Bool to indicate whether to perform an additional U-turn check between two trajectories.
"""
_, generate_proposal = proposal_generator(
hmc_energy(kinetic_energy), divergence_threshold
)
_, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
sample_proposal = progressive_uniform_sampling

def buildtree_integrate(
Expand Down Expand Up @@ -357,7 +354,8 @@ def buildtree_integrate(
if tree_depth == 0:
# Base case - take one leapfrog step in the direction v.
next_state = integrator(initial_state, direction * step_size)
new_proposal, is_diverging = generate_proposal(initial_energy, next_state)
new_proposal = generate_proposal(initial_energy, next_state)
is_diverging = -new_proposal.weight > divergence_threshold
trajectory = Trajectory(next_state, next_state, next_state.momentum, 1)
return (
rng_key,
Expand Down
20 changes: 9 additions & 11 deletions tests/mcmc/test_proposal_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class TestAsymmetricProposalGenerator(unittest.TestCase):
def test_new(self):
state = MagicMock()
new, _ = asymmetric_proposal_generator(None, None, None)
new, _ = asymmetric_proposal_generator(None, None)
assert new(state) == Proposal(state, 0.0, 0.0, -np.inf)

def test_update(self):
Expand All @@ -23,16 +23,13 @@ def transition_energy(prev, next):

new_proposal = MagicMock()

def proposal_factory(prev_energy, new_energy, divergence_threshold, new_state):
def proposal_factory(prev_energy, new_energy, new_state):
assert prev_energy == -20
assert new_energy == 20
assert divergence_threshold == 50
assert new_state == 50
return new_proposal

_, update = asymmetric_proposal_generator(
transition_energy, 50, proposal_factory
)
_, update = asymmetric_proposal_generator(transition_energy, proposal_factory)
proposed = update(30, 50)
assert proposed == new_proposal

Expand All @@ -52,25 +49,26 @@ class TestProposalFromEnergyDiff(parameterized.TestCase):
)
def test_divergence_threshold(self, before, after, threshold, is_divergent):
state = MagicMock()
proposal, divergence = proposal_from_energy_diff(5, 10, threshold, state)
proposal = proposal_from_energy_diff(5, 10, state)
divergence = -proposal.weight > threshold
assert divergence == is_divergent

def test_sum_log_paccept(self):
state = MagicMock()
proposal, _ = proposal_from_energy_diff(5, 10, 0, state)
proposal = proposal_from_energy_diff(5, 10, state)
np.testing.assert_allclose(proposal.sum_log_p_accept, -5.0)

proposal, _ = proposal_from_energy_diff(10, 5, 0, state)
proposal = proposal_from_energy_diff(10, 5, state)
np.testing.assert_allclose(proposal.sum_log_p_accept, 0.0)

def test_delta_energy_is_nan(self):
state = MagicMock()
proposal, _ = proposal_from_energy_diff(np.nan, np.nan, 0, state)
proposal = proposal_from_energy_diff(np.nan, np.nan, state)
assert np.isneginf(proposal.weight)

def test_weight(self):
state = MagicMock()
proposal, _ = proposal_from_energy_diff(5, 10, 0, state)
proposal = proposal_from_energy_diff(5, 10, state)

assert proposal.state == state
np.testing.assert_allclose(proposal.weight, -5)
Expand Down
2 changes: 1 addition & 1 deletion tests/mcmc/test_random_walk_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def init_proposal(self, state):
return Proposal(state, 0, 0, 0)

def generate_proposal(self, prev, new):
return Proposal(new, 0, 0, 0), False
return Proposal(new, 0, 0, 0)

def test_generate_reject(self):
"""
Expand Down

0 comments on commit 29dc2eb

Please sign in to comment.