diff --git a/egs/librispeech/ASR/zipformer/test_cope.py b/egs/librispeech/ASR/zipformer/test_cope.py index 5eb6ccfd98..00acbcb1de 100755 --- a/egs/librispeech/ASR/zipformer/test_cope.py +++ b/egs/librispeech/ASR/zipformer/test_cope.py @@ -1,14 +1,17 @@ #!/usr/bin/env python3 +import torch from zipformer import ContextualPositionalEncoding def test(): embed_dim = 5 npos_max = 10 + cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max) - q = torch.rand(2, 3, 4, embed_dim) - qk = torch.rand(2, 3, 4, 6) + q = torch.rand(2, 3, npos_max, embed_dim) + + qk = torch.rand(2, 3, npos_max, npos_max) p = cope(q=q, qk=qk) print(p.shape) @@ -19,4 +22,5 @@ def main(): if __name__ == "__main__": + torch.manual_seed(20240703) main() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 6a94e3ab00..a62cd54f15 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1402,26 +1402,59 @@ def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor: qk (torch.Tensor): A tensor of shape (head, batch, time1, time2) Returns: Return a tensor of shape (head, batch, time1, npos_max) + + Note the implementation assumes time1 == time2 and npos_max <= time2. + The implementation is reasonable for the streaming ASR encoder where + only self attention is used. """ + # The implementation on page 13 Listing 1 from the paper does not use + # a mask to ensure that only gates[:, :, i, j] where j < i is computed. 
+ # + # Here we fix that by introducing a mask + mask = torch.triu( + torch.full((qk.size(3), qk.size(3)), True, dtype=torch.bool), + diagonal=0, + ) + # + # if qk.size(3) is 4, mask is + # + # tensor([[ True, True, True, True], + # [False, True, True, True], + # [False, False, True, True], + # [False, False, False, True]]) + # + # mask[i, j] is True if j >= i gates = torch.sigmoid(qk) - pos = gates.sum(dim=-1, keepdim=True) # (head, batch, dim1, 1) - # Note: We don't use cumulative sum here for non-streaming - # speech recognition + + # We don't use an in-place operation here for the sake of autograd + gates = gates.masked_fill(mask, 0) + + # cumsum() is an inclusive sum in PyTorch + pos = gates.flip(-1).cumsum(dim=-1).flip(-1) # (head, batch, time1, time2) + # pos[:, :, i, j] should be 0 for j >= i + # pos[:, :, i, j] contains the position between i and j. If gates + # is a 0-1 matrix, then pos[:, :, i, j] equals i - j (for j < i) + # Note: The paper says on page 4 it equals i - j + 1 instead of i - j. pos = pos.clamp(max=self.npos_max - 1) pos_ceil = pos.ceil().long() pos_floor = pos.floor().long() + + # We assume query_head_dim equals embed_dim + logits_int = torch.matmul( q, self.embedding.weight.t() ) # (head, batch, time1, npos_max) - logits_cell = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape)) - logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape)) + + # We assume that npos_max <= time2 + logits_cell = logits_int.gather(-1, pos_ceil) + logits_floor = logits_int.gather(-1, pos_floor) w = pos - pos_floor - return logits_cell * w + logits_floor * (1 - w) - def streaming_forward(self): - raise RuntimeError("To be implemented") + # Note: The code in the paper on page 13 is correct + # while the description on page 4 equation (5) is wrong + return logits_cell * w + logits_floor * (1 - w) class CompactRelPositionalEncoding(torch.nn.Module):