diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 1094672ed..047c3a5eb 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -446,6 +446,7 @@ def segmentation_loss( if self.weigh_by_cardinality else None ) + seg_loss = nll_loss( permutated_prediction, torch.argmax(target, dim=-1), @@ -548,6 +549,12 @@ def training_step(self, batch, batch_idx: int): warm_up_right = round(self.warm_up[1] / self.duration * num_frames) weight[:, num_frames - warm_up_right :] = 0.0 + latency = 0.1 # will be a parameter of the task (in s) + delay = int(np.floor(num_frames * latency / self.duration)) # round down + + prediction = prediction[:, delay:, :] + target = target[:, :num_frames-delay, :] + if self.specifications.powerset: multilabel = self.model.powerset.to_multilabel(prediction) permutated_target, _ = permutate(multilabel, target) diff --git a/pyannote/audio/torchmetrics/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py index 70cdc052f..fe1e8752d 100644 --- a/pyannote/audio/torchmetrics/audio/diarization_error_rate.py +++ b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py @@ -49,10 +49,11 @@ class DiarizationErrorRate(Metric): higher_is_better = False is_differentiable = False - def __init__(self, threshold: float = 0.5): + def __init__(self, threshold: float = 0.5, per_frame: bool = False): super().__init__() self.threshold = threshold + self.per_frame = per_frame self.add_state("false_alarm", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state( @@ -85,14 +86,17 @@ def update( speech_total : torch.Tensor Diarization error rate components accumulated over the whole batch. 
""" - - false_alarm, missed_detection, speaker_confusion, speech_total = _der_update( + if self.per_frame: + self.false_alarm, self.missed_detection, self.speaker_confusion, self.speech_total = _der_update(preds, target, + per_frame = self.per_frame, threshold=self.threshold) + else: + false_alarm, missed_detection, speaker_confusion, speech_total = _der_update( preds, target, threshold=self.threshold - ) - self.false_alarm += false_alarm - self.missed_detection += missed_detection - self.speaker_confusion += speaker_confusion - self.speech_total += speech_total + ) + self.false_alarm += false_alarm + self.missed_detection += missed_detection + self.speaker_confusion += speaker_confusion + self.speech_total += speech_total def compute(self): return _der_compute( @@ -100,6 +104,7 @@ def compute(self): self.missed_detection, self.speaker_confusion, self.speech_total, + self.per_frame ) diff --git a/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py index 9502a527e..840e9df98 100644 --- a/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py +++ b/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py @@ -32,6 +32,7 @@ def _der_update( preds: torch.Tensor, target: torch.Tensor, + per_frame: bool = False, threshold: Union[torch.Tensor, float] = 0.5, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute components of diarization error rate @@ -53,7 +54,6 @@ def _der_update( speech_total : torch.Tensor Diarization error rate components accumulated over the whole batch. 
""" - # make threshold a (num_thresholds,) tensor scalar_threshold = isinstance(threshold, Number) if scalar_threshold: @@ -86,6 +86,9 @@ def _der_update( speaker_confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm # (batch_size, num_frames, num_thresholds) + + if per_frame: + return torch.sum(false_alarm, 0)[:,0], torch.sum(missed_detection, 0)[:,0], torch.sum(speaker_confusion, 0)[:,0], 1.0 * torch.sum(target) false_alarm = torch.sum(torch.sum(false_alarm, 1), 0) missed_detection = torch.sum(torch.sum(missed_detection, 1), 0) @@ -107,6 +110,7 @@ def _der_compute( missed_detection: torch.Tensor, speaker_confusion: torch.Tensor, speech_total: torch.Tensor, + per_frame: bool = False, ) -> torch.Tensor: """Compute diarization error rate from its components @@ -123,7 +127,8 @@ def _der_compute( der : (num_thresholds, )-shaped torch.Tensor Diarization error rate. """ - + if per_frame: + return false_alarm, missed_detection, speaker_confusion, speech_total return (false_alarm + missed_detection + speaker_confusion) / (speech_total + 1e-8) diff --git a/pyannote/audio/utils/loss.py b/pyannote/audio/utils/loss.py index 2c55b26f3..c6f8a3db7 100644 --- a/pyannote/audio/utils/loss.py +++ b/pyannote/audio/utils/loss.py @@ -155,7 +155,7 @@ def nll_loss( num_classes = prediction.shape[2] losses = F.nll_loss( - prediction.view(-1, num_classes), + prediction.reshape(-1, num_classes), # (batch_size x num_frames, num_classes) target.view(-1), # (batch_size x num_frames, ) diff --git a/tutorials/Notebook_test.ipynb b/tutorials/Notebook_test.ipynb new file mode 100644 index 000000000..79204b1f0 --- /dev/null +++ b/tutorials/Notebook_test.ipynb @@ -0,0 +1,735 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Protocol" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "'AMI.SpeakerDiarization.only_words' found in 
/home/brahou/pyannote-audio/tutorials/AMI-diarization-setup/pyannote/database.yml does not define the 'scope' of speaker labels (file, database, or global). Setting it to 'file'.\n", + "'AMI.SpeakerDiarization.word_and_vocalsounds' found in /home/brahou/pyannote-audio/tutorials/AMI-diarization-setup/pyannote/database.yml does not define the 'scope' of speaker labels (file, database, or global). Setting it to 'file'.\n" + ] + } + ], + "source": [ + "from pyannote.core import notebook, Segment\n", + "notebook.reset()\n", + "from pyannote.database import registry\n", + "registry.load_database(\"AMI-diarization-setup/pyannote/database.yml\")\n", + "from pyannote.database import FileFinder\n", + "protocol = registry.get_protocol('AMI.SpeakerDiarization.only_words', \n", + " preprocessors={\"audio\": FileFinder()})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/brahou/pyannote-audio/pyannote/audio/core/io.py:43: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n", + " torchaudio.set_audio_backend(\"soundfile\")\n", + "/home/brahou/anaconda3/lib/python3.11/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n", + " torchaudio.set_audio_backend(\"soundfile\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Protocol AMI.SpeakerDiarization.only_words does not precompute the output of torchaudio.info(): adding a 'torchaudio.info' preprocessor for you to speed up dataloaders. 
See pyannote.database documentation on how to do that yourself.\n" + ] + } + ], + "source": [ + "from pyannote.audio.tasks import SpeakerDiarization\n", + "diarization_task = SpeakerDiarization(protocol, duration=5, batch_size=256, max_speakers_per_frame = 2,\n", + " max_speakers_per_chunk = 3, num_workers = 0)\n", + "\n", + "from pyannote.audio.models.segmentation import PyanNet\n", + "model = PyanNet(task=diarization_task, lstm = {\n", + " \"hidden_size\": 128,\n", + " \"num_layers\": 2,\n", + " \"bidirectional\": True,\n", + " \"monolithic\": True,\n", + " \"dropout\": 0.0,\n", + " })\n", + "#If you infer without training\n", + "model.setup(stage = 'fit')\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "#batches = iter(model.train_dataloader())" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [], + "source": [ + "#%timeit next(batches)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "#import pytorch_lightning as pl\n", + "#trainer = pl.Trainer(max_epochs=1,accelerator='gpu', devices=[0])\n", + "#trainer.fit(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from pyannote.audio import Model\n", + "pretrained_model = Model.from_pretrained(\"pyannote/segmentation-3.0\", use_auth_token=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "#import torch\n", + "#model.load_state_dict(torch.load('model.pt'), strict = False)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from pyannote.audio import Inference\n", + "from pyannote.audio.torchmetrics import 
DiarizationErrorRate\n", + "import numpy as np\n", + "import torch\n", + "from rich.progress import Progress\n", + "\n", + "\n", + "def test_torchmetrics(model, files, latency):\n", + " with Progress() as progress:\n", + " main_task = progress.add_task(protocol.name, total=len(files))\n", + " file_task = progress.add_task(\"Processing\", total=1.0)\n", + " \n", + " def progress_hook(completed: int = None, total: int = None):\n", + " progress.update(file_task, completed=completed / total) \n", + "\n", + " inference = Inference(model)\n", + " metric = DiarizationErrorRate()\n", + " num_frames = model.example_output.num_frames\n", + " delay = int(np.floor(latency / model.example_output.frames.duration)) # round down\n", + " print(f\"delay : {delay:d} frame(s)\")\n", + "\n", + " # initialize error list\n", + " error=[]\n", + " for file in files:\n", + " progress.update(file_task, description=file[\"uri\"])\n", + " # calculate inference for current file\n", + " window_inference = inference(file)\n", + " hypothesis = torch.from_numpy(window_inference.data)\n", + "\n", + " # discretize reference annotation\n", + " annotation = file[\"annotation\"]\n", + " sliding_window = window_inference.sliding_window\n", + " support = Segment(sliding_window[0].start, sliding_window[hypothesis.size(0) - 1].end)\n", + " resolution = sliding_window.duration / num_frames\n", + " discretization = annotation.discretize(support, resolution=resolution)\n", + " max_num_speaker = len(annotation.labels())\n", + " reference = torch.zeros((hypothesis.size(0),hypothesis.size(1),max_num_speaker))\n", + " for i in range(hypothesis.size(0)):\n", + " reference_window = discretization.crop(sliding_window[i], mode=\"center\")\n", + " reference[i] = torch.from_numpy(np.array(reference_window.data))[:num_frames]\n", + "\n", + " # pad the hypothesis and permute the inputs (torchmetrics takes (num_chunks,num_speakers,num_frames))\n", + " if reference.size(2) > hypothesis.size(2):\n", + " hypothesis = 
torch.nn.functional.pad(hypothesis, (0, 1, 0, 0, 0, 0))\n", + " hypothesis = hypothesis.permute(0,2,1)\n", + " reference = reference.permute(0,2,1)\n", + " print(f\"hypothesis of size : {hypothesis[:,:,delay:].size()}\")\n", + " print(f\"reference of size : {reference[:,:,:num_frames-delay].size()}\")\n", + "\n", + " error.append(metric(hypothesis[:,:,delay:], reference[:,:,:num_frames-delay]))\n", + " print(f\"current error list : {error}\")\n", + " progress.advance(main_task)\n", + " return np.array(error).mean()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from pyannote.audio import Inference\n", + "from pyannote.audio.torchmetrics import DiarizationErrorRate\n", + "import numpy as np\n", + "import torch\n", + "from rich.progress import Progress\n", + "\n", + "def test_torchmetrics_per_frame(model, files, latency):\n", + " \n", + " inference = Inference(model)\n", + " metric = DiarizationErrorRate(per_frame = True)\n", + " num_frames = model.example_output.num_frames\n", + " delay = int(np.floor(latency / model.example_output.frames.duration)) # round down\n", + " print(f\"delay : {delay:d} frame(s)\")\n", + " \n", + " # initialize error lists\n", + " error=[]\n", + " false_alarm_error=[]\n", + " missed_detection_error=[]\n", + " speaker_confusion_error=[]\n", + " \n", + " for file in files:\n", + " # calculate inference for current file\n", + " window_inference = inference(file)\n", + " hypothesis = torch.from_numpy(window_inference.data)\n", + " \n", + " # discretize reference annotation\n", + " annotation = file[\"annotation\"]\n", + " sliding_window = window_inference.sliding_window\n", + " support = Segment(sliding_window[0].start, sliding_window[hypothesis.size(0) - 1].end)\n", + " resolution = sliding_window.duration / num_frames\n", + " discretization = annotation.discretize(support, resolution=resolution)\n", + " max_num_speaker = len(annotation.labels())\n", + " reference = 
torch.zeros((hypothesis.size(0),hypothesis.size(1),max_num_speaker))\n", + " for i in range(hypothesis.size(0)):\n", + " reference_window = discretization.crop(sliding_window[i], mode=\"center\")\n", + " reference[i] = torch.from_numpy(np.array(reference_window.data))[:num_frames]\n", + " \n", + " # permute the inputs (torchmetrics takes (num_chunks,num_speakers,num_frames))\n", + " hypothesis = torch.nn.functional.pad(hypothesis, (0, 1, 0, 0, 0, 0))\n", + " hypothesis = hypothesis.permute(0,2,1)\n", + " reference = reference.permute(0,2,1)\n", + " print(f\"hypothesis of size : {hypothesis[:,:,delay:].size()}\")\n", + " print(f\"reference of size : {reference[:,:,:num_frames-delay].size()}\")\n", + " \n", + " # calculate the metrics\n", + " false_alarm, missed_detection, speaker_confusion, total_speech = metric(\n", + " hypothesis[:,:,delay:], reference[:,:,:num_frames-delay])\n", + " print(f\"DER for this file = {(false_alarm.sum() + missed_detection.sum() + speaker_confusion.sum()) / (total_speech + 1e-8)}\")\n", + " false_alarm_per_frame = false_alarm * num_frames / (total_speech + 1e-8)\n", + " missed_detection_per_frame = missed_detection * num_frames / (total_speech + 1e-8)\n", + " speaker_confusion_per_frame = speaker_confusion * num_frames / (total_speech + 1e-8)\n", + " der_per_frame = false_alarm_per_frame + missed_detection_per_frame + speaker_confusion_per_frame\n", + " \n", + " # append the list of errors with current file errors\n", + " error.append(np.array(der_per_frame))\n", + " false_alarm_error.append(np.array(false_alarm_per_frame))\n", + " missed_detection_error.append(np.array(missed_detection_per_frame))\n", + " speaker_confusion_error.append(np.array(speaker_confusion_per_frame))\n", + " \n", + " # calculate mean of all test files \n", + " error=np.array(error).mean(axis = 0)\n", + " false_alarm_error = np.array(false_alarm_error).mean(axis = 0)\n", + " missed_detection_error = np.array(missed_detection_error).mean(axis = 0)\n", + " 
speaker_confusion_error = np.array(speaker_confusion_error).mean(axis = 0)\n", + " \n", + " return error, false_alarm_error, missed_detection_error, speaker_confusion_error" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bbe7edda52ac4d1b9355fe13cf44820b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
delay : 0 frame(s)\n",
+       "
\n" + ], + "text/plain": [ + "delay : 0 frame(s)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
hypothesis of size : torch.Size([1669, 4, 293])\n",
+       "
\n" + ], + "text/plain": [ + "hypothesis of size : torch.Size([1669, 4, 293])\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
reference of size : torch.Size([1669, 4, 293])\n",
+       "
\n" + ], + "text/plain": [ + "reference of size : torch.Size([1669, 4, 293])\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
current error list : [tensor(0.2593)]\n",
+       "
\n" + ], + "text/plain": [ + "current error list : [tensor(0.2593)]\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Local DER = 25.9%\n" + ] + } + ], + "source": [ + "files = list(getattr(protocol, \"test\")())\n", + "latency = 0\n", + "der = test_torchmetrics(model, files, latency)\n", + "print(f\"Local DER = {abs(der) * 100:.1f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "delay : 0 frame(s)\n", + "hypothesis of size : torch.Size([1669, 4, 293])\n", + "reference of size : torch.Size([1669, 4, 293])\n", + "DER for this file = 0.25927138328552246\n", + "hypothesis of size : torch.Size([4096, 4, 293])\n", + "reference of size : torch.Size([4096, 4, 293])\n", + "DER for this file = 0.15371479094028473\n", + "hypothesis of size : torch.Size([3633, 4, 293])\n", + "reference of size : torch.Size([3633, 4, 293])\n", + "DER for this file = 0.14838285744190216\n", + "hypothesis of size : torch.Size([3880, 4, 293])\n", + "reference of size : torch.Size([3880, 4, 293])\n", + "DER for this file = 0.24103473126888275\n", + "hypothesis of size : torch.Size([2090, 4, 293])\n", + "reference of size : torch.Size([2090, 4, 293])\n", + "DER for this file = 0.26910367608070374\n", + "hypothesis of size : torch.Size([4682, 4, 293])\n", + "reference of size : torch.Size([4682, 4, 293])\n", + "DER for this file = 0.19512279331684113\n", + "hypothesis of size : torch.Size([4660, 4, 293])\n", + "reference of size : torch.Size([4660, 4, 293])\n", + "DER for this file = 0.196478009223938\n", + "hypothesis of size : torch.Size([4436, 4, 293])\n", + "reference of size : torch.Size([4436, 4, 293])\n", + "DER for this file = 0.2791493833065033\n", + "hypothesis of size : torch.Size([3003, 4, 293])\n", + "reference of size : torch.Size([3003, 4, 293])\n", + "DER for this file = 0.21955639123916626\n", + "hypothesis of size : 
torch.Size([4412, 4, 293])\n", + "reference of size : torch.Size([4412, 4, 293])\n", + "DER for this file = 0.1819673478603363\n", + "hypothesis of size : torch.Size([5131, 4, 293])\n", + "reference of size : torch.Size([5131, 4, 293])\n", + "DER for this file = 0.20070508122444153\n", + "hypothesis of size : torch.Size([5228, 4, 293])\n", + "reference of size : torch.Size([5228, 4, 293])\n", + "DER for this file = 0.3011073172092438\n", + "hypothesis of size : torch.Size([4277, 4, 293])\n", + "reference of size : torch.Size([4277, 4, 293])\n", + "DER for this file = 0.32932376861572266\n", + "hypothesis of size : torch.Size([3565, 4, 293])\n", + "reference of size : torch.Size([3565, 4, 293])\n", + "DER for this file = 0.30412614345550537\n", + "hypothesis of size : torch.Size([5936, 4, 293])\n", + "reference of size : torch.Size([5936, 3, 293])\n", + "DER for this file = 0.2846376597881317\n", + "hypothesis of size : torch.Size([4411, 4, 293])\n", + "reference of size : torch.Size([4411, 4, 293])\n", + "DER for this file = 0.3492230176925659\n" + ] + }, + { + "data": { + "text/plain": [ + "array([0.33141226, 0.32828283, 0.3254712 , 0.32385397, 0.32102078,\n", + " 0.31781587, 0.31483144, 0.3118828 , 0.3095508 , 0.30769417,\n", + " 0.30492172, 0.30279967, 0.30185664, 0.3003403 , 0.29950818,\n", + " 0.2974546 , 0.29643768, 0.2952635 , 0.2943591 , 0.2934955 ,\n", + " 0.29226917, 0.29182842, 0.28979072, 0.28865463, 0.28763038,\n", + " 0.28585577, 0.2846649 , 0.2832142 , 0.28189656, 0.28114805,\n", + " 0.27901095, 0.27710798, 0.27613223, 0.27480465, 0.2739618 ,\n", + " 0.27200747, 0.2700129 , 0.26821354, 0.26670614, 0.2651845 ,\n", + " 0.2638062 , 0.26239705, 0.26102602, 0.25953332, 0.2591208 ,\n", + " 0.25828105, 0.2570774 , 0.2559183 , 0.25584063, 0.25441644,\n", + " 0.25364774, 0.2519333 , 0.2509564 , 0.25014085, 0.24833299,\n", + " 0.24697796, 0.24619237, 0.24567632, 0.24573608, 0.24532034,\n", + " 0.24451478, 0.24479821, 0.24377097, 0.24390468, 0.24280521,\n", + " 
0.2421332 , 0.24140316, 0.24052833, 0.23971869, 0.23828493,\n", + " 0.23733024, 0.23671696, 0.23572822, 0.23477092, 0.23483257,\n", + " 0.23470238, 0.23381706, 0.23352832, 0.23248194, 0.23244005,\n", + " 0.23191895, 0.23056792, 0.23067129, 0.22970879, 0.22897042,\n", + " 0.22844109, 0.22814372, 0.2281976 , 0.22832845, 0.22790605,\n", + " 0.22724718, 0.22683671, 0.22714718, 0.22620751, 0.2261574 ,\n", + " 0.22560719, 0.22536096, 0.22467043, 0.22339903, 0.22351626,\n", + " 0.22290963, 0.2221011 , 0.221811 , 0.22127423, 0.22140059,\n", + " 0.22096576, 0.22094062, 0.21971132, 0.21941712, 0.21886246,\n", + " 0.21794213, 0.21799375, 0.21791717, 0.21776289, 0.21730742,\n", + " 0.21689726, 0.21737412, 0.21717627, 0.21763906, 0.21722978,\n", + " 0.21712679, 0.2172139 , 0.21662505, 0.2162188 , 0.21584177,\n", + " 0.21599226, 0.2156428 , 0.2144337 , 0.21447644, 0.21455246,\n", + " 0.21390632, 0.21364939, 0.21381284, 0.21428956, 0.21387157,\n", + " 0.21350878, 0.21355487, 0.2128605 , 0.21283644, 0.21259557,\n", + " 0.21243428, 0.21266854, 0.21199487, 0.21181758, 0.21200636,\n", + " 0.2116846 , 0.21187645, 0.21240513, 0.2130516 , 0.21303566,\n", + " 0.21281463, 0.21330939, 0.21278341, 0.21276967, 0.21349609,\n", + " 0.2140504 , 0.2140244 , 0.21345386, 0.21307206, 0.21299285,\n", + " 0.21363075, 0.21332927, 0.2138904 , 0.21424842, 0.21428452,\n", + " 0.2149067 , 0.21462873, 0.21488793, 0.21478772, 0.21447697,\n", + " 0.21527193, 0.2145139 , 0.2143711 , 0.21494722, 0.21588585,\n", + " 0.21638153, 0.2170032 , 0.21727371, 0.21760772, 0.21724796,\n", + " 0.2179234 , 0.21791705, 0.21777347, 0.2189023 , 0.220375 ,\n", + " 0.22009227, 0.21989672, 0.2195742 , 0.22005339, 0.22005624,\n", + " 0.2200972 , 0.22042753, 0.22141808, 0.22161743, 0.22222513,\n", + " 0.22201827, 0.22299126, 0.22290714, 0.22261173, 0.22326928,\n", + " 0.22358622, 0.22328378, 0.22379616, 0.22481851, 0.2259741 ,\n", + " 0.22653295, 0.22694038, 0.22720665, 0.22725053, 0.22805731,\n", + " 0.22841386, 0.22888723, 
0.22907071, 0.23003177, 0.23025626,\n", + " 0.23097122, 0.23100737, 0.23167296, 0.23189557, 0.23287718,\n", + " 0.23242202, 0.23350279, 0.23474258, 0.23566872, 0.23595695,\n", + " 0.2362484 , 0.23680899, 0.23641714, 0.23713154, 0.23838924,\n", + " 0.23718446, 0.2376921 , 0.2394906 , 0.24025609, 0.24090442,\n", + " 0.24237992, 0.2427249 , 0.24297021, 0.24407573, 0.24501283,\n", + " 0.24509661, 0.24581507, 0.24699542, 0.24790806, 0.24876504,\n", + " 0.24916984, 0.24964936, 0.25082186, 0.25161406, 0.25240988,\n", + " 0.25344747, 0.2545621 , 0.2552997 , 0.2570173 , 0.25769967,\n", + " 0.2583722 , 0.25973082, 0.26096007, 0.26157868, 0.2620245 ,\n", + " 0.26331863, 0.2649801 , 0.26681465, 0.2678213 , 0.2693782 ,\n", + " 0.27037668, 0.27147996, 0.27291638, 0.27446246, 0.2756156 ,\n", + " 0.27669254, 0.27759477, 0.27955914, 0.2804497 , 0.2808544 ,\n", + " 0.2821845 , 0.28436345, 0.2855499 , 0.28686047, 0.28862983,\n", + " 0.29075706, 0.29199818, 0.2939653 , 0.2955644 , 0.2973317 ,\n", + " 0.29890612, 0.3008567 , 0.30361935, 0.30643216, 0.30874175,\n", + " 0.3112345 , 0.31372193, 0.3135783 ], dtype=float32)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "files = list(getattr(protocol, \"test\")())\n", + "latency = 0\n", + "error_per_frame, false_alarm, missed_detection, speaker_confusion = test_torchmetrics_per_frame(model, files, latency)\n", + "error_per_frame" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "fig, ax = plt.subplots()\n", + "\n", + "frames = np.arange(error_per_frame.shape[0]) \n", + "\n", + "# Plot the stacked area plot\n", + "ax.plot(frames, error_per_frame, color=\"black\", label=\"Error rate\")\n", + "\n", + "ax.stackplot(frames, false_alarm, missed_detection, speaker_confusion, \n", + " labels=[\"Speaker confusion\", \"Missed detection\", \"False alarm\"])\n", + "\n", + "# Set axis labels and a title\n", + "ax.set_xlabel(\"Frames\")\n", + "ax.set_ylabel(\"Error rate\")\n", + "ax.set_title(\"Proportion of confusion, missed detection, and false alarm in the error rate (per frame)\")\n", + "ax.legend(loc='upper center')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.setup(stage = 'fit')\n", + "from pyannote.audio.pipelines import SpeakerDiarization\n", + "pipeline = SpeakerDiarization(model).instantiate({\n", + " \"segmentation\": {\n", + " \"min_duration_off\": 0.0,\n", + " },\n", + " \"clustering\": {\n", + " \"method\": \"centroid\",\n", + " \"min_cluster_size\": 2,\n", + " \"threshold\": 0.01,\n", + " },\n", + "})\n", + "files = list(getattr(protocol, \"test\")())\n", + "\n", + "from pyannote.audio.utils.preview import listen\n", + "listen(\"test_short.wav\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#from huggingface_hub import notebook_login\n", + "#notebook_login()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyannote.audio import Pipeline\n", + "pretrained_pipeline = Pipeline.from_pretrained(\"pyannote/speaker-diarization\", use_auth_token=True)" + ] + }, 
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "annotation = pipeline(\"test_short.wav\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "list(annotation.itertracks())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimizing pipeline hyper-parameters" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can try to optimize the hyper-parameters (that we chose manually above) on the validation set to get better performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# to make things faster, we run the inference once and for all... \n", + "validation_files = list(protocol.development())\n", + "for file in validation_files:\n", + " file['osd'] = inference(file)\n", + "# ... and tell the pipeline to load OSD scores directly from files\n", + "pipeline = OverlappedSpeechDetectionPipeline(scores=\"osd\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyannote.pipeline import Optimizer\n", + "optimizer = Optimizer(pipeline)\n", + "optimizer.tune(validation_files, n_iterations=200, show_progress=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There you go: better hyper-parameters that should lead to better results!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "optimized_pipeline = OverlappedSpeechDetectionPipeline(scores=inference).instantiate(optimizer.best_params)\n", + "optimized_pipeline(test_file).get_timeline()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyannote.audio import Inference\n", + "from pyannote.audio.utils.metric import DiscreteDiarizationErrorRate\n", + "from pyannote.audio.utils.signal import binarize\n", + "from rich.progress import Progress\n", + "\n", + "#model.setup(stage = 'fit')\n", + "\n", + "def test_discrete(model, files): \n", + " inference = Inference(model)\n", + " metric = DiscreteDiarizationErrorRate()\n", + " for file in files[0:1]:\n", + " reference = file[\"annotation\"]\n", + " hypothesis = binarize(inference(file))\n", + " uem = file[\"annotated\"]\n", + " _ = metric(reference, hypothesis, uem=uem)\n", + " return metric\n", + " \n", + "files = list(getattr(protocol, \"train\")())\n", + "print(f\"Local DER = {abs(test_discrete(model, files)) * 100:.1f}%\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}