Replace seed with key #1145

Closed · wants to merge 18 commits
19 changes: 13 additions & 6 deletions optax/_src/alias.py
@@ -26,6 +26,7 @@
from optax._src import linesearch as _linesearch
from optax._src import transform
from optax._src import wrappers
import chex


MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]
@@ -1254,9 +1255,9 @@ def lamb(

def noisy_sgd(
learning_rate: base.ScalarOrSchedule,
key: Optional[chex.PRNGKey] = None,
eta: float = 0.01,
gamma: float = 0.55,
seed: int = 0,
) -> base.GradientTransformation:
r"""A variant of SGD with added noise.

@@ -1284,10 +1285,10 @@ def noisy_sgd(
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
key: A PRNG key used to sample the added gradient noise.
eta: Initial variance for the Gaussian noise added to gradients.
gamma: A parameter controlling the annealing of noise over time ``t``, the
variance decays according to ``(1+t)**(-gamma)``.
seed: Seed for the pseudo-random generation process.
variance decays according to ``(1+t)**(-gamma)``.

Returns:
The corresponding :class:`optax.GradientTransformation`.
@@ -1297,7 +1298,8 @@ def noisy_sgd(
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = optax.noisy_sgd(learning_rate=0.003)
>>> key = jax.random.key(0)
>>> solver = optax.noisy_sgd(learning_rate=0.003, key=key)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
@@ -1317,8 +1319,13 @@
Neelakantan et al, `Adding Gradient Noise Improves Learning for Very Deep
Networks <https://arxiv.org/abs/1511.06807>`_, 2015
"""
if key is None:
raise ValueError(
"noisy_sgd optimizer requires specifying key: "
"noisy_sgd(..., key=jax.random.key(0))"
)
return combine.chain(
transform.add_noise(eta, gamma, seed),
transform.add_noise(key, eta, gamma),
transform.scale_by_learning_rate(learning_rate),
)
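For reference, a minimal sketch of the new calling convention for noisy_sgd introduced above; the quadratic objective and the training loop are illustrative and not part of the diff:

import jax
import jax.numpy as jnp
import optax

def loss(params):
  return jnp.sum(params ** 2)  # toy quadratic objective

key = jax.random.key(0)  # an explicit PRNG key replaces the old seed=0 argument
solver = optax.noisy_sgd(learning_rate=0.003, key=key)

params = jnp.array([1., 2., 3.])
state = solver.init(params)
for _ in range(5):
  grads = jax.grad(loss)(params)
  updates, state = solver.update(grads, state)
  params = optax.apply_updates(params, updates)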

@@ -2394,7 +2401,7 @@ def lbfgs(
linesearch: Optional[
base.GradientTransformationExtraArgs
] = _linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=20, initial_guess_strategy='one'
max_linesearch_steps=20, initial_guess_strategy="one"
),
) -> base.GradientTransformationExtraArgs:
r"""L-BFGS optimizer.
13 changes: 8 additions & 5 deletions optax/_src/alias_test.py
@@ -63,7 +63,10 @@
),
dict(opt_name='nadam', opt_kwargs=dict(learning_rate=1e-2)),
dict(opt_name='nadamw', opt_kwargs=dict(learning_rate=1e-2)),
dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)),
dict(
opt_name='noisy_sgd',
opt_kwargs=dict(learning_rate=1e-3, key=jrd.key(0), eta=1e-4)
),
dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1e-3)),
dict(
opt_name='optimistic_gradient_descent',
@@ -566,7 +569,7 @@ def zakharov(x, xnp):
class LBFGSTest(chex.TestCase):

def test_plain_preconditioning(self):
key = jrd.PRNGKey(0)
key = jrd.key(0)
key_ws, key_us, key_vec = jrd.split(key, 3)
m = 4
d = 3
@@ -585,7 +588,7 @@ def test_plain_preconditioning(self):

@parameterized.product(idx=[0, 1, 2, 3])
def test_preconditioning_by_lbfgs_on_vectors(self, idx: int):
key = jrd.PRNGKey(0)
key = jrd.key(0)
key_ws, key_us, key_vec = jrd.split(key, 3)
m = 4
d = 3
@@ -612,7 +615,7 @@ def test_preconditioning_by_lbfgs_on_vectors(self, idx: int):

@parameterized.product(idx=[0, 1, 2, 3])
def test_preconditioning_by_lbfgs_on_trees(self, idx: int):
key = jrd.PRNGKey(0)
key = jrd.key(0)
key_ws, key_us, key_vec = jrd.split(key, 3)
m = 4
shapes = ((3, 2), (5,))
@@ -716,7 +719,7 @@ def fun_(x):
def fun(x):
return otu.tree_sum(jax.tree.map(fun_, x))

key = jrd.PRNGKey(0)
key = jrd.key(0)
init_array = jrd.normal(key, (2, 4))
init_tree = (init_array[0], init_array[1])

4 changes: 2 additions & 2 deletions optax/_src/utils.py
@@ -109,10 +109,10 @@ def __init__(self, loc: chex.Array, log_scale: chex.Array):
self._mean.shape, self._scale.shape
)

def sample(self, shape: Sequence[int], seed: chex.PRNGKey) -> chex.Array:
def sample(self, shape: Sequence[int], key: chex.PRNGKey) -> chex.Array:
sample_shape = tuple(shape) + self._param_shape
return (
jax.random.normal(seed, shape=sample_shape) * self._scale + self._mean
jax.random.normal(key, shape=sample_shape) * self._scale + self._mean
)

def log_prob(self, x: chex.Array) -> chex.Array:
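A short sketch of the renamed sampling argument; it assumes the multi_normal helper in optax._src.utils constructs the distribution class shown above (that helper is not part of the diff):

import jax
import jax.numpy as jnp
from optax._src import utils

# Assumed constructor for the diagonal-Gaussian class whose sample() is changed above.
dist = utils.multi_normal(jnp.zeros(3), jnp.zeros(3))
samples = dist.sample((16,), key=jax.random.key(0))  # keyword renamed from seed= to key=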
25 changes: 19 additions & 6 deletions optax/contrib/_privacy.py
@@ -16,6 +16,7 @@

from typing import Any, NamedTuple, Optional

import chex
import jax
from optax._src import base
from optax._src import clipping
@@ -33,14 +34,16 @@ class DifferentiallyPrivateAggregateState(NamedTuple):


def differentially_private_aggregate(
l2_norm_clip: float, noise_multiplier: float, seed: int
l2_norm_clip: float,
noise_multiplier: float,
key: Optional[chex.PRNGKey] = None
) -> base.GradientTransformation:
"""Aggregates gradients based on the DPSGD algorithm.

Args:
l2_norm_clip: maximum L2 norm of the per-example gradients.
noise_multiplier: ratio of standard deviation to the clipping norm.
seed: initial seed used for the jax.random.PRNGKey
key: a PRNG key used to sample the added Gaussian noise.

Returns:
A :class:`optax.GradientTransformation`.
@@ -56,11 +59,16 @@ def differentially_private_aggregate(
JAX using `jax.vmap`). It can still be composed with other transformations
as long as it is the first in the chain.
"""
if key is None:
raise ValueError(
"differentially_private_aggregate optimizer requires specifying key: "
"differentially_private_aggregate(..., key=jax.random.key(0))"
)
noise_std = l2_norm_clip * noise_multiplier

def init_fn(params):
del params
return DifferentiallyPrivateAggregateState(rng_key=jax.random.PRNGKey(seed))
return DifferentiallyPrivateAggregateState(rng_key=key)

def update_fn(updates, state, params=None):
del params
@@ -85,7 +93,7 @@ def dpsgd(
learning_rate: base.ScalarOrSchedule,
l2_norm_clip: float,
noise_multiplier: float,
seed: int,
key: Optional[chex.PRNGKey] = None,
momentum: Optional[float] = None,
nesterov: bool = False,
) -> base.GradientTransformation:
@@ -100,7 +108,7 @@
learning_rate: A fixed global scaling factor.
l2_norm_clip: Maximum L2 norm of the per-example gradients.
noise_multiplier: Ratio of standard deviation to the clipping norm.
seed: Initial seed used for the jax.random.PRNGKey
key: A PRNG key used to sample the added Gaussian noise.
momentum: Decay rate used by the momentum term, when it is set to `None`,
then momentum is not used at all.
nesterov: Whether Nesterov momentum is used.
@@ -117,11 +125,16 @@
batch dimension on the 0th axis. That is, this function expects per-example
gradients as input (which are easy to obtain in JAX using `jax.vmap`).
"""
if key is None:
raise ValueError(
"dpsgd optimizer requires specifying key: "
"dpsgd(..., key=jax.random.key(0))"
)
return combine.chain(
differentially_private_aggregate(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
seed=seed,
key=key,
),
(
transform.trace(decay=momentum, nesterov=nesterov)
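A sketch of the key-based dpsgd API after this change; the loss, batch, and hyperparameters are made up, and it assumes dpsgd is exposed under optax.contrib as in this file's package:

import jax
import jax.numpy as jnp
import optax
from optax import contrib

def per_example_loss(params, x):
  return jnp.sum((params - x) ** 2)

params = jnp.zeros(3)
batch = jnp.ones((8, 3))  # 8 examples; dpsgd expects a leading batch axis

opt = contrib.dpsgd(
    learning_rate=0.1,
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    key=jax.random.key(0),  # now a required PRNG key instead of seed=...
    momentum=0.9,
)
state = opt.init(params)

# Per-example gradients (batch dimension on axis 0), as the docstring requires.
per_eg_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0))(params, batch)
updates, state = opt.update(per_eg_grads, state, params)
params = optax.apply_updates(params, updates)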
13 changes: 9 additions & 4 deletions optax/contrib/_privacy_test.py
@@ -18,6 +18,7 @@
from absl.testing import parameterized
import chex
import jax
import jax.random as jrd
import jax.numpy as jnp
from optax.contrib import _privacy

@@ -45,7 +46,9 @@ def setUp(self):
def test_no_privacy(self):
"""l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD."""
dp_agg = _privacy.differentially_private_aggregate(
l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, seed=0
l2_norm_clip=jnp.finfo(jnp.float32).max,
noise_multiplier=0.0,
key=jrd.key(0)
)
state = dp_agg.init(self.params)
update_fn = self.variant(dp_agg.update)
@@ -59,7 +62,7 @@
@parameterized.parameters(0.5, 10.0, 20.0, 40.0, 80.0)
def test_clipping_norm(self, l2_norm_clip):
dp_agg = _privacy.differentially_private_aggregate(
l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, seed=42
l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, key=jrd.key(42)
)
state = dp_agg.init(self.params)
update_fn = self.variant(dp_agg.update)
@@ -87,7 +90,9 @@ def test_clipping_norm(self, l2_norm_clip):
def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
"""Standard dev. of noise should be l2_norm_clip * noise_multiplier."""
dp_agg = _privacy.differentially_private_aggregate(
l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, seed=1337
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
key=jrd.key(1337)
)
state = dp_agg.init(self.params)
update_fn = self.variant(dp_agg.update)
@@ -103,7 +108,7 @@ def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
def test_aggregated_updates_as_input_fails(self):
"""Expect per-example gradients as input to this transform."""
dp_agg = _privacy.differentially_private_aggregate(
l2_norm_clip=0.1, noise_multiplier=1.1, seed=2021
l2_norm_clip=0.1, noise_multiplier=1.1, key=jrd.key(2021)
)
state = dp_agg.init(self.params)
mean_grads = jax.tree.map(lambda g: g.mean(0), self.per_eg_grads)
8 changes: 4 additions & 4 deletions optax/monte_carlo/stochastic_gradient_estimators.py
@@ -85,7 +85,7 @@ def score_function_jacobians(
def surrogate(params):
dist = dist_builder(*params)
one_sample_surrogate_fn = lambda x: function(x) * dist.log_prob(x)
samples = jax.lax.stop_gradient(dist.sample((num_samples,), seed=rng))
samples = jax.lax.stop_gradient(dist.sample((num_samples,), key=rng))
# We vmap the function application over samples - this ensures that the
# function we use does not have to be vectorized itself.
return jax.vmap(one_sample_surrogate_fn)(samples)
@@ -141,7 +141,7 @@ def surrogate(params):
# We vmap the function application over samples - this ensures that the
# function we use does not have to be vectorized itself.
dist = dist_builder(*params)
return jax.vmap(function)(dist.sample((num_samples,), seed=rng))
return jax.vmap(function)(dist.sample((num_samples,), key=rng))

return jax.jacfwd(surrogate)(params)

@@ -239,7 +239,7 @@ def measure_valued_estimation_mean(
mean, log_std = dist.params
std = jnp.exp(log_std)

dist_samples = dist.sample((num_samples,), seed=rng)
dist_samples = dist.sample((num_samples,), key=rng)

pos_rng, neg_rng = jax.random.split(rng)
pos_sample = jax.random.weibull_min(
@@ -312,7 +312,7 @@ def measure_valued_estimation_std(
mean, log_std = dist.params
std = jnp.exp(log_std)

dist_samples = dist.sample((num_samples,), seed=rng)
dist_samples = dist.sample((num_samples,), key=rng)

pos_rng, neg_rng = jax.random.split(rng)

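For context, a sketch of how these estimators are driven after the change; the positional argument order follows the tests below, and utils.multi_normal is assumed as the distribution builder:

import jax
import jax.numpy as jnp
from optax._src import utils
from optax.monte_carlo import stochastic_gradient_estimators as sge

mean = jnp.zeros(3)
log_scale = jnp.zeros(3)
rng = jax.random.key(1)  # typed key, matching jrd.key(1) in the updated tests

# Internally the distribution is now sampled with dist.sample((num_samples,), key=rng).
jacobians = sge.score_function_jacobians(
    lambda x: jnp.sum(x ** 2), [mean, log_scale], utils.multi_normal, rng, 1000
)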
17 changes: 9 additions & 8 deletions optax/monte_carlo/stochastic_gradient_estimators_test.py
@@ -18,6 +18,7 @@
from absl.testing import parameterized
import chex
import jax
import jax.random as jrd
import jax.numpy as jnp
import numpy as np
from optax._src import utils
@@ -99,7 +100,7 @@ def testConstantFunction(self, estimator, constant):

effective_log_scale = 0.0
log_scale = effective_log_scale * _ones(data_dims)
rng = jax.random.PRNGKey(1)
rng = jrd.key(1)

jacobians = _estimator_variant(self.variant, estimator)(
lambda x: jnp.array(constant),
@@ -142,7 +143,7 @@ def testConstantFunction(self, estimator, constant):
def testLinearFunction(self, estimator, effective_mean, effective_log_scale):
data_dims = 3
num_samples = _estimator_to_num_samples[estimator]
rng = jax.random.PRNGKey(1)
rng = jrd.key(1)

mean = effective_mean * _ones(data_dims)
log_scale = effective_log_scale * _ones(data_dims)
@@ -183,7 +184,7 @@ def testQuadraticFunction(
):
data_dims = 3
num_samples = _estimator_to_num_samples[estimator]
rng = jax.random.PRNGKey(1)
rng = jrd.key(1)

mean = effective_mean * _ones(data_dims)
log_scale = effective_log_scale * _ones(data_dims)
@@ -231,7 +232,7 @@ def testWeightedLinear(
self, estimator, effective_mean, effective_log_scale, weights
):
num_samples = _weighted_estimator_to_num_samples[estimator]
rng = jax.random.PRNGKey(1)
rng = jrd.key(1)

mean = jnp.array(effective_mean)
log_scale = jnp.array(effective_log_scale)
@@ -278,7 +279,7 @@ def testWeightedQuadratic(
self, estimator, effective_mean, effective_log_scale, weights
):
num_samples = _weighted_estimator_to_num_samples[estimator]
rng = jax.random.PRNGKey(1)
rng = jrd.key(1)

mean = jnp.array(effective_mean, dtype=jnp.float32)
log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
@@ -340,8 +341,8 @@ def testNonPolynomialFunctionConsistencyWithPathwise(
self, effective_mean, effective_log_scale, function, coupling
):
num_samples = 10**5
rng = jax.random.PRNGKey(1)
measure_rng, pathwise_rng = jax.random.split(rng)
rng = jrd.key(1)
measure_rng, pathwise_rng = jrd.split(rng)

mean = jnp.array(effective_mean, dtype=jnp.float32)
log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
@@ -403,7 +404,7 @@ class MeasuredValuedEstimatorsTest(chex.TestCase):
@parameterized.parameters([True, False])
def testRaisesErrorForNonGaussian(self, coupling):
num_samples = 10**5
rng = jax.random.PRNGKey(1)
rng = jrd.key(1)

function = lambda x: jnp.sum(x) ** 2

8 changes: 4 additions & 4 deletions optax/perturbations/_make_pert.py
@@ -35,11 +35,11 @@ class Normal:

def sample(
self,
seed: chex.PRNGKey,
key: chex.PRNGKey,
sample_shape: Shape,
dtype: chex.ArrayDType = float,
) -> jax.Array:
return jax.random.normal(seed, sample_shape, dtype)
return jax.random.normal(key, sample_shape, dtype)

def log_prob(self, inputs: jax.Array) -> jax.Array:
return -0.5 * inputs**2
@@ -50,11 +50,11 @@ class Gumbel:

def sample(
self,
seed: chex.PRNGKey,
key: chex.PRNGKey,
sample_shape: Shape,
dtype: chex.ArrayDType = float,
) -> jax.Array:
return jax.random.gumbel(seed, sample_shape, dtype)
return jax.random.gumbel(key, sample_shape, dtype)

def log_prob(self, inputs: jax.Array) -> jax.Array:
return -inputs - jnp.exp(-inputs)
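Finally, a sketch of drawing noise from the classes above with the renamed first argument; the private module path is used only for illustration:

import jax
from optax.perturbations import _make_pert

key = jax.random.key(0)
normal_noise = _make_pert.Normal().sample(key, sample_shape=(4, 2))
gumbel_noise = _make_pert.Gumbel().sample(key, sample_shape=(4, 2))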