From a93af33afb29585d7398265f1076d6dec1f11ab8 Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Fri, 8 Mar 2024 20:57:44 +0800
Subject: [PATCH] [transformer] add norm eps (#2397)

---
 wenet/transformer/convolution.py   | 22 ++++++++++-------
 wenet/transformer/decoder.py       | 17 ++++++++++---
 wenet/transformer/decoder_layer.py |  7 +++---
 wenet/transformer/encoder.py       | 38 ++++++++++++++++++------------
 wenet/transformer/encoder_layer.py | 16 +++++++------
 5 files changed, 63 insertions(+), 37 deletions(-)

diff --git a/wenet/transformer/convolution.py b/wenet/transformer/convolution.py
index ad5f7f15e..2ce3b0699 100644
--- a/wenet/transformer/convolution.py
+++ b/wenet/transformer/convolution.py
@@ -25,13 +25,16 @@ class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model."""
 
-    def __init__(self,
-                 channels: int,
-                 kernel_size: int = 15,
-                 activation: nn.Module = nn.ReLU(),
-                 norm: str = "batch_norm",
-                 causal: bool = False,
-                 bias: bool = True):
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int = 15,
+        activation: nn.Module = nn.ReLU(),
+        norm: str = "batch_norm",
+        causal: bool = False,
+        bias: bool = True,
+        norm_eps: float = 1e-5,
+    ):
         """Construct an ConvolutionModule object.
         Args:
             channels (int): The number of channels of conv layers.
@@ -73,10 +76,11 @@ def __init__(self,
         assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
         if norm == "batch_norm":
             self.use_layer_norm = False
-            self.norm = WENET_NORM_CLASSES['batch_norm'](channels)
+            self.norm = WENET_NORM_CLASSES['batch_norm'](channels,
+                                                         eps=norm_eps)
         else:
             self.use_layer_norm = True
-            self.norm = WENET_NORM_CLASSES[norm](channels)
+            self.norm = WENET_NORM_CLASSES[norm](channels, eps=norm_eps)
 
         self.pointwise_conv2 = nn.Conv1d(
             channels,
diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index c315eb39f..4efd764f3 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -83,6 +83,7 @@ def __init__(
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -98,7 +99,7 @@ def __init__(
         assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
         self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim,
-                                                              eps=1e-5)
+                                                              eps=norm_eps)
         self.use_output_layer = use_output_layer
         if use_output_layer:
             self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
@@ -122,6 +123,8 @@ def __init__(
                           activation, mlp_bias),
                 dropout_rate,
                 normalize_before,
+                layer_norm_type,
+                norm_eps,
             ) for _ in range(self.num_blocks)
         ])
 
@@ -329,6 +332,8 @@ def __init__(
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
+        layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
 
         super().__init__()
@@ -352,7 +357,10 @@ def __init__(
             value_bias=value_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
-            use_sdpa=use_sdpa)
+            use_sdpa=use_sdpa,
+            layer_norm_type=layer_norm_type,
+            norm_eps=norm_eps,
+        )
 
         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -373,7 +381,10 @@ def __init__(
             mlp_bias=mlp_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
-            use_sdpa=use_sdpa)
+            use_sdpa=use_sdpa,
+            layer_norm_type=layer_norm_type,
+            norm_eps=norm_eps,
+        )
 
     def forward(
         self,
diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py
index d28da1dc2..3c0626535 100644
--- a/wenet/transformer/decoder_layer.py
+++ b/wenet/transformer/decoder_layer.py
@@ -49,6 +49,7 @@ def __init__(
         dropout_rate: float,
         normalize_before: bool = True,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct an DecoderLayer object."""
         super().__init__()
@@ -57,9 +58,9 @@ def __init__(
         self.src_attn = src_attn
         self.feed_forward = feed_forward
         assert layer_norm_type in ['layer_norm', 'rms_norm']
-        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
-        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
-        self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+        self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
         self.dropout = nn.Dropout(dropout_rate)
         self.normalize_before = normalize_before
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 4d1dddc2a..be319a89a 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -57,6 +57,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """
         Args:
@@ -107,7 +108,7 @@ def __init__(
         assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
         self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
-                                                              eps=1e-5)
+                                                              eps=norm_eps)
         self.static_chunk_size = static_chunk_size
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
@@ -373,6 +374,7 @@ def __init__(
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """ Construct TransformerEncoder
 
@@ -384,22 +386,24 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa, layer_norm_type)
+                         use_sdpa, layer_norm_type, norm_eps)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
         mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
-            TransformerEncoderLayer(output_size,
-                                    WENET_ATTENTION_CLASSES["selfattn"](
-                                        attention_heads, output_size,
-                                        attention_dropout_rate, query_bias,
-                                        key_bias, value_bias, use_sdpa),
-                                    mlp_class(output_size, linear_units,
-                                              dropout_rate, activation,
-                                              mlp_bias),
-                                    dropout_rate,
-                                    normalize_before,
-                                    layer_norm_type=layer_norm_type)
-            for _ in range(num_blocks)
+            TransformerEncoderLayer(
+                output_size,
+                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
+                                                    output_size,
+                                                    attention_dropout_rate,
+                                                    query_bias, key_bias,
+                                                    value_bias, use_sdpa),
+                mlp_class(output_size, linear_units, dropout_rate, activation,
+                          mlp_bias),
+                dropout_rate,
+                normalize_before,
+                layer_norm_type=layer_norm_type,
+                norm_eps=norm_eps,
+            ) for _ in range(num_blocks)
         ])
 
 
@@ -439,6 +443,8 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct ConformerEncoder
 
@@ -463,7 +469,7 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa)
+                         use_sdpa, layer_norm_type, norm_eps)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
 
         # self-attention module definition
@@ -500,5 +506,7 @@ def __init__(
                     *convolution_layer_args) if use_cnn_module else None,
                 dropout_rate,
                 normalize_before,
+                layer_norm_type=layer_norm_type,
+                norm_eps=norm_eps,
             ) for _ in range(num_blocks)
         ])
diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py
index 31cf59291..68825a52e 100644
--- a/wenet/transformer/encoder_layer.py
+++ b/wenet/transformer/encoder_layer.py
@@ -47,14 +47,15 @@ def __init__(
         dropout_rate: float,
         normalize_before: bool = True,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct an EncoderLayer object."""
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
         assert layer_norm_type in ['layer_norm', 'rms_norm']
-        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
-        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -140,6 +141,7 @@ def __init__(
         dropout_rate: float = 0.1,
         normalize_before: bool = True,
         layer_norm_type: str = 'layer_norm',
+        norm_eps: float = 1e-5,
     ):
         """Construct an EncoderLayer object."""
         super().__init__()
@@ -149,20 +151,20 @@ def __init__(
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
         self.norm_ff = WENET_NORM_CLASSES[layer_norm_type](
-            size, eps=1e-5)  # for the FNN module
+            size, eps=norm_eps)  # for the FNN module
         self.norm_mha = WENET_NORM_CLASSES[layer_norm_type](
-            size, eps=1e-5)  # for the MHA module
+            size, eps=norm_eps)  # for the MHA module
         if feed_forward_macaron is not None:
             self.norm_ff_macaron = WENET_NORM_CLASSES[layer_norm_type](
-                size, eps=1e-5)
+                size, eps=norm_eps)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
             self.norm_conv = WENET_NORM_CLASSES[layer_norm_type](
-                size, eps=1e-5)  # for the CNN module
+                size, eps=norm_eps)  # for the CNN module
             self.norm_final = WENET_NORM_CLASSES[layer_norm_type](
-                size, eps=1e-5)  # for the final output of the block
+                size, eps=norm_eps)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
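
Usage note (not part of the patch): the sketch below is a minimal standalone
illustration of the pattern this change threads through every module, namely
looking a norm class up in a registry and forwarding a configurable norm_eps
instead of the hard-coded eps=1e-5. NORM_CLASSES and TinyEncoderLayer are
hypothetical stand-ins, not WeNet code; only torch.nn.LayerNorm,
torch.nn.BatchNorm1d and their eps argument are real APIs.

# Minimal standalone sketch (plain PyTorch, not the WeNet source).
import torch

# Hypothetical stand-in for WeNet's WENET_NORM_CLASSES registry.
NORM_CLASSES = {
    "layer_norm": torch.nn.LayerNorm,
    "batch_norm": torch.nn.BatchNorm1d,
}


class TinyEncoderLayer(torch.nn.Module):
    """Toy layer mirroring how norm_eps is forwarded to the norm layers."""

    def __init__(self,
                 size: int = 256,
                 layer_norm_type: str = "layer_norm",
                 norm_eps: float = 1e-5):
        super().__init__()
        assert layer_norm_type in NORM_CLASSES
        # Same call shape as the patched code:
        #   WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
        self.norm1 = NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
        self.norm2 = NORM_CLASSES[layer_norm_type](size, eps=norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the two norms back to back, just to exercise them.
        return self.norm2(self.norm1(x))


if __name__ == "__main__":
    layer = TinyEncoderLayer(size=256, norm_eps=1e-6)  # non-default eps
    out = layer(torch.randn(2, 10, 256))
    print(out.shape)  # torch.Size([2, 10, 256])

With the patch applied, recipes that pass their encoder_conf / decoder_conf
dictionaries straight into these constructors should be able to override
norm_eps per model; that is an assumption about the recipe plumbing, not
something this diff itself shows.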