"""
This module contains the pattern (prompt template) and verbalizer pairs (PVPs) for the different tasks.
"""
import random
import string
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Tuple, Union
import torch
from utils import InputExample, get_verbalization_ids
import log
logger = log.get_logger('root')
# A filled pattern: two sequences of text segments, each segment optionally marked as shortenable
FilledPattern = Tuple[List[Union[str, Tuple[str, bool]]], List[Union[str, Tuple[str, bool]]]]
class PVP(ABC):
"""
This class contains functions to apply patterns and verbalizers as required by prompt learning.
Each task requires its own custom PVP implementation (processor).
"""
def __init__(self, wrapper, pattern_id: int = 0, verbalizer_file: str = None, seed: int = 42):
"""
Create a new PVP.
:param wrapper: the wrapper for the underlying language model
:param pattern_id: the pattern id to use
:param verbalizer_file: an optional file that contains the verbalizer to be used
:param seed: a seed to be used for generating random numbers if necessary
"""
self.wrapper = wrapper
self.pattern_id = pattern_id
self.rng = random.Random(seed) # random number generator
if verbalizer_file:
self.verbalize = PVP._load_verbalizer_from_file(verbalizer_file, self.pattern_id)
self.mlm_logits_to_cls_logits_tensor = self._build_mlm_logits_to_cls_logits_tensor()
def _build_mlm_logits_to_cls_logits_tensor(self):
label_list = self.wrapper.config.label_list
m2c_tensor = torch.ones([len(label_list), self.max_num_verbalizers], dtype=torch.long) * -1
for label_idx, label in enumerate(label_list):
verbalizers = self.verbalize(label)
for verbalizer_idx, verbalizer in enumerate(verbalizers):
verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True)
assert verbalizer_id != self.wrapper.tokenizer.unk_token_id, "verbalization was tokenized as <UNK>"
m2c_tensor[label_idx, verbalizer_idx] = verbalizer_id
return m2c_tensor
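    # Illustrative sketch (hypothetical token ids): for label_list ['1', '2'] with
    # verbalizers ['bad'] and ['great'] and max_num_verbalizers == 1, the result is
    #     tensor([[2919],    # assumed id of 'bad'
    #             [2307]])   # assumed id of 'great'
    # Rows index labels, columns index verbalizer slots; unused slots keep -1.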
@property
def mask(self) -> str:
"""Return the underlying LM's special mask token."""
return self.wrapper.tokenizer.mask_token
@property
def mask_id(self) -> int:
"""Return the underlying LM's mask token id."""
return self.wrapper.tokenizer.mask_token_id
@property
def max_num_verbalizers(self) -> int:
"""Return the maximum number of the verbalizers across all labels."""
return max(len(self.verbalize(label)) for label in self.wrapper.config.label_list)
    def encode(self, example: InputExample, priming: bool = False, labeled: bool = False,
               max_length: int = None) -> Tuple[List[int], List[int]]:
        """
        Encode an input example using this pattern-verbalizer pair.
        :param example: an input example to encode
        :param priming: whether to use this example for priming
        :param labeled: if ``priming=True``, whether the label should be appended to this example
        :param max_length: an optional total maximum length; defaults to the wrapper's max_seq_length
        :return: a tuple consisting of a list of input ids and a list of token type ids
"""
if not priming:
assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true."
tokenizer = self.wrapper.tokenizer # type: PreTrainedTokenizer
parts_a, parts_b = self.get_parts(example)
parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a]
parts_a = [(tokenizer.encode(x, add_special_tokens=False), s) for x, s in parts_a if x]
if parts_b:
parts_b = [x if isinstance(x, tuple) else (x, False) for x in parts_b]
parts_b = [(tokenizer.encode(x, add_special_tokens=False), s) for x, s in parts_b if x]
        if not max_length:
            max_length = self.wrapper.config.max_seq_length
        self.truncate(parts_a, parts_b, max_length=max_length)
tokens_a = [token_id for part, _ in parts_a for token_id in part]
tokens_b = [token_id for part, _ in parts_b for token_id in part]
if priming:
input_ids = tokens_a
if tokens_b:
input_ids += tokens_b
if labeled:
assert self.mask_id in input_ids, 'sequence of input_ids must contain a mask token'
mask_idx = input_ids.index(self.mask_id)
# assert len(self.verbalize(example.label)) == 1, 'priming only supports one verbalization per label'
verbalizer = self.verbalize(example.label)[0]
verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True)
input_ids[mask_idx] = verbalizer_id
return input_ids, []
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
return input_ids, token_type_ids
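    # Encoding sketch (assuming a BERT-style tokenizer): for ProductPVP pattern 1,
    # get_parts yields (['It was', [MASK], '.', text], []), so the returned input_ids
    # correspond to "[CLS] It was [MASK] . <review text> [SEP]" and token_type_ids
    # are all zero because there is no second sequence.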
@staticmethod
def shortenable(s: str) -> Tuple[str, bool]:
"""
        Return an instance of this string that is marked as shortenable.
        :param s: the string to be marked
        :return: a tuple of the string and ``True`` (the shortenable flag)
"""
return s, True
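    # e.g. PVP.shortenable('The best laptop I have ever used!')
    #     -> ('The best laptop I have ever used!', True)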
@staticmethod
def remove_final_punc(s: Union[str, Tuple[str, bool]]):
"""Remove the final punctuation mark."""
if isinstance(s, tuple):
return PVP.remove_final_punc(s[0]), s[1]
        # rstrip (not rsplit) removes any trailing punctuation characters
        return s.rstrip(string.punctuation)
    def truncate(self, parts_a: List[Tuple[List[int], bool]], parts_b: List[Tuple[List[int], bool]],
                 max_length: int):
        """
        Truncate two sequences of tokenized parts so that their combined length does not exceed a
        predefined maximum. Truncation happens in place, always removing a token from whichever
        sequence currently has more shortenable tokens.
        :param parts_a: the first sequence of (token ids, shortenable) parts
        :param parts_b: the second sequence of (token ids, shortenable) parts
        :param max_length: the predefined total maximum length
"""
total_len = self._seq_length(parts_a) + self._seq_length(parts_b)
        # `num_special_tokens_to_add` expects a flag indicating whether the input is a sequence pair
        total_len += self.wrapper.tokenizer.num_special_tokens_to_add(bool(parts_b))
num_tokens_to_remove = total_len - max_length
if num_tokens_to_remove <= 0:
return parts_a, parts_b
for _ in range(num_tokens_to_remove):
if self._seq_length(parts_a, only_shortenable=True) > self._seq_length(parts_b, only_shortenable=True):
self._remove_last(parts_a)
else:
self._remove_last(parts_b)
@staticmethod
    def _seq_length(parts: List[Tuple[List[int], bool]], only_shortenable: bool = False):
        return sum(len(x) for x, shortenable in parts if not only_shortenable or shortenable) if parts else 0
@staticmethod
    def _remove_last(parts: List[Tuple[List[int], bool]]):
last_idx = max(idx for idx, (seq, shortenable) in enumerate(parts) if shortenable and seq)
parts[last_idx] = (parts[last_idx][0][:-1], parts[last_idx][1])
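    # Truncation sketch with token-id parts: given
    #     parts = [([5, 6, 7], True), ([8, 9], False)]
    # a call to _remove_last drops the final token of the last non-empty shortenable
    # part, leaving [([5, 6], True), ([8, 9], False)].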
def convert_mlm_logits_to_cls_logits(self, mlm_labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
masked_logits = logits[mlm_labels >= 0]
cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(ml) for ml in masked_logits])
return cls_logits
def _convert_single_mlm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor:
m2c = self.mlm_logits_to_cls_logits_tensor
m2c = m2c.to(logits.device)
        # filler_len has shape (num_labels,): the number of verbalizers for each label
        filler_len = torch.tensor([len(self.verbalize(label)) for label in self.wrapper.config.label_list],
                                  dtype=torch.float)
        filler_len = filler_len.to(logits.device)
        # cls_logits has shape (num_labels, max_num_verbalizers); the -1 padding is clamped to 0 for indexing
        cls_logits = logits[torch.max(torch.zeros_like(m2c), m2c)]
        cls_logits = cls_logits * (m2c > 0).float()
        # cls_logits now has shape (num_labels,): the average logit over each label's verbalizers
        cls_logits = cls_logits.sum(dim=1) / filler_len
return cls_logits
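    # Shape sketch for 3 labels with at most 2 verbalizers each: m2c is (3, 2) with
    # -1 padding; clamping the padding to 0 lets it index `logits`, giving a (3, 2)
    # tensor of verbalizer logits; (m2c > 0) zeroes the padded slots; and dividing
    # the row sums by filler_len yields one averaged logit per label.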
def get_mask_positions(self, input_ids: List[int]) -> List[int]:
label_idx = input_ids.index(self.mask_id)
labels = [-1] * len(input_ids)
labels[label_idx] = 1
return labels
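    # e.g. assuming mask_id == 103 (BERT's [MASK]):
    #     get_mask_positions([101, 2009, 103, 102]) -> [-1, -1, 1, -1]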
@abstractmethod
def get_parts(self, example: InputExample) -> FilledPattern:
"""
        Given an input example, apply a pattern to obtain two sequences of text, text_a and text_b, that
        jointly contain exactly one mask token. If the task requires only a single sequence of text, the
        second sequence should be an empty list.
:param example: the input example to be processed
:return: Two sequences of texts. All text segments can optionally be marked as being shortenable.
"""
pass
@abstractmethod
def verbalize(self, label) -> List[str]:
"""
Return all verbalizations for a given label
:param label: the label
        :return: the list of all verbalizations of the label
"""
pass
@staticmethod
def _load_verbalizer_from_file(path: str, pattern_id: int):
        verbalizers = defaultdict(dict)  # type: Dict[int, Dict[str, List[str]]]
        current_pattern_id = None
        with open(path, 'r') as fh:
            for line in fh.read().splitlines():
                if line.isdigit():
                    current_pattern_id = int(line)
                elif line:
                    label, *realizations = line.split()
                    verbalizers[current_pattern_id][label] = realizations
        logger.info('Automatically loaded the following verbalizer: \n{}'.format(verbalizers[pattern_id]))
def verbalize(label) -> List[str]:
return verbalizers[pattern_id][label]
return verbalize
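# Sketch of the verbalizer file format parsed above (hypothetical contents): a bare
# pattern id on its own line, followed by one "label verbalization ..." line per label:
#
#     0
#     1 bad terrible
#     2 great good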
class ProductPVP(PVP):
    # label '1': bad (negative), label '2': great (positive)
    VERBALIZER = {
        '1': ['bad'],
        '2': ['great']
    }
    # verbalizer pool: good - 'positive', 'great', 'super'
    #                  bad  - 'terrible', 'negative'
    def get_parts(self, example: InputExample) -> FilledPattern:
        text = self.shortenable(example.text_a)
        # e.g. text = 'The best laptop I have ever used!'
        # pattern 0: The best laptop I have ever used! [MASK]
        if self.pattern_id == 0:
            return [text, self.mask], []
        # pattern 1: It was [MASK]. The best laptop I have ever used!
        elif self.pattern_id == 1:
            return ['It was', self.mask, '.', text], []
        # pattern 2: The best laptop I have ever used! All in all, it was [MASK].
        elif self.pattern_id == 2:
            return [text, 'All in all, it was', self.mask, '.'], []
        # pattern 3: Just [MASK]! The best laptop I have ever used!
        elif self.pattern_id == 3:
            return ['Just', self.mask, '!'], [text]
        # pattern 4: The best laptop I have ever used! In summary, the product is [MASK].
        elif self.pattern_id == 4:
            return [text], ['In summary, the product is', self.mask, '.']
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return ProductPVP.VERBALIZER[label]
class XnliPVP(PVP):
VERBALIZER_A = {
'0': ['Yes'],
'1': ['Maybe'],
'2': ['No']
}
VERBALIZER_B = {
'0': ['Right'],
'1': ['Maybe'],
'2': ['Wrong']
}
    def get_parts(self, example: InputExample) -> FilledPattern:
        # remove_final_punc returns the stripped string, which shortenable wraps into a tuple
        text_a = self.shortenable(self.remove_final_punc(example.text_a))
        text_b = self.shortenable(example.text_b)
if self.pattern_id == 0:
return [text_a, '.', self.mask, text_b], []
elif self.pattern_id == 1 or self.pattern_id == 2:
return [text_a, '?'], [self.mask, ',', text_b]
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
if self.pattern_id == 1 or self.pattern_id == 0:
return XnliPVP.VERBALIZER_A[label]
return XnliPVP.VERBALIZER_B[label]
class AgNewsPVP(PVP):
VERBALIZER = {
'0': ['World'],
'1': ['Sports'],
'2': ['Business'],
'3': ['Tech']
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
        # example text: Germany won the 2014 World Cup.
        # pattern 0: Germany won the 2014 World Cup. [MASK]
if self.pattern_id == 0:
return [text_a, self.mask], []
        # pattern 1: [MASK]: Germany won the 2014 World Cup.
elif self.pattern_id == 1:
return [self.mask, ':', text_a], []
        # pattern 2: [MASK] News: Germany won the 2014 World Cup.
elif self.pattern_id == 2:
return [self.mask, 'News', ':', text_a], []
        # pattern 3: Germany won the 2014 World Cup. Category: [MASK]
elif self.pattern_id == 3:
return [text_a, 'Category', ':', self.mask], []
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return AgNewsPVP.VERBALIZER[label]
class XtcPVP(PVP):
VERBALIZER = {
'1': ['Politics'],
'2': ['Military'],
'3': ['Law'],
'4': ['Economics'],
'5': ['Education'],
'6': ['Medicine'],
'7': ['Religion'],
'8': ['Literature'],
'9': ['Culture'],
'10': ['Transportation'],
'11': ['Sport'],
'12': ['History'],
'13': ['Landscape'],
'14': ['Science'],
'15': ['Daily'],
'16': ['Media'],
'17': ['Entertainment'],
'18': ['Food'],
'19': ['Philosophy'],
'20': ['News'],
'21': ['Person'],
'22': ['Popular'],
'23': ['Organization']
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
        # example text: Germany won the 2014 World Cup.
        # pattern 0: Germany won the 2014 World Cup. [MASK]
if self.pattern_id == 0:
return [text_a, self.mask], []
        # pattern 1: [MASK]: Germany won the 2014 World Cup.
elif self.pattern_id == 1:
return [self.mask, ':', text_a], []
        # pattern 2: [MASK] News: Germany won the 2014 World Cup.
elif self.pattern_id == 2:
return [self.mask, 'News', ':', text_a], []
        # pattern 3: Germany won the 2014 World Cup. Category: [MASK]
elif self.pattern_id == 3:
return [text_a, 'Category', ':', self.mask], []
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return XtcPVP.VERBALIZER[label]
PVPS = {
'product-review-polarity': ProductPVP,
'xnli': XnliPVP,
'ag_news': AgNewsPVP,
    'xtc': XtcPVP
}
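# Minimal usage sketch (hedged): `wrapper` stands in for the model wrapper that the
# training code passes in; PVP only touches wrapper.tokenizer, wrapper.config.label_list
# and wrapper.config.max_seq_length. The InputExample arguments are assumed from its
# definition in utils.
#
#     pvp = PVPS['ag_news'](wrapper, pattern_id=2)
#     example = InputExample(guid='0', text_a='Germany won the 2014 World Cup.', label='1')
#     input_ids, token_type_ids = pvp.encode(example)
#     mlm_labels = pvp.get_mask_positions(input_ids)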