15 changes: 12 additions & 3 deletions whisperx/asr.py
@@ -9,7 +9,7 @@
 from transformers import Pipeline
 from transformers.pipelines.pt_utils import PipelineIterator
 
-from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram, resample_audio
 from .vad import load_vad_model, merge_chunks
 from .types import TranscriptionResult, SingleSegment
 
@@ -258,10 +258,16 @@ def stack(items):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress=False, combined_progress=False
+        self, audio: Union[str, np.ndarray, torch.Tensor], sample_rate: int = SAMPLE_RATE, batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress=False, combined_progress=False
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
+        elif isinstance(audio, (np.ndarray, torch.Tensor)):
+            if sample_rate != SAMPLE_RATE:
+                audio = resample_audio(audio, sample_rate)
+                sample_rate = SAMPLE_RATE
+            else:
+                print("The 'sample_rate' argument defaults to 16000 (16 kHz); audio will not be resampled.")
 
         def data(audio, segments):
             for seg in segments:
@@ -270,7 +276,10 @@ def data(audio, segments):
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}
 
-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        vad_segments = self.vad_model(
+            {"waveform": torch.from_numpy(audio).unsqueeze(0) if isinstance(audio, np.ndarray) else audio.unsqueeze(0),
+             "sample_rate": SAMPLE_RATE}
+        )
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,
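For reviewers, a minimal usage sketch of the extended transcribe() signature. This is illustrative only: the model name, waveform, and batch size are placeholders, not part of this PR.

# Hypothetical example: pass in-memory audio at a non-16 kHz rate directly.
# transcribe() now resamples it to 16 kHz internally via resample_audio().
import numpy as np
import whisperx

model = whisperx.load_model("small", device="cpu", compute_type="int8")

sr = 44100
waveform = np.zeros(5 * sr, dtype=np.float32)  # placeholder: 5 s of silence

result = model.transcribe(waveform, sample_rate=sr, batch_size=8)
print(result["segments"])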
54 changes: 54 additions & 0 deletions whisperx/audio.py
@@ -6,11 +6,13 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
+from torchaudio.functional import resample
 
 from .utils import exact_div
 
 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
+ALPHA = 0.5  # stereo-to-mono mixing weight
 N_FFT = 400
 N_MELS = 80
 HOP_LENGTH = 160
@@ -53,6 +55,58 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
 
 
+def resample_audio(audio: Union[torch.Tensor, np.ndarray], sample_rate: int) -> torch.Tensor:
+    """
+    Resample audio to 16 kHz.
+
+    Parameters
+    ----------
+    audio: Union[np.ndarray, torch.Tensor]
+        The data to be resampled, 1D (mono) or 2D (stereo). This can be a
+        NumPy array or a PyTorch tensor of dtype float32, float64, int16,
+        or int32.
+
+    sample_rate: int
+        The sample rate of the input audio.
+
+    Returns
+    -------
+    A 1D torch.Tensor containing the audio waveform, in float32 dtype.
+    """
+    if not isinstance(audio, torch.Tensor):
+        audio = torch.from_numpy(audio)
+
+    if audio.dtype not in (torch.float32, torch.float64, torch.int16, torch.int32):
+        raise ValueError(f"Audio dtype must be one of [float32, float64, int16, int32], not {audio.dtype}")
+
+    audio_dtype = audio.dtype
+
+    if audio.ndim == 2:  # Stereo
+        if audio.shape[0] == 2 or audio.shape[1] == 2:
+            if audio.shape[1] == 2:  # SciPy / SoundFile layout: [time, channel]
+                audio = torch.transpose(audio, 0, 1)
+
+            # Convert to mono with equal channel weights:
+            # MIX = A * ALPHA + B * ALPHA
+            audio = audio[0] * ALPHA + audio[1] * ALPHA
+
+        elif audio.shape[0] != 1:
+            raise ValueError(f"Invalid audio shape ({audio.shape}). Audio must be provided as\n"
+                             "[channel, time] or [time] for mono audio, or\n"
+                             "[channel, time] or [time, channel] for stereo audio")
+
+    elif audio.ndim != 1:
+        raise ValueError("Audio must be 1D (mono) or 2D (stereo)")
+
+    # Normalize integer PCM to [-1, 1] float32
+    if audio_dtype in (torch.int16, torch.int32):
+        audio = audio.to(torch.float32) / (32768.0 if audio_dtype == torch.int16 else 2147483648.0)
+    elif audio_dtype == torch.float64:
+        audio = audio.to(torch.float32)
+
+    return resample(audio, sample_rate, SAMPLE_RATE).flatten()
+
+
 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
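As a sanity check, a short sketch of the helper's behavior on stereo int16 input (assuming torchaudio is installed; the waveform is an illustrative placeholder):

# Illustrative check of resample_audio(): a 1-second stereo int16 array at
# 44.1 kHz in [time, channel] layout comes back as mono float32 at 16 kHz.
import numpy as np
from whisperx.audio import resample_audio

sr = 44100
stereo = np.zeros((sr, 2), dtype=np.int16)  # placeholder waveform

mono = resample_audio(stereo, sr)
print(mono.dtype, mono.shape)  # torch.float32, (16000,)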
2 changes: 1 addition & 1 deletion whisperx/vad.py
@@ -282,7 +282,7 @@ def merge_chunks(
         segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
 
     if len(segments_list) == 0:
-        print("No active speech found in audio")
+        print("No active speech found in audio.")
         return []
     # assert segments_list, "segments_list is empty."
     # Make sure the starting point is the start of the segment.