From b67088af5f1e6e5b7c8ab9d2bd15c17d5e1785a2 Mon Sep 17 00:00:00 2001 From: Dinghao Zhou Date: Sat, 8 Feb 2025 13:42:48 +0800 Subject: [PATCH] [firered] simplified code (#2686) --- wenet/firered/model.py | 11 ----------- wenet/firered/subsampling.py | 3 +-- wenet/utils/class_utils.py | 2 +- 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/wenet/firered/model.py b/wenet/firered/model.py index f10a2e00d..e6d3ad46d 100644 --- a/wenet/firered/model.py +++ b/wenet/firered/model.py @@ -46,17 +46,6 @@ def __init__( self.eos = special_tokens["eos"] self.decode_maxlen = self.decoder.embed[1].max_len - # fix subsampling - odim = 32 - idim = 80 - self.encoder.embed.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(), - torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU()) - self.encoder.embed.out = torch.nn.Sequential( - torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), - self.encoder.output_size())) - - # fix final norm in conformer del self.encoder.after_norm @torch.jit.unused diff --git a/wenet/firered/subsampling.py b/wenet/firered/subsampling.py index 192ec4732..724b13f9a 100644 --- a/wenet/firered/subsampling.py +++ b/wenet/firered/subsampling.py @@ -36,7 +36,6 @@ def __init__(self, pos_enc_class: torch.nn.Module, odim: int = 32): """Construct an Conv2dSubsampling4 object.""" - super().__init__(idim, d_model, dropout_rate, pos_enc_class) del self.conv, self.out self.conv = torch.nn.Sequential( @@ -46,7 +45,7 @@ def __init__(self, torch.nn.ReLU(), ) self.out = torch.nn.Sequential( - torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), d_model)) self.pos_enc = pos_enc_class # The right context for every conv layer is computed by: # (kernel_size - 1) * frame_rate_of_this_layer diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py index 6d9dcbe13..59c67f816 100644 --- a/wenet/utils/class_utils.py +++ b/wenet/utils/class_utils.py @@ -76,7 +76,7 @@ "crossattn": MultiHeadedCrossAttention, 'shaw_rel_selfattn': ShawRelPositionMultiHeadedAttention, 'rope_abs_selfattn': RopeMultiHeadedAttention, - 'firered_rel_selfattn': FiredRelPositionMultiHeadedAttention, + 'firered_rel_selfattn': FiredRelPositionMultiHeadedAttention } WENET_MLP_CLASSES = {