Commit

add batched inference using whisperX with faster-whisper fallback for unsupported languages

MahmoudAshraf97 committed Dec 4, 2023
1 parent 6320541 commit d26814d
Showing 5 changed files with 305 additions and 75 deletions.
README.md: 4 changes (3 additions & 1 deletion)
@@ -76,14 +76,16 @@ If your system has enough VRAM (>=10GB), you can use `diarize_parallel.py` inste
- `--no-stem`: Disables source separation
- `--whisper-model`: The model to be used for ASR, default is `medium.en`
- `--suppress_numerals`: Transcribes numbers as their spoken words instead of digits; improves alignment accuracy
- `--device`: Choose which device to use; defaults to "cuda" if available
- `--language`: Manually select the language; useful if language detection fails
- `--batch-size`: Batch size for batched inference; reduce if you run out of memory, or set it to 0 for non-batched inference (see the sketch below)
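
Under the hood, this flag drives a simple dispatch between the two transcription paths added in this commit: a non-zero batch size uses whisperX's batched inference, while 0 falls back to non-batched faster-whisper with word-level timestamps. Below is a minimal sketch of that selection logic using the `transcribe` and `transcribe_batched` helpers defined further down in this commit; the import path and wrapper name are illustrative assumptions, not part of the repository.

```python
import torch

# Assumed import path for illustration only; in the notebook these helpers
# are defined inline rather than imported from a module.
from transcription_helpers import transcribe, transcribe_batched


def run_asr(audio_path, language, batch_size,
            model_name="medium.en", suppress_numerals=False):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"

    if batch_size != 0:
        # Batched inference via whisperX (faster, needs more VRAM).
        return transcribe_batched(
            audio_path, language, batch_size, model_name,
            compute_type, suppress_numerals, device,
        )
    # batch_size == 0: non-batched faster-whisper with word timestamps.
    return transcribe(
        audio_path, language, model_name,
        compute_type, suppress_numerals, device,
    )
```

Both helpers return the transcription segments along with the language, so the rest of the pipeline does not need to know which path was taken.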

## Known Limitations
- Overlapping speakers are not handled yet. A possible approach is to separate the audio, isolate one speaker at a time, and feed each stream through the pipeline, but this would require much more computation
- There might be errors; please raise an issue if you encounter any.

## Future Improvements
- Implement a maximum length per sentence for SRT
- Improve Batch Processing

## Acknowledgements
Special thanks to [@adamjonas](https://github.com/adamjonas) for supporting this project
Whisper_Transcription_+_NeMo_Diarization.ipynb: 198 changes (154 additions & 44 deletions)
@@ -29,9 +29,8 @@
},
"outputs": [],
"source": [
"!pip install git+https://github.com/m-bain/whisperX.git@49e0130e4e0c0d99d60715d76e65a71826a97109\n",
"!pip install --no-build-isolation nemo_toolkit[asr]==1.20.0\n",
"!pip install faster-whisper==0.9.0\n",
"!pip install git+https://github.com/m-bain/whisperX.git@a5dca2cc65b1a37f32a347e574b2c56af3a7434a\n",
"!pip install --no-build-isolation nemo_toolkit[asr]==1.21.0\n",
"!pip install git+https://github.com/facebookresearch/demucs#egg=demucs\n",
"!pip install deepmultilingualpunctuation\n",
"!pip install wget pydub\n",
@@ -62,7 +61,8 @@
"import re\n",
"import logging\n",
"import nltk\n",
"from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH"
"from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH\n",
"from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE"
]
},
{
@@ -101,6 +101,10 @@
" DEFAULT_ALIGN_MODELS_HF.keys()\n",
")\n",
"\n",
"whisper_langs = sorted(LANGUAGES.keys()) + sorted(\n",
" [k.title() for k in TO_LANGUAGE_CODE.keys()]\n",
")\n",
"\n",
"\n",
"def create_config(output_dir):\n",
" DOMAIN_TYPE = \"telephonic\" # Can be meeting, telephonic, or general based on domain type of the audio file\n",
@@ -130,9 +134,7 @@
"\n",
" pretrained_vad = \"vad_multilingual_marblenet\"\n",
" pretrained_speaker_model = \"titanet_large\"\n",
"\n",
" config.num_workers = 0 # Workaround for multiprocessing hanging with ipython issue\n",
"\n",
" config.diarizer.manifest_filepath = os.path.join(data_dir, \"input_manifest.json\")\n",
" config.diarizer.out_dir = (\n",
" output_dir # Directory to store intermediate files and prediction outputs\n",
@@ -315,16 +317,19 @@
"\n",
"def get_speaker_aware_transcript(sentences_speaker_mapping, f):\n",
" previous_speaker = sentences_speaker_mapping[0][\"speaker\"]\n",
" text = sentences_speaker_mapping[0][\"text\"]\n",
" for sentence_dict in sentences_speaker_mapping[1:]:\n",
" sp = sentence_dict[\"speaker\"]\n",
" f.write(f\"{previous_speaker}: \")\n",
"\n",
" for sentence_dict in sentences_speaker_mapping:\n",
" speaker = sentence_dict[\"speaker\"]\n",
" sentence = sentence_dict[\"text\"]\n",
" if sp != previous_speaker:\n",
" f.write(f\"{previous_speaker}: {text}\\n\\n\")\n",
" text = sentence\n",
" previous_speaker = sp\n",
" else:\n",
" text += \" \" + sentence\n",
"\n",
" # If this speaker doesn't match the previous one, start a new paragraph\n",
" if speaker != previous_speaker:\n",
" f.write(f\"\\n\\n{speaker}: \")\n",
" previous_speaker = speaker\n",
"\n",
" # No matter what, write the current sentence\n",
" f.write(sentence + \" \")\n",
"\n",
"\n",
"def format_timestamp(\n",
@@ -428,7 +433,100 @@
" # remove directory and all its content\n",
" shutil.rmtree(path)\n",
" else:\n",
" raise ValueError(\"Path {} is not a file or dir.\".format(path))"
" raise ValueError(\"Path {} is not a file or dir.\".format(path))\n",
"\n",
"\n",
"def process_language_arg(language: str, model_name: str):\n",
" \"\"\"\n",
" Process the language argument to make sure it's valid and convert language names to language codes.\n",
" \"\"\"\n",
" if language is not None:\n",
" language = language.lower()\n",
" if language not in LANGUAGES:\n",
" if language in TO_LANGUAGE_CODE:\n",
" language = TO_LANGUAGE_CODE[language]\n",
" else:\n",
" raise ValueError(f\"Unsupported language: {language}\")\n",
"\n",
" if model_name.endswith(\".en\") and language != \"en\":\n",
" if language is not None:\n",
" logging.warning(\n",
" f\"{model_name} is an English-only model but received '{language}'; using English instead.\"\n",
" )\n",
" language = \"en\"\n",
" return language\n",
"\n",
"\n",
"def transcribe(\n",
" audio_file: str,\n",
" language: str,\n",
" model_name: str,\n",
" compute_dtype: str,\n",
" suppress_numerals: bool,\n",
" device: str,\n",
"):\n",
" from faster_whisper import WhisperModel\n",
" from helpers import find_numeral_symbol_tokens, wav2vec2_langs\n",
"\n",
" # Faster Whisper non-batched\n",
" # Run on GPU with FP16\n",
" whisper_model = WhisperModel(model_name, device=device, compute_type=compute_dtype)\n",
"\n",
" # or run on GPU with INT8\n",
" # model = WhisperModel(model_size, device=\"cuda\", compute_type=\"int8_float16\")\n",
" # or run on CPU with INT8\n",
" # model = WhisperModel(model_size, device=\"cpu\", compute_type=\"int8\")\n",
"\n",
" if suppress_numerals:\n",
" numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)\n",
" else:\n",
" numeral_symbol_tokens = None\n",
"\n",
" if language is not None and language in wav2vec2_langs:\n",
" word_timestamps = False\n",
" else:\n",
" word_timestamps = True\n",
"\n",
" segments, info = whisper_model.transcribe(\n",
" audio_file,\n",
" language=language,\n",
" beam_size=5,\n",
" word_timestamps=word_timestamps, # TODO: disable this if the language is supported by wav2vec2\n",
" suppress_tokens=numeral_symbol_tokens,\n",
" vad_filter=True,\n",
" )\n",
" whisper_results = []\n",
" for segment in segments:\n",
" whisper_results.append(segment._asdict())\n",
" # clear gpu vram\n",
" del whisper_model\n",
" torch.cuda.empty_cache()\n",
" return whisper_results, language\n",
"\n",
"\n",
"def transcribe_batched(\n",
" audio_file: str,\n",
" language: str,\n",
" batch_size: int,\n",
" model_name: str,\n",
" compute_dtype: str,\n",
" suppress_numerals: bool,\n",
" device: str,\n",
"):\n",
" import whisperx\n",
"\n",
" # Faster Whisper batched\n",
" whisper_model = whisperx.load_model(\n",
" model_name,\n",
" device,\n",
" compute_type=compute_dtype,\n",
" asr_options={\"suppress_numerals\": suppress_numerals},\n",
" )\n",
" audio = whisperx.load_audio(audio_file)\n",
" result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)\n",
" del whisper_model\n",
" torch.cuda.empty_cache()\n",
" return result[\"segments\"], result[\"language\"]\n"
]
},
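
For reference, here is how the `process_language_arg` helper defined in the cell above behaves for a few representative inputs. This is only an illustrative sketch (not part of the commit) and assumes the function is in scope along with whisperX's `LANGUAGES` (code to name) and `TO_LANGUAGE_CODE` (name to code) mappings.

```python
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE  # same imports as the notebook

# Full language names are lower-cased and mapped to ISO codes.
assert process_language_arg("English", "large-v2") == "en"

# Valid codes pass through unchanged.
assert process_language_arg("de", "large-v2") == "de"

# English-only checkpoints force the language to "en"
# (with a warning when another language was explicitly requested).
assert process_language_arg("de", "medium.en") == "en"
assert process_language_arg(None, "medium.en") == "en"

# Anything outside LANGUAGES / TO_LANGUAGE_CODE raises a ValueError.
try:
    process_language_arg("klingon", "large-v2")
except ValueError as err:
    print(err)  # Unsupported language: klingon
```
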
{
@@ -455,11 +553,17 @@
"# Whether to enable music removal from speech, helps increase diarization quality but uses alot of ram\n",
"enable_stemming = True\n",
"\n",
"# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large')\n",
"# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large')\n",
"whisper_model_name = \"large-v2\"\n",
"\n",
"# replaces numerical digits with their pronounciation, increases diarization accuracy\n",
"suppress_numerals = True"
"suppress_numerals = True\n",
"\n",
"batch_size = 8\n",
"\n",
"language = None # autodetect language\n",
"\n",
"device =\"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
@@ -543,31 +647,31 @@
},
"outputs": [],
"source": [
"# Run on GPU with FP16\n",
"whisper_model = WhisperModel(whisper_model_name, device=\"cuda\", compute_type=\"float16\")\n",
"\n",
"compute_type = \"float16\"\n",
"# or run on GPU with INT8\n",
"# model = WhisperModel(model_size, device=\"cuda\", compute_type=\"int8_float16\")\n",
"# compute_type = \"int8_float16\"\n",
"# or run on CPU with INT8\n",
"# model = WhisperModel(model_size, device=\"cpu\", compute_type=\"int8\")\n",
"if suppress_numerals:\n",
" numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)\n",
"# compute_type = \"int8\"\n",
"\n",
"if batch_size != 0:\n",
" whisper_results, language = transcribe_batched(\n",
" vocal_target,\n",
" language,\n",
" batch_size,\n",
" whisper_model_name,\n",
" compute_type,\n",
" suppress_numerals,\n",
" device,\n",
" )\n",
"else:\n",
" numeral_symbol_tokens = None\n",
"\n",
"segments, info = whisper_model.transcribe(\n",
" vocal_target,\n",
" beam_size=5,\n",
" word_timestamps=True,\n",
" suppress_tokens=numeral_symbol_tokens,\n",
" vad_filter=True,\n",
")\n",
"whisper_results = []\n",
"for segment in segments:\n",
" whisper_results.append(segment._asdict())\n",
"# clear gpu vram\n",
"del whisper_model\n",
"torch.cuda.empty_cache()"
" whisper_results, language = transcribe(\n",
" vocal_target,\n",
" language,\n",
" whisper_model_name,\n",
" compute_type,\n",
" suppress_numerals,\n",
" device,\n",
" )"
]
},
{
@@ -592,10 +696,10 @@
"metadata": {},
"outputs": [],
"source": [
"if info.language in wav2vec2_langs:\n",
"if language in wav2vec2_langs:\n",
" device = \"cuda\"\n",
" alignment_model, metadata = whisperx.load_align_model(\n",
" language_code=info.language, device=device\n",
" language_code=language, device=device\n",
" )\n",
" result_aligned = whisperx.align(\n",
" whisper_results, alignment_model, metadata, vocal_target, device\n",
Expand All @@ -606,6 +710,12 @@
" del alignment_model\n",
" torch.cuda.empty_cache()\n",
"else:\n",
" assert (\n",
" batch_size == 0 # TODO: add a better check for word timestamps existence\n",
" ), (\n",
" f\"Unsupported language: {language}, use --batch_size to 0\"\n",
" \" to generate word timestamps using whisper directly and fix this error.\"\n",
" )\n",
" word_timestamps = []\n",
" for segment in whisper_results:\n",
" for word in segment[\"words\"]:\n",
@@ -725,7 +835,7 @@
},
"outputs": [],
"source": [
"if info.language in punct_model_langs:\n",
"if language in punct_model_langs:\n",
" # restoring punctuation in the transcript to help realign the sentences\n",
" punct_model = PunctuationModel(model=\"kredor/punctuate-all\")\n",
"\n",
@@ -754,7 +864,7 @@
" wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
"else:\n",
" print(\n",
" f'Punctuation restoration is not available for {info.language} language.'\n",
" f'Punctuation restoration is not available for {language} language.'\n",
" )\n",
"\n",
"ssm = get_sentences_speaker_mapping(wsm, speaker_ts)"