Skip to content

Commit

Permalink
Extend function signature for InitFn
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Dec 11, 2023
1 parent fac1d5e commit 9b29b00
Show file tree
Hide file tree
Showing 24 changed files with 63 additions and 45 deletions.
2 changes: 1 addition & 1 deletion blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def kernel(x):
initial_state = MCMCState(position=0, momentum=1)
# Generate a random number generator key
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
# Find the optimal parameters for the MCLMC algorithm
final_state, final_params = mclmc_find_L_and_step_size(
Expand Down
4 changes: 2 additions & 2 deletions blackjax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple
from typing import Callable, NamedTuple, Optional

from typing_extensions import Protocol

Expand All @@ -34,7 +34,7 @@ class InitFn(Protocol):
"""

def __call__(self, position: Position) -> State:
def __call__(self, position: Position, rng_key: Optional[PRNGKey]) -> State:
"""Initialize the algorithm's state.
Parameters
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
9 changes: 6 additions & 3 deletions blackjax/mcmc/dynamic_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,11 @@ def __new__( # type: ignore[misc]
integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn
)

def init_fn(position: ArrayLikeTree, random_generator_arg: Array):
return cls.init(position, logdensity_fn, random_generator_arg)
def init_fn(position: ArrayLikeTree, rng_key: Array):
# Note that rng_key here is not necessarily a PRNGKey, could be a Array that
# for generates a sequence of pseudo or quasi-random numbers (previously
# named as `random_generator_arg`)
return cls.init(position, logdensity_fn, rng_key)

def step_fn(rng_key: PRNGKey, state):
return kernel(
Expand All @@ -175,7 +178,7 @@ def step_fn(rng_key: PRNGKey, state):
inverse_mass_matrix,
)

return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return SamplingAlgorithm(init_fn, step_fn)


def halton_sequence(i: Array, max_bits: int = 10) -> float:
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/elliptical_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(cov, mean)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, loglikelihood_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
2 changes: 1 addition & 1 deletion blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,4 @@ def step_fn(rng_key: PRNGKey, state):
delta,
)

return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return SamplingAlgorithm(init_fn, step_fn)
3 changes: 2 additions & 1 deletion blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(integrator, divergence_threshold)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
8 changes: 5 additions & 3 deletions blackjax/mcmc/marginal_latent_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,20 @@ def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
covariance: Array,
delta: float,
mean: Optional[Array] = None,
) -> SamplingAlgorithm:
init, kernel = init_and_kernel(logdensity_fn, covariance, mean)

def init_fn(position: Array):
def init_fn(position: Array, rng_key=None):
del rng_key
return init(position)

def step_fn(rng_key: PRNGKey, state, delta: float):
def step_fn(rng_key: PRNGKey, state):
return kernel(
rng_key,
state,
delta,
)

return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return SamplingAlgorithm(init_fn, step_fn)
7 changes: 3 additions & 4 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,15 @@ def __new__( # type: ignore[misc]
L,
step_size,
integrator=noneuclidean_mclachlan,
seed=1,
) -> SamplingAlgorithm:
kernel = cls.build_kernel(logdensity_fn, integrator)

def init_fn(position: ArrayLike, rng_key: PRNGKey):
return cls.init(position, logdensity_fn, rng_key)

def update_fn(rng_key, state):
return kernel(rng_key, state, L, step_size)

def init_fn(position: ArrayLike):
return cls.init(position, logdensity_fn, jax.random.PRNGKey(seed))

return SamplingAlgorithm(init_fn, update_fn)


Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(integrator, divergence_threshold)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/mcmc/periodic_orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(bijection)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn, period)

def step_fn(rng_key: PRNGKey, state):
Expand Down
9 changes: 6 additions & 3 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down Expand Up @@ -345,7 +346,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down Expand Up @@ -466,7 +468,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/sgmcmc/csgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(num_partitions, energy_gap, min_energy)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, num_partitions)

def step_fn(
Expand Down
3 changes: 2 additions & 1 deletion blackjax/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel(alpha, beta)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position)

def step_fn(
Expand Down
3 changes: 2 additions & 1 deletion blackjax/sgmcmc/sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def __new__( # type: ignore[misc]
) -> SamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position)

def step_fn(
Expand Down
3 changes: 2 additions & 1 deletion blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def __new__( # type: ignore[misc]
root_solver,
)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position)

def step_fn(rng_key: PRNGKey, state):
Expand Down
3 changes: 2 additions & 1 deletion blackjax/smc/inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def __new__( # type: ignore[misc]
**extra_parameters,
)

def init_fn(position):
def init_fn(position, rng_key=None):
del rng_key
return cls.init(smc_algorithm.init, position, initial_parameter_value)

def step_fn(
Expand Down
3 changes: 2 additions & 1 deletion blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def __new__( # type: ignore[misc]
resampling_fn,
)

def init_fn(position: ArrayLikeTree):
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position)

def step_fn(rng_key: PRNGKey, state, lmbda):
Expand Down
9 changes: 7 additions & 2 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def run_inference_algorithm(
) -> tuple[State, State, Info]:
"""Wrapper to run an inference algorithm.
Note that this utility function does not work for Stochastic Gradient MCMC samplers
like sghmc, as SG-MCMC samplers require additional control flow for batches of data
to be passed in during each sample.
Parameters
----------
rng_key
Expand All @@ -175,13 +179,14 @@ def run_inference_algorithm(
2. The trace of states of the inference algorithm (contains the MCMC samples).
3. The trace of the info of the inference algorithm for diagnostics.
"""
init_key, sample_key = split(rng_key, 2)
try:
initial_state = inference_algorithm.init(initial_state_or_position)
initial_state = inference_algorithm.init(initial_state_or_position, init_key)
except TypeError:
# We assume initial_state is already in the right format.
initial_state = initial_state_or_position

keys = split(rng_key, num_steps)
keys = split(sample_key, num_steps)

@jit
def _one_step(state, xs):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/howto_use_oryx.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import jax
import jax.numpy as jnp
from datetime import date
rng_key = jax.random.PRNGKey(int(date.today().strftime("%Y%m%d")))
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

Oryx's approach, like Aesara's, is to implement probabilistic models as generative models and then apply transformations to get the log-probability density function. We begin with implementing a dense layer with normal prior probability on the weights and use the function `random_variable` to define random variables:
Expand Down
2 changes: 1 addition & 1 deletion tests/adaptation/test_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_chees_adaptation():
num_chains = 16
step_size = 0.1

init_key, warmup_key, inference_key = jax.random.split(jax.random.PRNGKey(0), 3)
init_key, warmup_key, inference_key = jax.random.split(jax.random.key(0), 3)

warmup = blackjax.chees_adaptation(
logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
Expand Down
12 changes: 3 additions & 9 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,7 @@ def test_latent_gaussian(self):
from blackjax import mgrad_gaussian

inference_algorithm = mgrad_gaussian(
lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C
)
inference_algorithm = inference_algorithm._replace(
step=functools.partial(
inference_algorithm.step,
delta=self.delta,
)
lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C, self.delta
)

initial_state = inference_algorithm.init(jnp.zeros((1,)))
Expand Down Expand Up @@ -588,8 +582,8 @@ def test_latent_gaussian(self):
"algorithm": blackjax.mala,
"initial_position": 1.0,
"parameters": {"step_size": 1e-1},
"num_sampling_steps": 20_000,
"burnin": 2_000,
"num_sampling_steps": 45_000,
"burnin": 5_000,
},
{
"algorithm": blackjax.elliptical_slice,
Expand Down
6 changes: 3 additions & 3 deletions tests/smc/test_inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def log_weights_fn(x, y):
class SMCParameterTuningTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

def logdensity_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
Expand Down Expand Up @@ -129,7 +129,7 @@ def wrapped_kernel(rng_key, state, logdensity):
class MeanAndStdFromParticlesTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

def test_mean_and_std(self):
particles = np.array(
Expand Down Expand Up @@ -182,7 +182,7 @@ def test_mean_and_std_multivariable_particles(self):
class InverseMassMatrixFromParticles(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

def test_inverse_mass_matrix_from_particles(self):
inverse_mass_matrix = mass_matrix_from_particles(
Expand Down

0 comments on commit 9b29b00

Please sign in to comment.