Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend function signature for InitFn #627

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def kernel(x):
initial_state = MCMCState(position=0, momentum=1)
# Generate a random number generator key
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
# Find the optimal parameters for the MCLMC algorithm
final_state, final_params = mclmc_find_L_and_step_size(
Expand Down
4 changes: 2 additions & 2 deletions blackjax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple
from typing import Callable, NamedTuple, Optional

from typing_extensions import Protocol

Expand All @@ -34,7 +34,7 @@ class InitFn(Protocol):
"""

def __call__(self, position: Position) -> State:
def __call__(self, position: Position, rng_key: Optional[PRNGKey]) -> State:
"""Initialize the algorithm's state.
Parameters
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
9 changes: 6 additions & 3 deletions blackjax/mcmc/dynamic_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,11 @@ def __new__( # type: ignore[misc]
integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn
)

def init_fn(position: ArrayLikeTree, random_generator_arg: Array):
return cls.init(position, logdensity_fn, random_generator_arg)
def init_fn(position: ArrayLikeTree, rng_key: Array):
# Note that rng_key here is not necessarily a PRNGKey, could be a Array that
# for generates a sequence of pseudo or quasi-random numbers (previously
# named as `random_generator_arg`)
return cls.init(position, logdensity_fn, rng_key)

def step_fn(rng_key: PRNGKey, state):
return kernel(
Expand All @@ -175,7 +178,7 @@ def step_fn(rng_key: PRNGKey, state):
inverse_mass_matrix,
)

return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return SamplingAlgorithm(init_fn, step_fn)


def halton_sequence(i: Array, max_bits: int = 10) -> float:
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/elliptical_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(cov, mean)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, loglikelihood_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
2 changes: 1 addition & 1 deletion blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,4 @@ def step_fn(rng_key: PRNGKey, state):
delta,
)

return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return SamplingAlgorithm(init_fn, step_fn)
3 changes: 2 additions & 1 deletion blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(integrator, divergence_threshold)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/marginal_latent_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ def __new__( # type: ignore[misc]

kernel = cls.build_kernel(cov_svd)

def init_fn(position: Array):
def init_fn(position: Array, rng_key=None):
del rng_key
return init(position, logdensity_fn, U_t)

def step_fn(rng_key: PRNGKey, state):
Expand Down
7 changes: 3 additions & 4 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,15 @@ def __new__( # type: ignore[misc]
L,
step_size,
integrator=noneuclidean_mclachlan,
seed=1,
) -> SamplingAlgorithm:
kernel = cls.build_kernel(logdensity_fn, integrator)

def init_fn(position: ArrayLike, rng_key: PRNGKey):
return cls.init(position, logdensity_fn, rng_key)

def update_fn(rng_key, state):
return kernel(rng_key, state, L, step_size)

def init_fn(position: ArrayLike):
return cls.init(position, logdensity_fn, jax.random.PRNGKey(seed))

return SamplingAlgorithm(init_fn, update_fn)


Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(integrator, divergence_threshold)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/periodic_orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(bijection)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn, period)

def step_fn(rng_key: PRNGKey, state):
Expand Down
9 changes: 6 additions & 3 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down Expand Up @@ -345,7 +346,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down Expand Up @@ -466,7 +468,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/sgmcmc/csgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(num_partitions, energy_gap, min_energy)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, num_partitions)

def step_fn(
Expand Down
3 changes: 2 additions & 1 deletion blackjax/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(alpha, beta)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position)

def step_fn(
Expand Down
3 changes: 2 additions & 1 deletion blackjax/sgmcmc/sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position)

def step_fn(
Expand Down
3 changes: 2 additions & 1 deletion blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def __new__( # type: ignore[misc]
root_solver,
)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/smc/inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def __new__( # type: ignore[misc]
**extra_parameters,
)

def init_fn(position):
def init_fn(position, rng_key=None):
del rng_key
return cls.init(smc_algorithm.init, position, initial_parameter_value)

def step_fn(
Expand Down
3 changes: 2 additions & 1 deletion blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def __new__( # type: ignore[misc]
resampling_fn,
)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position)

def step_fn(rng_key: PRNGKey, state, lmbda):
Expand Down
9 changes: 7 additions & 2 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def run_inference_algorithm(
) -> tuple[State, State, Info]:
"""Wrapper to run an inference algorithm.
Note that this utility function does not work for Stochastic Gradient MCMC samplers
like sghmc, as SG-MCMC samplers require additional control flow for batches of data
to be passed in during each sample.
Parameters
----------
rng_key
Expand All @@ -175,13 +179,14 @@ def run_inference_algorithm(
2. The trace of states of the inference algorithm (contains the MCMC samples).
3. The trace of the info of the inference algorithm for diagnostics.
"""
init_key, sample_key = split(rng_key, 2)
try:
initial_state = inference_algorithm.init(initial_state_or_position)
initial_state = inference_algorithm.init(initial_state_or_position, init_key)
except TypeError:
# We assume initial_state is already in the right format.
initial_state = initial_state_or_position

keys = split(rng_key, num_steps)
keys = split(sample_key, num_steps)

@jit
def _one_step(state, xs):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_use_oryx.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import jax
import jax.numpy as jnp
from datetime import date
rng_key = jax.random.PRNGKey(int(date.today().strftime("%Y%m%d")))
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

Oryx's approach, like Aesara's, is to implement probabilistic models as generative models and then apply transformations to get the log-probability density function. We begin with implementing a dense layer with normal prior probability on the weights and use the function `random_variable` to define random variables:
Expand Down
2 changes: 1 addition & 1 deletion tests/adaptation/test_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_chees_adaptation():
num_chains = 16
step_size = 0.1

init_key, warmup_key, inference_key = jax.random.split(jax.random.PRNGKey(0), 3)
init_key, warmup_key, inference_key = jax.random.split(jax.random.key(0), 3)

warmup = blackjax.chees_adaptation(
logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
Expand Down
4 changes: 2 additions & 2 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,8 @@ def test_latent_gaussian(self):
"algorithm": blackjax.mala,
"initial_position": 1.0,
"parameters": {"step_size": 1e-1},
"num_sampling_steps": 20_000,
"burnin": 2_000,
"num_sampling_steps": 45_000,
"burnin": 5_000,
},
{
"algorithm": blackjax.elliptical_slice,
Expand Down
6 changes: 3 additions & 3 deletions tests/smc/test_inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def log_weights_fn(x, y):
class SMCParameterTuningTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

def logdensity_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
Expand Down Expand Up @@ -129,7 +129,7 @@ def wrapped_kernel(rng_key, state, logdensity):
class MeanAndStdFromParticlesTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

def test_mean_and_std(self):
particles = np.array(
Expand Down Expand Up @@ -182,7 +182,7 @@ def test_mean_and_std_multivariable_particles(self):
class InverseMassMatrixFromParticles(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

def test_inverse_mass_matrix_from_particles(self):
inverse_mass_matrix = mass_matrix_from_particles(
Expand Down