Skip to content

Commit ac0c8bd

Browse files
committed
feat: add version and Python version arguments to CLI
1 parent cd59f21 commit ac0c8bd

File tree

1 file changed

+43
-8
lines changed

1 file changed

+43
-8
lines changed

whisperx/transcribe.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import argparse
22
import gc
33
import os
4+
import sys
45
import warnings
6+
import importlib.metadata
7+
import platform
58

69
import numpy as np
710
import torch
@@ -85,6 +88,8 @@ def cli():
8588
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
8689

8790
parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
91+
parser.add_argument("--version", "-V", action="version", version=f"%(prog)s {importlib.metadata.version('whisperx')}",help="Show whisperx version information and exit")
92+
parser.add_argument("--python-version", "-P", action="version", version=f"Python {platform.python_version()} ({platform.python_implementation()})",help="Show python version information and exit")
8893
# fmt: on
8994

9095
args = parser.parse_args().__dict__
@@ -138,7 +143,9 @@ def cli():
138143
f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
139144
)
140145
args["language"] = "en"
141-
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
146+
align_language = (
147+
args["language"] if args["language"] is not None else "en"
148+
) # default to loading english if not specified
142149

143150
temperature = args.pop("temperature")
144151
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
@@ -174,12 +181,29 @@ def cli():
174181
if args["max_line_count"] and not args["max_line_width"]:
175182
warnings.warn("--max_line_count has no effect without --max_line_width")
176183
writer_args = {arg: args.pop(arg) for arg in word_options}
177-
184+
178185
# Part 1: VAD & ASR Loop
179186
results = []
180187
tmp_results = []
181188
# model = load_model(model_name, device=device, download_root=model_dir)
182-
model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_method=vad_method, vad_options={"chunk_size":chunk_size, "vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, local_files_only=model_cache_only, threads=faster_whisper_threads)
189+
model = load_model(
190+
model_name,
191+
device=device,
192+
device_index=device_index,
193+
download_root=model_dir,
194+
compute_type=compute_type,
195+
language=args["language"],
196+
asr_options=asr_options,
197+
vad_method=vad_method,
198+
vad_options={
199+
"chunk_size": chunk_size,
200+
"vad_onset": vad_onset,
201+
"vad_offset": vad_offset,
202+
},
203+
task=task,
204+
local_files_only=model_cache_only,
205+
threads=faster_whisper_threads,
206+
)
183207

184208
for audio_path in args.pop("audio"):
185209
audio = load_audio(audio_path)
@@ -203,7 +227,9 @@ def cli():
203227
if not no_align:
204228
tmp_results = results
205229
results = []
206-
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
230+
align_model, align_metadata = load_align_model(
231+
align_language, device, model_name=align_model
232+
)
207233
for result, audio_path in tmp_results:
208234
# >> Align
209235
if len(tmp_results) > 1:
@@ -215,8 +241,12 @@ def cli():
215241
if align_model is not None and len(result["segments"]) > 0:
216242
if result.get("language", "en") != align_metadata["language"]:
217243
# load new language
218-
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
219-
align_model, align_metadata = load_align_model(result["language"], device)
244+
print(
245+
f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
246+
)
247+
align_model, align_metadata = load_align_model(
248+
result["language"], device
249+
)
220250
print(">>Performing alignment...")
221251
result: AlignedTranscriptionResult = align(
222252
result["segments"],
@@ -239,19 +269,24 @@ def cli():
239269
# >> Diarize
240270
if diarize:
241271
if hf_token is None:
242-
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
272+
print(
273+
"Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
274+
)
243275
tmp_results = results
244276
print(">>Performing diarization...")
245277
results = []
246278
diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
247279
for result, input_audio_path in tmp_results:
248-
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
280+
diarize_segments = diarize_model(
281+
input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers
282+
)
249283
result = assign_word_speakers(diarize_segments, result)
250284
results.append((result, input_audio_path))
251285
# >> Write
252286
for result, audio_path in results:
253287
result["language"] = align_language
254288
writer(result, audio_path, writer_args)
255289

290+
256291
if __name__ == "__main__":
257292
cli()

0 commit comments

Comments
 (0)