diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py
index 228fd0b51..56e1c1790 100644
--- a/blackjax/mcmc/hmc.py
+++ b/blackjax/mcmc/hmc.py
@@ -276,7 +276,7 @@ def hmc_proposal(
     """
     build_trajectory = trajectory.static_integration(integrator)
     init_proposal, generate_proposal = proposal.proposal_generator(
-        hmc_energy(kinetic_energy), divergence_threshold
+        hmc_energy(kinetic_energy)
     )
 
     def generate(
@@ -286,7 +286,8 @@ def generate(
         end_state = build_trajectory(state, step_size, num_integration_steps)
         end_state = flip_momentum(end_state)
         proposal = init_proposal(state)
-        new_proposal, is_diverging = generate_proposal(proposal.energy, end_state)
+        new_proposal = generate_proposal(proposal.energy, end_state)
+        is_diverging = -new_proposal.weight > divergence_threshold
         sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal)
         do_accept, p_accept = info
 
diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py
index 0f4295a0e..f4aad2fb6 100644
--- a/blackjax/mcmc/mala.py
+++ b/blackjax/mcmc/mala.py
@@ -90,7 +90,7 @@ def transition_energy(state, new_state, step_size):
         return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot
 
     init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
-        transition_energy, divergence_threshold=jnp.inf
+        transition_energy
     )
     sample_proposal = proposal.static_binomial_sampling
 
@@ -107,7 +107,7 @@ def kernel(
         new_state = MALAState(*new_state)
 
         proposal = init_proposal(state)
-        new_proposal, _ = generate_proposal(state, new_state, step_size=step_size)
+        new_proposal = generate_proposal(state, new_state, step_size=step_size)
         sampled_proposal, do_accept, p_accept = sample_proposal(
             key_rmh, proposal, new_proposal
         )
diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py
index 2ed3eca87..9415438b0 100644
--- a/blackjax/mcmc/proposal.py
+++ b/blackjax/mcmc/proposal.py
@@ -40,17 +40,13 @@ class Proposal(NamedTuple):
     sum_log_p_accept: float
 
 
-def proposal_generator(
-    energy: Callable, divergence_threshold: float
-) -> tuple[Callable, Callable]:
+def proposal_generator(energy: Callable) -> tuple[Callable, Callable]:
     """
 
     Parameters
     ----------
     energy
         A function that computes the energy associated to a given state
-    divergence_threshold
-        max value allowed for the difference in energies not to be considered a divergence
 
     Returns
     -------
@@ -61,7 +57,7 @@ def proposal_generator(
     def new(state: TrajectoryState) -> Proposal:
         return Proposal(state, energy(state), 0.0, -jnp.inf)
 
-    def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, bool]:
+    def update(initial_energy: float, state: TrajectoryState) -> Proposal:
         """Generate a new proposal from a trajectory state.
 
         The trajectory state records information about the position in the state
@@ -83,9 +79,7 @@ def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, boo
 
         """
         new_energy = energy(state)
-        return proposal_from_energy_diff(
-            initial_energy, new_energy, divergence_threshold, state
-        )
+        return proposal_from_energy_diff(initial_energy, new_energy, state)
 
     return new, update
 
@@ -93,22 +87,16 @@ def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, boo
 def proposal_from_energy_diff(
     initial_energy: float,
     new_energy: float,
-    divergence_threshold: float,
     state: TrajectoryState,
-) -> tuple[Proposal, bool]:
+) -> Proposal:
     """Computes a new proposal from the energy difference between two states.
 
-    It also verifies whether this difference is a divergence, if the
-    energy diff is above divergence_threshold.
-
     Parameters
     ----------
     initial_energy
         the energy from the initial state
     new_energy
         the energy at the proposed state
-    divergence_threshold
-        max value allowed for an increase in energies not to be considered a divergence
     state
         the proposed state
 
@@ -118,7 +106,6 @@ def proposal_from_energy_diff(
     """
     delta_energy = initial_energy - new_energy
     delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy)
-    is_transition_divergent = -delta_energy > divergence_threshold
 
     # The weight of the new proposal is equal to H0 - H(z_new)
     weight = delta_energy
@@ -126,20 +113,16 @@ def proposal_from_energy_diff(
     # Acceptance statistic min(e^{H0 - H(z_new)}, 1)
     sum_log_p_accept = jnp.minimum(delta_energy, 0.0)
 
-    return (
-        Proposal(
-            state,
-            new_energy,
-            weight,
-            sum_log_p_accept,
-        ),
-        is_transition_divergent,
+    return Proposal(
+        state,
+        new_energy,
+        weight,
+        sum_log_p_accept,
     )
 
 
 def asymmetric_proposal_generator(
     transition_energy_fn: Callable,
-    divergence_threshold: float,
     proposal_factory: Callable = proposal_from_energy_diff,
 ) -> tuple[Callable, Callable]:
     """A proposal generator that takes into account the transition between
@@ -153,8 +136,6 @@ def asymmetric_proposal_generator(
     transition_energy_fn
         A function that computes the energy of a transition from an initial state
         to a new state, given some optional keyword arguments.
-    divergence_threshold
-        The maximum value allowed for the difference in energies not to be considered a divergence.
     proposal_factory
         A function that builds a proposal from the transition energies.
 
@@ -174,7 +155,7 @@ def update(
-    ) -> tuple[Proposal, bool]:
+    ) -> Proposal:
         new_energy = transition_energy_fn(initial_state, state, **energy_params)
         prev_energy = transition_energy_fn(state, initial_state, **energy_params)
-        return proposal_factory(prev_energy, new_energy, divergence_threshold, state)
+        return proposal_factory(prev_energy, new_energy, state)
 
     return new, update
 
diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py
index 6d97c7c08..9d7a0abee 100644
--- a/blackjax/mcmc/random_walk.py
+++ b/blackjax/mcmc/random_walk.py
@@ -63,7 +63,6 @@
 from typing import Callable, NamedTuple, Optional
 
 import jax
-import numpy as np
 from jax import numpy as jnp
 
 from blackjax.base import SamplingAlgorithm
@@ -391,7 +390,7 @@ def kernel(
         transition_energy = build_rmh_transition_energy(proposal_logdensity_fn)
 
         init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
-            transition_energy, np.inf
+            transition_energy
         )
 
         proposal_generator = rmh_proposal(
@@ -496,7 +495,7 @@ def build_trajectory(rng_key, initial_state: RWState) -> RWState:
     def generate(rng_key, state: RWState) -> tuple[RWState, bool, float]:
         key_proposal, key_accept = jax.random.split(rng_key, 2)
         end_state = build_trajectory(key_proposal, state)
-        new_proposal, _ = generate_proposal(state, end_state)
+        new_proposal = generate_proposal(state, end_state)
         previous_proposal = init_proposal(state)
         sampled_proposal, do_accept, p_accept = sample_proposal(
             key_accept, previous_proposal, new_proposal
diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py
index 81d369c0b..00f25989d 100644
--- a/blackjax/mcmc/trajectory.py
+++ b/blackjax/mcmc/trajectory.py
@@ -159,9 +159,7 @@ def dynamic_progressive_integration(
         which we say a transition is divergent.
 
     """
-    _, generate_proposal = proposal_generator(
-        hmc_energy(kinetic_energy), divergence_threshold
-    )
+    _, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
     sample_proposal = progressive_uniform_sampling
 
     def integrate(
@@ -215,7 +213,8 @@ def add_one_state(loop_state):
             rng_key, proposal_key = jax.random.split(rng_key)
 
             new_state = integrator(trajectory.rightmost_state, direction * step_size)
-            new_proposal, is_diverging = generate_proposal(initial_energy, new_state)
+            new_proposal = generate_proposal(initial_energy, new_state)
+            is_diverging = -new_proposal.weight > divergence_threshold
 
             # At step 0, we always accept the proposal, since we
             # take one step to get the leftmost state of the tree.
@@ -248,7 +247,7 @@ def add_one_state(loop_state):
 
             return (rng_key, new_integration_state, (is_diverging, has_terminated))
 
-        proposal_placeholder, _ = generate_proposal(initial_energy, initial_state)
+        proposal_placeholder = generate_proposal(initial_energy, initial_state)
         trajectory_placeholder = Trajectory(
             initial_state, initial_state, initial_state.momentum, 0
         )
@@ -319,9 +318,7 @@ def dynamic_recursive_integration(
         Bool to indicate whether to perform additional U turn check between two trajectory.
 
     """
-    _, generate_proposal = proposal_generator(
-        hmc_energy(kinetic_energy), divergence_threshold
-    )
+    _, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
     sample_proposal = progressive_uniform_sampling
 
     def buildtree_integrate(
@@ -357,7 +354,8 @@ def buildtree_integrate(
         if tree_depth == 0:
             # Base case - take one leapfrog step in the direction v.
             next_state = integrator(initial_state, direction * step_size)
-            new_proposal, is_diverging = generate_proposal(initial_energy, next_state)
+            new_proposal = generate_proposal(initial_energy, next_state)
+            is_diverging = -new_proposal.weight > divergence_threshold
             trajectory = Trajectory(next_state, next_state, next_state.momentum, 1)
             return (
                 rng_key,
diff --git a/tests/mcmc/test_proposal_without_chex.py b/tests/mcmc/test_proposal_without_chex.py
index b34d758ef..82097f53f 100644
--- a/tests/mcmc/test_proposal_without_chex.py
+++ b/tests/mcmc/test_proposal_without_chex.py
@@ -14,7 +14,7 @@ class TestAsymmetricProposalGenerator(unittest.TestCase):
     def test_new(self):
         state = MagicMock()
 
-        new, _ = asymmetric_proposal_generator(None, None, None)
+        new, _ = asymmetric_proposal_generator(None, None)
         assert new(state) == Proposal(state, 0.0, 0.0, -np.inf)
 
     def test_update(self):
@@ -23,16 +23,13 @@ def transition_energy(prev, next):
 
         new_proposal = MagicMock()
 
-        def proposal_factory(prev_energy, new_energy, divergence_threshold, new_state):
+        def proposal_factory(prev_energy, new_energy, new_state):
             assert prev_energy == -20
             assert new_energy == 20
-            assert divergence_threshold == 50
             assert new_state == 50
             return new_proposal
 
-        _, update = asymmetric_proposal_generator(
-            transition_energy, 50, proposal_factory
-        )
+        _, update = asymmetric_proposal_generator(transition_energy, proposal_factory)
         proposed = update(30, 50)
         assert proposed == new_proposal
 
@@ -52,25 +49,26 @@ class TestProposalFromEnergyDiff(parameterized.TestCase):
     )
     def test_divergence_threshold(self, before, after, threshold, is_divergent):
         state = MagicMock()
-        proposal, divergence = proposal_from_energy_diff(5, 10, threshold, state)
+        proposal = proposal_from_energy_diff(5, 10, state)
+        divergence = -proposal.weight > threshold
         assert divergence == is_divergent
 
     def test_sum_log_paccept(self):
         state = MagicMock()
-        proposal, _ = proposal_from_energy_diff(5, 10, 0, state)
+        proposal = proposal_from_energy_diff(5, 10, state)
         np.testing.assert_allclose(proposal.sum_log_p_accept, -5.0)
 
-        proposal, _ = proposal_from_energy_diff(10, 5, 0, state)
+        proposal = proposal_from_energy_diff(10, 5, state)
         np.testing.assert_allclose(proposal.sum_log_p_accept, 0.0)
 
     def test_delta_energy_is_nan(self):
         state = MagicMock()
-        proposal, _ = proposal_from_energy_diff(np.nan, np.nan, 0, state)
+        proposal = proposal_from_energy_diff(np.nan, np.nan, state)
         assert np.isneginf(proposal.weight)
 
     def test_weight(self):
         state = MagicMock()
-        proposal, _ = proposal_from_energy_diff(5, 10, 0, state)
+        proposal = proposal_from_energy_diff(5, 10, state)
         assert proposal.state == state
         np.testing.assert_allclose(proposal.weight, -5)
 
diff --git a/tests/mcmc/test_random_walk_without_chex.py b/tests/mcmc/test_random_walk_without_chex.py
index 6e4e7afe1..8bbcd578e 100644
--- a/tests/mcmc/test_random_walk_without_chex.py
+++ b/tests/mcmc/test_random_walk_without_chex.py
@@ -90,7 +90,7 @@ def init_proposal(self, state):
         return Proposal(state, 0, 0, 0)
 
     def generate_proposal(self, prev, new):
-        return Proposal(new, 0, 0, 0), False
+        return Proposal(new, 0, 0, 0)
 
     def test_generate_reject(self):
         """