Commit 5a3b829: Merge pull request #1120 from cvnad1:issue-1118
PiperOrigin-RevId: 714262047
OptaxDev committed Jan 11, 2025
2 parents bff9977 + d1fb2c7 commit 5a3b829
Showing 4 changed files with 123 additions and 3 deletions.
6 changes: 5 additions & 1 deletion docs/api/losses.rst
@@ -31,7 +31,7 @@ Losses
softmax_cross_entropy_with_integer_labels
sparsemax_loss
squared_error

triplet_margin_loss

Convex Kullback Leibler divergence
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -116,3 +116,7 @@ Sparsemax
~~~~~~~~~
.. autofunction:: sparsemax_loss
.. autofunction:: multiclass_sparsemax_loss

Triplet margin loss
~~~~~~~~~~~~~~~~~~~
.. autofunction:: triplet_margin_loss
1 change: 1 addition & 0 deletions optax/losses/__init__.py
@@ -42,4 +42,5 @@
from optax.losses._regression import log_cosh
from optax.losses._regression import squared_error
from optax.losses._self_supervised import ntxent
from optax.losses._self_supervised import triplet_loss
from optax.losses._smoothing import smooth_labels
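A minimal usage sketch of the export added above (not part of the commit; the reduction to a scalar mean is an illustrative choice, and the input values are taken from the docstring example further down):

import jax.numpy as jnp
from optax import losses

anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])

# Per-example losses of shape [batch]; callers typically reduce them, e.g. with a mean.
per_example = losses.triplet_loss(anchors, positives, negatives, margin=1.0)
print(per_example)         # roughly [0.1414 0.1414]
print(per_example.mean())  # roughly 0.1414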
54 changes: 53 additions & 1 deletion optax/losses/_self_supervised.py
@@ -115,6 +115,58 @@ def ntxent(
  denom = jnp.sum(jnp.exp(xcs_shift_diffs), axis=1, keepdims=True)
  denom += numer_exp
  log_softm = numer - jnp.log(denom)
  loss = -jnp.where(matches == 1, log_softm, 0.0).sum() / matches.sum()
  loss = -jnp.where(matches == 1, log_softm, 0.0).sum()/matches.sum()

  return loss


def triplet_loss(
    anchors: chex.Array,
    positives: chex.Array,
    negatives: chex.Array,
    axis: int = -1,
    norm_degree: chex.Numeric = 2,
    margin: chex.Numeric = 1.0,
    eps: chex.Numeric = 1e-6,
) -> chex.Array:
  """Returns the triplet loss for a batch of embeddings.

  Examples:
    >>> import jax.numpy as jnp, optax, chex
    >>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
    >>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
    >>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
    >>> output = optax.triplet_loss(anchors, positives, negatives, margin=1.0)
    >>> print(output)
    [0.14142442 0.14142442]

  Args:
    anchors: An array of anchor embeddings, with shape [batch, feature_dim].
    positives: An array of positive embeddings (similar to anchors), with
      shape [batch, feature_dim].
    negatives: An array of negative embeddings (dissimilar to anchors), with
      shape [batch, feature_dim].
    axis: The axis along which to compute the distances (default is -1).
    norm_degree: The norm degree for distance calculation (default is 2 for
      Euclidean distance).
    margin: The minimum margin by which the positive distance should be
      smaller than the negative distance.
    eps: A small epsilon value to ensure numerical stability in the distance
      calculation.

  Returns:
    The computed triplet loss as an array of shape [batch].

  References:
    V. Balntas et al, `Learning shallow convolutional feature descriptors
    with triplet losses
    <https://bmva-archive.org.uk/bmvc/2016/papers/paper119/abstract119.pdf>`_,
    2016
  """
  chex.assert_type([anchors, positives, negatives], float)
  positive_distance = jnp.power(
      jnp.power(anchors - positives, norm_degree).sum(axis) + eps,
      1 / norm_degree,
  )
  negative_distance = jnp.power(
      jnp.power(anchors - negatives, norm_degree).sum(axis) + eps,
      1 / norm_degree,
  )
  loss = jnp.maximum(positive_distance - negative_distance + margin, 0)
  return loss
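For orientation, and not part of the commit, here is a sketch of how the new loss might be wired into a training step with an optax optimizer; the linear encoder, random data, and Adam hyperparameters are illustrative assumptions:

import jax
import jax.numpy as jnp
import optax
from optax import losses

# Hypothetical linear encoder: maps 4-d inputs to 2-d embeddings.
key = jax.random.PRNGKey(0)
params = jax.random.normal(key, (4, 2))

def embed(params, x):
  return x @ params

def loss_fn(params, anchors, positives, negatives):
  # Mean of the per-example triplet losses defined above.
  return losses.triplet_loss(
      embed(params, anchors),
      embed(params, positives),
      embed(params, negatives),
      margin=1.0,
  ).mean()

opt = optax.adam(1e-3)
opt_state = opt.init(params)

anchors = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
positives = anchors + 0.1 * jax.random.normal(jax.random.PRNGKey(2), (8, 4))
negatives = jax.random.normal(jax.random.PRNGKey(3), (8, 4))

grads = jax.grad(loss_fn)(params, anchors, positives, negatives)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)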
65 changes: 64 additions & 1 deletion optax/losses/_self_supervised_test.py
@@ -15,9 +15,12 @@
"""Tests for self-supervised losses in `optax.losses._self_supervised.py`."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np

from optax.losses import _self_supervised


@@ -46,7 +49,6 @@ def setUp(self):

  @chex.all_variants
  def test_batched(self):
    """Tests for a full batch."""
    np.testing.assert_allclose(
        self.variant(_self_supervised.ntxent)(self.ys, self.ts_1),
        self.exp_1,
@@ -66,5 +68,66 @@ def test_batched(self):
    )


class TripletMarginLossTest(chex.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.a1 = jnp.ones((2, 2))
    self.p1 = jnp.zeros((2, 2))
    self.n1 = jnp.ones((2, 2)) * 2
    self.a2 = jnp.zeros((2, 2))
    self.p2 = jnp.ones((2, 2))
    self.n2 = jnp.ones((2, 2)) * 2

  @chex.all_variants
  @parameterized.parameters([
      {
          'anchor': np.ones((2, 2)),
          'positive': np.zeros((2, 2)),
          'negative': np.ones((2, 2)) * 2,
          'margin': 1.0,
      },
      {
          'anchor': np.zeros((2, 2)),
          'positive': np.ones((2, 2)),
          'negative': np.ones((2, 2)) * 2,
          'margin': 1.0,
      }
  ])
  def test_batched(self, anchor, positive, negative, margin):
    def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
      ap_distance = jnp.sqrt(jnp.sum(jnp.power(a - p, p_norm)) + eps)
      an_distance = jnp.sqrt(jnp.sum(jnp.power(a - n, p_norm)) + eps)
      return jnp.maximum(ap_distance - an_distance + margin, 0)

    handmade_result = testing_triplet_loss(
        a=anchor, p=positive, n=negative, margin=margin
    )
    result = self.variant(_self_supervised.triplet_loss)(
        anchor, positive, negative
    )
    np.testing.assert_allclose(result, handmade_result, atol=1e-4)

  @chex.all_variants
  @parameterized.parameters([
      {
          'anchor': np.ones((2, 2)),
          'positive': np.zeros((2, 2)),
          'negative': np.ones((2, 2)) * 2,
      },
  ])
  def test_vmap(self, anchor, positive, negative):
    original_loss = _self_supervised.triplet_loss(anchor, positive,
                                                  negative)
    anchor_batched = anchor.reshape(1, *anchor.shape)
    positive_batched = positive.reshape(1, *positive.shape)
    negative_batched = negative.reshape(1, *negative.shape)
    vmap_loss = self.variant(
        jax.vmap(_self_supervised.triplet_loss, in_axes=(0, 0, 0)))(
            anchor_batched, positive_batched, negative_batched)
    np.testing.assert_allclose(vmap_loss.flatten(), original_loss.flatten(),
                               atol=1e-4)


if __name__ == '__main__':
  absltest.main()
