
Commit a51ae7a

feat: add centralized logging to replace ad-hoc print statements (#1254)
* feat: add logging utility functions
* feat: add logging setup and log level argument to CLI
* feat: integrate logging across modules
1 parent 3b1b9a8 commit a51ae7a


9 files changed (+145, -20 lines)


whisperx/__init__.py

Lines changed: 26 additions & 0 deletions
@@ -29,3 +29,29 @@ def load_audio(*args, **kwargs):
 def assign_word_speakers(*args, **kwargs):
     diarize = _lazy_import("diarize")
     return diarize.assign_word_speakers(*args, **kwargs)
+
+
+def setup_logging(*args, **kwargs):
+    """
+    Configure logging for WhisperX.
+
+    Args:
+        level: Logging level (debug, info, warning, error, critical). Default: warning
+        log_file: Optional path to log file. If None, logs only to console.
+    """
+    logging_module = _lazy_import("log_utils")
+    return logging_module.setup_logging(*args, **kwargs)
+
+
+def get_logger(*args, **kwargs):
+    """
+    Get a logger instance for the given module.
+
+    Args:
+        name: Logger name (typically __name__ from calling module)
+
+    Returns:
+        Logger instance configured with WhisperX settings
+    """
+    logging_module = _lazy_import("log_utils")
+    return logging_module.get_logger(*args, **kwargs)
whisperx/__main__.py

Lines changed: 12 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 from whisperx.utils import (LANGUAGES, TO_LANGUAGE_CODE, optional_float,
                             optional_int, str2bool)
+from whisperx.log_utils import setup_logging
 
 
 def cli():
@@ -23,6 +24,7 @@ def cli():
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
+    parser.add_argument("--log-level", type=str, default=None, choices=["debug", "info", "warning", "error", "critical"], help="logging level (overrides --verbose if set)")
 
     parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
     parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
@@ -80,6 +82,16 @@ def cli():
 
     args = parser.parse_args().__dict__
 
+    log_level = args.get("log_level")
+    verbose = args.get("verbose")
+
+    if log_level is not None:
+        setup_logging(level=log_level)
+    elif verbose:
+        setup_logging(level="info")
+    else:
+        setup_logging(level="warning")
+
     from whisperx.transcribe import transcribe_task
 
     transcribe_task(args, parser)
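
When --log-level is given it takes precedence; otherwise --verbose maps to the info level, and its absence falls back to warning. An illustrative invocation (the audio file name is a placeholder):

    whisperx audio.wav --log-level debug -o out/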

whisperx/alignment.py

Lines changed: 9 additions & 5 deletions
@@ -24,6 +24,9 @@
 )
 import nltk
 from nltk.data import load as nltk_load
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)
 
 LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
 
@@ -81,8 +84,9 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str]
     elif language_code in DEFAULT_ALIGN_MODELS_HF:
         model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
     else:
-        print(f"There is no default alignment model set for this language ({language_code}).\
-            Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
+        logger.error(f"No default alignment model for language: {language_code}. "
+                     f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
+                     f"then pass the model name via --align_model [MODEL_NAME]")
         raise ValueError(f"No default align-model for language: {language_code}")
 
     if model_name in torchaudio.pipelines.__all__:
@@ -223,12 +227,12 @@ def align(
 
         # check we can align
         if len(segment_data[sdx]["clean_char"]) == 0:
-            print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
+            logger.warning(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original')
            aligned_segments.append(aligned_seg)
            continue
 
        if t1 >= MAX_DURATION:
-            print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
+            logger.warning(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping')
            aligned_segments.append(aligned_seg)
            continue
 
@@ -270,7 +274,7 @@ def align(
        path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
 
        if path is None:
-            print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
+            logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original')
            aligned_segments.append(aligned_seg)
            continue
 
whisperx/asr.py

Lines changed: 7 additions & 4 deletions
@@ -14,6 +14,9 @@
 from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
 from whisperx.schema import SingleSegment, TranscriptionResult
 from whisperx.vads import Vad, Silero, Pyannote
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)
 
 
 def find_numeral_symbol_tokens(tokenizer):
@@ -247,7 +250,7 @@ def data(audio, segments):
        if self.suppress_numerals:
            previous_suppress_tokens = self.options.suppress_tokens
            numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
-            print(f"Suppressing numeral and symbol tokens")
+            logger.info("Suppressing numeral and symbol tokens")
            new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
            new_suppressed_tokens = list(set(new_suppressed_tokens))
            self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
@@ -285,7 +288,7 @@ def data(audio, segments):
 
    def detect_language(self, audio: np.ndarray) -> str:
        if audio.shape[0] < N_SAMPLES:
-            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
+            logger.warning("Audio is shorter than 30s, language detection may be inaccurate")
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        segment = log_mel_spectrogram(audio[: N_SAMPLES],
                                      n_mels=model_n_mels if model_n_mels is not None else 80,
@@ -294,7 +297,7 @@ def detect_language(self, audio: np.ndarray) -> str:
        results = self.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        language = language_token[2:-2]
-        print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
+        logger.info(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio")
        return language
 
 
@@ -344,7 +347,7 @@ def load_model(
    if language is not None:
        tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
-        print("No language specified, language will be first be detected for each audio file (increases inference time).")
+        logger.info("No language specified, language will be detected for each audio file (increases inference time)")
        tokenizer = None
 
    default_asr_options = {

whisperx/diarize.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,9 @@
 
 from whisperx.audio import load_audio, SAMPLE_RATE
 from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)
 
 
 class DiarizationPipeline:
@@ -18,6 +21,7 @@ def __init__(
        if isinstance(device, str):
            device = torch.device(device)
        model_config = model_name or "pyannote/speaker-diarization-3.1"
+        logger.info(f"Loading diarization model: {model_config}")
        self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
 
    def __call__(

whisperx/log_utils.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+import logging
+import sys
+from typing import Optional
+
+_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+
+def setup_logging(
+    level: str = "info",
+    log_file: Optional[str] = None,
+) -> None:
+    """
+    Configure logging for WhisperX.
+
+    Args:
+        level: Logging level (debug, info, warning, error, critical). Default: info
+        log_file: Optional path to log file. If None, logs only to console.
+    """
+    logger = logging.getLogger("whisperx")
+
+    logger.handlers.clear()
+
+    try:
+        log_level = getattr(logging, level.upper())
+    except AttributeError:
+        log_level = logging.WARNING
+    logger.setLevel(log_level)
+
+    formatter = logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT)
+
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(log_level)
+    console_handler.setFormatter(formatter)
+
+    logger.addHandler(console_handler)
+
+    if log_file:
+        try:
+            file_handler = logging.FileHandler(log_file)
+            file_handler.setLevel(log_level)
+            file_handler.setFormatter(formatter)
+            logger.addHandler(file_handler)
+        except (OSError) as e:
+            logger.warning(f"Failed to create log file '{log_file}': {e}")
+            logger.warning("Continuing with console logging only")
+
+    # Don't propagate to root logger to avoid duplicate messages
+    logger.propagate = False
+
+
+def get_logger(name: str) -> logging.Logger:
+    """
+    Get a logger instance for the given module.
+
+    Args:
+        name: Logger name (typically __name__ from calling module)
+
+    Returns:
+        Logger instance configured with WhisperX settings
+    """
+    whisperx_logger = logging.getLogger("whisperx")
+    if not whisperx_logger.handlers:
+        setup_logging()
+
+    logger_name = "whisperx" if name == "__main__" else name
+    return logging.getLogger(logger_name)
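
A short sketch of calling log_utils directly (the log file path and messages are illustrative). Note that an unrecognized level name falls back to WARNING rather than raising:

    from whisperx.log_utils import setup_logging, get_logger

    # "verbose" is not a standard logging level, so setup_logging falls back to WARNING
    setup_logging(level="verbose", log_file="run.log")

    logger = get_logger(__name__)  # "__main__" is remapped to the "whisperx" logger
    logger.warning("shown: at or above the effective WARNING level")
    logger.info("hidden: below the effective WARNING level")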

whisperx/transcribe.py

Lines changed: 10 additions & 7 deletions
@@ -12,6 +12,9 @@
 from whisperx.diarize import DiarizationPipeline, assign_word_speakers
 from whisperx.schema import AlignedTranscriptionResult, TranscriptionResult
 from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)
 
 
 def transcribe_task(args: dict, parser: argparse.ArgumentParser):
@@ -142,7 +145,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
    for audio_path in args.pop("audio"):
        audio = load_audio(audio_path)
        # >> VAD & ASR
-        print(">>Performing transcription...")
+        logger.info("Performing transcription...")
        result: TranscriptionResult = model.transcribe(
            audio,
            batch_size=batch_size,
@@ -175,13 +178,13 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
        if align_model is not None and len(result["segments"]) > 0:
            if result.get("language", "en") != align_metadata["language"]:
                # load new language
-                print(
+                logger.info(
                    f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
                )
                align_model, align_metadata = load_align_model(
                    result["language"], device
                )
-            print(">>Performing alignment...")
+            logger.info("Performing alignment...")
            result: AlignedTranscriptionResult = align(
                result["segments"],
                align_model,
@@ -203,12 +206,12 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
    # >> Diarize
    if diarize:
        if hf_token is None:
-            print(
-                "Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model..."
+            logger.warning(
+                "No --hf_token provided, needs to be saved in environment variable, otherwise will throw error loading diarization model"
            )
        tmp_results = results
-        print(">>Performing diarization...")
-        print(">>Using model:", diarize_model_name)
+        logger.info("Performing diarization...")
+        logger.info(f"Using model: {diarize_model_name}")
        results = []
        diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
        for result, input_audio_path in tmp_results:

whisperx/vads/pyannote.py

Lines changed: 5 additions & 2 deletions
@@ -13,6 +13,9 @@
 
 from whisperx.diarize import Segment as SegmentX
 from whisperx.vads.vad import Vad
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)
 
 
 def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
@@ -232,7 +235,7 @@ def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
 class Pyannote(Vad):
 
    def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
-        print(">>Performing voice activity detection using Pyannote...")
+        logger.info("Performing voice activity detection using Pyannote...")
        super().__init__(kwargs['vad_onset'])
        self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
 
@@ -257,7 +260,7 @@ def merge_chunks(segments,
            segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
 
        if len(segments_list) == 0:
-            print("No active speech found in audio")
+            logger.warning("No active speech found in audio")
            return []
        assert segments_list, "segments_list is empty."
        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)

whisperx/vads/silero.py

Lines changed: 5 additions & 2 deletions
@@ -8,14 +8,17 @@
 
 from whisperx.diarize import Segment as SegmentX
 from whisperx.vads.vad import Vad
+from whisperx.log_utils import get_logger
+
+logger = get_logger(__name__)
 
 AudioFile = Union[Text, Path, IOBase, Mapping]
 
 
 class Silero(Vad):
    # check again default values
    def __init__(self, **kwargs):
-        print(">>Performing voice activity detection using Silero...")
+        logger.info("Performing voice activity detection using Silero...")
        super().__init__(kwargs['vad_onset'])
 
        self.vad_onset = kwargs['vad_onset']
@@ -60,7 +63,7 @@ def merge_chunks(segments_list,
                 ):
        assert chunk_size > 0
        if len(segments_list) == 0:
-            print("No active speech found in audio")
+            logger.warning("No active speech found in audio")
            return []
        assert segments_list, "segments_list is empty."
        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
