[firered] bring firered AED to wenet (#2680)
* add att for fired

* merge main

* fix conflict

* merge main

* fix conflict

* convert work

* add missing files

* precision match!

* fix typo
Mddct authored Feb 8, 2025
1 parent 59dc505 commit de41dd7
Showing 10 changed files with 874 additions and 46 deletions.
186 changes: 186 additions & 0 deletions wenet/firered/attention.py
@@ -0,0 +1,186 @@
# Copyright (c) 2025 Wenet Community. authors: Mddct(Dinghao Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple, Union

import torch
from wenet.transformer.attention import (T_CACHE,
                                         RelPositionMultiHeadedAttention)
from wenet.transformer.embedding import PositionalEncoding


class FireRedRelPositionalEncoding(PositionalEncoding):

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):

        super().__init__(d_model, dropout_rate, max_len)
        pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
        pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            -(torch.log(torch.tensor(10000.0)).item() / d_model))
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.register_buffer('pe', pe)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:

        raise NotImplementedError(
            'FireRed ASR does not support streaming position encoding')

    def forward(self, x, offset=None):
        Tmax, T = self.pe.size(1), x.size(1)
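        # `pe` stores relative positions from +(max_len - 1) down to
        # -(max_len - 1) in order, so index Tmax // 2 is relative position 0.
        # The slice below keeps the 2 * T - 1 offsets a length-T sequence
        # needs, from +(T - 1) down to -(T - 1).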
        pos_emb = self.pe[:, Tmax // 2 - T + 1:Tmax // 2 + T].clone().detach()
        return self.dropout(x), self.dropout(pos_emb)


class FiredRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 query_bias: bool = True,
                 key_bias: bool = True,
                 value_bias: bool = True,
                 use_sdpa: bool = False,
                 n_kv_head: Optional[int] = None,
                 head_dim: Optional[int] = None):
        """Construct a FiredRelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
                         value_bias, use_sdpa, n_kv_head, head_dim)

        self.layer_norm_q = torch.nn.LayerNorm(n_feat)
        self.layer_norm_k = torch.nn.LayerNorm(n_feat)
        self.layer_norm_v = torch.nn.LayerNorm(n_feat)
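        # The LayerNorms above are applied to query/key/value in
        # forward_qkv() before the linear projections.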

    def rel_shift(self, x):
        """Compute relative positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
        Returns:
            torch.Tensor: Output tensor (batch, head, time1, time1).
        """

        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)
        x = x[:, :, :, :x.size(-1) // 2 + 1]
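        # After the pad/reshape trick, entry (i, j) holds the score for
        # relative offset i - j, so the result lines up with the
        # (time1, time1) content score matrix it is added to.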

        return x

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        query = self.layer_norm_q(query)
        key = self.layer_norm_k(key)
        value = self.layer_norm_v(value)
        return super().forward_qkv(query, key, value)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: T_CACHE = (torch.zeros((0, 0, 0, 0)), torch.zeros((0, 0, 0, 0)))
    ) -> Tuple[torch.Tensor, T_CACHE]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
        k, v, new_cache = self._update_kv_and_cache(k, v, cache)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1); rel_shift maps it to
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        if not self.use_sdpa:
            # compute attention score
            # first compute matrix a and matrix c
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            # (batch, head, time1, time2)
            matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

            scores = (matrix_ac + matrix_bd) / math.sqrt(
                self.d_k)  # (batch, head, time1, time2)

            return self.forward_attention(v, scores, mask), new_cache
        else:
            # NOTE(Mddct): we need mask bias, not boolean mask
            assert mask.dtype != torch.bool
            mask = mask.unsqueeze(1)
            # matrix_bd as a mask bias
            mask = (matrix_bd + mask) / math.sqrt(self.d_k)
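            # scaled_dot_product_attention only scales q @ k^T by `scale`,
            # so the additive bias above is pre-divided by sqrt(d_k) to match
            # (matrix_ac + matrix_bd) / sqrt(d_k) from the non-SDPA branch.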
            output = torch.nn.functional.scaled_dot_product_attention(
                q_with_bias_u,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout_rate,
                scale=1 / math.sqrt(self.d_k),
            )
            output = (output.transpose(1, 2).contiguous().view(
                query.size(0), -1,
                self.h * self.d_k))  # (batch, time1, d_model)
            return self.linear_out(output), new_cache
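
A minimal sketch of how the two classes above fit together (not part of the commit; it assumes the wenet package with this new wenet/firered/attention.py module is importable, and the shapes follow the code in the diff):

import torch

from wenet.firered.attention import (FireRedRelPositionalEncoding,
                                     FiredRelPositionMultiHeadedAttention)

batch, time, d_model, n_head = 2, 50, 256, 4
pos_enc = FireRedRelPositionalEncoding(d_model, dropout_rate=0.0)
attn = FiredRelPositionMultiHeadedAttention(n_head, d_model, dropout_rate=0.0)

x = torch.randn(batch, time, d_model)
x, pos_emb = pos_enc(x)                 # pos_emb: (1, 2*time - 1, d_model)
mask = torch.ones(batch, 1, time, dtype=torch.bool)
out, _ = attn(x, x, x, mask=mask, pos_emb=pos_emb)
print(out.shape)                        # torch.Size([2, 50, 256])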