From 80a127080cc57752eb341507a6a1c3a31f126f6b Mon Sep 17 00:00:00 2001 From: Daniel Bershatsky Date: Fri, 3 Jan 2025 23:27:46 +0300 Subject: [PATCH] Adjust according to review comments --- optax/losses/_classification.py | 27 ++++++++++++++++++++------- optax/losses/_classification_test.py | 6 +++++- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py index 6813e284..c5818ef5 100644 --- a/optax/losses/_classification.py +++ b/optax/losses/_classification.py @@ -15,18 +15,31 @@ """Classification losses.""" import functools +import operator from typing import Optional, Union import chex import jax import jax.numpy as jnp -import numpy as np from optax import projections -if np.__version__.startswith('1.'): - from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple -else: - from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple + +def canonicalize_axis(axis, ndim): + """Vendored version of :func:`numpy.lib.array_utils.normalize_axis_index`. + """ + if 0 <= (axis := operator.index(axis)) < ndim: + return axis + elif -ndim <= axis < 0: + return axis + ndim + else: + raise ValueError(f'axis {axis} is out of bounds for array of ' + f'dimension {ndim}') + + +def canonicalize_axes(axes, ndim) -> tuple[int, ...]: + """Vendored version of :func:`numpy.lib.array_utils.normalize_axis_tuple`. + """ + return tuple(canonicalize_axis(x, ndim) for x in axes) def sigmoid_binary_cross_entropy( @@ -354,11 +367,11 @@ def softmax_cross_entropy_with_integer_labels( chex.assert_type([logits], float) chex.assert_type([labels], int) if isinstance(axis, int): - axis = normalize_axis_index(axis, logits.ndim) + axis = canonicalize_axis(axis, logits.ndim) elif isinstance(axis, tuple): # Move all "feature" dimensions to the end preserving axis ordering and # subsequent flattening "feature" dimensions to a single one. - logit_axis = normalize_axis_tuple(axis, logits.ndim, argname='logits') + logit_axis = canonicalize_axes(axis, logits.ndim) batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis) axis = len(batch_axis) logits = logits.transpose(batch_axis + logit_axis) diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py index 2fc13271..9dbe9c5f 100644 --- a/optax/losses/_classification_test.py +++ b/optax/losses/_classification_test.py @@ -254,6 +254,8 @@ def test_axis(self, shape, axis): {'axis': (-3, -1), 'shape': (2, 3, 4, 5)}, {'axis': (-1, -2), 'shape': (2, 3, 4, 5)}, {'axis': (-2, -1), 'shape': (2, 3, 4, 5)}, + {'axis': (0, 1, 3), 'shape': (2, 3, 4, 5)}, + {'axis': (-4, -3, -1), 'shape': (2, 3, 4, 5)}, ) def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]): # Canonicalize axis and calculate shapes. @@ -274,9 +276,11 @@ def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]): desired = fn(logits, labels) # Apply inverse axes permutation to obtain an array of `shape` shape. + perm = labels_axis + logits_axis + perm_inv = tuple(i for i, p in sorted(enumerate(perm), key=lambda x: x[1])) logits = logits \ .reshape(labels_shape + logits_shape) \ - .transpose(labels_axis + logits_axis) + .transpose(perm_inv) assert logits.shape == shape actual = fn(logits, labels, axis) np.testing.assert_allclose(actual, desired)