[firered] simplified code (#2686)
Mddct authored Feb 8, 2025
1 parent de41dd7 commit b67088a
Showing 3 changed files with 2 additions and 14 deletions.
11 changes: 0 additions & 11 deletions wenet/firered/model.py
@@ -46,17 +46,6 @@ def __init__(
         self.eos = special_tokens["eos"]
         self.decode_maxlen = self.decoder.embed[1].max_len

-        # fix subsampling
-        odim = 32
-        idim = 80
-        self.encoder.embed.conv = torch.nn.Sequential(
-            torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(),
-            torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU())
-        self.encoder.embed.out = torch.nn.Sequential(
-            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2),
-                            self.encoder.output_size()))

         # fix final norm in conformer
         del self.encoder.after_norm

     @torch.jit.unused
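A minimal sketch of the shape arithmetic behind the deleted block (only idim = 80 and odim = 32 come from the diff; everything else is an illustrative assumption): two kernel-3, stride-2 Conv2d layers shrink the 80-bin feature axis to ((80 - 1) // 2 - 1) // 2 = 19, so the flattened Linear input is 32 * 19 = 608. The patch becomes unnecessary once the subsampling module itself projects to the encoder width, which is what the change in the next file appears to provide.

import torch

# Sketch only: reproduces the shapes the deleted model.py patch assumed.
idim, odim = 80, 32
conv = torch.nn.Sequential(
    torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(),
    torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU())

x = torch.randn(1, 1, 100, idim)                # (batch, channel, time, feature)
y = conv(x)                                     # -> (1, 32, 24, 19)
flat_dim = odim * (((idim - 1) // 2 - 1) // 2)  # 608, the Linear input size
assert y.shape[-1] == ((idim - 1) // 2 - 1) // 2
print(flat_dim)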
3 changes: 1 addition & 2 deletions wenet/firered/subsampling.py
@@ -36,7 +36,6 @@ def __init__(self,
                  pos_enc_class: torch.nn.Module,
                  odim: int = 32):
         """Construct an Conv2dSubsampling4 object."""
-
         super().__init__(idim, d_model, dropout_rate, pos_enc_class)
         del self.conv, self.out
         self.conv = torch.nn.Sequential(
@@ -46,7 +45,7 @@ def __init__(self,
             torch.nn.ReLU(),
         )
         self.out = torch.nn.Sequential(
-            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
+            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), d_model))
         self.pos_enc = pos_enc_class
         # The right context for every conv layer is computed by:
         # (kernel_size - 1) * frame_rate_of_this_layer
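A hedged end-to-end sketch of the corrected projection (d_model = 512 and the reshaping below are illustrative assumptions, not taken from the commit): with the final Linear now mapping to d_model instead of odim, the subsampled frames already match the encoder width, so no extra projection is needed downstream.

import torch

# Illustrative only: mirrors the fixed Linear(..., d_model) line above.
idim, odim, d_model = 80, 32, 512
conv = torch.nn.Sequential(
    torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(),
    torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU())
out = torch.nn.Sequential(
    torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), d_model))

x = torch.randn(2, 100, idim)       # (batch, time, feature)
c = conv(x.unsqueeze(1))            # (2, 32, 24, 19)
b, ch, t, f = c.size()
frames = out(c.transpose(1, 2).contiguous().view(b, t, ch * f))
print(frames.shape)                 # torch.Size([2, 24, 512])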
2 changes: 1 addition & 1 deletion wenet/utils/class_utils.py
@@ -76,7 +76,7 @@
     "crossattn": MultiHeadedCrossAttention,
     'shaw_rel_selfattn': ShawRelPositionMultiHeadedAttention,
     'rope_abs_selfattn': RopeMultiHeadedAttention,
-    'firered_rel_selfattn': FiredRelPositionMultiHeadedAttention,
+    'firered_rel_selfattn': FiredRelPositionMultiHeadedAttention
 }

WENET_MLP_CLASSES = {
