-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathgram_matrix_losses.py
157 lines (122 loc) · 7.39 KB
/
gram_matrix_losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Base class for all losses to be applied on a Gram matrix like output, ie when the output y_pred of the network is the the pair-wise
distance / similarity of all items of the batch (see GramMatrix layer for instance). y_true should be one-hot encoded.
For unsupervised metric learning, it is standard to use each image instance as a distinct class of its own. In this settings all
the losses are directly available by setting label = image_id and y_true stands indeed for all the patches/glimpses/etc. extracted from
the same image. It is usually supposed that the risk of collision is low. For more information on unsupervised learning of visual
representation, see for instance
[Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/abs/1911.05722)
[Dimensionality Reduction by Learning an Invariant Mapping](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
[Unsupervised feature learning via non-parametric instance discrimination](https://arxiv.org/abs/1805.01978v1)
"""
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_probability as tfp
from tensorflow.keras.losses import Loss
class MeanScoreClassificationLoss(Loss):
    """
    Turn the Gram matrix into per-class scores: each image gets, for every class, its mean score against the batch
    samples of that class. The scores are L1-normalized per image and the binary crossentropy of the true class is
    returned for each image.
    """

    def call(self, y_true, y_pred):
        # Number of samples of each class in the batch; absent classes give 0 and are left at 0 by divide_no_nan.
        samples_per_class = tf.reduce_sum(y_true, axis=0)
        averaging_weights = tf.math.divide_no_nan(y_true, samples_per_class)
        # mean score of each image against the samples of each class
        mean_class_scores = y_pred @ averaging_weights
        normalized_scores, _ = tf.linalg.normalize(mean_class_scores, ord=1, axis=1)
        per_class_crossentropy = K.binary_crossentropy(y_true, normalized_scores)
        # keep only the crossentropy of each image's own class
        return tf.reduce_sum(per_class_crossentropy * y_true, axis=1)
def class_consistency_loss(y_true, y_pred):
    """
    Use the mean score of an image against all the samples from the same class to get a score per class for each image.
    Then average again over all the samples to get a class-wise confusion matrix.

    Args:
        y_true: one-hot labels of the batch, shape (batch, n_classes).
        y_pred: Gram matrix of the batch, shape (batch, batch).

    Returns:
        Element-wise binary crossentropy between the identity matrix and the class-wise confusion matrix, restricted to
        the classes actually present in the batch.
    """
    # Column-normalize the labels so that a matmul averages over the samples of each class.
    class_weights = tf.math.divide_no_nan(y_true, tf.reduce_sum(y_true, axis=0))
    # After normalization, columns of present classes sum to 1 and columns of absent classes sum to 0.
    class_mask = tf.reduce_sum(class_weights, axis=0) > 0
    full_confusion_matrix = tf.matmul(class_weights, tf.matmul(y_pred, class_weights), transpose_a=True)
    confusion_matrix = tf.boolean_mask(
        tf.boolean_mask(full_confusion_matrix, class_mask, axis=0), class_mask, axis=1
    )
    # Match the confusion matrix dtype: tf.eye defaults to float32, which breaks on float64 predictions.
    identity_matrix = tf.eye(tf.shape(confusion_matrix)[0], dtype=confusion_matrix.dtype)
    return K.binary_crossentropy(identity_matrix, confusion_matrix)
class ClassConsistencyLoss(Loss):
    """
    Keras Loss wrapper around the class_consistency_loss function of this module.
    """

    def call(self, y_true, y_pred):
        consistency = class_consistency_loss(y_true, y_pred)
        return consistency
class BinaryCrossentropy(Loss):
    """
    Pair-wise binary crossentropy: the target of each pair (i, j) of the batch is the (i, j) entry of
    y_true @ y_true^T, i.e. 1 for pairs sharing a label when y_true is one-hot, 0 otherwise.
    """

    def call(self, y_true, y_pred):
        same_label_matrix = tf.matmul(y_true, y_true, transpose_b=True)
        pairwise_loss = K.binary_crossentropy(same_label_matrix, y_pred)
        return pairwise_loss
class ClippedBinaryCrossentropy(BinaryCrossentropy):
    """
    Compute the binary crossentropy loss of each possible pair in the batch.
    The margins define thresholds outside of which the error is not taken into account,
    i.e. only values with lower < |y_true - y_pred| < upper will be non-zero.

    Args:
        lower (float): ignore loss values below this threshold. Useful to make the network focus on more significant errors
        upper (float): ignore loss values above this threshold. Useful to prevent the network from focusing on errors due to
            wrong labels (or collision in unsupervised learning)
    """

    def __init__(self, lower=0.0, upper=1.0, **kwargs):
        super().__init__(**kwargs)
        self.lower = lower
        self.upper = upper

    def get_config(self):
        # Serialize the thresholds so that a reloaded loss keeps them instead of falling back to the defaults.
        config = super().get_config()
        config.update({"lower": self.lower, "upper": self.upper})
        return config

    def call(self, y_true, y_pred):
        loss = super().call(y_true, y_pred)
        # For a 0/1 target the crossentropy is -log(1 - |y_true - y_pred|), so thresholds on the absolute error
        # translate into thresholds on the loss. With upper=1.0 the upper bound is -log(0) = inf (no upper clipping).
        lower_loss = -tf.math.log(1 - tf.cast(self.lower, dtype=loss.dtype))
        upper_loss = -tf.math.log(1 - tf.cast(self.upper, dtype=loss.dtype))
        clip_mask = tf.math.logical_and(lower_loss < loss, loss < upper_loss)
        return tf.cast(clip_mask, dtype=loss.dtype) * loss
# TODO: once Keras accepts custom reductions, use the `reduction` kwarg of the base BinaryCrossentropy to include all the
# TODO: reductions below instead of subclassing it.
class MaxBinaryCrossentropy(BinaryCrossentropy):
    """
    Reduce the pair-wise binary crossentropy matrix to its largest entry, i.e. the worst pair of the batch.
    """

    def call(self, y_true, y_pred):
        pairwise_loss = super().call(y_true, y_pred)
        return tf.reduce_max(pairwise_loss)
class StdBinaryCrossentropy(BinaryCrossentropy):
    """
    Reduce the pair-wise binary crossentropy matrix to the standard deviation of all its entries.
    """

    def call(self, y_true, y_pred):
        pairwise_loss = super().call(y_true, y_pred)
        return tf.math.reduce_std(pairwise_loss)
class PercentileBinaryCrossentropy(BinaryCrossentropy):
    """
    Wrap tf probability percentile method to be used on the loss tensor: the pair-wise binary crossentropy matrix is
    reduced to a single percentile value.

    Args:
        percentile (float): percentile to compute, in [0, 100]. Defaults to the median (50).
        interpolation (str): interpolation method forwarded to tfp.stats.percentile. Defaults to "midpoint".
    """

    def __init__(self, percentile=50, interpolation="midpoint", **kwargs):
        super().__init__(**kwargs)
        self.percentile = percentile
        self.interpolation = interpolation

    def get_config(self):
        # Serialize the reduction parameters so that a reloaded loss keeps them.
        config = super().get_config()
        config.update({"percentile": self.percentile, "interpolation": self.interpolation})
        return config

    def call(self, y_true, y_pred):
        return tfp.stats.percentile(super().call(y_true, y_pred), self.percentile, interpolation=self.interpolation)
class TripletLoss(Loss):
    """
    Implement triplet loss with semi-hard negative mining as in
    [FaceNet: A Unified Embedding for Face Recognition and Clustering](https://arxiv.org/pdf/1503.03832.pdf):
    a triplet (A, N, P) consists of three images with two labels such that A and P are of the same class (positive pair) and A and N are
    from two different classes (negative pair).
    Then the loss tries to enforce the following relation: d(A, P) + margin < d(A,N) with d a given _distance_ function. In the original
    implementation as well as in the standard tf.addons one this is the euclidean distance. Here this can be any _kernel_ (see
    SupportLayer).
    Semi-hard mining means that given a batch of embeddings, all item in the batch are used as anchor and for all anchor, all positive pairs
    are used in a triplet and that the negative pair is chosen such that:
    1) it is the closest negative sample farther than the positive one
    2) or it is the farthest negative sample
    It means that we somehow select the negative sample the closest to the margin, but give a preference to sample beyond the margin.
    The returned value is the mean loss over all (anchor, positive) pairs with anchor != positive, or 0 when the batch
    contains no such pair.

    Args:
        margin (float): margin for separating positive to negative pairs
    """

    def __init__(self, margin=1.0, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def get_config(self):
        # Serialize the margin so that a reloaded loss keeps it instead of falling back to the default.
        config = super().get_config()
        config.update({"margin": self.margin})
        return config

    def call(self, y_true, y_pred):
        # Work in y_pred's dtype throughout: hard-coded float32 casts break on float64 predictions or integer labels.
        y_pred = tf.convert_to_tensor(y_pred)
        dtype = y_pred.dtype
        y_true = tf.cast(y_true, dtype)

        # 0) build triplets tensor such that triplets[a, p, n] = d(a, p) - d(a, n)
        adjacency_matrix = tf.matmul(y_true, y_true, transpose_b=True)
        # anp_mask[a, p, n] is 1 when exactly one of (a, p) and (a, n) is a same-label pair
        anp_mask = tf.cast(tf.expand_dims(adjacency_matrix, -1) + tf.expand_dims(adjacency_matrix, 1) == 1, dtype)
        triplets = tf.expand_dims(y_pred, -1) - tf.expand_dims(y_pred, 1)
        triplets_max = tf.reduce_max(triplets, axis=-1, keepdims=True)
        triplets_min = tf.reduce_min(triplets, axis=-1, keepdims=True)
        farther_negative_mask = tf.cast(triplets < 0, dtype)

        # 1) negatives_outside: smallest negative distance greater than positive one.
        # Shifting by -triplets_min + epsilon makes every candidate strictly positive, so masked-out entries (0) never
        # win the max and a zero result means that no such negative exists.
        negatives_outside = tf.reduce_max(
            (triplets - triplets_min + K.epsilon()) * farther_negative_mask * anp_mask, axis=-1
        )
        negatives_outside_mask = negatives_outside > 0
        loss_negatives_outside = tf.maximum(
            negatives_outside + tf.squeeze(triplets_min, axis=-1) - K.epsilon() + self.margin, 0
        )

        # 2) negatives_inside: greatest negative distance smaller than positive one.
        # Symmetric trick: shifting by -triplets_max makes candidates non-positive so the min selects among them.
        # NOTE(review): an anchor with no negative in the batch falls through to this branch and still gets
        # max(triplets_max - epsilon + margin, 0) — confirm whether this is intended.
        loss_negatives_inside = tf.maximum(
            tf.reduce_min((triplets - triplets_max + K.epsilon()) * (1 - farther_negative_mask) * anp_mask, axis=-1)
            + tf.squeeze(triplets_max, axis=-1)
            - K.epsilon()
            + self.margin,
            0,
        )

        all_losses = tf.where(negatives_outside_mask, loss_negatives_outside, loss_negatives_inside)
        # only (anchor, positive) pairs with anchor != positive form actual triplets
        true_triplets_mask = adjacency_matrix - tf.eye(tf.shape(y_true)[0], dtype=dtype)
        # divide_no_nan: a batch without any positive pair has no triplet and yields 0 instead of NaN
        return tf.math.divide_no_nan(
            tf.reduce_sum(all_losses * true_triplets_mask), tf.reduce_sum(true_triplets_mask)
        )