Skip to content

Commit

Permalink
add conv factor (#2681)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct authored Feb 6, 2025
1 parent 6284076 commit a344f37
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
12 changes: 6 additions & 6 deletions wenet/transformer/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import torch
from torch import nn

from wenet.utils.class_utils import WENET_NORM_CLASSES


Expand All @@ -34,6 +33,7 @@ def __init__(
causal: bool = False,
bias: bool = True,
norm_eps: float = 1e-5,
conv_inner_factor: int = 2,
):
"""Construct an ConvolutionModule object.
Args:
Expand All @@ -45,7 +45,7 @@ def __init__(

self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
conv_inner_factor * channels,
kernel_size=1,
stride=1,
padding=0,
Expand All @@ -64,12 +64,12 @@ def __init__(
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = nn.Conv1d(
channels,
channels,
conv_inner_factor * channels // 2,
conv_inner_factor * channels // 2,
kernel_size,
stride=1,
padding=padding,
groups=channels,
groups=conv_inner_factor * channels // 2,
bias=bias,
)

Expand All @@ -83,7 +83,7 @@ def __init__(
self.norm = WENET_NORM_CLASSES[norm](channels, eps=norm_eps)

self.pointwise_conv2 = nn.Conv1d(
channels,
conv_inner_factor * channels // 2,
channels,
kernel_size=1,
stride=1,
Expand Down
26 changes: 12 additions & 14 deletions wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,16 @@

import torch
import torch.utils.checkpoint as ckpt

from wenet.transformer.convolution import ConvolutionModule
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.utils.class_utils import (
WENET_EMB_CLASSES,
WENET_MLP_CLASSES,
WENET_NORM_CLASSES,
WENET_SUBSAMPLE_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_ACTIVATION_CLASSES,
)
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
TransformerEncoderLayer)
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_EMB_CLASSES, WENET_MLP_CLASSES,
WENET_NORM_CLASSES,
WENET_SUBSAMPLE_CLASSES)
from wenet.utils.common import mask_to_bias
from wenet.utils.mask import add_optional_chunk_mask, make_pad_mask


class BaseEncoder(torch.nn.Module):
Expand Down Expand Up @@ -479,6 +474,8 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
conv_norm_eps: float = 1e-5,
conv_inner_factor: int = 2,
):
"""Construct ConformerEncoder
Expand Down Expand Up @@ -530,7 +527,8 @@ def __init__(
)
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal, conv_bias)
cnn_module_norm, causal, conv_bias,
conv_norm_eps, conv_inner_factor)

mlp_class = WENET_MLP_CLASSES[mlp_type]
self.encoders = torch.nn.ModuleList([
Expand Down

0 comments on commit a344f37

Please sign in to comment.