[firred] bring fired aed to wenet (#2680)
* add att for fired
* merge main
* fix conflict
* merge main
* fix conflict
* convert work
* add missing files
* precision match!
* fix typo
Showing 10 changed files with 874 additions and 46 deletions.
@@ -0,0 +1,186 @@
# Copyright (c) 2025 Wenet Community. authors: Mddct(Dinghao Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple, Union

import torch
from wenet.transformer.attention import (T_CACHE,
                                         RelPositionMultiHeadedAttention)
from wenet.transformer.embedding import PositionalEncoding


class FireRedRelPositionalEncoding(PositionalEncoding):

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        super().__init__(d_model, dropout_rate, max_len)
        pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
        pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            -(torch.log(torch.tensor(10000.0)).item() / d_model))
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.register_buffer('pe', pe)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        raise NotImplementedError(
            'fireredasr does not support streaming pos encoding')

    def forward(self, x, offset=None):
        Tmax, T = self.pe.size(1), x.size(1)
        pos_emb = self.pe[:, Tmax // 2 - T + 1:Tmax // 2 + T].clone().detach()
        return self.dropout(x), self.dropout(pos_emb)
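
# NOTE (editorial sketch, not part of the commit): `pe` above holds
# 2 * max_len - 1 embeddings ordered by relative offset, from +(max_len - 1)
# down to -(max_len - 1), so index Tmax // 2 is offset 0. For an input of
# length T, the slice in forward() keeps offsets +(T - 1) .. -(T - 1):
#
#     max_len, T = 5000, 10
#     Tmax = 2 * max_len - 1                        # 9999
#     start, stop = Tmax // 2 - T + 1, Tmax // 2 + T
#     assert stop - start == 2 * T - 1              # 19 frames for T = 10
#
# so pos_emb comes out as (1, 2*T - 1, d_model), which is the layout
# rel_shift in the attention class below expects.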

class FiredRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 query_bias: bool = True,
                 key_bias: bool = True,
                 value_bias: bool = True,
                 use_sdpa: bool = False,
                 n_kv_head: Optional[int] = None,
                 head_dim: Optional[int] = None):
        """Construct a FiredRelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
                         value_bias, use_sdpa, n_kv_head, head_dim)

        self.layer_norm_q = torch.nn.LayerNorm(n_feat)
        self.layer_norm_k = torch.nn.LayerNorm(n_feat)
        self.layer_norm_v = torch.nn.LayerNorm(n_feat)

    def rel_shift(self, x):
        """Compute relative positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
        Returns:
            torch.Tensor: Output tensor (batch, head, time1, time1).
        """
        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0], x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)
        x = x[:, :, :, :x.size(-1) // 2 + 1]

        return x
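
    # NOTE (editorial sketch, not part of the commit): with the positional
    # encoding above, matrix_bd enters rel_shift as (batch, head, T, 2*T - 1),
    # one score per query frame for each relative offset +(T - 1) .. -(T - 1).
    # The zero-pad / reshape trick realigns the rows so that, after slicing
    # down to the first (2*T - 1) // 2 + 1 == T columns, element [i, j] holds
    # the score for offset i - j, giving a (batch, head, T, T) matrix that can
    # be added directly to the content scores.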

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        query = self.layer_norm_q(query)
        key = self.layer_norm_k(key)
        value = self.layer_norm_v(value)
        return super().forward_qkv(query, key, value)
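
    # NOTE (editorial, not part of the commit): compared with the base class
    # forward_qkv, the only addition here is that query, key and value each
    # pass through their own LayerNorm before the usual linear projections.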

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: T_CACHE = (torch.zeros((0, 0, 0, 0)), torch.zeros((0, 0, 0, 0)))
    ) -> Tuple[torch.Tensor, T_CACHE]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`.
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`.
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
        k, v, new_cache = self._update_kv_and_cache(k, v, cache)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        if not self.use_sdpa:
            # compute attention score
            # first compute matrix a and matrix c
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            # (batch, head, time1, time2)
            matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

            scores = (matrix_ac + matrix_bd) / math.sqrt(
                self.d_k)  # (batch, head, time1, time2)

            return self.forward_attention(v, scores, mask), new_cache
        else:
            # NOTE(Mddct): we need mask bias, not boolean mask
            assert mask.dtype != torch.bool
            mask = mask.unsqueeze(1)
            # matrix_bd as a mask bias
            mask = (matrix_bd + mask) / math.sqrt(self.d_k)
            output = torch.nn.functional.scaled_dot_product_attention(
                q_with_bias_u,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout_rate,
                scale=1 / math.sqrt(self.d_k),
            )
            output = (output.transpose(1, 2).contiguous().view(
                query.size(0), -1,
                self.h * self.d_k))  # (batch, time1, d_model)
            return self.linear_out(output), new_cache
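
Below is a minimal, hypothetical usage sketch (not part of the commit) of the attention module above. It assumes FiredRelPositionMultiHeadedAttention is in scope; the sizes are arbitrary, and the random pos_emb simply stands in for the (1, 2*T - 1, d_model) tensor that FireRedRelPositionalEncoding.forward would return for a T-frame input.

import torch

torch.manual_seed(0)
batch, T, n_head, n_feat = 2, 8, 4, 256          # arbitrary demo sizes
attn = FiredRelPositionMultiHeadedAttention(n_head=n_head,
                                            n_feat=n_feat,
                                            dropout_rate=0.0)
x = torch.randn(batch, T, n_feat)                # (batch, time1, size)
pos_emb = torch.randn(1, 2 * T - 1, n_feat)      # stand-in for the encoding output
out, _ = attn(x, x, x, pos_emb=pos_emb)          # default mask: no masking applied
print(out.shape)                                 # torch.Size([2, 8, 256])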