-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest_lstm2d_inference.py
60 lines (47 loc) · 2.35 KB
/
test_lstm2d_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from unittest import TestCase, skip
import torch
from model.lstm2d import LSTM2d
class LSTM2dInferenceTest(TestCase):
    """
    Unit tests for the 2D-LSTM in inference mode.

    Every test runs on the CPU with a freshly built model and torch's RNG seeded to 42,
    so results are reproducible across runs.
    """
    embed_dim = 50
    encoder_state_dim = 20
    cell_state_dim = 25
    batch_size = 42
    input_seq_len = 4
    input_vocab_size = 3
    output_vocab_size = 5

    def setUp(self):
        """
        Seeds torch's RNG and builds a new CPU LSTM2d from the class-level hyperparameters.
        """
        torch.manual_seed(42)
        device = torch.device('cpu')
        self.lstm = LSTM2d(embed_dim=self.embed_dim, state_dim_2d=self.cell_state_dim,
                           encoder_state_dim=self.encoder_state_dim, input_vocab_size=self.input_vocab_size,
                           output_vocab_size=self.output_vocab_size, device=device)

    def _full_lengths(self, seq_len):
        """
        Returns a long tensor of shape (batch_size,) in which every sequence has length `seq_len`.
        """
        return torch.tensor([seq_len] * self.batch_size, dtype=torch.long)

    def test_dimensions(self):
        """
        Tests if the input and output dimensions of the 2D-LSTM are as expected.
        """
        # random token indices of shape (input_seq_len x batch_size)
        sample_x = torch.randint(0, self.input_vocab_size, (self.input_seq_len, self.batch_size), dtype=torch.long)
        x_lengths = self._full_lengths(self.input_seq_len)

        # toy inference: eval() puts the module into evaluation mode; no_grad() skips building an
        # autograd graph, which inference does not need
        self.lstm.eval()
        with torch.no_grad():
            pred = self.lstm.predict(x=sample_x, x_lengths=x_lengths)

        pred_shape = list(pred.shape)
        # the output length depends on the model parameters (when it predicts '<eos>'),
        # so it is read from the result rather than fixed in advance
        output_seq_len = pred_shape[0]
        self.assertEqual(pred_shape, [output_seq_len, self.batch_size, self.output_vocab_size],
                         'The predictions have an unexpected shape.')

    def test_same_over_batch(self):
        """
        Tests if the outputs of the 2D-LSTM are the same over the batch if the same input is fed in multiple times.
        """
        # deterministic token sequence cycling through the input vocabulary
        # (for input_seq_len=4, input_vocab_size=3 this is [0, 1, 2, 0]); derived from the class
        # attributes so the test keeps working if they change
        repeated_x = torch.arange(self.input_seq_len, dtype=torch.long) % self.input_vocab_size
        seq_len = repeated_x.numel()
        # shape (seq_len x batch_size): every column holds the same sequence
        batch_x = repeated_x.expand(self.batch_size, seq_len).t()
        x_lengths = self._full_lengths(seq_len)

        self.lstm.eval()
        with torch.no_grad():
            pred = self.lstm.predict(x=batch_x, x_lengths=x_lengths)  # shape (output_seq_len x batch_size x vocab_size)

        # broadcast the first batch element's predictions across the whole batch and compare
        pred_expected = pred[:, :1, :].expand_as(pred)
        self.assertTrue(torch.allclose(pred, pred_expected), 'Predictions vary across same-input batch.')