Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,9 @@ package.json
package-lock.json

node_modules

.vscode
.venv
__pycache__

.env
43 changes: 43 additions & 0 deletions diarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
import warnings
from typing import Callable, TYPE_CHECKING

import torch


with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=SyntaxWarning)
from pyannote.audio import Pipeline

if TYPE_CHECKING:
from pyannote.core import Annotation


def diarization_pipeline(
    audio_file: str | os.PathLike,
    device: str | torch.device | None = None,
    duration_seconds: int | None = None,
    hook: Callable | None = None,
) -> "Annotation":
    # NOTE: the return annotation must stay a string — ``Annotation`` is only
    # imported under TYPE_CHECKING, so an unquoted annotation raises NameError
    # at import time (no ``from __future__ import annotations`` in this file).
    """Run pyannote speaker diarization on an audio file.

    Args:
        audio_file: Path to the audio file to diarize.
        device: Torch device (or device string) to run inference on.
            When ``None``, CUDA is used if available, else CPU.
        duration_seconds: If given, only the first ``duration_seconds``
            seconds of the audio are diarized (the waveform is trimmed
            in memory before being fed to the pipeline).
        hook: Optional progress hook forwarded to the pyannote pipeline.

    Returns:
        The pyannote ``Annotation`` produced by the diarization pipeline.
    """
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    # Auto-select a GPU when the caller did not pin a device.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    pipeline.to(device)

    print(f"Processing {audio_file} on {device}")
    input_ = audio_file
    if duration_seconds is not None:
        # Local import: torchaudio is only needed on the trimming path.
        import torchaudio

        waveform, sample_rate = torchaudio.load(audio_file)
        print(f"Trimming waveform to {duration_seconds} seconds")
        max_samples = sample_rate * duration_seconds
        waveform = waveform[:, :max_samples]
        # pyannote accepts an in-memory dict instead of a file path.
        input_ = {"waveform": waveform, "sample_rate": sample_rate}
    with warnings.catch_warnings():
        # Suppress the noisy MP3-decoder UserWarning emitted by torchaudio.
        warnings.filterwarnings(
            "ignore", category=UserWarning, message=".*MPEG_LAYER_III.*"
        )
        output = pipeline(input_, hook=hook)
    return output
11 changes: 11 additions & 0 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import requests

from diarization import diarization_pipeline

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024

Expand Down Expand Up @@ -69,6 +71,7 @@ def transcribe(job):
engine = job['input'].get('engine', 'faster-whisper')
model_name = job['input'].get('model', 'large-v2')
is_streaming = job['input'].get('streaming', False)
enable_diarization = job["input"].get("diarization", False)

if not datatype:
yield { "error" : "datatype field not provided. Should be 'blob' or 'url'." }
Expand Down Expand Up @@ -96,6 +99,14 @@ def transcribe(job):
return

stream_gen = transcribe_core(engine, model_name, audio_file)

if enable_diarization:
print("Running speaker diarization...")
try:
_ = diarization_pipeline(audio_file, device=device) # todo - use outputs
print("Diarization completed successfully")
except Exception as e:
print(f"Diarization failed: {e}")

if is_streaming:
for entry in stream_gen:
Expand Down