From ee1907dd6ec8bc9e67b3c5d5a5f815d524fec596 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Wed, 13 Nov 2024 22:27:46 +0300 Subject: [PATCH 01/18] changed seed to ket in add_noise and noisy_sgd. Added example to add_noise from noisy sgd --- optax/_src/alias.py | 11 ++++++----- optax/transforms/_adding.py | 31 ++++++++++++++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 08346fa13..8be68b7dc 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -26,7 +26,7 @@ from optax._src import linesearch as _linesearch from optax._src import transform from optax._src import wrappers - +import chex MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]] @@ -1253,10 +1253,10 @@ def lamb( def noisy_sgd( + key: chex.PRNGKey, learning_rate: base.ScalarOrSchedule, eta: float = 0.01, gamma: float = 0.55, - seed: int = 0, ) -> base.GradientTransformation: r"""A variant of SGD with added noise. @@ -1282,12 +1282,12 @@ def noisy_sgd( represents the initial variance ``eta``. Args: + key: a PRNG key used as the random key. learning_rate: A global scaling factor, either fixed or evolving along iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. eta: Initial variance for the Gaussian noise added to gradients. gamma: A parameter controlling the annealing of noise over time ``t``, the variance decays according to ``(1+t)**(-gamma)``. - seed: Seed for the pseudo-random generation process. Returns: The corresponding :class:`optax.GradientTransformation`. @@ -1297,7 +1297,8 @@ def noisy_sgd( >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> solver = optax.noisy_sgd(learning_rate=0.003) + >>> key = jax.random.key(42) + >>> solver = optax.noisy_sgd(key, learning_rate=0.003) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 @@ -1318,7 +1319,7 @@ def noisy_sgd( Networks `_, 2015 """ return combine.chain( - transform.add_noise(eta, gamma, seed), + transform.add_noise(eta, gamma, key), transform.scale_by_learning_rate(learning_rate), ) diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index cfc5c93fe..daeecf6d7 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -71,18 +71,43 @@ class AddNoiseState(NamedTuple): def add_noise( - eta: float, gamma: float, seed: int + eta: float, gamma: float, key: chex.PRNGKey ) -> base.GradientTransformation: """Add gradient noise. Args: eta: Base variance of the gaussian noise added to the gradient. gamma: Decay exponent for annealing of the variance. - seed: Seed for random number generation. + key: a PRNG key used as the random key. Returns: A :class:`optax.GradientTransformation` object. + Examples: + >>> import optax + >>> import jax + >>> import jax.numpy as jnp + >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function + >>> key = jax.random.key(42) + >>> noise = optax.add_noise(eta=0.01, gamma=0.55, key=key) + >>> sgd = optax.scale_by_learning_rate(learning_rate=0.003) + >>> solver = optax.chain(noise, sgd) + >>> params = jnp.array([1., 2., 3.]) + >>> print('Objective function: ', f(params)) + >>> print('Objective function: ', f(params)) + Objective function: 14.0 + >>> opt_state = solver.init(params) + >>> for _ in range(5): + ... grad = jax.grad(f)(params) + ... updates, opt_state = solver.update(grad, opt_state, params) + ... params = optax.apply_updates(params, updates) + ... 
print('Objective function: {:.2E}'.format(f(params))) + Objective function: 1.38E+01 + Objective function: 1.37E+01 + Objective function: 1.35E+01 + Objective function: 1.33E+01 + Objective function: 1.32E+01 + References: Neelakantan et al, `Adding Gradient Noise Improves Learning for Very Deep Networks `_, 2015 @@ -91,7 +116,7 @@ def add_noise( def init_fn(params): del params return AddNoiseState( - count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed) + count=jnp.zeros([], jnp.int32), rng_key=key ) def update_fn(updates, state, params=None): # pylint: disable=missing-docstring From b8bd3cc23ca140718fdb539bbf5755859e6f0523 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Wed, 13 Nov 2024 22:37:19 +0300 Subject: [PATCH 02/18] typo in add_noise example --- optax/transforms/_adding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index daeecf6d7..4578ce537 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -94,7 +94,6 @@ def add_noise( >>> solver = optax.chain(noise, sgd) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) - >>> print('Objective function: ', f(params)) Objective function: 14.0 >>> opt_state = solver.init(params) >>> for _ in range(5): From a40c1f998801e33fec5f1dbd677c5467205f2d66 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Thu, 14 Nov 2024 15:21:03 +0300 Subject: [PATCH 03/18] made the key the first argument in add_noise to achieve style consistency with jax.random --- optax/_src/alias.py | 2 +- optax/transforms/_adding.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 8be68b7dc..9812e92c2 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -1319,7 +1319,7 @@ def noisy_sgd( Networks `_, 2015 """ return combine.chain( - transform.add_noise(eta, gamma, key), + transform.add_noise(key, eta, gamma), transform.scale_by_learning_rate(learning_rate), ) diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index 4578ce537..154c32925 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -71,14 +71,14 @@ class AddNoiseState(NamedTuple): def add_noise( - eta: float, gamma: float, key: chex.PRNGKey + key: chex.PRNGKey, eta: float, gamma: float ) -> base.GradientTransformation: """Add gradient noise. Args: + key: a PRNG key used as the random key. eta: Base variance of the gaussian noise added to the gradient. gamma: Decay exponent for annealing of the variance. - key: a PRNG key used as the random key. Returns: A :class:`optax.GradientTransformation` object. 
@@ -89,7 +89,7 @@ def add_noise( >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function >>> key = jax.random.key(42) - >>> noise = optax.add_noise(eta=0.01, gamma=0.55, key=key) + >>> noise = optax.add_noise(key=key, eta=0.01, gamma=0.55) >>> sgd = optax.scale_by_learning_rate(learning_rate=0.003) >>> solver = optax.chain(noise, sgd) >>> params = jnp.array([1., 2., 3.]) From 698130dd1d035c170ef595a0e48a6839bb10f42d Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Thu, 14 Nov 2024 15:22:32 +0300 Subject: [PATCH 04/18] returned second empty line after imports --- optax/_src/alias.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 9812e92c2..91568c940 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -28,6 +28,7 @@ from optax._src import wrappers import chex + MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]] From db8291c07a6aa35bff80e64878974e9298bed955 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 22:44:50 +0300 Subject: [PATCH 05/18] changed key seed to 0 and replaced old-style random key generation --- optax/_src/alias.py | 12 +++++++----- optax/_src/alias_test.py | 10 +++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 91568c940..e83a65e45 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -1254,8 +1254,8 @@ def lamb( def noisy_sgd( - key: chex.PRNGKey, learning_rate: base.ScalarOrSchedule, + key: Optional[chex.PRNGKey] = None, eta: float = 0.01, gamma: float = 0.55, ) -> base.GradientTransformation: @@ -1283,12 +1283,12 @@ def noisy_sgd( represents the initial variance ``eta``. Args: - key: a PRNG key used as the random key. learning_rate: A global scaling factor, either fixed or evolving along iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. + key: a PRNG key used as the random key. eta: Initial variance for the Gaussian noise added to gradients. gamma: A parameter controlling the annealing of noise over time ``t``, the - variance decays according to ``(1+t)**(-gamma)``. + variance decays according to ``(1+t)**(-gamma)``. Returns: The corresponding :class:`optax.GradientTransformation`. 
@@ -1298,8 +1298,8 @@ def noisy_sgd( >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> key = jax.random.key(42) - >>> solver = optax.noisy_sgd(key, learning_rate=0.003) + >>> key = jax.random.key(0) + >>> solver = optax.noisy_sgd(learning_rate=0.003, key) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 @@ -1319,6 +1319,8 @@ def noisy_sgd( Neelakantan et al, `Adding Gradient Noise Improves Learning for Very Deep Networks `_, 2015 """ + if key is None: + raise ValueError("noisy_sgd optimizer requires specifying random key: noisy_sgd(..., key=key=jax.random.key(0))") return combine.chain( transform.add_noise(key, eta, gamma), transform.scale_by_learning_rate(learning_rate), diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index b03471b98..d5e6c4c97 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -63,7 +63,7 @@ ), dict(opt_name='nadam', opt_kwargs=dict(learning_rate=1e-2)), dict(opt_name='nadamw', opt_kwargs=dict(learning_rate=1e-2)), - dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)), + dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, key=jrd.key(0), eta=1e-4)), dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1e-3)), dict( opt_name='optimistic_gradient_descent', @@ -566,7 +566,7 @@ def zakharov(x, xnp): class LBFGSTest(chex.TestCase): def test_plain_preconditioning(self): - key = jrd.PRNGKey(0) + key = jrd.key(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 d = 3 @@ -585,7 +585,7 @@ def test_plain_preconditioning(self): @parameterized.product(idx=[0, 1, 2, 3]) def test_preconditioning_by_lbfgs_on_vectors(self, idx: int): - key = jrd.PRNGKey(0) + key = jrd.key(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 d = 3 @@ -612,7 +612,7 @@ def test_preconditioning_by_lbfgs_on_vectors(self, idx: int): @parameterized.product(idx=[0, 1, 2, 3]) def test_preconditioning_by_lbfgs_on_trees(self, idx: int): - key = jrd.PRNGKey(0) + key = jrd.key(0) key_ws, key_us, key_vec = jrd.split(key, 3) m = 4 shapes = ((3, 2), (5,)) @@ -716,7 +716,7 @@ def fun_(x): def fun(x): return otu.tree_sum(jax.tree.map(fun_, x)) - key = jrd.PRNGKey(0) + key = jrd.key(0) init_array = jrd.normal(key, (2, 4)) init_tree = (init_array[0], init_array[1]) From 0db6b30ae54226edec9acb360b12653ed4f82113 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 22:45:16 +0300 Subject: [PATCH 06/18] changed key seed to 0 --- optax/transforms/_adding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index 154c32925..78e71ab75 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -88,7 +88,7 @@ def add_noise( >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> key = jax.random.key(42) + >>> key = jax.random.key(0) >>> noise = optax.add_noise(key=key, eta=0.01, gamma=0.55) >>> sgd = optax.scale_by_learning_rate(learning_rate=0.003) >>> solver = optax.chain(noise, sgd) From 29642d76f3cac74f8cda6ffb565b544d1fecc65e Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 22:46:04 +0300 Subject: [PATCH 07/18] ranamed seed with key and changed old-style random key generation --- optax/contrib/_privacy.py | 9 ++++++--- optax/contrib/_privacy_test.py | 9 +++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/optax/contrib/_privacy.py 
b/optax/contrib/_privacy.py index b36901014..eb1c81294 100644 --- a/optax/contrib/_privacy.py +++ b/optax/contrib/_privacy.py @@ -16,6 +16,7 @@ from typing import Any, NamedTuple, Optional +import chex import jax from optax._src import base from optax._src import clipping @@ -33,14 +34,14 @@ class DifferentiallyPrivateAggregateState(NamedTuple): def differentially_private_aggregate( - l2_norm_clip: float, noise_multiplier: float, seed: int + l2_norm_clip: float, noise_multiplier: float, key: Optional[chex.PRNGKey] = None ) -> base.GradientTransformation: """Aggregates gradients based on the DPSGD algorithm. Args: l2_norm_clip: maximum L2 norm of the per-example gradients. noise_multiplier: ratio of standard deviation to the clipping norm. - seed: initial seed used for the jax.random.PRNGKey + key: a PRNG key used as the random key. Returns: A :class:`optax.GradientTransformation`. @@ -56,11 +57,13 @@ def differentially_private_aggregate( JAX using `jax.vmap`). It can still be composed with other transformations as long as it is the first in the chain. """ + if key is None: + raise ValueError("differentially_private_aggregate optimizer requires specifying random key: differentially_private_aggregate(..., key=crandom.key(0))") noise_std = l2_norm_clip * noise_multiplier def init_fn(params): del params - return DifferentiallyPrivateAggregateState(rng_key=jax.random.PRNGKey(seed)) + return DifferentiallyPrivateAggregateState(rng_key=key) def update_fn(updates, state, params=None): del params diff --git a/optax/contrib/_privacy_test.py b/optax/contrib/_privacy_test.py index 47e72f40e..f6016c118 100644 --- a/optax/contrib/_privacy_test.py +++ b/optax/contrib/_privacy_test.py @@ -18,6 +18,7 @@ from absl.testing import parameterized import chex import jax +import jax.random as jrd import jax.numpy as jnp from optax.contrib import _privacy @@ -45,7 +46,7 @@ def setUp(self): def test_no_privacy(self): """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD.""" dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, seed=0 + l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, key=jrd.key(0) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -59,7 +60,7 @@ def test_no_privacy(self): @parameterized.parameters(0.5, 10.0, 20.0, 40.0, 80.0) def test_clipping_norm(self, l2_norm_clip): dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, seed=42 + l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, key=jrd.key(42) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -87,7 +88,7 @@ def test_clipping_norm(self, l2_norm_clip): def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): """Standard dev. 
of noise should be l2_norm_clip * noise_multiplier.""" dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, seed=1337 + l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, key=jrd.key(1337) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -103,7 +104,7 @@ def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): def test_aggregated_updates_as_input_fails(self): """Expect per-example gradients as input to this transform.""" dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=0.1, noise_multiplier=1.1, seed=2021 + l2_norm_clip=0.1, noise_multiplier=1.1, key=jrd.key(2021) ) state = dp_agg.init(self.params) mean_grads = jax.tree.map(lambda g: g.mean(0), self.per_eg_grads) From a7640c004d084a2fe22eab6aee480adb2674721e Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 22:46:34 +0300 Subject: [PATCH 08/18] replaced seed with key --- optax/_src/utils.py | 4 ++-- optax/perturbations/_make_pert.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/optax/_src/utils.py b/optax/_src/utils.py index 3c59cad3f..54c988ab1 100644 --- a/optax/_src/utils.py +++ b/optax/_src/utils.py @@ -109,10 +109,10 @@ def __init__(self, loc: chex.Array, log_scale: chex.Array): self._mean.shape, self._scale.shape ) - def sample(self, shape: Sequence[int], seed: chex.PRNGKey) -> chex.Array: + def sample(self, shape: Sequence[int], key: chex.PRNGKey) -> chex.Array: sample_shape = tuple(shape) + self._param_shape return ( - jax.random.normal(seed, shape=sample_shape) * self._scale + self._mean + jax.random.normal(key, shape=sample_shape) * self._scale + self._mean ) def log_prob(self, x: chex.Array) -> chex.Array: diff --git a/optax/perturbations/_make_pert.py b/optax/perturbations/_make_pert.py index ae711f5d2..434fa4d1d 100644 --- a/optax/perturbations/_make_pert.py +++ b/optax/perturbations/_make_pert.py @@ -35,11 +35,11 @@ class Normal: def sample( self, - seed: chex.PRNGKey, + key: chex.PRNGKey, sample_shape: Shape, dtype: chex.ArrayDType = float, ) -> jax.Array: - return jax.random.normal(seed, sample_shape, dtype) + return jax.random.normal(key, sample_shape, dtype) def log_prob(self, inputs: jax.Array) -> jax.Array: return -0.5 * inputs**2 @@ -50,11 +50,11 @@ class Gumbel: def sample( self, - seed: chex.PRNGKey, + key: chex.PRNGKey, sample_shape: Shape, dtype: chex.ArrayDType = float, ) -> jax.Array: - return jax.random.gumbel(seed, sample_shape, dtype) + return jax.random.gumbel(key, sample_shape, dtype) def log_prob(self, inputs: jax.Array) -> jax.Array: return -inputs - jnp.exp(-inputs) From d00bc8d1c8e9231e729d9e9cff1b1ab4b6e5a8c0 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:01:31 +0300 Subject: [PATCH 09/18] replaced missed seed and shorted error string --- optax/contrib/_privacy.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/optax/contrib/_privacy.py b/optax/contrib/_privacy.py index eb1c81294..721a1509a 100644 --- a/optax/contrib/_privacy.py +++ b/optax/contrib/_privacy.py @@ -34,7 +34,9 @@ class DifferentiallyPrivateAggregateState(NamedTuple): def differentially_private_aggregate( - l2_norm_clip: float, noise_multiplier: float, key: Optional[chex.PRNGKey] = None + l2_norm_clip: float, + noise_multiplier: float, + key: Optional[chex.PRNGKey] = None ) -> base.GradientTransformation: """Aggregates gradients based on the DPSGD algorithm. 
@@ -58,7 +60,10 @@ def differentially_private_aggregate( as long as it is the first in the chain. """ if key is None: - raise ValueError("differentially_private_aggregate optimizer requires specifying random key: differentially_private_aggregate(..., key=crandom.key(0))") + raise ValueError( + "differentially_private_aggregate optimizer requires specifying random key: " + "differentially_private_aggregate(..., key=random.key(0))" + ) noise_std = l2_norm_clip * noise_multiplier def init_fn(params): @@ -88,7 +93,7 @@ def dpsgd( learning_rate: base.ScalarOrSchedule, l2_norm_clip: float, noise_multiplier: float, - seed: int, + key: Optional[chex.PRNGKey] = None, momentum: Optional[float] = None, nesterov: bool = False, ) -> base.GradientTransformation: @@ -103,7 +108,7 @@ def dpsgd( learning_rate: A fixed global scaling factor. l2_norm_clip: Maximum L2 norm of the per-example gradients. noise_multiplier: Ratio of standard deviation to the clipping norm. - seed: Initial seed used for the jax.random.PRNGKey + key: a PRNG key used as the random key. momentum: Decay rate used by the momentum term, when it is set to `None`, then momentum is not used at all. nesterov: Whether Nesterov momentum is used. @@ -120,11 +125,16 @@ def dpsgd( batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using `jax.vmap`). """ + if key is None: + raise ValueError( + "dpsgd optimizer requires specifying random key: " + "dpsgd(..., key=random.key(0))" + ) return combine.chain( differentially_private_aggregate( l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, - seed=seed, + key=key, ), ( transform.trace(decay=momentum, nesterov=nesterov) From 2d7bb16873bec90271756f08eb3aa305a287386b Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:03:00 +0300 Subject: [PATCH 10/18] added jax. 
to error message about key --- optax/contrib/_privacy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optax/contrib/_privacy.py b/optax/contrib/_privacy.py index 721a1509a..69dff1523 100644 --- a/optax/contrib/_privacy.py +++ b/optax/contrib/_privacy.py @@ -62,7 +62,7 @@ def differentially_private_aggregate( if key is None: raise ValueError( "differentially_private_aggregate optimizer requires specifying random key: " - "differentially_private_aggregate(..., key=random.key(0))" + "differentially_private_aggregate(..., key=jax.random.key(0))" ) noise_std = l2_norm_clip * noise_multiplier @@ -128,7 +128,7 @@ def dpsgd( if key is None: raise ValueError( "dpsgd optimizer requires specifying random key: " - "dpsgd(..., key=random.key(0))" + "dpsgd(..., key=jax.random.key(0))" ) return combine.chain( differentially_private_aggregate( From 2d4b798b560b4c6bdb1c7f79afed91866e1c986c Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:04:02 +0300 Subject: [PATCH 11/18] shorted key=None error message --- optax/_src/alias.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index e83a65e45..12479309d 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -1320,7 +1320,10 @@ def noisy_sgd( Networks `_, 2015 """ if key is None: - raise ValueError("noisy_sgd optimizer requires specifying random key: noisy_sgd(..., key=key=jax.random.key(0))") + raise ValueError( + "noisy_sgd optimizer requires specifying random key: " + "noisy_sgd(..., key=jax.random.key(0))" + ) return combine.chain( transform.add_noise(key, eta, gamma), transform.scale_by_learning_rate(learning_rate), From f0010b764bde1eb9f44bd7f74e618d2fea271587 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:08:59 +0300 Subject: [PATCH 12/18] removed key from error message and trailing whitespaces --- optax/contrib/_privacy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optax/contrib/_privacy.py b/optax/contrib/_privacy.py index 69dff1523..c5d5414ac 100644 --- a/optax/contrib/_privacy.py +++ b/optax/contrib/_privacy.py @@ -34,8 +34,8 @@ class DifferentiallyPrivateAggregateState(NamedTuple): def differentially_private_aggregate( - l2_norm_clip: float, - noise_multiplier: float, + l2_norm_clip: float, + noise_multiplier: float, key: Optional[chex.PRNGKey] = None ) -> base.GradientTransformation: """Aggregates gradients based on the DPSGD algorithm. 
@@ -61,7 +61,7 @@ def differentially_private_aggregate( """ if key is None: raise ValueError( - "differentially_private_aggregate optimizer requires specifying random key: " + "differentially_private_aggregate optimizer requires specifying key: " "differentially_private_aggregate(..., key=jax.random.key(0))" ) noise_std = l2_norm_clip * noise_multiplier @@ -127,7 +127,7 @@ def dpsgd( """ if key is None: raise ValueError( - "dpsgd optimizer requires specifying random key: " + "dpsgd optimizer requires specifying key: " "dpsgd(..., key=jax.random.key(0))" ) return combine.chain( From 2673d250a0dbb91c6defbfc2d069c94195850f80 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:09:16 +0300 Subject: [PATCH 13/18] removed key from error message --- optax/_src/alias.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 12479309d..c952ce471 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -1321,7 +1321,7 @@ def noisy_sgd( """ if key is None: raise ValueError( - "noisy_sgd optimizer requires specifying random key: " + "noisy_sgd optimizer requires specifying key: " "noisy_sgd(..., key=jax.random.key(0))" ) return combine.chain( @@ -2401,7 +2401,7 @@ def lbfgs( linesearch: Optional[ base.GradientTransformationExtraArgs ] = _linesearch.scale_by_zoom_linesearch( - max_linesearch_steps=20, initial_guess_strategy='one' + max_linesearch_steps=20, initial_guess_strategy="one" ), ) -> base.GradientTransformationExtraArgs: r"""L-BFGS optimizer. From 4c1d306d7163c456e3906e57f00a2b1057ee7118 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:15:58 +0300 Subject: [PATCH 14/18] changed layout of function arguments --- optax/contrib/_privacy_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/optax/contrib/_privacy_test.py b/optax/contrib/_privacy_test.py index f6016c118..ae07d8c41 100644 --- a/optax/contrib/_privacy_test.py +++ b/optax/contrib/_privacy_test.py @@ -46,7 +46,9 @@ def setUp(self): def test_no_privacy(self): """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD.""" dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, key=jrd.key(0) + l2_norm_clip=jnp.finfo(jnp.float32).max, + noise_multiplier=0.0, + key=jrd.key(0) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) @@ -88,7 +90,9 @@ def test_clipping_norm(self, l2_norm_clip): def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): """Standard dev. 
of noise should be l2_norm_clip * noise_multiplier.""" dp_agg = _privacy.differentially_private_aggregate( - l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, key=jrd.key(1337) + l2_norm_clip=l2_norm_clip, + noise_multiplier=noise_multiplier, + key=jrd.key(1337) ) state = dp_agg.init(self.params) update_fn = self.variant(dp_agg.update) From a162271057c81948b3f323882b767d546fed9186 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:17:29 +0300 Subject: [PATCH 15/18] changed layout of noisy_sgd dict --- optax/_src/alias_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index d5e6c4c97..3595b3ecc 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -63,7 +63,10 @@ ), dict(opt_name='nadam', opt_kwargs=dict(learning_rate=1e-2)), dict(opt_name='nadamw', opt_kwargs=dict(learning_rate=1e-2)), - dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, key=jrd.key(0), eta=1e-4)), + dict( + opt_name='noisy_sgd', + opt_kwargs=dict(learning_rate=1e-3, key=jrd.key(0), eta=1e-4) + ), dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1e-3)), dict( opt_name='optimistic_gradient_descent', From ec4217c0f7671df4a1cb1f2c191504f79b3e7bd0 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:20:41 +0300 Subject: [PATCH 16/18] trailing whitespace in noisy_sgd dict --- optax/_src/alias_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 3595b3ecc..314103619 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -64,7 +64,7 @@ dict(opt_name='nadam', opt_kwargs=dict(learning_rate=1e-2)), dict(opt_name='nadamw', opt_kwargs=dict(learning_rate=1e-2)), dict( - opt_name='noisy_sgd', + opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, key=jrd.key(0), eta=1e-4) ), dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1e-3)), From 7be3a7563cda001608a3debc0dbdf82e3e477c63 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:39:28 +0300 Subject: [PATCH 17/18] replaced seed with key --- optax/transforms/_adding_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/optax/transforms/_adding_test.py b/optax/transforms/_adding_test.py index a1cc3e924..a11172d66 100644 --- a/optax/transforms/_adding_test.py +++ b/optax/transforms/_adding_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest import chex import jax +import jax.random as jrd import jax.numpy as jnp from optax.transforms import _adding @@ -73,9 +74,9 @@ def test_add_noise_has_correct_variance_scaling(self): # Prepare to compare noise with a rescaled unit-variance substitute. 
eta = 0.3 gamma = 0.55 - seed = 314 - noise = _adding.add_noise(eta, gamma, seed) - noise_unit = _adding.add_noise(1.0, 0.0, seed) + key = jrd.key(314) + noise = _adding.add_noise(key, eta, gamma) + noise_unit = _adding.add_noise(key, 1.0, 0.0) params = self.init_params state = noise.init(params) From 03c618b81b82e63ee032e06bc3490371e4c52124 Mon Sep 17 00:00:00 2001 From: Artyom Iudin Date: Sat, 30 Nov 2024 23:42:57 +0300 Subject: [PATCH 18/18] replaced seed with key and old-style PRGKey generation --- .../stochastic_gradient_estimators.py | 8 ++++---- .../stochastic_gradient_estimators_test.py | 17 +++++++++-------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/optax/monte_carlo/stochastic_gradient_estimators.py b/optax/monte_carlo/stochastic_gradient_estimators.py index 541b697a3..50fd86ee1 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators.py +++ b/optax/monte_carlo/stochastic_gradient_estimators.py @@ -85,7 +85,7 @@ def score_function_jacobians( def surrogate(params): dist = dist_builder(*params) one_sample_surrogate_fn = lambda x: function(x) * dist.log_prob(x) - samples = jax.lax.stop_gradient(dist.sample((num_samples,), seed=rng)) + samples = jax.lax.stop_gradient(dist.sample((num_samples,), key=rng)) # We vmap the function application over samples - this ensures that the # function we use does not have to be vectorized itself. return jax.vmap(one_sample_surrogate_fn)(samples) @@ -141,7 +141,7 @@ def surrogate(params): # We vmap the function application over samples - this ensures that the # function we use does not have to be vectorized itself. dist = dist_builder(*params) - return jax.vmap(function)(dist.sample((num_samples,), seed=rng)) + return jax.vmap(function)(dist.sample((num_samples,), key=rng)) return jax.jacfwd(surrogate)(params) @@ -239,7 +239,7 @@ def measure_valued_estimation_mean( mean, log_std = dist.params std = jnp.exp(log_std) - dist_samples = dist.sample((num_samples,), seed=rng) + dist_samples = dist.sample((num_samples,), key=rng) pos_rng, neg_rng = jax.random.split(rng) pos_sample = jax.random.weibull_min( @@ -312,7 +312,7 @@ def measure_valued_estimation_std( mean, log_std = dist.params std = jnp.exp(log_std) - dist_samples = dist.sample((num_samples,), seed=rng) + dist_samples = dist.sample((num_samples,), key=rng) pos_rng, neg_rng = jax.random.split(rng) diff --git a/optax/monte_carlo/stochastic_gradient_estimators_test.py b/optax/monte_carlo/stochastic_gradient_estimators_test.py index b71dde808..1b085a03d 100644 --- a/optax/monte_carlo/stochastic_gradient_estimators_test.py +++ b/optax/monte_carlo/stochastic_gradient_estimators_test.py @@ -18,6 +18,7 @@ from absl.testing import parameterized import chex import jax +import jax.random as jrd import jax.numpy as jnp import numpy as np from optax._src import utils @@ -99,7 +100,7 @@ def testConstantFunction(self, estimator, constant): effective_log_scale = 0.0 log_scale = effective_log_scale * _ones(data_dims) - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) jacobians = _estimator_variant(self.variant, estimator)( lambda x: jnp.array(constant), @@ -142,7 +143,7 @@ def testConstantFunction(self, estimator, constant): def testLinearFunction(self, estimator, effective_mean, effective_log_scale): data_dims = 3 num_samples = _estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) mean = effective_mean * _ones(data_dims) log_scale = effective_log_scale * _ones(data_dims) @@ -183,7 +184,7 @@ def testQuadraticFunction( ): data_dims = 3 num_samples = 
_estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) mean = effective_mean * _ones(data_dims) log_scale = effective_log_scale * _ones(data_dims) @@ -231,7 +232,7 @@ def testWeightedLinear( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) mean = jnp.array(effective_mean) log_scale = jnp.array(effective_log_scale) @@ -278,7 +279,7 @@ def testWeightedQuadratic( self, estimator, effective_mean, effective_log_scale, weights ): num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) mean = jnp.array(effective_mean, dtype=jnp.float32) log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) @@ -340,8 +341,8 @@ def testNonPolynomialFunctionConsistencyWithPathwise( self, effective_mean, effective_log_scale, function, coupling ): num_samples = 10**5 - rng = jax.random.PRNGKey(1) - measure_rng, pathwise_rng = jax.random.split(rng) + rng = jrd.key(1) + measure_rng, pathwise_rng = jrd.split(rng) mean = jnp.array(effective_mean, dtype=jnp.float32) log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) @@ -403,7 +404,7 @@ class MeasuredValuedEstimatorsTest(chex.TestCase): @parameterized.parameters([True, False]) def testRaisesErrorForNonGaussian(self, coupling): num_samples = 10**5 - rng = jax.random.PRNGKey(1) + rng = jrd.key(1) function = lambda x: jnp.sum(x) ** 2