diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index f91129e79..29cdd3c2f 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -165,13 +165,15 @@ def triplet_loss(
     If reduction is 'mean' or 'sum', returns a scalar.

   References:
-    V. Balntas et al, `Learning shallow convolutional feature descriptors with triplet losses
-    `_, 2016
+    V. Balntas et al,
+    `Learning shallow convolutional feature descriptors with triplet losses
+    `
+    _, 2016
   """
   chex.assert_type([anchors, positives, negatives], float)
-  positive_distance = jnp.power(jnp.power(anchors - positives, norm_degree).sum(axis) + eps,
-                                1/norm_degree)
-  negative_distance = jnp.power(jnp.power(anchors - negatives, norm_degree).sum(axis) + eps,
-                                1/norm_degree)
+  positive_distance = jnp.power(jnp.power(anchors - positives, norm_degree)
+                                .sum(axis) + eps, 1/norm_degree)
+  negative_distance = jnp.power(jnp.power(anchors - negatives, norm_degree)
+                                .sum(axis) + eps, 1/norm_degree)
   loss = jnp.maximum(positive_distance - negative_distance + margin, 0)
   return loss
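
Note: the distance computation above is only re-wrapped across lines, not changed. A minimal sketch of the equivalence check follows; it assumes the function's default hyperparameters (norm_degree=2, axis=-1, eps=1e-6), and the input arrays are made up purely for illustration.

import jax.numpy as jnp

anchors = jnp.array([[0.0, 1.0], [1.0, 2.0]])
positives = jnp.array([[0.0, 0.5], [1.5, 2.0]])
norm_degree, axis, eps = 2, -1, 1e-6

# Original single-line form of the positive distance.
old_form = jnp.power(jnp.power(anchors - positives, norm_degree).sum(axis) + eps,
                     1/norm_degree)
# Re-wrapped form from the diff: `.sum(axis)` still binds to the inner
# jnp.power(...) result, since the line break sits inside the outer parentheses.
new_form = jnp.power(jnp.power(anchors - positives, norm_degree)
                     .sum(axis) + eps, 1/norm_degree)
assert jnp.allclose(old_form, new_form)  # purely a line-wrapping change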