[doc] refine comments for mask.py (#142)
* refine comments for mask.py

* update
placebokkk authored Feb 4, 2021
1 parent 5dfc0a9 commit b8dad2b
Showing 3 changed files with 21 additions and 5 deletions.
3 changes: 1 addition & 2 deletions wenet/transformer/encoder_layer.py
@@ -185,7 +185,7 @@ def forward(
Args:
x (torch.Tensor): (#batch, time, size)
- mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
pos_emb (torch.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
output_cache (torch.Tensor): Cache tensor of the output
@@ -223,7 +223,6 @@ def forward(
mask = mask[:, -chunk:, :]

x_att = self.self_attn(x_q, x, x, pos_emb, mask)

if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
2 changes: 1 addition & 1 deletion wenet/transformer/subsampling.py
@@ -116,7 +116,7 @@ def forward(
torch.Tensor: positional encoding
"""
- x = x.unsqueeze(1)  # (b, c, t, f)
+ x = x.unsqueeze(1)  # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
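
For reference, a minimal standalone sketch of the shape flow that the updated comment describes; the conv configuration and sizes below are assumptions for illustration, not the module defined in subsampling.py:

import torch

b, t, f = 2, 100, 80                                      # assumed batch, frames, feature dim
x = torch.randn(b, t, f)
x = x.unsqueeze(1)                                        # (b, c=1, t, f): add a channel axis for Conv2d
conv = torch.nn.Sequential(                               # two stride-2 convs, roughly 4x time subsampling (assumed sizes)
    torch.nn.Conv2d(1, 256, 3, 2), torch.nn.ReLU(),
    torch.nn.Conv2d(256, 256, 3, 2), torch.nn.ReLU(),
)
x = conv(x)                                               # (b, c, t', f')
b, c, t2, f2 = x.size()
x = x.transpose(1, 2).contiguous().view(b, t2, c * f2)    # (b, t', c * f'), ready for a Linear output layer
print(x.shape)                                            # torch.Size([2, 24, 4864])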
21 changes: 19 additions & 2 deletions wenet/utils/mask.py
@@ -12,6 +12,15 @@ def subsequent_mask(
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in the decoder, which works in an auto-regressive mode.
This means the current step may only attend to the steps on its left.
In the encoder, full attention is used when streaming is not required and
the sequence is not long. In this case, no attention mask is needed.
When streaming is required, chunk-based attention is used in the encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
device (str): "cpu" or "cuda" or torch.Tensor.device
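
For illustration, a behavioral sketch of the decoder mask described above (a lower-triangular matrix where step i attends only to steps 0..i); this is not necessarily the exact implementation in wenet/utils/mask.py:

import torch

def subsequent_mask_sketch(size: int, device: str = "cpu") -> torch.Tensor:
    # Step i may attend only to itself and the steps on its left.
    return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))

print(subsequent_mask_sketch(4).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)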
@@ -47,7 +56,7 @@ def subsequent_chunk_mask(
torch.Tensor: mask
Examples:
- >>> subsequent_mask(4, 2)
+ >>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
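
A behavioral sketch that reproduces the docstring example above: with chunk size 2, step i may attend to every step up to the end of its own chunk. This illustrates the documented behavior, not the library's exact code:

import torch

def subsequent_chunk_mask_sketch(size: int, chunk_size: int) -> torch.Tensor:
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        # Allow attention up to the last frame of the chunk that contains step i.
        mask[i, : min((i // chunk_size + 1) * chunk_size, size)] = True
    return mask

print(subsequent_chunk_mask_sketch(4, 2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]], dtype=torch.int32)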
@@ -113,7 +122,7 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
1 for padded part and 0 for non-padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
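
A behavioral sketch of the pad mask for a batch of lengths, where 1 marks a padded position and 0 a real frame (illustrative; not necessarily the exact implementation):

import torch

def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # Positions at or beyond each sequence's length are padding.
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

print(make_pad_mask_sketch(torch.tensor([5, 3, 2])).int())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)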
@@ -142,6 +151,14 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Make mask tensor containing indices of non-padded part.
The sequences in a batch may have different lengths. To enable
batch computing, padding is needed to make all sequences the same
size. To keep the padding part from passing values into context-dependent
blocks such as attention or convolution, this padding part is
masked.
This pad_mask is used in both the encoder and the decoder.
1 for non-padded part and 0 for padded part.
Args:
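
To make the purpose of the non-pad mask concrete, a small illustrative snippet that zeroes out padded frames before they reach context-dependent blocks such as attention or convolution; the tensor names and shapes here are assumptions, not taken from the library:

import torch

lengths = torch.tensor([5, 3, 2])
xs = torch.randn(3, 5, 8)                          # (batch, time, feature); tails beyond each length are padding

max_len = int(lengths.max())
non_pad_mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)   # (B, T), True = real frame

# Zero out padded frames so they contribute nothing downstream.
xs = xs * non_pad_mask.unsqueeze(-1)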
