Skip to content

Commit

Permalink
Add l2-norm utility
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Mar 15, 2024
1 parent 4429745 commit 729c1ca
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
8 changes: 8 additions & 0 deletions jaxley/optimize/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import jax.numpy as jnp
from jax import tree_util


def l2_norm(x: "PyTree") -> jnp.array:
"""Return the L2-norm of a pytree. Taken from GH jax/issues/3124."""
leaves, _ = tree_util.tree_flatten(x)
return jnp.sqrt(sum([jnp.sum(leaf**2) for leaf in leaves]))
12 changes: 12 additions & 0 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jaxley as jx
from jaxley.channels import HH
from jaxley.optimize import TypeOptimizer
from jaxley.optimize.utils import l2_norm


def test_type_optimizer():
Expand Down Expand Up @@ -75,3 +76,14 @@ def loss_fn(params):
opt_params = optax.apply_updates(opt_params, updates)

assert l > 30.0, "Loss should be high if a uniformly high lr is used."


def test_l2_norm_utility():
true_norm = np.sqrt(np.sum(np.asarray([0.01, 0.003, 0.05, 0.006, 0.07, 0.04]) ** 2))
pytree = [
{"a": 0.01},
{"b": jnp.asarray([[0.003, 0.05]])},
{"c": jnp.asarray([[0.006, 0.07]])},
0.04,
]
assert l2_norm(pytree).item() == true_norm

0 comments on commit 729c1ca

Please sign in to comment.