Skip to content

Commit

Permalink
Fix docs for optax.partition.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702279061
  • Loading branch information
mtthss authored and OptaxDev committed Dec 3, 2024
1 parent 02a1bd7 commit 3b14762
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions optax/transforms/_combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def partition(
>>> gradients = jax.tree.map(jnp.ones_like, params) # dummy gradients
>>> label_fn = map_nested_fn(lambda k, _: k)
>>> tx = optax.multi_transform(
>>> tx = optax.partition(
... {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn)
>>> state = tx.init(params)
>>> updates, new_state = tx.update(gradients, state, params)
Expand All @@ -214,12 +214,12 @@ def partition(
>>> all_params = (generator_params, discriminator_params)
>>> param_labels = ('generator', 'discriminator')
>>> tx = optax.multi_transform(
>>> tx = optax.partition(
>>> {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
>>> param_labels)
If you would like to not optimize some parameters, you may wrap
:func:`optax.multi_transform` with :func:`optax.masked`.
:func:`optax.partition` with :func:`optax.masked`.
"""

transforms = {
Expand Down

0 comments on commit 3b14762

Please sign in to comment.