diff --git a/test/wenet/dataset/test_datapipes.py b/test/wenet/dataset/test_datapipes.py index f269788c9e..d36afa9c82 100644 --- a/test/wenet/dataset/test_datapipes.py +++ b/test/wenet/dataset/test_datapipes.py @@ -7,9 +7,10 @@ from wenet.dataset.datapipes import (RepeatDatapipe, SortDataPipe, WenetRawDatasetSource, WenetTarShardDatasetSource) -from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding, - parse_json, compute_fbank, - detect_language, detect_task) +from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, + feats_length_fn, padding, parse_json, + compute_fbank, detect_language, + detect_task) @pytest.mark.parametrize("data_list", [ @@ -106,7 +107,8 @@ def test_dynamic_batch_datapipe(data_list): max_frames_in_batch = 10000 dataset = dataset.dynamic_batch( window_class=DynamicBatchWindow(max_frames_in_batch), - wrapper_class=padding) + wrapper_class=padding, + elem_size_fn=feats_length_fn) dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, diff --git a/wenet/AudioLLM/audiollm_model.py b/wenet/AudioLLM/audiollm_model.py new file mode 100644 index 0000000000..3208d46069 --- /dev/null +++ b/wenet/AudioLLM/audiollm_model.py @@ -0,0 +1,312 @@ +from typing import Dict, List, Optional, Union +import torch +from wenet.LLM.sampler import sampler +from wenet.transformer.encoder import TransformerEncoder +from wenet.AudioLLM.bottleneck import ConvLinearBottleNeck +from wenet.LLM.decoder import DecoderOnly +from wenet.utils.common import IGNORE_ID, th_accuracy +from wenet.utils.mask import make_pad_mask, subsequent_mask + + +class AudioLLM(torch.nn.Module): + def __init__( + self, + vocab_size: int, + encoder: TransformerEncoder, + decoder: DecoderOnly, + special_tokens: dict, + tie_word_embedding: bool = False, + linear_bias: bool = False, + ignore_id: int = IGNORE_ID, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + bottleneck_type: str = "conv-linear", + freeze_encoder: bool = True, + freeze_llm_embed: bool = True, + freeze_decoder: bool = True, + **kwargs, + ) -> None: + super().__init__() + del special_tokens + self.encoder = encoder + self.decoder = decoder + self.embed = torch.nn.Embedding(vocab_size, decoder.hidden_size) + self.out = torch.nn.Linear(decoder.hidden_size, + vocab_size, + bias=linear_bias) + if bottleneck_type == "conv-linear": + self.bottleneck = ConvLinearBottleNeck(encoder.output_size(), decoder.hidden_size, **kwargs) + self.vocab_size = vocab_size + self.criterion_att = torch.nn.CrossEntropyLoss(ignore_index=ignore_id, + reduction='sum' if length_normalized_loss else 'mean', + label_smoothing=lsm_weight) + self.tie_word_embedding = tie_word_embedding + self.ignore_id = ignore_id + + self.freeze_encoder = freeze_encoder + if freeze_encoder: + self.freeze_parameters(self.encoder) + self.encoder.eval() + self.freeze_decoder = freeze_decoder + if freeze_decoder: + self.freeze_parameters(self.decoder) + self.decoder.eval() + self.freeze_llm_embed = freeze_llm_embed + if freeze_llm_embed: + self.freeze_parameters(self.embed) + self.freeze_parameters(self.out) + + def train(self, mode: bool = True): + self.bottleneck.train(mode) + if not self.freeze_encoder: + self.encoder.train(mode) + if not self.freeze_decoder: + self.encoder.train(mode) + + def freeze_parameters(self, moudle: torch.nn.Module): + for _, param in moudle.named_parameters(): + param.requires_grad = False + + def extract_audio_features(self, audio, audio_lengths): + output, masks = self.encoder(audio, audio_lengths) + output, sub_lengths = 
self.bottleneck(output, masks.sum(-1)) + return output, sub_lengths + + @torch.jit.unused + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ Forward for training + """ + prefix_tokens = batch['prefix_tokens'].to(device) + audio_feats = batch['audio_feats'].to(device) + suffix_tokens = batch['suffix_tokens'].to(device) + prefix_target = batch['prefix_target'].to(device) + suffix_target = batch['suffix_target'].to(device) + prefix_tokens_lengths = batch['prefix_tokens_lengths'].to(device) + audio_feats_lengths = batch['audio_feats_lengths'].to(device) + suffix_tokens_lengths = batch['suffix_tokens_lengths'].to(device) + + audio_embeds, audio_lengths = self.extract_audio_features(audio_feats, audio_feats_lengths) + + prefix_tokens_embeds = self.embed(prefix_tokens) + suffix_tokens_embeds = self.embed(suffix_tokens) + + # | prefix_embeds | audio_embeds | suffix_embeds | paddings | + b, c = prefix_tokens_embeds.size(0), prefix_tokens_embeds.size(2) + prefix_t = prefix_tokens_embeds.size(1) + audio_t = audio_embeds.size(1) + suffix_t = suffix_tokens_embeds.size(1) + inputs_lengths = prefix_t + audio_t + suffix_t + input_embeds = torch.ones([b, inputs_lengths, c], device=device) + targets = torch.ones([b, inputs_lengths], dtype=torch.long, device=device) + for i in range(b): + index = 0 + input_embeds[i, :prefix_tokens_lengths[i]] = prefix_tokens_embeds[i, :prefix_tokens_lengths[i]] + targets[i, :prefix_tokens_lengths[i]-1] = prefix_target[i, :prefix_tokens_lengths[i]-1] + + index += prefix_tokens_lengths[i] + input_embeds[i, index:index + audio_lengths[i]] = audio_embeds[i, :audio_lengths[i]] + targets[i, index-1:index + audio_lengths[i]-1] = self.ignore_id + + index += audio_lengths[i] + input_embeds[i, index:index + suffix_tokens_lengths[i]] = suffix_tokens_embeds[i, :suffix_tokens_lengths[i]] + targets[i, index-1:index + suffix_tokens_lengths[i]] = suffix_target[i, :suffix_tokens_lengths[i]+1] + + index += suffix_tokens_lengths[i] + input_embeds[i, index:] = torch.cat([prefix_tokens_embeds[i, prefix_tokens_lengths[i]:], + audio_embeds[i, audio_lengths[i]:], + suffix_tokens_embeds[i, suffix_tokens_lengths[i]:]], dim=0) + targets[i,index:] = self.ignore_id + + mask = ~make_pad_mask(audio_lengths + prefix_tokens_lengths + suffix_tokens_lengths, + max_len=inputs_lengths, + pad_type="right").unsqueeze( + 1) # (B,1,L) + + causal_mask = subsequent_mask( + mask.size(-1), device=mask.device).unsqueeze(0) # (1,L,L) + att_mask = causal_mask & mask # (B, L, L) + + decoder_out = self.out(self.decoder(input_embeds, + att_mask)[0]) # (B, L, vocab_size) + + loss = self.criterion_att(decoder_out.view(-1, self.vocab_size), + targets.view(-1)) + acc = th_accuracy(decoder_out.view(-1, self.vocab_size), + targets, + ignore_label=self.ignore_id) + + return { + "loss": loss, + "ppl": torch.exp(loss.detach()), + "th_accuracy": acc + } + + def tie_or_clone_weights(self, jit_mode: bool): + if not self.tie_word_embedding: + return + if jit_mode: + self.out.weight = torch.nn.Parameter(self.embed.weight.clone()) + else: + self.out.weight = self.embed.weight + # TODO(Mddct): whether to deal bias for other llm model + + @torch.jit.unused + @torch.inference_mode() + def generate( + self, + batch: dict, + device: torch.device, + stop_tokens: List[int], + dtype: torch.dtype = torch.float32, + output_len: int = 100, + temperature: Union[float, None] = 0.95, + top_p: float = 1.0, + top_k: int = 100, + ) -> List[List[int]]: + """Generates responses for given prompts 
using Gemma model.""" + # If a single prompt is provided, treat it as a batch of 1. + + prefix_tokens = batch['prefix_tokens'].to(device) + audio_feats = batch['audio_feats'].to(device) + suffix_tokens = batch['suffix_tokens'].to(device) + prefix_tokens_lengths = batch['prefix_tokens_lengths'].to(device) + audio_feats_lengths = batch['audio_feats_lengths'].to(device) + suffix_tokens_lengths = batch['suffix_tokens_lengths'].to(device) + + audio_embeds, audio_lengths = self.extract_audio_features(audio_feats, audio_feats_lengths) + + prefix_tokens_embeds = self.embed(prefix_tokens) + suffix_tokens_embeds = self.embed(suffix_tokens) + + b, c = prefix_tokens_embeds.size(0), prefix_tokens_embeds.size(2) + input_embeds_list = [] + token_ids_list = [] + for i in range(b): + input_embeds = [] + token_ids = [] + input_embeds.append(prefix_tokens_embeds[i, :prefix_tokens_lengths[i]]) + token_ids.append(prefix_tokens[i, :prefix_tokens_lengths[i]]) + input_embeds.append(audio_embeds[i, :audio_lengths[i]]) + token_ids.append(torch.full((1, audio_lengths[i]), + IGNORE_ID, + dtype=torch.int64, + device=device).squeeze(0)) + input_embeds.append(suffix_tokens_embeds[i, :suffix_tokens_lengths[i]]) + token_ids.append(suffix_tokens[i, :suffix_tokens_lengths[i]]) + input_embeds = torch.cat(input_embeds, dim=0) + token_ids = torch.cat(token_ids, dim=0) + input_embeds_list.append(input_embeds) + token_ids_list.append(token_ids) + + min_prompt_len = min(p.shape[0] for p in token_ids_list) + max_prompt_len = max(p.shape[0] for p in token_ids_list) + max_seq_len = max_prompt_len + output_len + assert max_seq_len <= self.decoder.pos_enc.max_len + + # build KV caches + kv_caches = [] + for _ in range(len(self.decoder.decoders)): + size = (b, 0, self.decoder.n_kv_head, + self.decoder.head_dim) + k_cache = torch.zeros(size=size, dtype=dtype, device=device) + v_cache = torch.zeros(size=size, dtype=dtype, device=device) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full((b, max_seq_len), + IGNORE_ID, + dtype=torch.int64, + device=device) + input_embeds_tensor = torch.zeros((b, min_prompt_len, c), + dtype=dtype, + device=device) + # right padding + for i, (embeds, tokens) in enumerate(zip(input_embeds_list, token_ids_list)): + token_ids_tensor[i, :len(tokens)] = tokens + input_embeds_tensor[i, :min_prompt_len] = embeds[:min_prompt_len] + + prompt_mask_tensor = ~make_pad_mask(audio_lengths + prefix_tokens_lengths + suffix_tokens_lengths, + max_len=max_seq_len) + input_positions_tensor = torch.arange(0, + min_prompt_len, + dtype=torch.int64).to(device) + mask_tensor = torch.ones((1, 1, max_seq_len, max_seq_len), + dtype=torch.bool) + mask_tensor = torch.tril(mask_tensor).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + att_mask = curr_mask_tensor.squeeze( + 1)[:, :min_prompt_len, :min_prompt_len] + output_positions_tensor = torch.LongTensor([min_prompt_len - 1 + ]).to(device) + temperatures_tensor = None if not temperature else torch.FloatTensor( + [temperature] * b).to(device) + top_ps_tensor = torch.FloatTensor([top_p] * b).to(device) + top_ks_tensor = torch.LongTensor([top_k] * b).to(device) + output_index = torch.tensor(min_prompt_len, + dtype=torch.int64).to(device) + + offset = torch.tensor([0] * b).to(device) + input_offset = offset + + stop_tokens_tensor = torch.tensor(stop_tokens, device=device) + # Prefill up to min_prompt_len tokens, then treat other prefill as + # decode and ignore output. 
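+        # Each iteration of the loop below is one decode step: positions that
+        # are still inside a prompt (per prompt_mask_tensor) keep their
+        # original token instead of the sampled one, so prompts longer than
+        # min_prompt_len are consumed one token per step, and generation stops
+        # early once every sequence in the batch has emitted a stop token.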
+ for i in range(max_seq_len - min_prompt_len): + decoder_out, kv_caches, = self.decoder( + input_embeds_tensor, + att_mask, + input_offset, + kv_caches, + ) + decoder_out = self.out(decoder_out) + decoder_out = decoder_out.index_select(1, output_positions_tensor) + next_token_ids = sampler( + decoder_out, + temperatures_tensor, + top_ps_tensor, + top_ks_tensor, + ) + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select( + 1, output_index).squeeze(dim=1) + output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, + next_token_ids).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_embeds_tensor = self.embed(input_token_ids_tensor) + + input_positions_tensor = output_index.unsqueeze(dim=-1) + curr_mask_tensor = mask_tensor.index_select( + 2, input_positions_tensor) + att_mask = curr_mask_tensor.squeeze(1)[:, :output_index + + 1, :output_index + 1] + + output_positions_tensor = torch.tensor( + 0, dtype=torch.int64).to(device) + input_offset = offset + output_index.unsqueeze(-1) + output_index = output_index + 1 + + if all(torch.isin(next_token_ids, stop_tokens_tensor)): + break + + token_ids = token_ids_tensor.tolist() + results = [] + for i, tokens in enumerate(token_ids): + trimmed_output = tokens[len(token_ids_list[i] + ):len(token_ids_list[i]) + output_len] + for stop_token in stop_tokens: + try: + eos_index = trimmed_output.index(stop_token) + trimmed_output = trimmed_output[:eos_index] + break + except Exception: + continue + results.append(trimmed_output) + + return results \ No newline at end of file diff --git a/wenet/AudioLLM/bottleneck.py b/wenet/AudioLLM/bottleneck.py new file mode 100644 index 0000000000..0c2f7ee5c9 --- /dev/null +++ b/wenet/AudioLLM/bottleneck.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +from typing import List + +class Conv1dSubsampler(nn.Module): + """Convolutional subsampler: a stack of 1D convolution (along temporal + dimension) followed by non-linear activation via gated linear units + (https://arxiv.org/abs/1911.08460) + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__( + self, + in_dim: int, + mid_dim: int, + out_dim: int, + kernel_sizes: List[int] = (3, 3), + ): + super(Conv1dSubsampler, self).__init__() + self.n_layers = len(kernel_sizes) + self.conv_layers = nn.ModuleList( + nn.Conv1d( + in_dim if i == 0 else mid_dim // 2, + mid_dim if i < self.n_layers - 1 else out_dim * 2, + k, + stride=2, + padding=k // 2, + ) + for i, k in enumerate(kernel_sizes) + ) + + def get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + for _ in range(self.n_layers): + out = ((out.float() - 1) / 2 + 1).floor().long() + return out.squeeze(-1) + + def forward(self, src_tokens, src_lengths): + bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D) + x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T + for conv in self.conv_layers: + x = conv(x) + x = nn.functional.glu(x, dim=1) + _, _, out_seq_len = x.size() + x = x.transpose(1, 2).contiguous() # -> B x T x (C x D) + return x, self.get_out_seq_lens_tensor(src_lengths) + +class ConvLinearBottleNeck(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + 
bottleneck_mid_dim: int, + conv_kernel_sizes: List[int] = (3, 3) + ): + super(ConvLinearBottleNeck, self).__init__() + + self.subsampling = Conv1dSubsampler(encoder_dim, 2 * encoder_dim, decoder_dim, conv_kernel_sizes) + + self.activation = nn.GELU() + self.fc1 = nn.Linear(decoder_dim, bottleneck_mid_dim, bias=False) + self.fc2 = nn.Linear(bottleneck_mid_dim, decoder_dim, bias=False) + + self.speech_ln = torch.nn.LayerNorm(decoder_dim) + + def forward(self, x, x_lengths): + x, out_lengths = self.subsampling(x, x_lengths) + residual = x + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return self.speech_ln(residual + x), out_lengths \ No newline at end of file diff --git a/wenet/AudioLLM/template.py b/wenet/AudioLLM/template.py new file mode 100644 index 0000000000..ac302e1880 --- /dev/null +++ b/wenet/AudioLLM/template.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class Template: + # one turn :{system_format}{user_text_format}{user_audio_format}{assistant_format} + # multi turns: + # {system_format}{user_format}{assistant_format}{user_format}{assistant_format}... + system: Optional[str] + + prefix_user: str + suffix_user: str + assistant: str + + bos: str + eos: str + + +audio_gemma = Template( + '', + 'user\n{content}\n', + '\nmodel\n', + '{content}\n', + '', + '', +) + +audio_llama3 = Template( + '<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>', + '<|start_header_id|>user<|end_header_id|>\n\n{content}\n\n', + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n', + '{content}<|eot_id|>', + '<|begin_of_text|>', + '<|end_of_text|>', +) + + + +WENET_LLM_Template = { + "audio_gemma": audio_gemma, + 'audio_llama3': audio_llama3, +} diff --git a/wenet/LLM/causal_model.py b/wenet/LLM/causal_model.py new file mode 100644 index 0000000000..b9192ce15d --- /dev/null +++ b/wenet/LLM/causal_model.py @@ -0,0 +1,208 @@ +from typing import Dict, List, Optional, Union +import torch +from wenet.LLM.decoder import DecoderOnly +from wenet.LLM.sampler import sampler +from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss +from wenet.utils.common import IGNORE_ID, th_accuracy +from wenet.utils.mask import make_pad_mask, subsequent_mask + + +class CausalLM(torch.nn.Module): + + def __init__( + self, + vocab_size: int, + decoder: DecoderOnly, + special_tokens: dict, + tie_word_embedding: bool = False, + linear_bias: bool = False, + ignore_id: int = IGNORE_ID, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + ) -> None: + super().__init__() + del special_tokens + + self.embed = torch.nn.Embedding(vocab_size, decoder.hidden_size) + self.out = torch.nn.Linear(decoder.hidden_size, + vocab_size, + bias=linear_bias) + + self.decoder = decoder + self.vocab_size = vocab_size + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + self.tie_word_embedding = tie_word_embedding + self.ignore_id = ignore_id + + @torch.jit.unused + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ Forward for training + """ + text = batch['feats'].to(device) + target = batch['target'].to(device) + text_length = batch['feats_lengths'].to(device) + + mask = ~make_pad_mask(text_length, max_len=text.size(1)).unsqueeze( + 1) # (B,1,L) + causal_mask = subsequent_mask( + mask.size(-1), device=mask.device).unsqueeze(0) # (1,L,L) + att_mask = causal_mask & mask # 
(B, L, L) + + embeding = self.embed(text) + decoder_out = self.out(self.decoder(embeding, + att_mask)[0]) # (B, L, vocab_size) + loss = self.criterion_att(decoder_out, target) + acc = th_accuracy(decoder_out.view(-1, self.vocab_size), + target, + ignore_label=self.ignore_id) + + return { + "loss": loss, + "ppl": torch.exp(loss.detach()), + "th_accuracy": acc + } + + def tie_or_clone_weights(self, jit_mode: bool): + if not self.tie_word_embedding: + return + if jit_mode: + self.out.weight = torch.nn.Parameter(self.embed.weight.clone()) + else: + self.out.weight = self.embed.weight + # TODO(Mddct): whether to deal bias for other llm model + + @torch.jit.unused + @torch.inference_mode() + def generate( + self, + prompts_tokens: List[List[int]], + device: torch.device, + stop_tokens: List[int], + dtype: torch.dtype = torch.float32, + output_len: int = 100, + temperature: Union[float, None] = 0.95, + top_p: float = 1.0, + top_k: int = 100, + ) -> List[List[int]]: + """Generates responses for given prompts using Gemma model.""" + # If a single prompt is provided, treat it as a batch of 1. + batch_size = len(prompts_tokens) + min_prompt_len = min(len(p) for p in prompts_tokens) + max_prompt_len = max(len(p) for p in prompts_tokens) + max_seq_len = max_prompt_len + output_len + assert max_seq_len <= self.decoder.pos_enc.max_len + + # build KV caches + kv_caches = [] + for _ in range(len(self.decoder.decoders)): + size = (batch_size, 0, self.decoder.n_kv_head, + self.decoder.head_dim) + k_cache = torch.zeros(size=size, dtype=dtype, device=device) + v_cache = torch.zeros(size=size, dtype=dtype, device=device) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full((batch_size, max_seq_len), + IGNORE_ID, + dtype=torch.int64, + device=device) + input_token_ids_tensor = torch.full((batch_size, min_prompt_len), + IGNORE_ID, + dtype=torch.int64, + device=device) + # right padding + for i, p in enumerate(prompts_tokens): + token_ids_tensor[i, :len(p)] = torch.tensor(p) + input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( + p[:min_prompt_len]) + + prompt_mask_tensor = token_ids_tensor != IGNORE_ID + input_positions_tensor = torch.arange(0, + min_prompt_len, + dtype=torch.int64).to(device) + mask_tensor = torch.ones((1, 1, max_seq_len, max_seq_len), + dtype=torch.bool) + mask_tensor = torch.tril(mask_tensor).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + att_mask = curr_mask_tensor.squeeze( + 1)[:, :min_prompt_len, :min_prompt_len] + output_positions_tensor = torch.LongTensor([min_prompt_len - 1 + ]).to(device) + temperatures_tensor = None if not temperature else torch.FloatTensor( + [temperature] * batch_size).to(device) + top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) + top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) + output_index = torch.tensor(min_prompt_len, + dtype=torch.int64).to(device) + + input_token_embeding = self.embed(input_token_ids_tensor) + offset = torch.tensor([0] * len(prompts_tokens)).to(device) + input_offset = offset + + stop_tokens_tensor = torch.tensor(stop_tokens, device=device) + # Prefill up to min_prompt_len tokens, then treat other prefill as + # decode and ignore output. 
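+        # One decode step per iteration: torch.where keeps the original prompt
+        # token wherever prompt_mask_tensor is True (prompts longer than
+        # min_prompt_len), otherwise the sampled token is written into
+        # token_ids_tensor; the loop breaks once every sequence has produced a
+        # stop token.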
+ for i in range(max_seq_len - min_prompt_len): + decoder_out, kv_caches, = self.decoder( + input_token_embeding, + att_mask, + input_offset, + kv_caches, + ) + decoder_out = self.out(decoder_out) + decoder_out = decoder_out.index_select(1, output_positions_tensor) + next_token_ids = sampler( + decoder_out, + temperatures_tensor, + top_ps_tensor, + top_ks_tensor, + ) + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select( + 1, output_index).squeeze(dim=1) + output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, + next_token_ids).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_token_embeding = self.embed(input_token_ids_tensor) + + input_positions_tensor = output_index.unsqueeze(dim=-1) + curr_mask_tensor = mask_tensor.index_select( + 2, input_positions_tensor) + att_mask = curr_mask_tensor.squeeze(1)[:, :output_index + + 1, :output_index + 1] + + output_positions_tensor = torch.tensor( + 0, dtype=torch.int64).to(device) + input_offset = offset + output_index.unsqueeze(-1) + output_index = output_index + 1 + + if all(torch.isin(next_token_ids, stop_tokens_tensor)): + break + + token_ids = token_ids_tensor.tolist() + results = [] + for i, tokens in enumerate(token_ids): + trimmed_output = tokens[len(prompts_tokens[i] + ):len(prompts_tokens[i]) + output_len] + for stop_token in stop_tokens: + try: + eos_index = trimmed_output.index(stop_token) + trimmed_output = trimmed_output[:eos_index] + break + except Exception: + continue + results.append(trimmed_output) + + return results diff --git a/wenet/LLM/decoder.py b/wenet/LLM/decoder.py new file mode 100644 index 0000000000..e3331ad944 --- /dev/null +++ b/wenet/LLM/decoder.py @@ -0,0 +1,161 @@ +from functools import partial +from typing import List, Optional, Tuple, Union +import torch +import torch.utils.checkpoint as ckpt +from wenet.transformer.attention import T_CACHE + +from wenet.transformer.encoder_layer import TransformerEncoderLayer +from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES, + WENET_ATTENTION_CLASSES, + WENET_EMB_CLASSES, WENET_MLP_CLASSES, + WENET_NORM_CLASSES) +from wenet.utils.common import mask_to_bias + + +class DecoderOnly(torch.nn.Module): + + def __init__( + self, + n_kv_head: int, + head_dim: int, + hidden_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + normalize_before: bool = True, + query_bias: bool = False, + key_bias: bool = False, + value_bias: bool = False, + mlp_bias: bool = False, + activation_type: str = "gelu", + gelu_approximate: Union[str, None] = None, + max_position_embeding: int = 8192, + mlp_type: str = 'gated', + layer_norm_type: str = 'rms_norm', + norm_eps: float = 1e-5, + rms_norm_offset: bool = True, + selfattention_layer_type: str = "rope_abs_selfattn", + use_sdpa: bool = False, + gradient_checkpointing: bool = False, + rope_theta: float = 10000.0, + rope_style: str = 'google', + scale_embed: bool = True, + ) -> None: + super().__init__() + + assert selfattention_layer_type in ['rope_abs_selfattn'] + self.pos_enc = WENET_EMB_CLASSES["rope_pos"]( + hidden_size, + head_dim, + max_len=max_position_embeding, + dropout_rate=positional_dropout_rate, + rope_theta=rope_theta, + scale=scale_embed) + if activation_type == "gelu" and gelu_approximate is not None: + 
activation = WENET_ACTIVATION_CLASSES['gelu']( + approximate=gelu_approximate) + else: + activation = WENET_ACTIVATION_CLASSES[activation_type]() + + mlp_class = WENET_MLP_CLASSES[mlp_type] + self.num_blocks = num_blocks + # TODO: support lora & refactor lora + self.decoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + hidden_size, + WENET_ATTENTION_CLASSES[selfattention_layer_type]( + attention_heads, + hidden_size, + attention_dropout_rate, + query_bias, + key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, + style=rope_style), + mlp_class(hidden_size, linear_units, dropout_rate, activation, + mlp_bias), + dropout_rate, + normalize_before, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + rms_norm_offset=rms_norm_offset, + ) for _ in range(self.num_blocks) + ]) + self.pre_norm = normalize_before + self.final_norm: Optional[torch.nn.Module] = None + if self.pre_norm: + norm_class = WENET_NORM_CLASSES[layer_norm_type] + if layer_norm_type == "rms_norm": + norm_class = partial( + norm_class, + add_unit_offset=rms_norm_offset, + ) + self.final_norm = norm_class(hidden_size, eps=norm_eps) + + self.n_kv_head = n_kv_head + self.head_dim = head_dim + self._hidden_size = hidden_size + self.use_sdpa = use_sdpa + self.gradient_checkpointing = gradient_checkpointing + + def forward( + self, + input: torch.Tensor, + att_mask: torch.Tensor, + input_position: Union[int, torch.Tensor] = 0, + kv_caches: Optional[List[T_CACHE]] = None, + ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]: + xs, pos_emb = self.pos_enc(input, offset=input_position) + if self.use_sdpa: + att_mask = mask_to_bias(att_mask, xs.dtype) + + if self.gradient_checkpointing and self.training: + xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb) + else: + xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb, + kv_caches) + if self.pre_norm and self.final_norm is not None: + xs = self.final_norm(xs) + return xs, kv_caches + + def forward_layers( + self, + xs: torch.Tensor, + att_mask: torch.Tensor, + pos_emb: torch.Tensor, + kv_caches: Optional[List[T_CACHE]] = None, + ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]: + if self.training or kv_caches is None: + for (i, layer) in enumerate(self.decoders): + xs, _, _, _ = layer(xs, att_mask, pos_emb) + new_kv_caches = kv_caches + else: + assert kv_caches is not None + new_kv_caches = [] + for (i, layer) in enumerate(self.decoders): + xs, _, new_kv_cache, _ = layer(xs, + att_mask, + pos_emb, + att_cache=(kv_caches[i][0], + kv_caches[i][1])) + new_kv_caches.append(new_kv_cache) + + return xs, new_kv_caches + + @torch.jit.ignore(drop=True) + def forward_layers_checkpointed(self, xs: torch.Tensor, + att_mask: torch.Tensor, + pos_emb: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask, + pos_emb) + return xs + + @property + def hidden_size(self): + return self._hidden_size diff --git a/wenet/LLM/sampler.py b/wenet/LLM/sampler.py new file mode 100644 index 0000000000..19f0d5cdaf --- /dev/null +++ b/wenet/LLM/sampler.py @@ -0,0 +1,43 @@ +from typing import Union +import torch + + +# modified from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L26 +@torch.no_grad() +def sampler( + logits: torch.Tensor, + temperatures: Union[torch.Tensor, None], + top_ps: torch.Tensor, + top_ks: torch.Tensor, +) -> torch.Tensor: + assert logits.size(1) == 1 + logits = logits.squeeze(1) # (batch_size, vocab_size) + if temperatures is None: + return torch.argmax(logits, 
dim=-1).squeeze(dim=-1) + + # Apply temperature scaling. + logits.div_(temperatures.unsqueeze(dim=1)) + + # Calculate probabilities with softmax. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + + # Apply top-p, top-k. + probs_sum = torch.cumsum(probs_sort, dim=-1) + top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1) + probs_sort = torch.where(top_ps_mask, 0, probs_sort) + + top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) + top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) + top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1) + probs_sort = torch.where(top_ks_mask, 0, probs_sort) + + # Re-normalization. + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + probs = torch.gather(probs_sort, + dim=-1, + index=torch.argsort(probs_idx, dim=-1)) + + next_token_ids = torch.multinomial(probs, num_samples=1, + replacement=True).squeeze(dim=-1) + return next_token_ids diff --git a/wenet/LLM/script/config.py b/wenet/LLM/script/config.py new file mode 100644 index 0000000000..c37959af45 --- /dev/null +++ b/wenet/LLM/script/config.py @@ -0,0 +1,205 @@ +import dataclasses +from typing import Dict, Optional, Union + +import yaml + + +# https://github.com/google/gemma_pytorch/blob/main/gemma/config.py#L32 +@dataclasses.dataclass +class Config: + vocab_size: int = 256000 + # The maximum sequence length that this model might ever be used with. + max_position_embeddings: int = 8192 + # The number of blocks in the model. + num_hidden_layers: int = 28 + # The number of attention heads used in the attention layers of the model. + num_attention_heads: int = 16 + # The number of key-value heads for implementing attention. + num_key_value_heads: int = 16 + # The hidden size of the model. + hidden_size: int = 3072 + # The dimension of the MLP representations. + intermediate_size: int = 24576 + # The number of head dimensions. + head_dim: int = 256 + # The epsilon used by the rms normalization layers. + rms_norm_eps: float = 1e-6 + # tope theta + rope_theta: float = 500000.0 + # rope style: google or llama + rope_style: str = 'google' + # rms_norm offset + rms_norm_offset: bool = True + # activation type + activation_type: str = 'gelu' + # gelu approximate + gelu_approximate: Union[str, None] = None + # The dtype of the weights. 
+ dtype: str = 'bfloat16' + + # scale embed + scale_embed: bool = True + + def to_wenet_config(self) -> Dict: + configs = {} + configs['max_position_embeding'] = self.max_position_embeddings + configs['num_blocks'] = self.num_hidden_layers + configs['attention_heads'] = self.num_attention_heads + configs['n_kv_head'] = self.num_key_value_heads + configs['head_dim'] = self.head_dim + configs['hidden_size'] = self.hidden_size + configs['linear_units'] = self.intermediate_size + configs['norm_eps'] = self.rms_norm_eps + configs['rope_theta'] = self.rope_theta + configs['activation_type'] = self.activation_type + configs['gelu_approximate'] = self.gelu_approximate + configs['rope_style'] = self.rope_style + configs['rms_norm_offset'] = self.rms_norm_offset + configs['scale_embed'] = self.scale_embed + return configs + + +def wenet_llm_tokenizer_conf(config: Config, tokenizer_path: str, + special_tokens: Dict) -> Dict: + configs = {} + configs['tokenizer'] = 'huggingface' + configs['tokenizer_conf'] = {} + configs['tokenizer_conf']['model'] = tokenizer_path + configs['tokenizer_conf']['special_tokens'] = special_tokens + return configs + + +def wenet_llm_dataset_and_train_conf(config: Config, + template: str = 'gemma') -> Dict: + configs = {} + configs['dataset'] = "llm" + configs['dataset_conf'] = {} + configs['dataset_conf']['filter_conf'] = {} + configs['dataset_conf']['filter_conf'][ + 'token_max_length'] = config.max_position_embeddings + configs['dataset_conf']['filter_conf']['token_min_length'] = 1 + configs['dataset_conf']['shuffle'] = True + configs['dataset_conf']['shuffle_conf'] = {} + configs['dataset_conf']['shuffle_conf']['shuffle_size'] = 1500 + configs['dataset_conf']['shuffle_list'] = True + configs['dataset_conf']['shuffle_list_conf'] = {} + configs['dataset_conf']['shuffle_list_conf']['shuffle_size'] = 15000 + configs['dataset_conf']['sort'] = True + configs['dataset_conf']['sort_conf'] = {} + configs['dataset_conf']['sort_conf']['sort_size'] = 500 + configs['dataset_conf']['batch_conf'] = {} + configs['dataset_conf']['batch_conf']['batch_type'] = 'dynamic' + configs['dataset_conf']['batch_conf']['max_frames_in_batch'] = 12000 + + configs['dataset_conf']['data_style'] = 'sft' + configs['dataset_conf']['data_style_conf'] = {} + configs['dataset_conf']['data_style_conf']['add_bos'] = True + configs['dataset_conf']['data_style_conf']['add_eos'] = True + configs['dataset_conf']['data_style_conf']['template'] = template + configs['dataset_conf']['shift'] = True + + configs['grad_clip'] = 5 + configs['accum_grad'] = 4 + configs['max_epoch'] = 100 + configs['log_interval'] = 100 + configs['save_interval'] = 3000 + + configs['optim'] = "adam" + configs['optim_conf'] = {} + configs['optim_conf']['lr'] = 0.0005 + configs['scheduler'] = "warmuplr" + configs['scheduler_conf'] = {} + configs['scheduler_conf']['warmup_steps'] = 12000 + return configs + + +def wenet_decoderonly_conf(config: Config): + configs = {} + configs['decoder'] = 'decoder_only' + configs['decoder_conf'] = config.to_wenet_config() + configs['decoder_conf']['dropout_rate'] = 0.0 + configs['decoder_conf']['attention_dropout_rate'] = 0.0 + configs['decoder_conf']['positional_dropout_rate'] = 0.0 + configs['decoder_conf']['gradient_checkpointing'] = True + configs['decoder_conf']['normalize_before'] = True + configs['decoder_conf']['use_sdpa'] = True + return configs + + +def wenet_model_conf(config: Config, tie_word_embedding: bool = True): + configs = {} + configs['output_dim'] = config.vocab_size + configs['model'] = 
"causal_lm" + configs['model_conf'] = {} + configs['model_conf']['linear_bias'] = False + configs['model_conf']['tie_word_embedding'] = tie_word_embedding + configs['model_conf']['lsm_weight'] = 0.1 + configs['model_conf']['length_normalized_loss'] = False + return configs + + +def convert_to_wenet_yaml(config: Config, + wenet_yaml_path: str, + tokenizer_path, + template: str = 'gemma', + tie_word_embedding: bool = True, + special_tokens: Optional[Dict] = None): + configs = {} + configs.update( + wenet_llm_tokenizer_conf(config, tokenizer_path, special_tokens)) + configs.update(wenet_decoderonly_conf(config)) + configs.update( + wenet_model_conf(config, tie_word_embedding=tie_word_embedding)) + configs.update(wenet_llm_dataset_and_train_conf(config, template=template)) + + with open(wenet_yaml_path, '+w') as f: + f.write(yaml.dump(configs)) + f.flush() + + print(configs) + + +def gemma_config_for_7b() -> Config: + return Config(rope_theta=10000.0, gelu_approximate='tanh') + + +def gemma_config_for_2b() -> Config: + return Config(num_hidden_layers=18, + num_attention_heads=8, + num_key_value_heads=1, + hidden_size=2048, + intermediate_size=16384, + rope_theta=10000.0, + gelu_approximate='tanh') + + +def llama3_config_for_8b() -> Config: + return Config(vocab_size=128256, + num_hidden_layers=32, + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=128, + intermediate_size=14336, + rms_norm_eps=1e-5, + rope_theta=500000.0, + activation_type='swish', + rms_norm_offset=False, + rope_style='llama', + scale_embed=False) + + +def llama3_config_for_70b() -> Config: + return Config(vocab_size=128256, + num_hidden_layers=80, + hidden_size=8192, + head_dim=128, + num_attention_heads=64, + num_key_value_heads=8, + intermediate_size=28672, + rms_norm_eps=1e-5, + rope_theta=500000.0, + activation_type='swish', + rms_norm_offset=False, + rope_style='llama', + scale_embed=False) diff --git a/wenet/LLM/script/convert_main.py b/wenet/LLM/script/convert_main.py new file mode 100644 index 0000000000..30bca315af --- /dev/null +++ b/wenet/LLM/script/convert_main.py @@ -0,0 +1,86 @@ +import argparse + +import os + +import torch + +from wenet.LLM.script.config import (convert_to_wenet_yaml, + gemma_config_for_2b, gemma_config_for_7b, + llama3_config_for_70b, + llama3_config_for_8b) +from wenet.LLM.script.gemma_config import (convert_to_wenet_state_dict as + gemma_convert_ckpt_fn, + gemma_special_tokens) +from wenet.LLM.script.llama3_config import (convert_to_wenet_state_dict as + llama3_convert_ckpt_fn, + llama3_special_tokens) + + +def get_args(): + parser = argparse.ArgumentParser(description='load and convert llm ckpt') + parser.add_argument('--ckpt', + required=True, + help='llama3: https://llama.meta.com/llama-downloads/ \ + \ngemma: https://www.kaggle.com/models/google/gemma/frameworks/pyTorch' + ) + parser.add_argument('--model_size', type=str, required=True) + parser.add_argument('--model_name', type=str, required=True) + parser.add_argument('--output_dir', + default='.', + help='output file in wenet\'s style') + args = parser.parse_args() + return args + + +MODEL = { + "gemma": { + "2b": (gemma_config_for_2b(), 'google/gemma-2b'), + "7b": (gemma_config_for_7b(), 'google/gemma-7b'), + "ckpt_fn": gemma_convert_ckpt_fn, + 'tie_word_embeding': True, + 'special_tokens_fn': gemma_special_tokens, + }, + "llama3": { + "8b": (llama3_config_for_8b(), 'meta-llama/Meta-Llama-3-8B'), + "70b": (llama3_config_for_70b(), 'meta-llama/Meta-Llama-3-70B'), + "ckpt_fn": llama3_convert_ckpt_fn, + 
'tie_word_embeding': False, + 'special_tokens_fn': llama3_special_tokens, + }, +} + + +def main(): + args = get_args() + args.jit = False + model_size = args.model_size + model_name = args.model_name + assert model_name in MODEL.keys() + all(model_size in size.keys() for size in MODEL.values()) + config = MODEL[model_name][model_size][0] + args.tokenizer = MODEL[model_name][model_size][1] + + os.makedirs(args.output_dir, exist_ok=True) + + checkpoint = torch.load(args.ckpt, map_location="cpu") + if model_name == 'gemma': + checkpoint = checkpoint["model_state_dict"] + wenet_ckpt_path = os.path.join(args.output_dir, + 'wenet_' + os.path.basename(args.ckpt)) + wenet_ckpt_path = os.path.splitext(wenet_ckpt_path)[0] + ".pt" + convert_fn = MODEL[model_name]['ckpt_fn'] + convert_fn(checkpoint, wenet_ckpt_path, config) + + wenet_yaml_path = os.path.join(args.output_dir, 'train.yaml') + convert_to_wenet_yaml( + config, + wenet_yaml_path, + args.tokenizer, + template=model_name, + tie_word_embedding=MODEL[model_name]['tie_word_embeding'], + special_tokens=MODEL[model_name]['special_tokens_fn'](args.tokenizer, + config)) + + +if __name__ == '__main__': + main() diff --git a/wenet/LLM/script/gemma_config.py b/wenet/LLM/script/gemma_config.py new file mode 100644 index 0000000000..0ff7362747 --- /dev/null +++ b/wenet/LLM/script/gemma_config.py @@ -0,0 +1,83 @@ +import torch + +from wenet.LLM.script.config import Config +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer + + +def convert_to_wenet_state_dict(gemma_state_dict, wenet_state_dict_path, + config: Config): + + print("==============start CKPT Conversion =========================") + wenet_state_dict = {} + for name in gemma_state_dict.keys(): + old_name = name + # embed + name = name.replace('embedder.weight', 'embed.weight') + + # layers to decoders + name = name.replace('model.layers', 'decoder.decoders') + + if 'self_attn.qkv_proj' in name: + # att weight + i_layer = name.split('.')[2] + layer_prefix = 'decoder.decoders.' 
+ i_layer + linear_q_name = layer_prefix + '.self_attn.linear_q.weight' + linear_k_name = layer_prefix + '.self_attn.linear_k.weight' + linear_v_name = layer_prefix + '.self_attn.linear_v.weight' + + start = 0 + offset = config.num_attention_heads * config.head_dim + linear_q_value = gemma_state_dict[old_name][start:offset, :] + start = offset + offset = offset + config.head_dim * config.num_key_value_heads + linear_k_value = gemma_state_dict[old_name][start:offset, :] + start = offset + linear_v_value = gemma_state_dict[old_name][start:, :] + wenet_state_dict[linear_q_name] = linear_q_value + wenet_state_dict[linear_k_name] = linear_k_value + wenet_state_dict[linear_v_name] = linear_v_value + elif name == 'freqs_cis': + # rope position embeding + name = 'decoder.pos_enc.pe' + pe = torch.view_as_real(gemma_state_dict[old_name].unsqueeze(0)) + wenet_state_dict[name] = pe + else: + # att out dim + name = name.replace('self_attn.o_proj', 'self_attn.linear_out') + + # mlp + name = name.replace('mlp.gate_proj', 'feed_forward.gate') + name = name.replace('mlp.up_proj', 'feed_forward.w_1') + name = name.replace('mlp.down_proj', 'feed_forward.w_2') + + # pre ln (rms norm) + name = name.replace('input_layernorm', 'norm1') + # before mlp ln: (rms norm) + name = name.replace('post_attention_layernorm', 'norm2') + # final norm + name = name.replace('model.norm.weight', + 'decoder.final_norm.weight') + + wenet_state_dict[name] = gemma_state_dict[old_name] + # NOTE(Mddct): tie weight + wenet_state_dict['out.weight'] = wenet_state_dict['embed.weight'] + print("Saving {} ckpt to {}...".format(config.dtype, + wenet_state_dict_path)) + torch.save(wenet_state_dict, wenet_state_dict_path) + print( + "DONE\n===================- End CKPT Conversion ====================\n" + ) + + +def gemma_special_tokens(tokenizer_path, config: Config): + tokenizer = HuggingFaceTokenizer(tokenizer_path) + assert config.vocab_size == tokenizer.vocab_size() + special_tokens = {} + bos = tokenizer.tokens2ids([""])[0] + eos = tokenizer.tokens2ids([""])[0] + unk = tokenizer.tokens2ids([""])[0] + special_tokens = {} + special_tokens[''] = bos + special_tokens[''] = eos + special_tokens[''] = unk + return special_tokens diff --git a/wenet/LLM/script/llama3_config.py b/wenet/LLM/script/llama3_config.py new file mode 100644 index 0000000000..074f14a147 --- /dev/null +++ b/wenet/LLM/script/llama3_config.py @@ -0,0 +1,74 @@ +from typing import Dict +import torch +from wenet.LLM.script.config import Config + +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer + + +def llama3_special_tokens(tokenizer_path, config: Config) -> Dict: + tokenizer = HuggingFaceTokenizer(tokenizer_path) + assert config.vocab_size == tokenizer.vocab_size() + # "<|reserved_special_token_0|>", + # "<|reserved_special_token_1|>", + # "<|reserved_special_token_2|>", + # "<|reserved_special_token_3|>", + shi = tokenizer.tokens2ids(["<|start_header_id|>"])[0] + ehi = tokenizer.tokens2ids(["<|end_header_id|>"])[0] + bos = tokenizer.tokens2ids(["<|begin_of_text|>"])[0] + eos = tokenizer.tokens2ids(["<|end_of_text|>"])[0] + eoti = tokenizer.tokens2ids(["<|eot_id|>"])[0] + special_tokens = {} + special_tokens['<|begin_of_text|>'] = bos + special_tokens['<|end_of_text|>'] = eos + special_tokens['<|eot_id|>'] = eoti + special_tokens['<|start_header_id|>'] = shi + special_tokens['<|end_header_id|>'] = ehi + return special_tokens + + +def convert_to_wenet_state_dict(Llama3_state_dict, wenet_state_dict_path, + config: Config): + + wenet_state_dict = {} + + 
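+    # Parameter-name mapping applied below (Llama3 checkpoint -> wenet):
+    #   tok_embeddings.weight -> embed.weight, output.weight -> out.weight
+    #   layers.N.attention.wq/wk/wv/wo -> decoder.decoders.N.self_attn.linear_q/k/v/out
+    #   layers.N.feed_forward.w1/w3/w2 -> decoder.decoders.N.feed_forward.gate/w_1/w_2
+    #   attention_norm -> norm1, ffn_norm -> norm2, norm.weight -> decoder.final_norm.weight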
print("==============start CKPT Conversion =========================") + conformer_state_dict = Llama3_state_dict + wenet_state_dict = {} + for name in conformer_state_dict.keys(): + old_name = name + # embed + name = name.replace('tok_embeddings.weight', 'embed.weight') + # output + name = name.replace('output.weight', 'out.weight') + # layers to decoders + name = name.replace('layers', 'decoder.decoders') + if 'attention' in name: + # pre ln (rms norm) + name = name.replace('attention_norm.weight', 'norm1.weight') + # att weight + name = name.replace('.attention.wq.weight', + '.self_attn.linear_q.weight') + name = name.replace('.attention.wk.weight', + '.self_attn.linear_k.weight') + name = name.replace('.attention.wv.weight', + '.self_attn.linear_v.weight') + # att out dim + name = name.replace('attention.wo', 'self_attn.linear_out') + else: + # mlp + name = name.replace('feed_forward.w1', 'feed_forward.gate') + name = name.replace('feed_forward.w3', 'feed_forward.w_1') + name = name.replace('feed_forward.w2', 'feed_forward.w_2') + + # before mlp ln: (rms norm) + name = name.replace('ffn_norm', 'norm2') + wenet_state_dict[name] = conformer_state_dict[old_name] + # final norm weight + wenet_state_dict['decoder.final_norm.weight'] = conformer_state_dict[ + 'norm.weight'] + print("Saving {} ckpt to {}...".format(config.dtype, + wenet_state_dict_path)) + torch.save(wenet_state_dict, wenet_state_dict_path) + print( + "DONE\n===================- End CKPT Conversion ====================\n" + ) diff --git a/wenet/LLM/template.py b/wenet/LLM/template.py new file mode 100644 index 0000000000..b0e0a95418 --- /dev/null +++ b/wenet/LLM/template.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class Template: + # one turn :{system_format}{user_format}{assistant_format} + # multi turns: + # {system_format}{user_format}{assistant_format}{user_format}{assistant_format}... + system: Optional[str] + user: str + assistant: str + + bos: str + eos: str + + +gemma = Template( + '', + 'user\n{content}\nmodel\n', + '{content}\n', + '', + '', +) + +llama3 = Template( + '<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>', + '<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n', + '{content}<|eot_id|>', + '<|begin_of_text|>', + '<|end_of_text|>', +) +WENET_LLM_Template = { + "gemma": gemma, + 'llama3': llama3, +} diff --git a/wenet/bin/audiollm_recognize.py b/wenet/bin/audiollm_recognize.py new file mode 100644 index 0000000000..f0880c07ab --- /dev/null +++ b/wenet/bin/audiollm_recognize.py @@ -0,0 +1,178 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os + +import torch +import yaml +from torch.utils.data import DataLoader + +from wenet.dataset.audiollm_dataset import Dataset +from wenet.utils.config import override_config +from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--data_type', + default='raw', + choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--dtype', + type=str, + default='fp32', + choices=['fp16', 'fp32', 'bf16'], + help='model\'s dtype') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--beam_size', + type=int, + default=10, + help='beam size for search') + parser.add_argument('--output_len', + type=int, + default=100, + help='output length') + parser.add_argument('--result_dir', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=16, + help='batch size') + parser.add_argument('--override_config', + action='append', + default=[], + help="override yaml config") + parser.add_argument('--temperature', + default=0.95, + type=float, + help='temperature') + parser.add_argument('--top_p', + default=1.0, + type=float, + help='top_p') + parser.add_argument('--top_k', + default=100, + type=int, + help='top_k') + parser.add_argument('--use_lora', + type=bool, + default=False, + help='''Whether to use lora for biasing''') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if len(args.override_config) > 0: + configs = override_config(configs, args.override_config) + + test_conf = copy.deepcopy(configs['dataset_conf']) + + test_conf['filter_conf']['token_max_length'] = 102400 + test_conf['filter_conf']['token_min_length'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['spec_sub'] = False + test_conf['spec_trim'] = False + test_conf['shuffle'] = False + test_conf['sort'] = False + test_conf['cycle'] = 1 + test_conf['list_shuffle'] = False + if 'fbank_conf' in test_conf: + test_conf['fbank_conf']['dither'] = 0.0 + elif 'mfcc_conf' in test_conf: + test_conf['mfcc_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_type'] = "static" + test_conf['batch_conf']['batch_size'] = args.batch_size + + tokenizer = init_tokenizer(configs) + test_dataset = Dataset(args.data_type, + args.test_data, + tokenizer, + test_conf, + partition=False, + train=False) + + test_data_loader = DataLoader(test_dataset, + batch_size=None, + num_workers=args.num_workers) + + # Init asr model from configs + args.jit = False + model, configs = init_model(args, configs) + + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + 
model.eval() + dtype = torch.float32 + if args.dtype == 'fp16': + dtype = torch.float16 + elif args.dtype == 'bf16': + dtype = torch.bfloat16 + logging.info("compute dtype is {}".format(dtype)) + + dir_name = os.path.join(args.result_dir, + f"temp{args.temperature}_topk{args.top_k}_topp{args.top_p}") + os.makedirs(dir_name, exist_ok=True) + file_name = os.path.join(dir_name, 'text') + file = open(file_name, 'w') + + stop_tokens = tokenizer.tokens2ids(["<|eot_id|>", "<|end_of_text|>"]) + + with torch.cuda.amp.autocast(enabled=True, + dtype=dtype, + cache_enabled=False): + with torch.no_grad(): + for batch_idx, batch in enumerate(test_data_loader): + keys = batch["keys"] + results = model.generate( + batch, + device, + stop_tokens, + dtype, + args.output_len, + args.temperature, + args.top_p, + args.top_k) + for i, (key, tokens) in enumerate(zip(keys, results)): + line = '{} {}'.format(key, tokenizer.detokenize(tokens)[0]) + logging.info('{}'.format(line)) + file.write(line + '\n') + file.close() + +if __name__ == '__main__': + main() diff --git a/wenet/dataset/audiollm_dataset.py b/wenet/dataset/audiollm_dataset.py new file mode 100644 index 0000000000..6bf047b3fe --- /dev/null +++ b/wenet/dataset/audiollm_dataset.py @@ -0,0 +1,159 @@ +from functools import partial +import sys +from wenet.AudioLLM.template import WENET_LLM_Template +from wenet.dataset.datapipes import (WenetRawDatasetSource) +from wenet.dataset import (processor, audiollm_processor) +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.utils.file_utils import read_symbol_table +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer + + +def Dataset(data_type, + data_list_file, + tokenizer: BaseTokenizer, + conf=None, + partition=True, + train=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
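+    On top of the usual ASR processing (decode_wav, resample,
+    fbank/mfcc/log_mel features, optional spec augmentation), each sample
+    goes through audiollm_processor.parse_audiosft, which builds the
+    prefix/suffix token ids around the audio features according to the
+    chat template.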
+ + Args: + data_type(str): raw/shard + tokenizer (BaseTokenizer or None): tokenizer to tokenize + partition(bool): whether to do data partition in terms of rank + """ + assert conf is not None + assert data_type in ['raw', 'shard'] + # cycle dataset + cycle = conf.get('cycle', 1) + # stage1 shuffle: source + list_shuffle = conf.get('list_shuffle', True) + list_shuffle_size = sys.maxsize + if list_shuffle: + list_shuffle_conf = conf.get('list_shuffle_conf', {}) + list_shuffle_size = list_shuffle_conf.get('shuffle_size', + list_shuffle_size) + if data_type == 'raw': + dataset = WenetRawDatasetSource(data_list_file, + partition=partition, + shuffle=list_shuffle, + shuffle_size=list_shuffle_size, + cycle=cycle) + dataset = dataset.map(processor.parse_json) + else: + raise NotImplementedError('only support jsonl for now') + + dataset = dataset.map_ignore_error(processor.decode_wav) + + singal_channel_conf = conf.get('singal_channel_conf', {}) + dataset = dataset.map( + partial(processor.singal_channel, **singal_channel_conf)) + + + speaker_conf = conf.get('speaker_conf', None) + if speaker_conf is not None: + assert 'speaker_table_path' in speaker_conf + speaker_table = read_symbol_table(speaker_conf['speaker_table_path']) + dataset = dataset.map( + partial(processor.parse_speaker, speaker_dict=speaker_table)) + if tokenizer is not None: + dataset = dataset.map(partial(processor.tokenize, tokenizer=tokenizer)) + audio_filter_conf = conf.get('audio_filter_conf', {}) + dataset = dataset.filter(partial(processor.filter, **audio_filter_conf)) + + resample_conf = conf.get('resample_conf', {}) + dataset = dataset.map(partial(processor.resample, **resample_conf)) + + speed_perturb = conf.get('speed_perturb', False) + if speed_perturb: + dataset = dataset.map(partial(processor.speed_perturb)) + + feats_type = conf.get('feats_type', 'fbank') + assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram'] + if feats_type == 'fbank': + fbank_conf = conf.get('fbank_conf', {}) + dataset = dataset.map(partial(processor.compute_fbank, **fbank_conf)) + elif feats_type == 'mfcc': + mfcc_conf = conf.get('mfcc_conf', {}) + dataset = dataset.map(partial(processor.compute_mfcc, **mfcc_conf)) + elif feats_type == 'log_mel_spectrogram': + log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {}) + dataset = dataset.map( + partial(processor.compute_log_mel_spectrogram, + **log_mel_spectrogram_conf)) + spec_aug = conf.get('spec_aug', True) + spec_sub = conf.get('spec_sub', False) + spec_trim = conf.get('spec_trim', False) + if spec_aug: + spec_aug_conf = conf.get('spec_aug_conf', {}) + dataset = dataset.map(partial(processor.spec_aug, **spec_aug_conf)) + if spec_sub: + spec_sub_conf = conf.get('spec_sub_conf', {}) + dataset = dataset.map(partial(processor.spec_sub, **spec_sub_conf)) + if spec_trim: + spec_trim_conf = conf.get('spec_trim_conf', {}) + dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf)) + + # TODO: DPO etc + assert isinstance(tokenizer, HuggingFaceTokenizer) + style_conf = conf.get('data_style_conf', {}) + template = WENET_LLM_Template[style_conf.get('template', 'audio_llama3')] + + dataset = dataset.map( + partial( + audiollm_processor.parse_audiosft, + tokenizer=tokenizer, + template=template, + train=train, + add_bos=style_conf.get('add_bos', True), + add_eos=style_conf.get('add_eos', True), + )) + + shuffle = conf.get('shuffle', True) + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size']) + + sort = 
conf.get('sort', True) + if sort: + sort_conf = conf.get('sort_conf', {}) + dataset = dataset.sort(buffer_size=sort_conf['sort_size'], + key_func=audiollm_processor.sort_by_input) + shift = conf.get('shift', True) + if shift: + dataset = dataset.map(audiollm_processor.shift) + + filter_conf = conf.get('filter_conf', {}) + dataset = dataset.filter(partial(audiollm_processor.filter, **filter_conf)) + + batch_conf = conf.get('batch_conf', {}) + batch_type = batch_conf.get('batch_type', 'static') + assert batch_type in ['static', 'bucket', 'dynamic'] + if batch_type == 'static': + assert 'batch_size' in batch_conf + batch_size = batch_conf.get('batch_size', 16) + dataset = dataset.batch( + batch_size, + wrapper_class=audiollm_processor.padding, + ) + elif batch_type == 'bucket': + assert 'bucket_boundaries' in batch_conf + assert 'bucket_batch_sizes' in batch_conf + dataset = dataset.bucket_by_sequence_length( + audiollm_processor.input_length_fn, + batch_conf['bucket_boundaries'], + batch_conf['bucket_batch_sizes'], + wrapper_class=audiollm_processor.padding, + ) + else: + max_tokens_in_batch = batch_conf.get('max_tokens_in_batch', 50000) + dataset = dataset.dynamic_batch( + processor.DynamicBatchWindow(max_tokens_in_batch), + wrapper_class=audiollm_processor.padding, + elem_size_fn=audiollm_processor.input_length_fn, + ) + + return dataset diff --git a/wenet/dataset/audiollm_processor.py b/wenet/dataset/audiollm_processor.py new file mode 100644 index 0000000000..1f7b36005d --- /dev/null +++ b/wenet/dataset/audiollm_processor.py @@ -0,0 +1,182 @@ +from typing import Dict, List + +import torch +from torch.nn.utils.rnn import pad_sequence +from wenet.AudioLLM.template import Template +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer +from wenet.utils.common import IGNORE_ID + +def parse_audiosft(sample, + tokenizer: HuggingFaceTokenizer, + template: Template, + train: bool = True, + add_bos: bool = True, + add_eos: bool = True): + """Paser sft json line to tensor + + Args: + sample: + { + 'system': 'you are a helpful ...', + 'text': '...', + 'wav': '...', + } + + Returns: + {input_ids, output_ids} + """ + chat_pattern = template + prefix_input_ids = [] + prefix_output_ids = [] + suffix_input_ids = [] + suffix_output_ids = [] + system_text = sample.get('system', 'you are a helpful ASR agent') + if chat_pattern.system is not None: + system_text = chat_pattern.system.format(content=system_text) + if add_bos: + system_text = template.bos + system_text + _, system_text_ids = tokenizer.tokenize(system_text) + prefix_input_ids += system_text_ids + prefix_output_ids += [IGNORE_ID] * len(system_text_ids) + + human_text = sample.get('human_text', 'please help me transcribe this audio in english') + human_text = chat_pattern.prefix_user.format(content=human_text) + _, human_ids = tokenizer.tokenize(human_text) + prefix_input_ids += human_ids + prefix_output_ids += [IGNORE_ID] * len(human_ids) + + _, suffix_ids = tokenizer.tokenize(chat_pattern.suffix_user) + suffix_input_ids += suffix_ids + suffix_output_ids += [IGNORE_ID] * len(suffix_ids) + if train: + assistant = sample['txt'] + assistant = chat_pattern.assistant.format(content=assistant) + _, assistant_ids = tokenizer.tokenize(assistant) + suffix_input_ids += assistant_ids + suffix_output_ids += assistant_ids + + if add_eos: + eos_id = tokenizer.tokens2ids([template.eos]) + suffix_input_ids += eos_id + suffix_output_ids += eos_id + + assert len(prefix_input_ids) == len(prefix_output_ids) + assert len(suffix_input_ids) == 
len(suffix_output_ids) + + sample['prefix_input_ids'] = torch.tensor(prefix_input_ids) + sample['prefix_output_ids'] = torch.tensor(prefix_output_ids) + sample['suffix_input_ids'] = torch.tensor(suffix_input_ids) + sample['suffix_output_ids'] = torch.tensor(suffix_output_ids) + return sample + + +def shift(sample): + prefix_output_ids = sample['prefix_output_ids'] + suffix_input_ids = sample['suffix_input_ids'] + + sample['prefix_output_ids'] = prefix_output_ids[1:] + sample['suffix_input_ids'] = suffix_input_ids[:-1] + return sample + + +def filter(sample, token_max_length: int = 8190, token_min_length=1): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + sample: {input_ids, output_ids} + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + + Returns: + bool: True to keep, False to filter + """ + total_lens = sample['prefix_input_ids'].size(0) + \ + sample['feat'].size(0) + \ + sample['suffix_input_ids'].size(0) + if total_lens < token_min_length: + return False + if total_lens > token_max_length: + return False + return True + + +def sort_by_input(sample): + total_lens = sample['prefix_input_ids'].size(0) + \ + sample['feat'].size(0) + \ + sample['suffix_input_ids'].size(0) + return total_lens + +def input_length_fn(sample) -> int: + total_lens = sample['prefix_input_ids'].size(0) + \ + sample['feat'].size(0) + \ + sample['suffix_input_ids'].size(0) + return total_lens + +def padding(data: List[Dict]): + """ Padding the data into training data + + Args: + data: List[{input_ids, output_ids} + + Returns: + Tuple(feats, labels) + """ + sample = data + + total_lens = torch.tensor([x['prefix_input_ids'].size(0) + + x['feat'].size(0) + + x['suffix_input_ids'].size(0) + for x in sample],dtype=torch.int32) + + order = torch.argsort(total_lens, descending=True) + sorted_keys = [sample[i]['key'] for i in order] + prefix_tokens_lengths = torch.tensor( + [sample[i]['prefix_input_ids'].size(0) for i in order], dtype=torch.int32) + sorted_prefix_tokens = [sample[i]['prefix_input_ids'] for i in order] + sorted_prefix_labels = [sample[i]['prefix_output_ids'] for i in order] + padded_prefix_tokens = pad_sequence(sorted_prefix_tokens, + batch_first=True, + padding_value=0) + padding_prefix_labels = pad_sequence(sorted_prefix_labels, + batch_first=True, + padding_value=IGNORE_ID) + audio_feats_lengths = torch.tensor( + [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) + sorted_audio_feats = [sample[i]['feat'] for i in order] + padded_audio_feats = pad_sequence(sorted_audio_feats, + batch_first=True, + padding_value=0) + suffix_tokens_lengths = torch.tensor( + [sample[i]['suffix_input_ids'].size(0) for i in order], dtype=torch.int32) + sorted_suffix_tokens = [sample[i]['suffix_input_ids'] for i in order] + sorted_suffix_labels = [sample[i]['suffix_output_ids'] for i in order] + padded_suffix_tokens = pad_sequence(sorted_suffix_tokens, + batch_first=True, + padding_value=0) + padding_suffix_labels = pad_sequence(sorted_suffix_labels, + batch_first=True, + padding_value=IGNORE_ID) + sorted_labels = [ + torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order + ] + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + batch = { + 'keys': sorted_keys, + 'prefix_tokens': padded_prefix_tokens, + 'audio_feats': padded_audio_feats, + 'suffix_tokens': padded_suffix_tokens, 
+ "prefix_target": padding_prefix_labels, + "suffix_target": padding_suffix_labels, + "prefix_tokens_lengths": prefix_tokens_lengths, + "audio_feats_lengths": audio_feats_lengths, + "suffix_tokens_lengths": suffix_tokens_lengths, + "target_lengths": prefix_tokens_lengths + audio_feats_lengths + suffix_tokens_lengths, + "label": padding_labels, + } + return batch diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py index 6d89ab5522..0d2fdca38e 100644 --- a/wenet/dataset/datapipes.py +++ b/wenet/dataset/datapipes.py @@ -184,21 +184,24 @@ def __iter__(self): @functional_datapipe("dynamic_batch") class DynamicBatchDataPipe(IterDataPipe): - def __init__(self, dataset: IterDataPipe, window_class, - wrapper_class) -> None: + def __init__(self, dataset: IterDataPipe, window_class, wrapper_class, + elem_size_fn) -> None: _check_unpickable_fn(window_class) _check_unpickable_fn(wrapper_class) + _check_unpickable_fn(elem_size_fn) super().__init__() self.dp = dataset assert window_class is not None assert wrapper_class is not None + self.elem_size_fn = elem_size_fn self.window_class = window_class self._buffer = [] self._wrappr_class = wrapper_class def __iter__(self): for elem in self.dp: - if not self.window_class(elem, len(self._buffer)): + if not self.window_class(self.elem_size_fn(elem), len( + self._buffer)): self._buffer.append(elem) else: if len(self._buffer) > 0: diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py index 95a3eafa97..d88ae8ad0c 100644 --- a/wenet/dataset/dataset.py +++ b/wenet/dataset/dataset.py @@ -150,6 +150,7 @@ def Dataset(data_type, dataset = dataset.dynamic_batch( processor.DynamicBatchWindow(max_frames_in_batch), wrapper_class=processor.padding, + elem_size_fn=processor.feats_length_fn, ) return dataset diff --git a/wenet/dataset/llm_dataset.py b/wenet/dataset/llm_dataset.py new file mode 100644 index 0000000000..dd5f323d59 --- /dev/null +++ b/wenet/dataset/llm_dataset.py @@ -0,0 +1,116 @@ +from functools import partial +import sys +from wenet.LLM.template import WENET_LLM_Template +from wenet.dataset.datapipes import (WenetRawDatasetSource) +from wenet.dataset import (processor, llm_processor) +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer + + +def Dataset(data_type, + data_list_file, + tokenizer: BaseTokenizer, + conf=None, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
+ + Args: + data_type(str): raw/shard + tokenizer (BaseTokenizer or None): tokenizer to tokenize + partition(bool): whether to do data partition in terms of rank + """ + assert conf is not None + assert data_type in ['raw', 'shard'] + # cycle dataset + cycle = conf.get('cycle', 1) + # stage1 shuffle: source + list_shuffle = conf.get('list_shuffle', True) + list_shuffle_size = sys.maxsize + if list_shuffle: + list_shuffle_conf = conf.get('list_shuffle_conf', {}) + list_shuffle_size = list_shuffle_conf.get('shuffle_size', + list_shuffle_size) + if data_type == 'raw': + dataset = WenetRawDatasetSource(data_list_file, + partition=partition, + shuffle=list_shuffle, + shuffle_size=list_shuffle_size, + cycle=cycle) + dataset = dataset.map(processor.parse_json) + + else: + raise NotImplementedError('only support jsonl for now') + + # TODO: DPO etc + data_style = conf.get('style', 'sft') + assert data_style in ['pretrain', 'sft'] + assert isinstance(tokenizer, HuggingFaceTokenizer) + style_conf = conf.get('data_style_conf', {}) + template = WENET_LLM_Template[style_conf.get('template', 'gemma')] + if data_style == 'sft': + dataset = dataset.map( + partial( + llm_processor.parse_sft, + tokenizer=tokenizer, + template=template, + add_bos=style_conf.get('add_bos', True), + add_eos=style_conf.get('add_eos', True), + )) + else: + dataset = dataset.map( + partial( + llm_processor.parse_pretrain, + tokenizer=tokenizer, + template=template, + add_bos=style_conf.get('add_bos', True), + add_eos=style_conf.get('add_eos', True), + )) + shuffle = conf.get('shuffle', True) + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size']) + + sort = conf.get('sort', True) + if sort: + sort_conf = conf.get('sort_conf', {}) + dataset = dataset.sort(buffer_size=sort_conf['sort_size'], + key_func=llm_processor.sort_by_input) + shift = conf.get('shift', True) + if shift: + dataset = dataset.map(llm_processor.shift) + + filter_conf = conf.get('filter_conf', {}) + dataset = dataset.filter(partial(llm_processor.filter, **filter_conf)) + + batch_conf = conf.get('batch_conf', {}) + batch_type = batch_conf.get('batch_type', 'static') + assert batch_type in ['static', 'bucket', 'dynamic'] + if batch_type == 'static': + assert 'batch_size' in batch_conf + batch_size = batch_conf.get('batch_size', 16) + dataset = dataset.batch( + batch_size, + wrapper_class=llm_processor.padding, + ) + elif batch_type == 'bucket': + assert 'bucket_boundaries' in batch_conf + assert 'bucket_batch_sizes' in batch_conf + dataset = dataset.bucket_by_sequence_length( + llm_processor.input_length_fn, + batch_conf['bucket_boundaries'], + batch_conf['bucket_batch_sizes'], + wrapper_class=llm_processor.padding, + ) + else: + max_tokens_in_batch = batch_conf.get('max_tokens_in_batch', 50000) + dataset = dataset.dynamic_batch( + processor.DynamicBatchWindow(max_tokens_in_batch), + wrapper_class=llm_processor.padding, + elem_size_fn=llm_processor.input_length_fn, + ) + + return dataset diff --git a/wenet/dataset/llm_processor.py b/wenet/dataset/llm_processor.py new file mode 100644 index 0000000000..f88bdcdc5e --- /dev/null +++ b/wenet/dataset/llm_processor.py @@ -0,0 +1,173 @@ +from typing import Dict, List + +import torch +from torch.nn.utils.rnn import pad_sequence +from wenet.LLM.template import Template +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer +from wenet.utils.common import IGNORE_ID + + +def parse_sft( + sample, + tokenizer: HuggingFaceTokenizer, + template: 
Template, + add_bos: bool = True, + add_eos: bool = True, +): + """Paser sft json line to tensor + + Args: + sample: + { + 'system': 'you are a helpful ...', + "conversation": [{ + 'human': '...', + 'assistant': '...' + }] + } + + Returns: + {input_ids, output_ids} + """ + chat_pattern = template + input_ids = [] + output_ids = [] + system_text = sample.get('system', '') + if chat_pattern.system is not None: + system_text = chat_pattern.system.format(content=system_text) + if add_bos: + system_text = template.bos + system_text + _, system_text_ids = tokenizer.tokenize(system_text) + input_ids += system_text_ids + output_ids += [IGNORE_ID] * len(system_text_ids) + conversations = sample['conversation'] + assert isinstance(conversations, List) + for conversation in conversations: + human = conversation['human'] + human = chat_pattern.user.format(content=human) + _, human_ids = tokenizer.tokenize(human) + input_ids += human_ids + output_ids += [IGNORE_ID] * len(human_ids) + if 'assistant' in conversation: + assistant = conversation['assistant'] + assistant = chat_pattern.assistant.format(content=assistant) + _, assistant_ids = tokenizer.tokenize(assistant) + input_ids += assistant_ids + output_ids += assistant_ids + + if add_eos: + eos_id = tokenizer.tokens2ids([template.eos]) + input_ids += eos_id + output_ids += eos_id + + assert len(input_ids) == len(output_ids) + return { + 'input_ids': torch.tensor(input_ids), + 'output_ids': torch.tensor(output_ids), + } + + +def parse_pretrain(sample, + tokenizer: HuggingFaceTokenizer, + template: Template, + add_bos: bool = True, + add_eos: bool = False): + """ Parse text from json line + + Args: + sample: str, str is a json line has txt + + Returns: + {input_ids, output_ids} + """ + assert 'text' in sample + text = sample['text'] + _, input_ids = tokenizer.tokenize(text) + if add_bos: + input_ids = [template.bos] + input_ids + if add_eos: + input_ids = input_ids + [template.eos] + + return { + 'input_ids': torch.tensor(input_ids), + 'output_ids': torch.tensor(input_ids), + } + + +def shift(sample): + input_ids = sample['input_ids'] + output_ids = sample['output_ids'] + + sample['input_ids'] = input_ids[:-1] + sample['output_ids'] = output_ids[1:] + return sample + + +def filter(sample, token_max_length: int = 8190, token_min_length=1): + """ Filter sample according to feature and label length + Inplace operation. 
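To make the alignment produced by parse_sft and shift above concrete, here is a worked toy example; the token ids are made up and IGNORE_ID is redefined locally only so the snippet runs on its own:

import torch

IGNORE_ID = -1  # mirrors wenet.utils.common.IGNORE_ID

# parse_sft masks prompt tokens with IGNORE_ID and keeps the assistant
# answer (plus eos) as supervised targets.  Hypothetical ids:
#   system + user prompt -> [7, 8, 9]
#   assistant answer     -> [4, 5]
#   eos                  -> [2]
sample = {
    'input_ids': torch.tensor([7, 8, 9, 4, 5, 2]),
    'output_ids': torch.tensor([IGNORE_ID, IGNORE_ID, IGNORE_ID, 4, 5, 2]),
}

# shift() drops the last input and the first target, so position t of
# input_ids is trained to predict position t of output_ids:
shifted_inputs = sample['input_ids'][:-1]    # [7, 8, 9, 4, 5]
shifted_targets = sample['output_ids'][1:]   # [IGNORE_ID, IGNORE_ID, 4, 5, 2]

# Reading "... 9" the model is supervised to emit 4, reading "... 4" to emit
# 5, and finally to emit eos; the masked prompt positions contribute no loss.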
+ + Args:: + sample: {input_ids, output_ids} + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + + Returns: + bool: True to keep, False to filter + """ + assert 'input_ids' in sample + assert 'output_ids' in sample + assert isinstance(sample['input_ids'], torch.Tensor) + assert isinstance(sample['output_ids'], torch.Tensor) + if sample['input_ids'].size(0) < token_min_length: + return False + if sample['input_ids'].size(0) > token_max_length: + return False + return True + + +def sort_by_input(sample): + assert 'input_ids' in sample + assert isinstance(sample['input_ids'], torch.Tensor) + return sample['input_ids'].size(0) + + +def input_length_fn(sample) -> int: + assert 'input_ids' in sample + return sample['input_ids'].size(0) + + +def padding(data: List[Dict]): + """ Padding the data into training data + + Args: + data: List[{input_ids, output_ids} + + Returns: + Tuple(feats, labels) + """ + sample = data + feats_length = torch.tensor([x['input_ids'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor( + [sample[i]['input_ids'].size(0) for i in order], dtype=torch.int32) + sorted_feats = [sample[i]['input_ids'] for i in order] + sorted_labels = [sample[i]['output_ids'] for i in order] + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=IGNORE_ID) + + batch = { + 'feats': padded_feats, + "target": padding_labels, + "feats_lengths": feats_lengths, + "target_lengths": feats_lengths, + } + return batch diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 5131a13ac6..77ddd4bd7f 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -570,11 +570,8 @@ def __init__(self, max_frames_in_batch=12000): self.longest_frames = 0 self.max_frames_in_batch = max_frames_in_batch - def __call__(self, sample, buffer_size): - assert isinstance(sample, dict) - assert 'feat' in sample - assert isinstance(sample['feat'], torch.Tensor) - new_sample_frames = sample['feat'].size(0) + def __call__(self, elem_size, buffer_size): + new_sample_frames = elem_size self.longest_frames = max(self.longest_frames, new_sample_frames) frames_after_padding = self.longest_frames * (buffer_size + 1) if frames_after_padding > self.max_frames_in_batch: diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 4099947b8f..fb5e6c6920 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -245,6 +245,9 @@ def ctc_logprobs(self, return ctc_probs + def tie_or_clone_weights(self, jit_mode: bool = True): + self.decoder.tie_or_clone_weights(jit_mode) + def decode( self, methods: List[str], diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py index c9d8f07b4b..75633bf309 100644 --- a/wenet/transformer/attention.py +++ b/wenet/transformer/attention.py @@ -21,7 +21,7 @@ import torch from torch import nn -from wenet.utils.rope_utils import llama_apply_rotary_emb +from wenet.utils.rope_utils import WENET_APPLY_ROTARY_EMB T_CACHE = Tuple[torch.Tensor, torch.Tensor] @@ -80,7 +80,10 @@ def __init__(self, self.use_sdpa = use_sdpa self.dropout_rate = dropout_rate - def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor: + def _forward_linearx(self, + name: str, + x: 
torch.Tensor, + head_first: bool = True) -> torch.Tensor: assert x.ndim >= 3 if name == 'query': x = self.linear_q(x) @@ -98,7 +101,9 @@ def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor: # split last dim x = x.view(x_shape) - x = x.transpose(-3, -2) # (batch, ..., head or head_kv, time, d_k) + if head_first: + x = x.transpose(-3, + -2) # (batch, ..., head or head_kv, time, d_k) return x def forward_qkv( @@ -172,9 +177,15 @@ def forward_attention( return self.linear_out(x) # (batch, ..., time1, d_model) def _update_kv_and_cache( - self, k: torch.Tensor, v: torch.Tensor, - cache: T_CACHE) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE]: + self, + k: torch.Tensor, + v: torch.Tensor, + cache: T_CACHE, + head_first: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE]: new_cache = cache + seq_axis = -2 if head_first else -3 + head_axis = -3 if head_first else -2 if not self.training: # NOTE(xcsong): # when export onnx model, for 1st chunk, we feed @@ -194,9 +205,9 @@ def _update_kv_and_cache( # >>> torch.equal(d[0], d[1]) # True key_cache, value_cache = cache if key_cache.size(0) > 0: - k = torch.cat([key_cache, k], dim=2) + k = torch.cat([key_cache, k], dim=seq_axis) if value_cache.size(0) > 0: - v = torch.cat([value_cache, v], dim=2) + v = torch.cat([value_cache, v], dim=seq_axis) # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. # new_cache = torch.cat((k, v), dim=-1) if not self.training else cache @@ -206,12 +217,12 @@ def _update_kv_and_cache( k = torch.repeat_interleave( k, self.h // self.h_kv, - dim=-3, + dim=head_axis, ) v = torch.repeat_interleave( v, self.h // self.h_kv, - dim=-3, + dim=-head_axis, ) return k, v, new_cache @@ -578,9 +589,11 @@ def __init__(self, value_bias: bool = True, use_sdpa: bool = False, n_kv_head: Optional[int] = None, - head_dim: Optional[int] = None): + head_dim: Optional[int] = None, + style='google'): super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias, value_bias, use_sdpa, n_kv_head, head_dim) + self.style = style def forward( self, @@ -621,14 +634,21 @@ def forward( and `head * d_k == size` """ - q, k, v = self.forward_qkv(query, key, value) + q = self._forward_linearx('query', query, head_first=False) + k = self._forward_linearx('key', key, head_first=False) + v = self._forward_linearx('value', value, head_first=False) # NOTE(Mddct): In order to make the code easier to read, # these two lines are not placed in MultiHeadedAttention. 
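# Editor's sketch, not part of the patch: with head_first=False the q/k/v
# projections keep the (batch, time, head, d_k) layout while rotary
# embeddings are applied, so _update_kv_and_cache concatenates the KV cache
# along the time axis (seq_axis = -3 rather than -2) and repeats KV heads
# for GQA along the head axis (head_axis = -2); only afterwards does
# forward() transpose back to (batch, head, time, d_k) for the score matmul.
# Shape-only illustration with made-up sizes:
import torch

b, t_new, t_cache, h, h_kv, d_k = 2, 5, 3, 4, 2, 64
k = torch.zeros(b, t_new, h_kv, d_k)                 # head_first=False layout
key_cache = torch.zeros(b, t_cache, h_kv, d_k)       # 3 cached time steps
k = torch.cat([key_cache, k], dim=-3)                # time axis -> (2, 8, 2, 64)
k = torch.repeat_interleave(k, h // h_kv, dim=-2)    # GQA heads -> (2, 8, 4, 64)
k = k.transpose(1, 2)                                # (batch, head, time, d_k)
assert k.shape == (2, 4, 8, 64)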
- q = llama_apply_rotary_emb(q, pos_emb) - k = llama_apply_rotary_emb(k, pos_emb) - # see above - k, v, new_cache = self._update_kv_and_cache(k, v, cache) - + q = WENET_APPLY_ROTARY_EMB[self.style](q, pos_emb) + k = WENET_APPLY_ROTARY_EMB[self.style](k, pos_emb) + + k, v, new_cache = self._update_kv_and_cache(k, + v, + cache, + head_first=False) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) if not self.use_sdpa: scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask), new_cache diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index ba31edffc7..0c4fab62af 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -286,6 +286,8 @@ def tie_or_clone_weights(self, jit_mode: bool = True): rank = int(os.environ.get('RANK', 0)) if not self.use_output_layer: return + if not self.tie_word_embedding: + return if jit_mode: if rank == 0: logging.info("clone emb.weight to output.weight") diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py index db8a41333b..a884b8d53d 100644 --- a/wenet/transformer/embedding.py +++ b/wenet/transformer/embedding.py @@ -205,13 +205,15 @@ def __init__(self, head_dim: int, dropout_rate: float, max_len: int = 1500, - rope_theta=10000.0): + rope_theta=10000.0, + scale: bool = True): super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len) delattr(self, 'pe') self.max_len = max_len * 2 pe = precompute_freqs_cis(head_dim, self.max_len, rope_theta) - self.register_buffer("pe", torch.view_as_real(pe.unsqueeze(0))) + self.register_buffer("pe", pe.unsqueeze(0)) self.dropout_rate = dropout_rate + self.scale = scale def forward( self, @@ -220,10 +222,10 @@ def forward( torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]: pos_emb = self.position_encoding(offset, x.size(1), True) - pos_emb = pos_emb.unsqueeze(1) # [1, 1, seq, head_dim//2] + pos_emb = pos_emb.unsqueeze(2) # [1, 1, seq, head_dim//2] # NOTE(Mddct): some model don't scale - # TODO(Mddct): fix - x = x * self.xscale + if self.scale: + x = x * self.xscale return self.dropout(x), pos_emb def position_encoding(self, @@ -231,7 +233,7 @@ def position_encoding(self, size: int, apply_dropout: bool = True) -> torch.Tensor: - pe = torch.view_as_complex(self.pe) + pe = self.pe if isinstance(offset, int): assert offset + size <= self.max_len pos_emb = pe[:, offset:offset + size] diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py index bf2cdb2c2c..068d3f99b6 100644 --- a/wenet/transformer/encoder_layer.py +++ b/wenet/transformer/encoder_layer.py @@ -15,6 +15,7 @@ # Modified from ESPnet(https://github.com/espnet/espnet) """Encoder self-attention layer definition.""" +from functools import partial from typing import Optional, Tuple import torch @@ -49,14 +50,21 @@ def __init__( normalize_before: bool = True, layer_norm_type: str = 'layer_norm', norm_eps: float = 1e-5, + rms_norm_offset: bool = True, ): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward assert layer_norm_type in ['layer_norm', 'rms_norm'] - self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) - self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) + norm_class = WENET_NORM_CLASSES[layer_norm_type] + if layer_norm_type == "rms_norm": + norm_class = partial( + norm_class, + add_unit_offset=rms_norm_offset, + ) + self.norm1 = norm_class(size, eps=norm_eps) + self.norm2 = 
norm_class(size, eps=norm_eps) self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before diff --git a/wenet/transformer/norm.py b/wenet/transformer/norm.py index 2c3756f13f..8039228630 100644 --- a/wenet/transformer/norm.py +++ b/wenet/transformer/norm.py @@ -9,14 +9,19 @@ def __init__( self, dim: int, eps: float = 1e-6, + add_unit_offset: bool = True, ): super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.ones(dim)) + self.add_unit_offset = add_unit_offset def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): x = self._norm(x.float()).type_as(x) - return x * self.weight + if self.add_unit_offset: + return x * (1 + self.weight) + else: + return x * self.weight diff --git a/wenet/utils/fsdp_utils.py b/wenet/utils/fsdp_utils.py index 33871f6f0c..77ca195953 100644 --- a/wenet/utils/fsdp_utils.py +++ b/wenet/utils/fsdp_utils.py @@ -5,6 +5,7 @@ from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy, transformer_auto_wrap_policy) +from wenet.LLM.decoder import DecoderOnly from wenet.branchformer.encoder_layer import BranchformerEncoderLayer from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer @@ -91,6 +92,8 @@ def check_gradient_checkpoint(model): if model.decoder.gradient_checkpointing: model.decoder.gradient_checkpointing = False ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values()) + if isinstance(model.decoder, DecoderOnly): + ckpt_laye_types += [DecoderOnly] return tuple(ckpt_laye_types) diff --git a/wenet/utils/init_dataset.py b/wenet/utils/init_dataset.py new file mode 100644 index 0000000000..37950c5d81 --- /dev/null +++ b/wenet/utils/init_dataset.py @@ -0,0 +1,24 @@ +from typing import Optional +from wenet.dataset.audiollm_dataset import Dataset as AudioLLMDataset +from wenet.dataset.dataset import Dataset as ASRDatast +from wenet.dataset.llm_dataset import Dataset as LLMDataset +from wenet.text.base_tokenizer import BaseTokenizer + + +def init_dataset(data_type, + data_list_file, + conf, + tokenizer: Optional[BaseTokenizer] = None, + partition=True, + dataset_type: str = 'asr'): + assert dataset_type in ['asr', 'llm', 'audio_llm'] + if dataset_type == 'asr': + return ASRDatast(data_type, data_list_file, tokenizer, conf, partition) + elif dataset_type == 'audio_llm': + assert tokenizer is not None + return AudioLLMDataset(data_type, data_list_file, tokenizer, conf, + partition) + else: + assert tokenizer is not None + return LLMDataset(data_type, data_list_file, tokenizer, conf, + partition) diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index ce8c12eeaf..b209e6a827 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -20,6 +20,9 @@ from wenet.paraformer.cif import Cif from wenet.paraformer.layers import SanmDecoder, SanmEncoder from wenet.paraformer.paraformer import Paraformer, Predictor +from wenet.AudioLLM.audiollm_model import AudioLLM +from wenet.LLM.causal_model import CausalLM +from wenet.LLM.decoder import DecoderOnly from wenet.transducer.joint import TransducerJoint from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor, RNNPredictor) @@ -84,11 +87,12 @@ "k2_model": K2Model, "transducer": Transducer, 'paraformer': Paraformer, + 'causal_llm': CausalLM, + 'audio_llm': AudioLLM, } -def init_model(args, configs): - +def init_speech_model(args, configs): # TODO(xcsong): Forcefully read the 
'cmvn' attribute. if configs.get('cmvn', None) == 'global_cmvn': mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'], @@ -168,6 +172,85 @@ def init_model(args, configs): special_tokens=configs.get('tokenizer_conf', {}).get('special_tokens', None), **configs['model_conf']) + return model, configs + + +def init_causal_llm(configs): + vocab_size = configs['output_dim'] + assert configs['decoder'] == 'decoder_only' + assert configs['model'] == 'causal_lm' + decoder_only = DecoderOnly(**configs['decoder_conf']) + + model = CausalLM( + vocab_size, + decoder_only, + **configs['model_conf'], + special_tokens=configs.get('tokenizer_conf', + {}).get('special_tokens', None), + ) + return model, configs + +def init_audio_llm(args, configs): + assert configs['decoder'] == 'decoder_only' + assert configs['model'] == 'audio_llm' + if configs.get('cmvn', None) == 'global_cmvn': + mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'], + configs['cmvn_conf']['is_json_cmvn']) + global_cmvn = GlobalCMVN( + torch.from_numpy(mean).float(), + torch.from_numpy(istd).float()) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + + encoder_type = configs.get('encoder', 'conformer') + encoder = WENET_ENCODER_CLASSES[encoder_type]( + input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf'], + **configs['encoder_conf']['efficient_conf'] + if 'efficient_conf' in configs['encoder_conf'] else {}) + + decoder = DecoderOnly(**configs['decoder_conf']) + + model = AudioLLM( + vocab_size, + encoder, + decoder, + **configs['model_conf'], + special_tokens=configs.get('tokenizer_conf', + {}).get('special_tokens', None), + ) + + if hasattr(args, 'pretrain_encoder') and args.pretrain_encoder is not None: + encoder_state_dict = torch.load(args.pretrain_encoder, map_location='cpu') + new_encoder_state_dict = {} + for key in encoder_state_dict.keys(): + if 'encoder.' in key: + new_encoder_state_dict[key] = encoder_state_dict[key] + model.load_state_dict(new_encoder_state_dict, strict=False) + if int(os.environ.get('RANK', 0)) == 0: + print("load pretrained encoder from {}".format(args.pretrain_encoder)) + + if hasattr(args, 'pretrain_decoder') and args.pretrain_decoder is not None: + decoder_state_dict = torch.load(args.pretrain_decoder, map_location='cpu') + model.load_state_dict(decoder_state_dict, strict=False) + if int(os.environ.get('RANK', 0)) == 0: + print("load pretrained decoder from {}".format(args.pretrain_decoder)) + + return model, configs + +def init_model(args, configs): + + model_type = configs.get('model', 'asr_model') + if model_type == 'causal_lm': + model, configs = init_causal_llm(configs) + elif model_type == 'audio_llm': + model, configs = init_audio_llm(args, configs) + else: + model, configs = init_speech_model(args, configs) # If specify checkpoint, load some info from checkpoint if hasattr(args, 'checkpoint') and args.checkpoint is not None: @@ -178,16 +261,16 @@ def init_model(args, configs): infos = {} configs["init_infos"] = infos + # Trye to tie some weights + if hasattr(model, 'tie_or_clone_weights'): + if not hasattr(args, 'jit'): + args.jit = True # i.e. export onnx/jit/ipex + model.tie_or_clone_weights(args.jit) + if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora: mark_only_lora_as_trainable(model, bias='lora_only') if int(os.environ.get('RANK', 0)) == 0: print(configs) - # Tie emb.weight to decoder.output_layer.weight - if model.decoder.tie_word_embedding: - if not hasattr(args, 'jit'): - args.jit = True # i.e. 
export onnx/jit/ipex - model.decoder.tie_or_clone_weights(jit_mode=args.jit) - return model, configs diff --git a/wenet/utils/init_tokenizer.py b/wenet/utils/init_tokenizer.py index c0c2ce7d77..9f42f058a3 100644 --- a/wenet/utils/init_tokenizer.py +++ b/wenet/utils/init_tokenizer.py @@ -18,6 +18,7 @@ from wenet.text.base_tokenizer import BaseTokenizer from wenet.text.bpe_tokenizer import BpeTokenizer from wenet.text.char_tokenizer import CharTokenizer +from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer from wenet.text.paraformer_tokenizer import ParaformerTokenizer from wenet.text.whisper_tokenizer import WhisperTokenizer @@ -47,6 +48,9 @@ def init_tokenizer(configs) -> BaseTokenizer: tokenizer = ParaformerTokenizer( symbol_table=configs['tokenizer_conf']['symbol_table_path'], seg_dict=configs['tokenizer_conf']['seg_dict_path']) + elif tokenizer_type == 'huggingface': + tokenizer = HuggingFaceTokenizer( + model=configs['tokenizer_conf']['model']) else: raise NotImplementedError logging.info("use {} tokenizer".format(configs["tokenizer"])) diff --git a/wenet/utils/mask.py b/wenet/utils/mask.py index 0480fb4f6a..c151d0425e 100644 --- a/wenet/utils/mask.py +++ b/wenet/utils/mask.py @@ -197,13 +197,14 @@ def add_optional_chunk_mask(xs: torch.Tensor, return chunk_masks -def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0, pad_type: str = "right") -> torch.Tensor: """Make mask tensor containing indices of padded part. See description of make_non_pad_mask. Args: lengths (torch.Tensor): Batch of lengths (B,). + type (string): Choice of [right, left] Returns: torch.Tensor: Mask tensor containing indices of padded part. @@ -222,7 +223,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: device=lengths.device) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_length_expand = lengths.unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand + if pad_type == "right": + mask = seq_range_expand >= seq_length_expand + elif pad_type == "left": + mask = seq_range_expand.flip(dims=[-1]) >= seq_length_expand return mask diff --git a/wenet/utils/rope_utils.py b/wenet/utils/rope_utils.py index e80bf9ace7..54f13c47b8 100644 --- a/wenet/utils/rope_utils.py +++ b/wenet/utils/rope_utils.py @@ -31,3 +31,9 @@ def llama_apply_rotary_emb(x: torch.Tensor, x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_out = torch.view_as_real(x_ * freqs_cis).flatten(3) return x_out.type_as(x) + + +WENET_APPLY_ROTARY_EMB = { + 'google': google_apply_rotary_emb, + 'llama': llama_apply_rotary_emb, +} diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index fee435dc58..b6ca93522e 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -39,7 +39,6 @@ estimate_zero3_model_states_mem_needs_all_live) from deepspeed.utils.zero_to_fp32 import ( convert_zero_checkpoint_to_fp32_state_dict) -from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import save_checkpoint from wenet.utils.common import (StepTimer, get_nested_attribute, lrs_to_str, tensor_to_scalar) @@ -48,11 +47,14 @@ wenet_fsdp_wrap_policy) from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing from wenet.utils.ctc_utils import get_blank_id +from wenet.utils.init_dataset import init_dataset def add_model_args(parser): parser.add_argument('--config', required=True, help='config file') parser.add_argument('--model_dir', required=True, help='save model 
dir') + parser.add_argument('--pretrain_encoder', help='pretrain encoder model') + parser.add_argument('--pretrain_decoder', help='pretrain decoder model') parser.add_argument('--checkpoint', help='checkpoint model') parser.add_argument('--tensorboard_dir', default='tensorboard', @@ -281,21 +283,23 @@ def check_modify_and_save_config(args, configs, symbol_table): configs['encoder_conf']['lora_rank'] = args.lora_rank configs['encoder_conf']['lora_alpha'] = args.lora_alpha configs['encoder_conf']['lora_dropout'] = args.lora_dropout - - if 'input_dim' not in configs: - if 'fbank_conf' in configs['dataset_conf']: - input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] - elif 'log_mel_spectrogram_conf' in configs['dataset_conf']: - input_dim = configs['dataset_conf']['log_mel_spectrogram_conf'][ - 'num_mel_bins'] + if configs["model"] == 'asr_model': + if 'input_dim' not in configs: + if 'fbank_conf' in configs['dataset_conf']: + input_dim = configs['dataset_conf']['fbank_conf'][ + 'num_mel_bins'] + elif 'log_mel_spectrogram_conf' in configs['dataset_conf']: + input_dim = configs['dataset_conf'][ + 'log_mel_spectrogram_conf']['num_mel_bins'] + else: + input_dim = configs['dataset_conf']['mfcc_conf'][ + 'num_mel_bins'] else: - input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] - else: - input_dim = configs['input_dim'] + input_dim = configs['input_dim'] - configs, _ = get_blank_id(configs, symbol_table) + configs, _ = get_blank_id(configs, symbol_table) - configs['input_dim'] = input_dim + configs['input_dim'] = input_dim configs['output_dim'] = configs['vocab_size'] configs['train_engine'] = args.train_engine @@ -335,13 +339,18 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777): cv_conf['list_shuffle'] = False configs['vocab_size'] = tokenizer.vocab_size() - train_dataset = Dataset(args.data_type, args.train_data, tokenizer, - train_conf, True) - cv_dataset = Dataset(args.data_type, - args.cv_data, - tokenizer, - cv_conf, - partition=False) + train_dataset = init_dataset(args.data_type, + args.train_data, + train_conf, + tokenizer, + True, + dataset_type=configs['dataset']) + cv_dataset = init_dataset(args.data_type, + args.cv_data, + cv_conf, + tokenizer, + partition=False, + dataset_type=configs['dataset']) # NOTE(xcsong): Why we prefer persistent_workers=True ? # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110 @@ -790,8 +799,8 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None): steps_per_second = timer.steps_per_second(timer_step) log_str += 'steps/sec {:.1f}| '.format(steps_per_second) log_str += 'Batch {}/{} loss {:.6f} '.format( - epoch, - batch_idx + 1 if 'save_interval' not in info_dict else step + 1, + epoch, batch_idx + 1 if 'save_interval' not in info_dict else + (step + 1) * accum_grad, tensor_to_scalar(loss_dict['loss']) * accum_grad) for name, value in loss_dict.items(): if name != 'loss' and value is not None:
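Taken together, the config plumbing above means a recipe selects the data pipeline with a top-level `dataset` key and init_dataset dispatches to the matching builder. A minimal sketch of that call path; the config values, tokenizer model id and data path below are hypothetical, not taken from the patch:

from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer
from wenet.utils.init_dataset import init_dataset

configs = {
    'dataset': 'llm',                      # 'asr' | 'llm' | 'audio_llm'
    'dataset_conf': {
        'data_style_conf': {'template': 'gemma', 'add_bos': True, 'add_eos': True},
        'shuffle_conf': {'shuffle_size': 1500},
        'sort_conf': {'sort_size': 500},
        'batch_conf': {'batch_type': 'dynamic', 'max_tokens_in_batch': 50000},
    },
}

tokenizer = HuggingFaceTokenizer(model='google/gemma-2b')   # hypothetical model id
train_dataset = init_dataset('raw',
                             'data/train/data.list',        # hypothetical path
                             configs['dataset_conf'],
                             tokenizer,
                             partition=True,
                             dataset_type=configs['dataset'])
# dataset_type='asr' keeps the original wenet.dataset.dataset.Dataset path,
# while 'audio_llm' routes to the new AudioLLM pipeline added in this patch.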