From 0e092741ef233015b6533ba25fb1d7a93d512dd4 Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Sat, 8 Feb 2025 15:17:21 +0800
Subject: [PATCH] jit works (#2687)

---
 wenet/firered/attention.py   | 14 +++++---------
 wenet/firered/model.py       |  2 --
 wenet/firered/subsampling.py |  8 ++++++--
 3 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/wenet/firered/attention.py b/wenet/firered/attention.py
index e53363ba0..3e4907b94 100644
--- a/wenet/firered/attention.py
+++ b/wenet/firered/attention.py
@@ -49,7 +49,7 @@ def position_encoding(self,
         raise NotImplementedError('firedasr not support streaming pos encding')
 
-    def forward(self, x, offset=None):
+    def forward(self, x, offset: Optional[Union[int, torch.Tensor]] = None):
         Tmax, T = self.pe.size(1), x.size(1)
         pos_emb = self.pe[:, Tmax // 2 - T + 1:Tmax // 2 + T].clone().detach()
         return self.dropout(x), self.dropout(pos_emb)
 
@@ -103,14 +103,6 @@ def rel_shift(self, x):
 
         return x
 
-    def forward_qkv(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        query = self.layer_norm_q(query)
-        key = self.layer_norm_k(key)
-        value = self.layer_norm_v(value)
-        return super().forward_qkv(query, key, value)
-
     def forward(
         self,
         query: torch.Tensor,
@@ -138,6 +130,10 @@ def forward(
             where `cache_t == chunk_size * num_decoding_left_chunks`
             and `head * d_k == size`
         """
+        query = self.layer_norm_q(query)
+        key = self.layer_norm_k(key)
+        value = self.layer_norm_v(value)
+
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
         k, v, new_cache = self._update_kv_and_cache(k, v, cache)
diff --git a/wenet/firered/model.py b/wenet/firered/model.py
index e6d3ad46d..d3044142d 100644
--- a/wenet/firered/model.py
+++ b/wenet/firered/model.py
@@ -46,8 +46,6 @@ def __init__(
         self.eos = special_tokens["eos"]
         self.decode_maxlen = self.decoder.embed[1].max_len
 
-        del self.encoder.after_norm
-
     @torch.jit.unused
     def forward_encoder_chunk(
         self,
diff --git a/wenet/firered/subsampling.py b/wenet/firered/subsampling.py
index 724b13f9a..b04e693ff 100644
--- a/wenet/firered/subsampling.py
+++ b/wenet/firered/subsampling.py
@@ -65,6 +65,10 @@ def forward(
         x_mask = make_non_pad_mask(x_lens).unsqueeze(1)
         x = torch.nn.functional.pad(x, (0, 0, 0, self.right_context),
                                     'constant', 0.0)
-        x, pos, _ = super().forward(x, x_mask, offset)
+        x = x.unsqueeze(1)  # (b, c=1, t, f)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        x, pos_emb = self.pos_enc(x, offset)
         mask = x_mask[:, :, :-2:2][:, :, :-2:2]
-        return x, pos, mask
+        return x, pos_emb, mask
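
Note (not part of the patch itself): the changes above all serve the same
goal of making the firered model exportable with torch.jit.script.
TorchScript cannot compile super() calls, so the layer norms formerly
applied in the overridden forward_qkv() are inlined at the top of forward(),
and the subsampling module expands super().forward() into its conv / out /
pos_enc steps. TorchScript also rejects a None default on a parameter that
lacks an explicit Optional[...] annotation, hence the new signature on the
positional-encoding forward(). Dropping `del self.encoder.after_norm`
presumably fixes scripting as well, since the scripted encoder class still
references that submodule. Below is a minimal, self-contained sketch of the
two compiler limitations, assuming a PyTorch recent enough to support Union
in TorchScript (1.10+); the Base/Delegating/Scriptable classes are invented
for illustration and are not wenet code:

    from typing import Optional, Union

    import torch


    class Base(torch.nn.Module):

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2


    class Delegating(Base):

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # TorchScript cannot compile super() calls, so scripting this
            # subclass fails -- the same pattern this patch removes from
            # forward_qkv() and the subsampling forward().
            return super().forward(x) + 1


    class Scriptable(Base):

        def forward(self,
                    x: torch.Tensor,
                    offset: Optional[Union[int, torch.Tensor]] = None
                    ) -> torch.Tensor:
            # Without the Optional[...] annotation above, TorchScript would
            # infer `offset` as Tensor and reject the None default -- the
            # same fix applied to the positional-encoding forward().
            # The parent body is inlined instead of delegating via super().
            return x * 2 + 1


    torch.jit.script(Scriptable())    # compiles
    # torch.jit.script(Delegating())  # fails: super() is unsupported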