Skip to content

Commit 35d904c

Browse files
committed
fix: add type checks for diarization results in transcribe_task function
1 parent 7019370 commit 35d904c

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

whisperx/transcribe.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import warnings
55

66
import numpy as np
7+
import pandas as pd
78
import torch
89

910
from whisperx.alignment import align, load_align_model
@@ -218,15 +219,23 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
218219
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
219220
for result, input_audio_path in tmp_results:
220221
diarize_result = diarize_model(
221-
input_audio_path,
222-
min_speakers=min_speakers,
223-
max_speakers=max_speakers,
222+
input_audio_path,
223+
min_speakers=min_speakers,
224+
max_speakers=max_speakers,
224225
return_embeddings=return_speaker_embeddings
225226
)
226227

227228
if return_speaker_embeddings:
229+
if not isinstance(diarize_result, tuple):
230+
raise TypeError(
231+
f"Expected tuple when return_embeddings=True, got {type(diarize_result).__name__}"
232+
)
228233
diarize_segments, speaker_embeddings = diarize_result
229234
else:
235+
if not isinstance(diarize_result, pd.DataFrame):
236+
raise TypeError(
237+
f"Expected DataFrame when return_embeddings=False, got {type(diarize_result).__name__}"
238+
)
230239
diarize_segments = diarize_result
231240
speaker_embeddings = None
232241

0 commit comments

Comments
 (0)