Commit

jit works (#2687)
Mddct authored Feb 8, 2025
1 parent b67088a commit 0e09274
Showing 3 changed files with 11 additions and 13 deletions.
14 changes: 5 additions & 9 deletions wenet/firered/attention.py
@@ -49,7 +49,7 @@ def position_encoding(self,
 
         raise NotImplementedError('firedasr not support streaming pos encding')
 
-    def forward(self, x, offset=None):
+    def forward(self, x, offset: Optional[Union[int, torch.Tensor]] = None):
         Tmax, T = self.pe.size(1), x.size(1)
         pos_emb = self.pe[:, Tmax // 2 - T + 1:Tmax // 2 + T].clone().detach()
         return self.dropout(x), self.dropout(pos_emb)
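
Note: the only change in this hunk is the type annotation on `offset`. torch.jit.script infers unannotated parameters as Tensor, so a default of None (or an int argument) cannot compile without an explicit Optional annotation. A minimal sketch of that behaviour, using a hypothetical toy module rather than the real positional-encoding class:

import torch
from typing import Optional, Union

class ToyPosEnc(torch.nn.Module):
    # Hypothetical module, only to illustrate why the annotation matters.
    def forward(self, x: torch.Tensor,
                offset: Optional[Union[int, torch.Tensor]] = None) -> torch.Tensor:
        # Without the annotation, TorchScript assumes `offset: Tensor`,
        # and the `= None` default then fails to compile.
        if offset is None:
            return x
        return x

scripted = torch.jit.script(ToyPosEnc())
print(scripted(torch.randn(2, 3)).shape)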
@@ -103,14 +103,6 @@ def rel_shift(self, x):
 
         return x
 
-    def forward_qkv(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        query = self.layer_norm_q(query)
-        key = self.layer_norm_k(key)
-        value = self.layer_norm_v(value)
-        return super().forward_qkv(query, key, value)
-
     def forward(
         self,
         query: torch.Tensor,
@@ -138,6 +130,10 @@ def forward(
                 where `cache_t == chunk_size * num_decoding_left_chunks`
                 and `head * d_k == size`
         """
+        query = self.layer_norm_q(query)
+        key = self.layer_norm_k(key)
+        value = self.layer_norm_v(value)
+
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
         k, v, new_cache = self._update_kv_and_cache(k, v, cache)
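
Note: the removed `forward_qkv` override applied the three layer norms and then delegated to `super().forward_qkv(...)`; the added lines apply the norms directly in `forward` instead. TorchScript does not support `super()` calls, so keeping the override would presumably block `torch.jit.script`; the `subsampling.py` change below follows the same pattern. A small sketch of the scriptable arrangement, with hypothetical class and method names:

import torch

class Base(torch.nn.Module):
    def project(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

class Child(Base):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize here and call the inherited project() directly.
        # An overridden project() that delegated via super().project()
        # would not compile, since TorchScript rejects super() in methods.
        return self.project(self.norm(x))

scripted = torch.jit.script(Child())
print(scripted(torch.randn(2, 4)).shape)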
2 changes: 0 additions & 2 deletions wenet/firered/model.py
@@ -46,8 +46,6 @@ def __init__(
         self.eos = special_tokens["eos"]
         self.decode_maxlen = self.decoder.embed[1].max_len
 
-        del self.encoder.after_norm
-
     @torch.jit.unused
     def forward_encoder_chunk(
         self,
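
Note: deleting a submodule works in eager mode as long as nothing executes the code that references it, but `torch.jit.script` compiles every method up front, so an attribute deleted from the encoder while still appearing in its source presumably breaks scripting, which would explain dropping `del self.encoder.after_norm` here. A minimal illustration of that general behaviour (hypothetical module, not the wenet encoder):

import torch

class Enc(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.after_norm = torch.nn.LayerNorm(4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.after_norm(x)

enc = Enc()
del enc.after_norm  # legal in eager mode, until forward() is actually called
try:
    torch.jit.script(enc)  # scripting compiles forward(), which references the deleted module
except Exception as e:
    print("scripting failed:", type(e).__name__)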
8 changes: 6 additions & 2 deletions wenet/firered/subsampling.py
@@ -65,6 +65,10 @@ def forward(
         x_mask = make_non_pad_mask(x_lens).unsqueeze(1)
         x = torch.nn.functional.pad(x, (0, 0, 0, self.right_context),
                                     'constant', 0.0)
-        x, pos, _ = super().forward(x, x_mask, offset)
+        x = x.unsqueeze(1)  # (b, c=1, t, f)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        x, pos_emb = self.pos_enc(x, offset)
         mask = x_mask[:, :, :-2:2][:, :, :-2:2]
-        return x, pos, mask
+        return x, pos_emb, mask
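
Note: here the `super().forward(...)` call is replaced by inlining the parent's conv/out/pos_enc sequence, again avoiding `super()` for TorchScript. With the untyped Optional argument, the `super()` delegation, and the attribute deletion gone, the FireRed model should script end to end, which is presumably what the commit title refers to. A sketch of the intended usage; the Sequential below is only a stand-in, since constructing the real model needs its config and checkpoint, which are outside this diff:

import torch

# Placeholder for a constructed FireRed model from this repository.
model = torch.nn.Sequential(torch.nn.Linear(80, 256), torch.nn.ReLU())

scripted = torch.jit.script(model)    # fails if the model still uses super(),
scripted.save("model.jit")            # deleted attributes, or untyped Optional args
reloaded = torch.jit.load("model.jit")
print(reloaded(torch.randn(4, 80)).shape)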
