Skip to content

Commit

Permalink
Pylint corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
Saanidhyavats committed Dec 13, 2024
1 parent 3b9a71a commit d1fb2c7
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,15 @@ def triplet_loss(
If reduction is 'mean' or 'sum', returns a scalar.
References:
V. Balntas et al, `Learning shallow convolutional feature descriptors with triplet losses
<https://bmva-archive.org.uk/bmvc/2016/papers/paper119/abstract119.pdf>`_, 2016
V. Balntas et al,
`Learning shallow convolutional feature descriptors with triplet losses
<https://bmva-archive.org.uk/bmvc/2016/papers/paper119/abstract119.pdf>`
_, 2016
"""
chex.assert_type([anchors, positives, negatives], float)
positive_distance = jnp.power(jnp.power(anchors - positives, norm_degree).sum(axis) + eps,
1/norm_degree)
negative_distance = jnp.power(jnp.power(anchors - negatives, norm_degree).sum(axis) + eps,
1/norm_degree)
positive_distance = jnp.power(jnp.power(anchors - positives, norm_degree)
.sum(axis) + eps, 1/norm_degree)
negative_distance = jnp.power(jnp.power(anchors - negatives, norm_degree)
.sum(axis) + eps, 1/norm_degree)
loss = jnp.maximum(positive_distance - negative_distance + margin, 0)
return loss

0 comments on commit d1fb2c7

Please sign in to comment.