Skip to content

Commit

Permalink
Adjust according to review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
daskol committed Jan 8, 2025
1 parent d2d14cf commit 80a1270
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
27 changes: 20 additions & 7 deletions optax/losses/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,31 @@
"""Classification losses."""

import functools
import operator
from typing import Optional, Union

import chex
import jax
import jax.numpy as jnp
import numpy as np
from optax import projections

if np.__version__.startswith('1.'):
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
else:
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple

def canonicalize_axis(axis, ndim):
"""Vendored version of :func:`numpy.lib.array_utils.normalize_axis_index`.
"""
if 0 <= (axis := operator.index(axis)) < ndim:
return axis
elif -ndim <= axis < 0:
return axis + ndim
else:
raise ValueError(f'axis {axis} is out of bounds for array of '
f'dimension {ndim}')


def canonicalize_axes(axes, ndim) -> tuple[int, ...]:
"""Vendored version of :func:`numpy.lib.array_utils.normalize_axis_tuple`.
"""
return tuple(canonicalize_axis(x, ndim) for x in axes)


def sigmoid_binary_cross_entropy(
Expand Down Expand Up @@ -354,11 +367,11 @@ def softmax_cross_entropy_with_integer_labels(
chex.assert_type([logits], float)
chex.assert_type([labels], int)
if isinstance(axis, int):
axis = normalize_axis_index(axis, logits.ndim)
axis = canonicalize_axis(axis, logits.ndim)
elif isinstance(axis, tuple):
# Move all "feature" dimensions to the end preserving axis ordering and
# subsequent flattening "feature" dimensions to a single one.
logit_axis = normalize_axis_tuple(axis, logits.ndim, argname='logits')
logit_axis = canonicalize_axes(axis, logits.ndim)
batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis)
axis = len(batch_axis)
logits = logits.transpose(batch_axis + logit_axis)
Expand Down
6 changes: 5 additions & 1 deletion optax/losses/_classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def test_axis(self, shape, axis):
{'axis': (-3, -1), 'shape': (2, 3, 4, 5)},
{'axis': (-1, -2), 'shape': (2, 3, 4, 5)},
{'axis': (-2, -1), 'shape': (2, 3, 4, 5)},
{'axis': (0, 1, 3), 'shape': (2, 3, 4, 5)},
{'axis': (-4, -3, -1), 'shape': (2, 3, 4, 5)},
)
def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]):
# Canonicalize axis and calculate shapes.
Expand All @@ -274,9 +276,11 @@ def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]):
desired = fn(logits, labels)

# Apply inverse axes permutation to obtain an array of `shape` shape.
perm = labels_axis + logits_axis
perm_inv = tuple(i for i, p in sorted(enumerate(perm), key=lambda x: x[1]))
logits = logits \
.reshape(labels_shape + logits_shape) \
.transpose(labels_axis + logits_axis)
.transpose(perm_inv)
assert logits.shape == shape
actual = fn(logits, labels, axis)
np.testing.assert_allclose(actual, desired)
Expand Down

0 comments on commit 80a1270

Please sign in to comment.