diff --git a/lhotse/bin/modes/workflows.py b/lhotse/bin/modes/workflows.py
index 76fb3217b..4179b37c9 100644
--- a/lhotse/bin/modes/workflows.py
+++ b/lhotse/bin/modes/workflows.py
@@ -172,6 +172,55 @@ def align_with_torchaudio(
         writer.write(cut, flush=True)
+@workflows.command()
+@click.argument(
+    "in_cuts", type=click.Path(exists=True, dir_okay=False, allow_dash=True)
+)
+@click.argument("out_cuts", type=click.Path(allow_dash=True))
+@click.option(
+    "-d", "--device", default="cpu", help="Device on which to run the inference."
+)
+@click.option(
+    "--num-speakers",
+    type=int,
+    default=None,
+    help="Number of speakers to cluster into. If not provided, --threshold is used instead.",
+)
+@click.option(
+    "--threshold",
+    type=float,
+    default=None,
+    help="Distance threshold for agglomerative clustering. If not provided, --num-speakers is used instead.",
+)
+def diarize_segments_with_speechbrain(
+    in_cuts: str,
+    out_cuts: str,
+    device: str = "cpu",
+    num_speakers: Optional[int] = None,
+    threshold: Optional[float] = None,
+):
+    """
+    This workflow uses SpeechBrain's pretrained speaker embedding model to compute speaker
+    embeddings for each cut in the CutSet. The cuts belonging to the same recording are then
+    clustered using agglomerative hierarchical clustering, and the resulting cluster indices
+    are used to assign speaker labels to new cuts.
+
+    Please refer to https://huggingface.co/speechbrain/spkrec-xvect-voxceleb for more details
+    about the speaker embedding extractor.
+    """
+    from lhotse.workflows import diarize_segments_with_speechbrain
+
+    assert exactly_one_not_null(
+        num_speakers, threshold
+    ), "Exactly one of --num-speakers and --threshold must be provided."
+
+    cuts = load_manifest_lazy_or_eager(in_cuts)
+    cuts_with_spk_id = diarize_segments_with_speechbrain(
+        cuts, device=device, num_speakers=num_speakers, threshold=threshold
+    )
+    cuts_with_spk_id.to_file(out_cuts)
+
+
 @workflows.command()
 @click.argument(
     "in_cuts", type=click.Path(exists=True, dir_okay=False, allow_dash=True)
 )
 @click.argument("out_cuts", type=click.Path(allow_dash=True))
@@ -203,8 +252,6 @@ def align_with_torchaudio(
     show_default=True,
 )
 # Options used with the "conversational" method
-
-
 @click.option(
     "--same-spk-pause",
     type=float,
diff --git a/lhotse/workflows/__init__.py b/lhotse/workflows/__init__.py
index e14b0f64e..4f6f71abe 100644
--- a/lhotse/workflows/__init__.py
+++ b/lhotse/workflows/__init__.py
@@ -1,3 +1,4 @@
+from .diarization import diarize_segments_with_speechbrain
 from .forced_alignment import align_with_torchaudio
 from .meeting_simulation import *
 from .whisper import annotate_with_whisper
diff --git a/lhotse/workflows/diarization.py b/lhotse/workflows/diarization.py
new file mode 100644
index 000000000..c149efa20
--- /dev/null
+++ b/lhotse/workflows/diarization.py
@@ -0,0 +1,120 @@
+import logging
+import shutil
+import tempfile
+from typing import Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from lhotse import CutSet, Recording
+from lhotse.utils import fastcopy, is_module_available
+
+logging.basicConfig(
+    format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%d:%H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+def diarize_segments_with_speechbrain(
+    cuts: CutSet,
+    device: str = "cpu",
+    num_speakers: Optional[int] = None,
+    threshold: float = 0.5,
+) -> CutSet:
+    """
+    This workflow uses SpeechBrain's pretrained speaker embedding model to compute speaker
+    embeddings for each cut in the CutSet.
+    The cuts belonging to the same recording are then clustered using agglomerative
+    hierarchical clustering, and the resulting cluster indices are used to assign
+    speaker labels to new cuts.
+
+    Please refer to https://huggingface.co/speechbrain/spkrec-xvect-voxceleb for more details
+    about the speaker embedding extractor.
+
+    :param cuts: a ``CutSet`` object.
+    :param device: Where to run the inference (cpu, cuda, etc.).
+    :param num_speakers: Number of speakers to cluster the cuts into. If not specified, the
+        threshold parameter is used to determine the number of speakers.
+    :param threshold: The distance threshold for agglomerative clustering.
+    :return: a new ``CutSet`` with speaker labels.
+    """
+    assert is_module_available("speechbrain"), (
+        "This function expects SpeechBrain to be installed. "
+        "You can install it via 'pip install speechbrain'."
+    )
+
+    assert is_module_available("sklearn"), (
+        "This function expects scikit-learn to be installed. "
+        "You can install it via 'pip install scikit-learn'."
+    )
+
+    from sklearn.cluster import AgglomerativeClustering
+    from speechbrain.pretrained import EncoderClassifier
+
+    # When the number of speakers is known, cluster into exactly that many groups;
+    # otherwise, the distance threshold determines the number of clusters.
+    threshold = None if num_speakers is not None else threshold
+    dirpath = tempfile.mkdtemp()
+
+    recordings, _, _ = cuts.decompose(dirpath, verbose=True)
+    recordings = recordings.to_eager()
+    recording_ids = frozenset(recordings.ids)
+
+    logger.info("Saving cut recordings temporarily to disk...")
+    cuts_ = []
+    for cut in tqdm(cuts):
+        save_path = f"{dirpath}/{cut.recording_id}.wav"
+        _ = cut.save_audio(save_path)
+        cuts_.append(fastcopy(cut, recording=Recording.from_file(save_path)))
+
+    cuts_ = CutSet.from_cuts(cuts_).trim_to_supervisions(keep_overlapping=False)
+
+    # Load the pretrained speaker embedding model
+    model = EncoderClassifier.from_hparams(
+        source="speechbrain/spkrec-xvect-voxceleb",
+        savedir="pretrained_models/spkrec-xvect-voxceleb",
+        run_opts={"device": device},
+    )
+
+    out_cuts = []
+
+    for recording_id in tqdm(recording_ids, total=len(recording_ids)):
+        logger.info(f"Processing recording {recording_id}...")
+        embeddings = []
+        reco_cuts = cuts_.filter(lambda c: c.recording_id == recording_id)
+        num_cuts = len(frozenset(reco_cuts.ids))
+        if num_cuts == 0:
+            continue
+        for cut in tqdm(reco_cuts, total=num_cuts):
+            audio = torch.from_numpy(cut.load_audio())
+            embedding = model.encode_batch(audio).cpu().numpy()
+            embeddings.append(embedding.squeeze())
+
+        embeddings = np.vstack(embeddings)
+        # NOTE: scikit-learn >= 1.2 renames 'affinity' to 'metric'.
+        clusterer = AgglomerativeClustering(
+            n_clusters=num_speakers,
+            affinity="euclidean",
+            linkage="ward",
+            distance_threshold=threshold,
+        )
+        clusterer.fit(embeddings)
+
+        # Assign the cluster indices to the cuts
+        for cut, cluster_idx in zip(reco_cuts, clusterer.labels_):
+            sup = fastcopy(cut.supervisions[0], speaker=f"spk{cluster_idx}")
+            out_cuts.append(
+                fastcopy(
+                    cut,
+                    recording=recordings[cut.recording_id],
+                    supervisions=[sup],
+                )
+            )
+
+    # Remove the temporary directory
+    shutil.rmtree(dirpath)
+
+    return CutSet.from_cuts(out_cuts)
diff --git a/lhotse/workflows/whisper.py b/lhotse/workflows/whisper.py
index c5fef029b..69651c8a7 100644
--- a/lhotse/workflows/whisper.py
+++ b/lhotse/workflows/whisper.py
@@ -2,6 +2,7 @@
 from typing import Any, Generator, List, Optional, Union
 
 import torch
+from tqdm import tqdm
 
 from lhotse import (
     CutSet,
@@ -87,7 +88,7 @@ def _annotate_recordings(
 
     model = whisper.load_model(model_name, device=device, download_root=download_root)
 
-    for recording in recordings:
+    for recording in tqdm(recordings):
         if recording.num_channels > 1:
             logging.warning(
                 f"Skipping recording '{recording.id}'. It has {recording.num_channels} channels, "
@@ -141,7 +142,7 @@ def _annotate_cuts(
 
     model = whisper.load_model(model_name, device=device, download_root=download_root)
 
-    for cut in cuts:
+    for cut in tqdm(cuts):
         if cut.num_channels > 1:
             logging.warning(
                 f"Skipping cut '{cut.id}'. It has {cut.num_channels} channels, "
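
Usage note (not part of the patch): a minimal smoke-test sketch of the new workflow for reviewers. The manifest paths and the speaker count are placeholder assumptions, and the CLI spelling assumes click's usual underscore-to-hyphen conversion of the command function name.

# Hypothetical smoke test (not part of the patch). Assumes placeholder
# manifest paths, a known speaker count of 2, and that speechbrain and
# scikit-learn are installed, as the workflow itself asserts.
from lhotse import CutSet
from lhotse.workflows import diarize_segments_with_speechbrain

cuts = CutSet.from_file("cuts.jsonl.gz")  # cuts with supervision segments

# Cluster into a known number of speakers...
cuts_spk = diarize_segments_with_speechbrain(cuts, device="cpu", num_speakers=2)

# ...or omit num_speakers and let the clustering distance threshold decide:
# cuts_spk = diarize_segments_with_speechbrain(cuts, device="cpu", threshold=0.5)

cuts_spk.to_file("cuts_diarized.jsonl.gz")

# CLI equivalent (command name assuming click's underscore-to-hyphen conversion):
# lhotse workflows diarize-segments-with-speechbrain --num-speakers 2 \
#     cuts.jsonl.gz cuts_diarized.jsonl.gz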