diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py
index b766bd97..541b697a 100644
--- a/optax/monte_carlo/stochastic_gradient_estimators.py
+++ b/optax/monte_carlo/stochastic_gradient_estimators.py
@@ -34,7 +34,6 @@
 import chex
 import jax
-import jax.radom as jrd
 import jax.numpy as jnp
 import numpy as np
 
 from optax._src import base
@@ -242,15 +241,15 @@ def measure_valued_estimation_mean(
 
   dist_samples = dist.sample((num_samples,), seed=rng)
 
-  pos_rng, neg_rng = jrd.split(rng)
-  pos_sample = jrd.weibull_min(
+  pos_rng, neg_rng = jax.random.split(rng)
+  pos_sample = jax.random.weibull_min(
       pos_rng, scale=math.sqrt(2.0), concentration=2.0, shape=dist_samples.shape
   )
 
   if coupling:
     neg_sample = pos_sample
   else:
-    neg_sample = jrd.weibull_min(
+    neg_sample = jax.random.weibull_min(
         neg_rng,
         scale=math.sqrt(2.0),
         concentration=2.0,
@@ -315,17 +314,17 @@ def measure_valued_estimation_std(
   dist_samples = dist.sample((num_samples,), seed=rng)
 
-  pos_rng, neg_rng = jrd.split(rng)
+  pos_rng, neg_rng = jax.random.split(rng)
 
   # The only difference between mean and std gradients is what we sample.
-  pos_sample = jrd.double_sided_maxwell(
+  pos_sample = jax.random.double_sided_maxwell(
       pos_rng, loc=0.0, scale=1.0, shape=dist_samples.shape
   )
 
   if coupling:
-    unif_rvs = jrd.uniform(neg_rng, dist_samples.shape)
+    unif_rvs = jax.random.uniform(neg_rng, dist_samples.shape)
     neg_sample = unif_rvs * pos_sample
   else:
-    neg_sample = jrd.normal(neg_rng, dist_samples.shape)
+    neg_sample = jax.random.normal(neg_rng, dist_samples.shape)
 
   # Both need to be positive in the case of the scale.
   # N x D