position_encoding.py
import math
import torch
import torch.nn as nn


class PositionEncoding(nn.Module):
    """
    Add positional information to the input tensor.

    :Examples:
        >>> model = PositionEncoding(n_filters=6, max_len=10)
        >>> test_input1 = torch.zeros(3, 10, 6)
        >>> output1 = model(test_input1)
        >>> output1.size()
        torch.Size([3, 10, 6])
        >>> test_input2 = torch.zeros(5, 3, 9, 6)
        >>> output2 = model(test_input2)
        >>> output2.size()
        torch.Size([5, 3, 9, 6])
    """

    def __init__(self, n_filters=128, max_len=500):
        """
        :param n_filters: same as the input hidden size
        :param max_len: maximum sequence length
        """
        super(PositionEncoding, self).__init__()
        # Compute the positional encodings once, in log space:
        # PE(pos, 2i)   = sin(pos / 10000^(2i / n_filters))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i / n_filters))
        pe = torch.zeros(max_len, n_filters)  # (L, D)
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_filters, 2).float() * -(math.log(10000.0) / n_filters))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)  # registered as a buffer, not a trainable parameter; shape (L, D)

    def forward(self, x):
        """
        :Input: (*, L, D)
        :Output: (*, L, D), the same size as the input
        """
        pe = self.pe.data[:x.size(-2), :]  # (x.size(-2), n_filters)
        extra_dim = len(x.size()) - 2
        for _ in range(extra_dim):
            pe = pe.unsqueeze(0)
        x = x + pe
        return x
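

# --- Sanity-check sketch (added for illustration; not part of the original module) ---
# Assuming the standard sinusoidal formulation, the log-space div_term computed in
# __init__ above should equal the direct form 1 / 10000^(i / n_filters) for even i.
def _check_div_term(n_filters=128):
    i = torch.arange(0, n_filters, 2).float()
    log_space = torch.exp(i * -(math.log(10000.0) / n_filters))
    direct = 1.0 / torch.pow(torch.tensor(10000.0), i / n_filters)
    assert torch.allclose(log_space, direct, atol=1e-6)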


def test_pos_enc():
    mdl = PositionEncoding()
    batch_size = 8
    n_channels = 128
    n_items = 60
    inp = torch.ones(batch_size, n_items, n_channels)
    out = mdl(inp)
    print(out)
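

# --- Usage sketch (hypothetical; the embedding and sizes below are illustrative, not from this repo) ---
# Shows how PositionEncoding might be applied to token embeddings: n_filters must
# match the embedding (hidden) size, and max_len must cover the sequence length.
def example_usage():
    vocab_size, hidden_size, seq_len, batch_size = 1000, 128, 60, 8
    embed = nn.Embedding(vocab_size, hidden_size)
    pos_enc = PositionEncoding(n_filters=hidden_size, max_len=500)
    token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    hidden = pos_enc(embed(token_ids))  # (batch_size, seq_len, hidden_size)
    print(hidden.size())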


if __name__ == '__main__':
    test_pos_enc()