From d2d14cfe287c46fda9087bec68b7c8a7b36991a1 Mon Sep 17 00:00:00 2001
From: Daniel Bershatsky
Date: Fri, 3 Jan 2025 00:11:03 +0300
Subject: [PATCH] Support tuple of axes in
 `softmax_cross_entropy_with_integer_labels`

---
 optax/losses/_classification.py      | 50 +++++++++++++++++++++++++++---
 optax/losses/_classification_test.py | 34 ++++++++++++++++++++
 2 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py
index 694a84f40..6813e2841 100644
--- a/optax/losses/_classification.py
+++ b/optax/losses/_classification.py
@@ -20,8 +20,14 @@
 import chex
 import jax
 import jax.numpy as jnp
+import numpy as np
 from optax import projections
 
+if np.__version__.startswith('1.'):
+  from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
+else:
+  from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
+
 
 def sigmoid_binary_cross_entropy(
     logits,
@@ -273,7 +279,7 @@ def softmax_cross_entropy(
 def softmax_cross_entropy_with_integer_labels(
     logits: chex.Array,
     labels: chex.Array,
-    axis: Union[int, None] = -1,
+    axis: Union[int, tuple[int, ...]] = -1,
     where: Union[chex.Array, None] = None,
 ) -> chex.Array:
   r"""Computes softmax cross entropy between the logits and integer labels.
@@ -297,7 +303,10 @@
     labels: Integers specifying the correct class for each input, with shape
       ``[batch_size]``. Class labels are assumed to be between 0 and
       ``num_classes - 1`` inclusive.
-    axis: Axis along which to compute.
+    axis: Axis or axes along which to compute. If a tuple of axes is passed,
+      ``num_classes`` must equal the total number of elements spanned by the
+      ``axis`` dimensions, and each label is interpreted as a flat (row-major)
+      index into those dimensions, taken in the order given by ``axis``.
     where: Elements to include in the computation.
 
   Returns:
@@ -313,6 +322,21 @@
     >>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
     [0.2761297 2.951799 ]
 
+    >>> import jax.numpy as jnp
+    >>> import numpy as np
+    >>> import optax
+    >>> # example: batch shape (1, 2), num_classes = 12 (i.e. 3 * 4)
+    >>> shape = (1, 2, 3, 4)
+    >>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
+    >>> # indices of the target elements within each (3, 4) logits slice
+    >>> ix = jnp.array([[1, 2]])
+    >>> jx = jnp.array([[1, 3]])
+    >>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
+    >>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
+    ...   logits, labels, axis=(2, 3))
+    >>> print(cross_entropy)
+    [[6.458669   0.45866907]]
+
   References:
     `Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
     Wikipedia
@@ -329,9 +353,25 @@
   """
   chex.assert_type([logits], float)
   chex.assert_type([labels], int)
-  if axis is not None and not isinstance(axis, int):
-    raise ValueError(f'axis = {axis} is unsupported. Provide an int or None.')
-
+  if isinstance(axis, int):
+    axis = normalize_axis_index(axis, logits.ndim)
+  elif isinstance(axis, tuple):
+    # Move all "feature" dimensions to the end, preserving their order, then
+    # flatten them into a single trailing "class" dimension.
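+    # For example, logits of shape (2, 3, 4, 5) with axis=(1, 3) are
+    # transposed to shape (2, 4, 3, 5) and reshaped to (2, 4, 15); each label
+    # is then a flat (row-major) index into the trailing 15-class dimension.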
+    logit_axis = normalize_axis_tuple(axis, logits.ndim, argname='logits')
+    batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis)
+    axis = len(batch_axis)
+    logits = logits.transpose(batch_axis + logit_axis)
+    logits = logits.reshape(logits.shape[:len(batch_axis)] + (-1, ))
+    if where is not None:
+      where = where.transpose(batch_axis + logit_axis)
+      where = where.reshape(where.shape[:len(batch_axis)] + (-1, ))
+  else:
+    raise ValueError('Keyword argument \'axis\' must be of type \'int\' or '
+                     f'\'tuple[int, ...]\' but actual type is {type(axis)}.')
   # This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
   # we avoid subtracting the normalizer from all values, just from the values
   # for the correct labels.
diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py
index 7b6321618..2fc132712 100644
--- a/optax/losses/_classification_test.py
+++ b/optax/losses/_classification_test.py
@@ -247,6 +247,40 @@ def test_axis(self, shape, axis):
     )
     np.testing.assert_allclose(x, y, atol=1e-4)
 
+  @parameterized.parameters(
+      {'axis': (1, 3), 'shape': (2, 3, 4, 5)},
+      {'axis': (3, 2), 'shape': (2, 3, 4, 5)},
+      {'axis': (2, 3), 'shape': (2, 3, 4, 5)},
+      {'axis': (-3, -1), 'shape': (2, 3, 4, 5)},
+      {'axis': (-1, -2), 'shape': (2, 3, 4, 5)},
+      {'axis': (-2, -1), 'shape': (2, 3, 4, 5)},
+  )
+  def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]):
+    # Canonicalize axes and compute shapes.
+    ndim = len(shape)
+    logits_axis = tuple((x + ndim) % ndim for x in axis)
+    labels_axis = tuple(x for x in range(ndim) if x not in logits_axis)
+    # Obtain the shapes of the batch and class (logits) subspaces.
+    logits_shape = tuple(shape[x] for x in logits_axis)
+    labels_shape = tuple(shape[x] for x in labels_axis)
+    num_classes: int = np.prod(logits_shape).item()
+
+    key = jax.random.key(42)
+    keys = jax.random.split(key, 2)
+    logits = jax.random.uniform(keys[0], labels_shape + (num_classes, ))
+    labels = jax.random.randint(keys[1], labels_shape, 0, num_classes)
+
+    fn = _classification.softmax_cross_entropy_with_integer_labels
+    desired = fn(logits, labels)
+
+    # Apply the inverse axes permutation to recover the original `shape`.
+    logits = (logits
+              .reshape(labels_shape + logits_shape)
+              .transpose(labels_axis + logits_axis))
+    assert logits.shape == shape
+    actual = fn(logits, labels, axis)
+    np.testing.assert_allclose(actual, desired)
+
 
 class SigmoidCrossEntropyTest(parameterized.TestCase):
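
Note (illustration, not part of the patch): a minimal usage sketch of the new
tuple `axis` behaviour, assuming a build of optax with this patch applied. It
mirrors what `test_axes` checks: passing the class axes as a tuple should match
manually moving those axes to the end and flattening them into a single class
dimension. The shapes and key seeds below are arbitrary.

import jax
import jax.numpy as jnp
import optax

# Class information is spread over axes 1 and 3; axes 0 and 2 are batch axes.
logits = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
# Labels are flat (row-major) indices into the 3 * 5 = 15 classes.
labels = jax.random.randint(jax.random.key(1), (2, 4), 0, 15)

# New behaviour introduced by this patch: pass the class axes as a tuple.
loss_tuple = optax.softmax_cross_entropy_with_integer_labels(
    logits, labels, axis=(1, 3))

# Reference computation: move the class axes to the end, flatten them, and use
# the default axis=-1.
flat_logits = logits.transpose((0, 2, 1, 3)).reshape(2, 4, -1)
loss_flat = optax.softmax_cross_entropy_with_integer_labels(flat_logits, labels)

assert jnp.allclose(loss_tuple, loss_flat)  # both losses have batch shape (2, 4)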