Skip to content

Commit

Permalink
returned jax.random instead of jrd v2
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas542 committed Jan 6, 2025
1 parent 8a3a909 commit 5499f47
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions optax/monte_carlo/stochastic_gradient_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import chex
import jax
import jax.radom as jrd
import jax.numpy as jnp
import numpy as np
from optax._src import base
Expand Down Expand Up @@ -242,15 +241,15 @@ def measure_valued_estimation_mean(

dist_samples = dist.sample((num_samples,), seed=rng)

pos_rng, neg_rng = jrd.split(rng)
pos_sample = jrd.weibull_min(
pos_rng, neg_rng = jax.random.split(rng)
pos_sample = jax.random.weibull_min(
pos_rng, scale=math.sqrt(2.0), concentration=2.0, shape=dist_samples.shape
)

if coupling:
neg_sample = pos_sample
else:
neg_sample = jrd.weibull_min(
neg_sample = jax.random.weibull_min(
neg_rng,
scale=math.sqrt(2.0),
concentration=2.0,
Expand Down Expand Up @@ -315,17 +314,17 @@ def measure_valued_estimation_std(

dist_samples = dist.sample((num_samples,), seed=rng)

pos_rng, neg_rng = jrd.split(rng)
pos_rng, neg_rng = jax.random.split(rng)

# The only difference between mean and std gradients is what we sample.
pos_sample = jrd.double_sided_maxwell(
pos_sample = jax.random.double_sided_maxwell(
pos_rng, loc=0.0, scale=1.0, shape=dist_samples.shape
)
if coupling:
unif_rvs = jrd.uniform(neg_rng, dist_samples.shape)
unif_rvs = jax.random.uniform(neg_rng, dist_samples.shape)
neg_sample = unif_rvs * pos_sample
else:
neg_sample = jrd.normal(neg_rng, dist_samples.shape)
neg_sample = jax.random.normal(neg_rng, dist_samples.shape)

# Both need to be positive in the case of the scale.
# N x D
Expand Down

0 comments on commit 5499f47

Please sign in to comment.