Commit a1bad91
enable punctuation alignment for unsupported languages
MahmoudAshraf97 committed Dec 6, 2023
1 parent d26814d commit a1bad91
Showing 3 changed files with 59 additions and 44 deletions.
16 changes: 7 additions & 9 deletions Whisper_Transcription_+_NeMo_Diarization.ipynb
@@ -526,7 +526,7 @@
" result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)\n",
" del whisper_model\n",
" torch.cuda.empty_cache()\n",
" return result[\"segments\"], result[\"language\"]\n"
" return result[\"segments\"], result[\"language\"]"
]
},
{
@@ -561,9 +561,9 @@
"\n",
"batch_size = 8\n",
"\n",
"language = None # autodetect language\n",
"language = None # autodetect language\n",
"\n",
"device =\"cuda\" if torch.cuda.is_available() else \"cpu\""
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
@@ -710,9 +710,7 @@
" del alignment_model\n",
" torch.cuda.empty_cache()\n",
"else:\n",
" assert (\n",
" batch_size == 0 # TODO: add a better check for word timestamps existence\n",
" ), (\n",
" assert batch_size == 0, ( # TODO: add a better check for word timestamps existence\n",
" f\"Unsupported language: {language}, use --batch_size to 0\"\n",
" \" to generate word timestamps using whisper directly and fix this error.\"\n",
" )\n",
@@ -861,12 +859,12 @@
" word = word.rstrip(\".\")\n",
" word_dict[\"word\"] = word\n",
"\n",
" wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
"else:\n",
" print(\n",
" f'Punctuation restoration is not available for {language} language.'\n",
" logging.warning(\n",
" f\"Punctuation restoration is not available for {language} language. Using the original punctuation.\"\n",
" )\n",
"\n",
"wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
"ssm = get_sentences_speaker_mapping(wsm, speaker_ts)"
]
},
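A minimal sketch of the control flow this hunk produces (the same change appears in diarize.py and diarize_parallel.py below): punctuation restoration still runs only for languages the punctuation model supports, while realignment and sentence/speaker mapping now run for every language. Names such as wsm, speaker_ts, language, punct_model_langs, get_realigned_ws_mapping_with_punctuation and get_sentences_speaker_mapping are the ones already used in the scripts; the punctuation-prediction loop is elided here.

import logging

if language in punct_model_langs:
    # restore punctuation in the transcript to help realign the sentences
    # (punctuation-model prediction loop elided in this sketch)
    pass
else:
    logging.warning(
        f"Punctuation restoration is not available for {language} language."
        " Using the original punctuation."
    )

# realignment and sentence/speaker mapping now happen for every language,
# not only for those with punctuation-restoration support
wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)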
4 changes: 2 additions & 2 deletions diarize.py
@@ -193,12 +193,12 @@
word = word.rstrip(".")
word_dict["word"] = word

wsm = get_realigned_ws_mapping_with_punctuation(wsm)
else:
logging.warning(
f"Punctuation restoration is not available for {language} language."
f"Punctuation restoration is not available for {language} language. Using the original punctuation."
)

wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)

with open(f"{os.path.splitext(args.audio)[0]}.txt", "w", encoding="utf-8-sig") as f:
83 changes: 50 additions & 33 deletions diarize_parallel.py
@@ -41,6 +41,22 @@
help="name of the Whisper model to use",
)

parser.add_argument(
"--batch-size",
type=int,
dest="batch_size",
default=8,
help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference",
)

parser.add_argument(
"--language",
type=str,
default=None,
choices=whisper_langs,
help="Language spoken in the audio, specify None to perform language detection",
)
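The two options above are the new ones in this file. A standalone sanity check of how they parse (hypothetical demo parser; the choices=whisper_langs restriction and the script's other arguments are omitted):

import argparse

# Demo parser mirroring only the two new options.
demo = argparse.ArgumentParser()
demo.add_argument("--batch-size", type=int, dest="batch_size", default=8)
demo.add_argument("--language", type=str, default=None)

args = demo.parse_args([])                     # defaults: batched inference, language autodetected
assert args.batch_size == 8 and args.language is None

args = demo.parse_args(["--batch-size", "0"])  # 0 switches to non-batched inference
assert args.batch_size == 0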

parser.add_argument(
"--device",
dest="device",
@@ -76,39 +92,34 @@
nemo_process = subprocess.Popen(
["python3", "nemo_process.py", "-a", vocal_target, "--device", args.device],
)
# Run on GPU with FP16
whisper_model = WhisperModel(
args.model_name, device=args.device, compute_type=mtypes[args.device]
)

# or run on GPU with INT8
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_size, device="cpu", compute_type="int8")

if args.suppress_numerals:
numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
# Transcribe the audio file
if args.batch_size != 0:
from transcription_helpers import transcribe_batched

whisper_results, language = transcribe_batched(
vocal_target,
args.language,
args.batch_size,
args.model_name,
mtypes[args.device],
args.suppress_numerals,
args.device,
)
else:
numeral_symbol_tokens = None

segments, info = whisper_model.transcribe(
vocal_target,
beam_size=5,
word_timestamps=True,
suppress_tokens=numeral_symbol_tokens,
vad_filter=True,
)
whisper_results = []
for segment in segments:
whisper_results.append(segment._asdict())

# clear gpu vram
del whisper_model
torch.cuda.empty_cache()
from transcription_helpers import transcribe

whisper_results, language = transcribe(
vocal_target,
args.language,
args.model_name,
mtypes[args.device],
args.suppress_numerals,
args.device,
)
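The transcribe_batched and transcribe helpers are imported from transcription_helpers, which is not part of this diff. Below is a hypothetical sketch of the batched variant, inferred from the call site above and from the notebook cell earlier (whisper_model.transcribe(audio, language=language, batch_size=batch_size) returning segments and language); loading the model through whisperx and skipping the suppress_numerals handling are assumptions of this sketch, not code from the repository.

import torch
import whisperx


def transcribe_batched(
    audio_file, language, batch_size, model_name, compute_dtype, suppress_numerals, device
):
    # Hypothetical helper matching the call signature used above;
    # suppress_numerals handling is omitted for brevity.
    model = whisperx.load_model(
        model_name, device, compute_type=compute_dtype, language=language
    )
    audio = whisperx.load_audio(audio_file)
    result = model.transcribe(audio, language=language, batch_size=batch_size)
    del model
    torch.cuda.empty_cache()
    return result["segments"], result["language"]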

if info.language in wav2vec2_langs:
if language in wav2vec2_langs:
alignment_model, metadata = whisperx.load_align_model(
language_code=info.language, device=args.device
language_code=language, device=args.device
)
result_aligned = whisperx.align(
whisper_results, alignment_model, metadata, vocal_target, args.device
@@ -118,6 +129,12 @@
del alignment_model
torch.cuda.empty_cache()
else:
assert (
args.batch_size == 0 # TODO: add a better check for word timestamps existence
), (
f"Unsupported language: {language}, use --batch_size to 0"
" to generate word timestamps using whisper directly and fix this error."
)
word_timestamps = []
for segment in whisper_results:
for word in segment["words"]:
@@ -139,7 +156,7 @@

wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")

if info.language in punct_model_langs:
if language in punct_model_langs:
# restoring punctuation in the transcript to help realign the sentences
punct_model = PunctuationModel(model="kredor/punctuate-all")

@@ -165,12 +182,12 @@
word = word.rstrip(".")
word_dict["word"] = word

wsm = get_realigned_ws_mapping_with_punctuation(wsm)
else:
logging.warning(
f"Punctuation restoration is not available for {info.language} language."
f"Punctuation restoration is not available for {language} language. Using the original punctuation."
)

wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)

with open(f"{os.path.splitext(args.audio)[0]}.txt", "w", encoding="utf-8-sig") as f:
