diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py
index 1a909b714..ab4139212 100644
--- a/wenet/transformer/encoder_layer.py
+++ b/wenet/transformer/encoder_layer.py
@@ -185,7 +185,7 @@ def forward(
 
         Args:
             x (torch.Tensor): (#batch, time, size)
-            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
             pos_emb (torch.Tensor): positional encoding, must not be None
                 for ConformerEncoderLayer.
             output_cache (torch.Tensor): Cache tensor of the output
@@ -223,7 +223,6 @@ def forward(
             mask = mask[:, -chunk:, :]
 
         x_att = self.self_attn(x_q, x, x, pos_emb, mask)
-
         if self.concat_after:
             x_concat = torch.cat((x, x_att), dim=-1)
             x = residual + self.concat_linear(x_concat)
diff --git a/wenet/transformer/subsampling.py b/wenet/transformer/subsampling.py
index c13eb7d23..8933d07c3 100644
--- a/wenet/transformer/subsampling.py
+++ b/wenet/transformer/subsampling.py
@@ -116,7 +116,7 @@ def forward(
             torch.Tensor: positional encoding
 
         """
-        x = x.unsqueeze(1)  # (b, c, t, f)
+        x = x.unsqueeze(1)  # (b, c=1, t, f)
         x = self.conv(x)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
diff --git a/wenet/utils/mask.py b/wenet/utils/mask.py
index 36d14b654..a98588835 100644
--- a/wenet/utils/mask.py
+++ b/wenet/utils/mask.py
@@ -12,6 +12,15 @@ def subsequent_mask(
 ) -> torch.Tensor:
     """Create mask for subsequent steps (size, size).
 
+    This mask is used only in the decoder, which works in an
+    auto-regressive mode: the current step can only attend to its left steps.
+
+    In the encoder, full attention is used when streaming is not required
+    and the sequence is not long. In this case, no attention mask is needed.
+
+    When streaming is required, chunk-based attention is used in the encoder.
+    See subsequent_chunk_mask for the chunk-based attention mask.
+
     Args:
         size (int): size of mask
         str device (str): "cpu" or "cuda" or torch.Tensor.device
@@ -47,7 +56,7 @@ def subsequent_chunk_mask(
         torch.Tensor: mask
     Examples:
-        >>> subsequent_mask(4, 2)
+        >>> subsequent_chunk_mask(4, 2)
         [[1, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]
@@ -113,7 +122,7 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
 def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
     """Make mask tensor containing indices of padded part.
 
-    1 for padded part and 0 for non-padded part.
+    See the description of make_non_pad_mask.
 
     Args:
         lengths (torch.Tensor): Batch of lengths (B,).
@@ -142,6 +151,14 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
 def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
     """Make mask tensor containing indices of non-padded part.
 
+    The sequences in a batch may have different lengths. To enable
+    batch computing, padding is needed to make all sequences the same
+    length. To prevent the padded part from passing values into
+    context-dependent blocks such as attention or convolution, the
+    padded part is masked.
+
+    This pad_mask is used in both the encoder and the decoder.
+
     1 for non-padded part and 0 for padded part.
 
     Args:
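
Note: the chunk-based attention behaviour described in the new subsequent_mask
docstring (and referenced by the corrected subsequent_chunk_mask example) can
be checked with a small standalone sketch. The helper below, chunk_mask_sketch,
is a hypothetical re-implementation for illustration only, not the wenet
function itself; it assumes each position may attend to everything up to the
end of its own chunk, which reproduces the 4x2 example from the docstring.

    import torch

    def chunk_mask_sketch(size: int, chunk_size: int) -> torch.Tensor:
        # Hypothetical sketch: each position may attend to every position up
        # to the end of its own chunk (no left-chunk limit), matching the
        # subsequent_chunk_mask(4, 2) example in the docstring above.
        ret = torch.zeros(size, size, dtype=torch.bool)
        for i in range(size):
            ending = min((i // chunk_size + 1) * chunk_size, size)
            ret[i, :ending] = True
        return ret

    print(chunk_mask_sketch(4, 2).int())
    # tensor([[1, 1, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 1],
    #         [1, 1, 1, 1]], dtype=torch.int32)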
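
Similarly, the padding-mask behaviour described in the new make_non_pad_mask
docstring can be illustrated with a minimal sketch. pad_mask_sketch below is a
hypothetical standalone helper, not the wenet implementation: True (1) marks
padded frames that attention or convolution should ignore, and the non-pad
mask is simply its logical negation.

    import torch

    def pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
        # Hypothetical sketch: True (1) marks padded frames, False (0) marks
        # real frames, as described in the docstrings added above.
        batch = lengths.size(0)
        max_len = int(lengths.max().item())
        seq = torch.arange(max_len, device=lengths.device)
        return seq.unsqueeze(0).expand(batch, max_len) >= lengths.unsqueeze(1)

    lengths = torch.tensor([5, 3, 2])
    print(pad_mask_sketch(lengths).int())
    # tensor([[0, 0, 0, 0, 0],
    #         [0, 0, 0, 1, 1],
    #         [0, 0, 1, 1, 1]], dtype=torch.int32)
    # The non-pad mask used by the encoder and decoder is the negation:
    # (~pad_mask_sketch(lengths)).int()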