From 3b14762419e41f97187791fc5e5cb9df702e0e73 Mon Sep 17 00:00:00 2001 From: Matteo Hessel Date: Tue, 3 Dec 2024 03:51:43 -0800 Subject: [PATCH] Fix docs for `optax.partition`. PiperOrigin-RevId: 702279061 --- optax/transforms/_combining.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optax/transforms/_combining.py b/optax/transforms/_combining.py index 4a1fd5002..364031b9a 100644 --- a/optax/transforms/_combining.py +++ b/optax/transforms/_combining.py @@ -199,7 +199,7 @@ def partition( >>> gradients = jax.tree.map(jnp.ones_like, params) # dummy gradients >>> label_fn = map_nested_fn(lambda k, _: k) - >>> tx = optax.multi_transform( + >>> tx = optax.partition( ... {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn) >>> state = tx.init(params) >>> updates, new_state = tx.update(gradients, state, params) @@ -214,12 +214,12 @@ def partition( >>> all_params = (generator_params, discriminator_params) >>> param_labels = ('generator', 'discriminator') - >>> tx = optax.multi_transform( + >>> tx = optax.partition( >>> {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)}, >>> param_labels) If you would like to not optimize some parameters, you may wrap - :func:`optax.multi_transform` with :func:`optax.masked`. + :func:`optax.partition` with :func:`optax.masked`. """ transforms = {