diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 4e1f4ee75..ba7d0f399 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -92,7 +92,7 @@ def kernel(x): initial_state = MCMCState(position=0, momentum=1) # Generate a random number generator key - rng_key = jax.random.PRNGKey(0) + rng_key = jax.random.key(0) # Find the optimal parameters for the MCLMC algorithm final_state, final_params = mclmc_find_L_and_step_size( diff --git a/blackjax/base.py b/blackjax/base.py index 7f709b895..f766e98b5 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple +from typing import Callable, NamedTuple, Optional from typing_extensions import Protocol @@ -34,7 +34,7 @@ class InitFn(Protocol): """ - def __call__(self, position: Position) -> State: + def __call__(self, position: Position, rng_key: Optional[PRNGKey]) -> State: """Initialize the algorithm's state. Parameters diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py index 2d9e9d648..b91721a71 100644 --- a/blackjax/mcmc/barker.py +++ b/blackjax/mcmc/barker.py @@ -189,7 +189,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/dynamic_hmc.py b/blackjax/mcmc/dynamic_hmc.py index adc269c40..0fe4ec992 100644 --- a/blackjax/mcmc/dynamic_hmc.py +++ b/blackjax/mcmc/dynamic_hmc.py @@ -163,8 +163,11 @@ def __new__( # type: ignore[misc] integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn ) - def init_fn(position: ArrayLikeTree, random_generator_arg: Array): - return cls.init(position, logdensity_fn, random_generator_arg) + def init_fn(position: ArrayLikeTree, rng_key: Array): + # Note that rng_key here is not necessarily a PRNGKey, could be a Array that + # for generates a sequence of pseudo or quasi-random numbers (previously + # named as `random_generator_arg`) + return cls.init(position, logdensity_fn, rng_key) def step_fn(rng_key: PRNGKey, state): return kernel( @@ -175,7 +178,7 @@ def step_fn(rng_key: PRNGKey, state): inverse_mass_matrix, ) - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) def halton_sequence(i: Array, max_bits: int = 10) -> float: diff --git a/blackjax/mcmc/elliptical_slice.py b/blackjax/mcmc/elliptical_slice.py index 4ff310445..c0d1c5998 100644 --- a/blackjax/mcmc/elliptical_slice.py +++ b/blackjax/mcmc/elliptical_slice.py @@ -164,7 +164,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel(cov, mean) - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, loglikelihood_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index 345b3fe5f..ada6bea9c 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -287,4 +287,4 @@ def step_fn(rng_key: PRNGKey, state): delta, ) - return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 499a23e6a..90bdbc60c 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -229,7 +229,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel(integrator, divergence_threshold) - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 369433341..9690bc7f5 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -177,7 +177,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/marginal_latent_gaussian.py b/blackjax/mcmc/marginal_latent_gaussian.py index 8c910769b..8d4d76f6a 100644 --- a/blackjax/mcmc/marginal_latent_gaussian.py +++ b/blackjax/mcmc/marginal_latent_gaussian.py @@ -272,7 +272,8 @@ def __new__( # type: ignore[misc] kernel = cls.build_kernel(cov_svd) - def init_fn(position: Array): + def init_fn(position: Array, rng_key=None): + del rng_key return init(position, logdensity_fn, U_t) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 26bbda2b8..4da35677d 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -155,16 +155,15 @@ def __new__( # type: ignore[misc] L, step_size, integrator=noneuclidean_mclachlan, - seed=1, ) -> SamplingAlgorithm: kernel = cls.build_kernel(logdensity_fn, integrator) + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return cls.init(position, logdensity_fn, rng_key) + def update_fn(rng_key, state): return kernel(rng_key, state, L, step_size) - def init_fn(position: ArrayLike): - return cls.init(position, logdensity_fn, jax.random.PRNGKey(seed)) - return SamplingAlgorithm(init_fn, update_fn) diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index b185cff27..883121514 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -222,7 +222,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel(integrator, divergence_threshold) - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py index a8bb54787..6e4a2ca5e 100644 --- a/blackjax/mcmc/periodic_orbital.py +++ b/blackjax/mcmc/periodic_orbital.py @@ -276,7 +276,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel(bijection) - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, logdensity_fn, period) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index 05a43e37b..e454c057d 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -243,7 +243,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): @@ -345,7 +346,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): @@ -466,7 +468,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index 93a8ea1de..e0e008a33 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -223,7 +223,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel(num_partitions, energy_gap, min_energy) - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position, num_partitions) def step_fn( diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 0b1cbfd14..806bbc14e 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -123,7 +123,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel(alpha, beta) - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position) def step_fn( diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index b43f3de89..e2055c511 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -109,7 +109,8 @@ def __new__( # type: ignore[misc] ) -> SamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position) def step_fn( diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index bbb71761a..5b02c783b 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -159,7 +159,8 @@ def __new__( # type: ignore[misc] root_solver, ) - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 83bfdb50d..6aaf3a5d3 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -139,7 +139,8 @@ def __new__( # type: ignore[misc] **extra_parameters, ) - def init_fn(position): + def init_fn(position, rng_key=None): + del rng_key return cls.init(smc_algorithm.init, position, initial_parameter_value) def step_fn( diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 40e95b665..49fa21277 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -203,7 +203,8 @@ def __new__( # type: ignore[misc] resampling_fn, ) - def init_fn(position: ArrayLikeTree): + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key return cls.init(position) def step_fn(rng_key: PRNGKey, state, lmbda): diff --git a/blackjax/util.py b/blackjax/util.py index 6f1d49072..ee226a5ac 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -150,6 +150,10 @@ def run_inference_algorithm( ) -> tuple[State, State, Info]: """Wrapper to run an inference algorithm. + Note that this utility function does not work for Stochastic Gradient MCMC samplers + like sghmc, as SG-MCMC samplers require additional control flow for batches of data + to be passed in during each sample. + Parameters ---------- rng_key @@ -175,13 +179,14 @@ def run_inference_algorithm( 2. The trace of states of the inference algorithm (contains the MCMC samples). 3. The trace of the info of the inference algorithm for diagnostics. """ + init_key, sample_key = split(rng_key, 2) try: - initial_state = inference_algorithm.init(initial_state_or_position) + initial_state = inference_algorithm.init(initial_state_or_position, init_key) except TypeError: # We assume initial_state is already in the right format. initial_state = initial_state_or_position - keys = split(rng_key, num_steps) + keys = split(sample_key, num_steps) @jit def _one_step(state, xs): diff --git a/docs/examples/howto_use_oryx.md b/docs/examples/howto_use_oryx.md index c145ed1ee..dd34a47e3 100644 --- a/docs/examples/howto_use_oryx.md +++ b/docs/examples/howto_use_oryx.md @@ -43,7 +43,7 @@ import jax import jax.numpy as jnp from datetime import date -rng_key = jax.random.PRNGKey(int(date.today().strftime("%Y%m%d"))) +rng_key = jax.random.key(int(date.today().strftime("%Y%m%d"))) ``` Oryx's approach, like Aesara's, is to implement probabilistic models as generative models and then apply transformations to get the log-probability density function. We begin with implementing a dense layer with normal prior probability on the weights and use the function `random_variable` to define random variables: diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index 93bf418d2..286bf30aa 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -44,7 +44,7 @@ def test_chees_adaptation(): num_chains = 16 step_size = 0.1 - init_key, warmup_key, inference_key = jax.random.split(jax.random.PRNGKey(0), 3) + init_key, warmup_key, inference_key = jax.random.split(jax.random.key(0), 3) warmup = blackjax.chees_adaptation( logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75 diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index b87c14d7c..c719286a4 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -584,8 +584,8 @@ def test_latent_gaussian(self): "algorithm": blackjax.mala, "initial_position": 1.0, "parameters": {"step_size": 1e-1}, - "num_sampling_steps": 20_000, - "burnin": 2_000, + "num_sampling_steps": 45_000, + "burnin": 5_000, }, { "algorithm": blackjax.elliptical_slice, diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 038f27e0a..1bbc68970 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -60,7 +60,7 @@ def log_weights_fn(x, y): class SMCParameterTuningTest(chex.TestCase): def setUp(self): super().setUp() - self.key = jax.random.PRNGKey(42) + self.key = jax.random.key(42) def logdensity_fn(self, log_scale, coefs, preds, x): """Linear regression""" @@ -129,7 +129,7 @@ def wrapped_kernel(rng_key, state, logdensity): class MeanAndStdFromParticlesTest(chex.TestCase): def setUp(self): super().setUp() - self.key = jax.random.PRNGKey(42) + self.key = jax.random.key(42) def test_mean_and_std(self): particles = np.array( @@ -182,7 +182,7 @@ def test_mean_and_std_multivariable_particles(self): class InverseMassMatrixFromParticles(chex.TestCase): def setUp(self): super().setUp() - self.key = jax.random.PRNGKey(42) + self.key = jax.random.key(42) def test_inverse_mass_matrix_from_particles(self): inverse_mass_matrix = mass_matrix_from_particles(