fix: correct the positional encoding of Transformer in pytorch examples
Galaxy-Husky authored and lantiga committed Nov 12, 2024
1 parent b0aa504 commit 7038b8d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/lightning/pytorch/demos/transformer.py
@@ -88,7 +88,7 @@ def forward(self, x: Tensor) -> Tensor:
         # TODO: Could make this a `nn.Parameter` with `requires_grad=False`
         self.pe = self._init_pos_encoding(device=x.device)

-        x = x + self.pe[: x.size(0), :]
+        x = x + self.pe[:, : x.size(1)]
         return self.dropout(x)

     def _init_pos_encoding(self, device: torch.device) -> Tensor:
@@ -97,7 +97,7 @@ def _init_pos_encoding(self, device: torch.device) -> Tensor:
         div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0).transpose(0, 1)
+        pe = pe.unsqueeze(0)
         return pe
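
Why the fix works: the new slice on x.size(1) implies the demo's model consumes batch-first tensors of shape (batch, seq_len, dim). The old code built pe with shape (max_len, 1, dim), the sequence-first layout of the classic PyTorch tutorial, and then sliced pe[: x.size(0), :], which indexes the batch dimension of a batch-first input. The patch keeps pe as (1, max_len, dim) and slices the sequence axis instead, broadcasting over the batch. Below is a minimal sketch of the corrected module, reconstructed from the two hunks above; the constructor signature and the max_len default are assumptions, not part of the diff.

import math
from typing import Optional

import torch
from torch import Tensor, nn


class PositionalEncoding(nn.Module):
    """Sketch of the corrected module; constructor args are assumed, not from the diff."""

    def __init__(self, dim: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim
        self.max_len = max_len
        self.pe: Optional[Tensor] = None

    def forward(self, x: Tensor) -> Tensor:
        # x is batch-first: (batch, seq_len, dim)
        if self.pe is None:
            # TODO: Could make this a `nn.Parameter` with `requires_grad=False`
            self.pe = self._init_pos_encoding(device=x.device)

        # pe is (1, max_len, dim): slice the sequence axis, broadcast over the batch
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

    def _init_pos_encoding(self, device: torch.device) -> Tensor:
        pe = torch.zeros(self.max_len, self.dim, device=device)
        position = torch.arange(0, self.max_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, dim) instead of the old (max_len, 1, dim)
        return pe


# Shape check: the encoding must broadcast over the batch dimension.
x = torch.zeros(4, 32, 64)  # (batch=4, seq_len=32, dim=64)
out = PositionalEncoding(dim=64)(x)
assert out.shape == (4, 32, 64)

With the old layout, pe[: x.size(0), :] added the encoding for sequence position i to batch element i, so the result depended on batch size rather than token position.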


