Skip to content

Commit

Permalink
Merge pull request #1170 from leloykun:fc--fix-muon
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714411374
  • Loading branch information
OptaxDev committed Jan 11, 2025
2 parents 6ed9095 + e9f838f commit 831ddbf
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def update_fn(updates, state, params=None):
# Apply Newton-schulz orthogonalization.
updates = jax.tree.map(
lambda x: orthogonalize_via_newton_schulz(x, ns_coeffs_, ns_steps, eps),
updates,
mu_hat,
)
if adaptive:
# Scale the orthogonalized updates by the dual norm of the original
Expand Down

0 comments on commit 831ddbf

Please sign in to comment.