
Commit

conformer: support configurable final norm (#2682)
Mddct authored Feb 6, 2025
1 parent a344f37 commit 0228877
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions wenet/transformer/encoder.py
@@ -53,6 +53,7 @@ def __init__(
         use_sdpa: bool = False,
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
+        final_norm: bool = True,
     ):
         """
         Args:
@@ -105,6 +106,7 @@ def __init__(

         assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
+        self.final_norm = final_norm
         self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
                                                               eps=norm_eps)
         self.static_chunk_size = static_chunk_size
@@ -170,7 +172,7 @@ def forward(
                                                   mask_pad)
         else:
             xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
-        if self.normalize_before:
+        if self.normalize_before and self.final_norm:
             xs = self.after_norm(xs)
         # Here we assume the mask is not changed in encoder layers, so just
         # return the masks before encoder layers, and the masks will be used
@@ -285,7 +287,7 @@ def forward_chunk(
             # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
             r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
             r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
-        if self.normalize_before:
+        if self.normalize_before and self.final_norm:
             xs = self.after_norm(xs)

         # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
@@ -476,6 +478,7 @@ def __init__(
         n_expert_activated: int = 2,
         conv_norm_eps: float = 1e-5,
         conv_inner_factor: int = 2,
+        final_norm: bool = True,
     ):
         """Construct ConformerEncoder
@@ -500,7 +503,7 @@ def __init__(
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa, layer_norm_type, norm_eps)
+                         use_sdpa, layer_norm_type, norm_eps, final_norm)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()

         # self-attention module definition
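The new flag only gates the trailing self.after_norm call in forward() and forward_chunk(); nothing else in the encoder changes. As a rough illustration (not part of the commit), the sketch below shows how the argument might be used. It assumes the ConformerEncoder constructor and forward() signatures in wenet/transformer/encoder.py otherwise keep their current defaults; input_size, the tensor shapes, and the subsampling factor mentioned in the comments are assumptions made for this example only.

# Minimal usage sketch (not from the commit). Assumes current WeNet defaults
# for all arguments other than the ones set explicitly below.
import torch
from wenet.transformer.encoder import ConformerEncoder

encoder = ConformerEncoder(
    input_size=80,     # e.g. 80-dim fbank features (assumed for this example)
    output_size=256,
    final_norm=False,  # flag added by this commit: skip self.after_norm
)
encoder.eval()

xs = torch.randn(2, 100, 80)       # (batch, time, feature)
xs_lens = torch.tensor([100, 80])  # valid frames per utterance
with torch.no_grad():
    out, masks = encoder(xs, xs_lens)
print(out.shape)  # roughly (2, time/4, 256) with the default conv2d subsampling

With final_norm=False, the output of the last encoder layer is returned without the final layer_norm/rms_norm, presumably for setups where a downstream module applies its own normalization.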
