rnnt_decoder.py
"""
Classes and methods from the nemo-toolkit for using the RNNTDecoder module.
"""
import torch
import numpy as np
from torch.nn import Module
from typing import Any, Dict, List, Optional, Tuple
class LSTMDropout(torch.nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int,
dropout: Optional[float],
forget_gate_bias: Optional[float],
t_max: Optional[int] = None,
weights_init_scale: float = 1.0,
hidden_hidden_bias_scale: float = 0.0,
proj_size: int = 0,
):
"""Returns an LSTM with forget gate bias init to `forget_gate_bias`.
Args:
input_size: See `torch.nn.LSTM`.
hidden_size: See `torch.nn.LSTM`.
num_layers: See `torch.nn.LSTM`.
dropout: See `torch.nn.LSTM`.
forget_gate_bias: float, set by default to 1.0, which constructs a forget gate
initialized to 1.0.
Reference:
[An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf)
t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization
of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course
of training.
Reference:
[Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
sometimes helps reduce variance between runs.
hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
the default behaviour.
Returns:
A `torch.nn.LSTM`.
"""
super(LSTMDropout, self).__init__()
self.lstm = torch.nn.LSTM(
input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, proj_size=proj_size
)
if t_max is not None:
# apply chrono init
for name, v in self.lstm.named_parameters():
if 'bias' in name:
p = getattr(self.lstm, name)
n = p.nelement()
hidden_size = n // 4
p.data.fill_(0)
p.data[hidden_size : 2 * hidden_size] = torch.log(
torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1)
)
# forget gate biases = log(uniform(1, Tmax-1))
p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size]
# input gate biases = -(forget gate biases)
elif forget_gate_bias is not None:
for name, v in self.lstm.named_parameters():
if "bias_ih" in name:
bias = getattr(self.lstm, name)
bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias)
if "bias_hh" in name:
bias = getattr(self.lstm, name)
bias.data[hidden_size : 2 * hidden_size] *= float(hidden_hidden_bias_scale)
self.dropout = torch.nn.Dropout(dropout) if dropout else None
for name, v in self.named_parameters():
if 'weight' in name or 'bias' in name:
v.data *= float(weights_init_scale)
def forward(
self, x: torch.Tensor, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
x, h = self.lstm(x, h)
if self.dropout:
x = self.dropout(x)
return x, h
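# Minimal usage sketch (illustrative, not part of the original NeMo code): the
# sizes and the `_example_lstm_dropout` name below are made up. It builds a
# two-layer LSTMDropout with unit forget gate bias and runs one forward pass;
# inputs follow torch.nn.LSTM's default (seq_len, batch, input_size) layout
# since batch_first is not set above.
def _example_lstm_dropout():
    lstm = LSTMDropout(
        input_size=8,
        hidden_size=16,
        num_layers=2,
        dropout=0.1,
        forget_gate_bias=1.0,  # forget gate biases are filled with 1.0
    )
    x = torch.randn(5, 3, 8)  # (U, B, input_size)
    out, (h, c) = lstm(x)     # out: (U, B, 16); h, c: (2, B, 16)
    return out.shape, h.shape, c.shape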
class RNNTDecoder(Module):
def __init__(
self,
prednet: Dict[str, Any],
vocab_size: int,
normalization_mode: Optional[str] = None,
random_state_sampling: bool = False,
blank_as_pad: bool = True,
):
# Required arguments
self.pred_hidden = prednet['pred_hidden']
self.pred_rnn_layers = prednet["pred_rnn_layers"]
self.blank_idx = vocab_size
super().__init__()
self.vocab_size = vocab_size
self.blank_as_pad = blank_as_pad
# Optional arguments
forget_gate_bias = prednet.get('forget_gate_bias', 1.0)
t_max = prednet.get('t_max', None)
weights_init_scale = prednet.get('weights_init_scale', 1.0)
hidden_hidden_bias_scale = prednet.get('hidden_hidden_bias_scale', 0.0)
dropout = prednet.get('dropout', 0.0)
self.random_state_sampling = random_state_sampling
self.prediction = self._predict_modules(
vocab_size=vocab_size, # add 1 for blank symbol
pred_n_hidden=self.pred_hidden,
pred_rnn_layers=self.pred_rnn_layers,
forget_gate_bias=forget_gate_bias,
t_max=t_max,
norm=normalization_mode,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
dropout=dropout,
rnn_hidden_size=prednet.get("rnn_hidden_size", -1),
)
self._rnnt_export = False
def forward(self, targets, target_length, states=None):
# y: (B, U)
y = label_collate(targets)
# state maintenance is unnecessary during training forward call
# to get state, use .predict() method.
if self._rnnt_export:
add_sos = False
else:
add_sos = True
g, states = self.predict(y, state=states, add_sos=add_sos) # (B, U, D)
g = g.transpose(1, 2) # (B, D, U)
return g, target_length, states
def predict(
self,
y: Optional[torch.Tensor] = None,
state: Optional[List[torch.Tensor]] = None,
add_sos: bool = True,
batch_size: Optional[int] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Stateful prediction of scores and state for a (possibly null) tokenset.
This method takes various cases into consideration :
- No token, no state - used for priming the RNN
- No token, state provided - used for blank token scoring
- Given token, states - used for scores + new states
Here:
B - batch size
U - label length
H - Hidden dimension size of RNN
L - Number of RNN layers
Args:
y: Optional torch tensor of shape [B, U] of dtype long which will be passed to the Embedding.
If None, creates a zero tensor of shape [B, 1, H] which mimics the output of the pad token on the Embedding.
state: An optional list of states for the RNN. E.g. for an LSTM, the state list has length 2.
Each state must be a tensor of shape [L, B, H].
If None, and during training mode and `random_state_sampling` is set, will sample a
normal distribution tensor of the above shape. Otherwise, None will be passed to the RNN.
add_sos: bool flag, whether a zero vector describing a "start of signal" token should be
prepended to the above "y" tensor. When set, output size is (B, U + 1, H).
batch_size: An optional int, specifying the batch size of the `y` tensor.
Can be inferred from `y` or `state` if either is provided. If both are None, batch_size cannot be None.
Returns:
A tuple (g, hid) such that -
If add_sos is False:
g: (B, U, H)
hid: (h, c) where h is the final sequence hidden state and c is the final cell state:
h (tensor), shape (L, B, H)
c (tensor), shape (L, B, H)
If add_sos is True:
g: (B, U + 1, H)
hid: (h, c) where h is the final sequence hidden state and c is the final cell state:
h (tensor), shape (L, B, H)
c (tensor), shape (L, B, H)
"""
# Get device and dtype of current module
_p = next(self.parameters())
device = _p.device
dtype = _p.dtype
# If y is not None, it is of shape [B, U] with dtype long.
if y is not None:
if y.device != device:
y = y.to(device)
# (B, U) -> (B, U, H)
y = self.prediction["embed"](y)
else:
# Y is not provided, assume zero tensor with shape [B, 1, H] is required
# Emulates output of embedding of pad token.
if batch_size is None:
B = 1 if state is None else state[0].size(1)
else:
B = batch_size
y = torch.zeros((B, 1, self.pred_hidden), device=device, dtype=dtype)
# Prepend blank "start of sequence" symbol (zero tensor)
if add_sos:
B, U, H = y.shape
start = torch.zeros((B, 1, H), device=y.device, dtype=y.dtype)
y = torch.cat([start, y], dim=1).contiguous() # (B, U + 1, H)
else:
start = None # makes del call later easier
# If in training mode, and random_state_sampling is set,
# initialize state to random normal distribution tensor.
if state is None:
if self.random_state_sampling and self.training:
state = self.initialize_state(y)
# Forward step through RNN
y = y.transpose(0, 1) # (U + 1, B, H)
g, hid = self.prediction["dec_rnn"](y, state)
g = g.transpose(0, 1) # (B, U + 1, H)
del y, start, state
return g, hid
def _predict_modules(
self,
vocab_size,
pred_n_hidden,
pred_rnn_layers,
forget_gate_bias,
t_max,
norm,
weights_init_scale,
hidden_hidden_bias_scale,
dropout,
rnn_hidden_size,
):
"""
Prepare the trainable parameters of the Prediction Network.
Args:
vocab_size: Vocab size (excluding the blank token).
pred_n_hidden: Hidden size of the RNNs.
pred_rnn_layers: Number of RNN layers.
forget_gate_bias: Initial value for the forget gate bias (1.0 gives a unit forget gate bias).
t_max: If set to an int, perform Chrono initialization of the LSTM with this maximum timestep count.
norm: Type of normalization to perform in RNN.
weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
sometimes helps reduce variance between runs.
hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
the default behaviour.
dropout: Dropout probability applied to the RNN outputs (0.0 disables dropout).
rnn_hidden_size: The hidden size of the RNN; if not specified (<= 0), pred_n_hidden is used.
"""
if self.blank_as_pad:
embed = torch.nn.Embedding(vocab_size + 1, pred_n_hidden, padding_idx=self.blank_idx)
else:
embed = torch.nn.Embedding(vocab_size, pred_n_hidden)
layers = torch.nn.ModuleDict(
{
"embed": embed,
"dec_rnn": rnn(
input_size=pred_n_hidden,
hidden_size=rnn_hidden_size if rnn_hidden_size > 0 else pred_n_hidden,
num_layers=pred_rnn_layers,
norm=norm,
forget_gate_bias=forget_gate_bias,
t_max=t_max,
dropout=dropout,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
proj_size=pred_n_hidden if pred_n_hidden < rnn_hidden_size else 0,
),
}
)
return layers
def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
"""
Initialize the state of the RNN layers, with same dtype and device as input `y`.
Args:
y: A torch.Tensor whose device the generated states will be placed on.
Returns:
List of torch.Tensor, each of shape [L, B, H], where
L = Number of RNN layers
B = Batch size
H = Hidden size of RNN.
"""
batch = y.size(0)
if self.random_state_sampling and self.training:
state = [
torch.randn(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device),
torch.randn(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device),
]
else:
state = [
torch.zeros(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device),
torch.zeros(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device),
]
return state
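# End-to-end sketch (illustrative; the sizes, the prednet config values, and the
# `_example_rnnt_decoder` name are assumptions, not taken from this file). It
# shows the training-style forward() call on ragged target lists and the
# stateful predict() call used for step-wise decoding.
def _example_rnnt_decoder():
    prednet = {"pred_hidden": 16, "pred_rnn_layers": 1, "dropout": 0.0}
    decoder = RNNTDecoder(prednet=prednet, vocab_size=28)
    targets = [[3, 5, 7], [2, 4]]        # ragged label sequences, collated inside forward()
    target_length = torch.tensor([3, 2])
    g, lengths, states = decoder(targets, target_length)
    # g: (B, D, U + 1) after the transpose in forward()
    # Blank-token style scoring with no label and no prior state:
    g_step, hid = decoder.predict(y=None, state=None, add_sos=False, batch_size=2)
    # g_step: (B, 1, H); hid: (h, c), each of shape (L, B, H)
    return g.shape, g_step.shape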
def label_collate(labels, device=None):
"""Collates the label inputs for the rnn-t prediction network.
If `labels` is already a torch.Tensor it is only cast to int64.
Args:
labels: A list or tuple of label-index sequences, or a torch.Tensor.
device: Optional torch device to place the label on.
Returns:
A padded torch.Tensor of shape (batch, max_seq_len).
"""
if isinstance(labels, torch.Tensor):
return labels.type(torch.int64)
if not isinstance(labels, (list, tuple)):
raise ValueError(f"`labels` should be a list or tensor not {type(labels)}")
batch_size = len(labels)
max_len = max(len(label) for label in labels)
cat_labels = np.full((batch_size, max_len), fill_value=0.0, dtype=np.int32)
for e, l in enumerate(labels):
cat_labels[e, : len(l)] = l
labels = torch.tensor(cat_labels, dtype=torch.int64, device=device)
return labels
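# Small illustrative check (assumed example): ragged label lists are right-padded
# with zeros and returned as an int64 tensor of shape (batch, max_seq_len).
def _example_label_collate():
    out = label_collate([[1, 2, 3], [4, 5]])
    # out -> tensor([[1, 2, 3],
    #                [4, 5, 0]]), dtype torch.int64
    return out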
def rnn(
input_size: int,
hidden_size: int,
num_layers: int,
norm: Optional[str] = None,
forget_gate_bias: Optional[float] = 1.0,
dropout: Optional[float] = 0.0,
norm_first_rnn: Optional[bool] = None,
t_max: Optional[int] = None,
weights_init_scale: float = 1.0,
hidden_hidden_bias_scale: float = 0.0,
proj_size: int = 0,
) -> torch.nn.Module:
return LSTMDropout(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
forget_gate_bias=forget_gate_bias,
t_max=t_max,
weights_init_scale=weights_init_scale,
hidden_hidden_bias_scale=hidden_hidden_bias_scale,
proj_size=proj_size,
)
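# Hedged sketch of the Chrono initialization path (the `_example_chrono_init`
# name and sizes are assumptions): passing `t_max` to rnn() makes LSTMDropout
# set the forget-gate biases to log(Uniform(1, t_max - 1)) and the input-gate
# biases to their negation, following "Can recurrent neural networks warp time?".
def _example_chrono_init():
    dec_rnn = rnn(
        input_size=8,
        hidden_size=16,
        num_layers=1,
        forget_gate_bias=None,  # ignored: the t_max branch takes precedence
        t_max=100,
    )
    bias = dec_rnn.lstm.bias_ih_l0.data
    forget_bias = bias[16:32]   # values in [0, log(99)]
    input_bias = bias[0:16]     # equal to -forget_bias
    return forget_bias, input_bias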