diff --git a/pyannote/audio/pipelines/speech_separation.py b/pyannote/audio/pipelines/speech_separation.py
index 55f4ef5e7..164062c9f 100644
--- a/pyannote/audio/pipelines/speech_separation.py
+++ b/pyannote/audio/pipelines/speech_separation.py
@@ -31,12 +31,12 @@
 from typing import Callable, Optional, Text, Tuple, Union
 
 import numpy as np
+from scipy.ndimage import binary_dilation, binary_opening
 import torch
 from einops import rearrange
 from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature
 from pyannote.metrics.diarization import GreedyDiarizationErrorRate
 from pyannote.pipeline.parameter import Categorical, ParamDict, Uniform
-from scipy.ndimage import binary_dilation
 
 from pyannote.audio import Audio, Inference, Model, Pipeline
 from pyannote.audio.core.io import AudioFile
@@ -163,12 +163,14 @@ def __init__(
 
         if self._segmentation.model.specifications[0].powerset:
             self.segmentation = ParamDict(
+                min_duration_on=Uniform(0.0, 1.0),
                 min_duration_off=Uniform(0.0, 1.0),
             )
 
         else:
             self.segmentation = ParamDict(
                 threshold=Uniform(0.1, 0.9),
+                min_duration_on=Uniform(0.0, 1.0),
                 min_duration_off=Uniform(0.0, 1.0),
             )
 
@@ -602,6 +604,19 @@ def apply(
         # shape: (num_speakers, )
         discrete_diarization.data = discrete_diarization.data[:, active_speakers]
         num_frames, num_speakers = discrete_diarization.data.shape
+
+        # filter out active segments shorter than min_duration_on
+        min_frames_on = int(
+            self._segmentation.model.num_frames(
+                self.segmentation.min_duration_on * self._audio.sample_rate
+            )
+        )
+
+        if min_frames_on > 0:
+            discrete_diarization.data = binary_opening(
+                discrete_diarization.data, structure=np.array([[True] * min_frames_on]).T
+            )
+
         hook("discrete_diarization", discrete_diarization)
 
         clustered_separations = self.reconstruct(separations, hard_clusters, count)
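
Note on the morphological filter added in the last hunk: binary_opening (an erosion followed by a dilation) erases per-speaker active runs shorter than the structuring element while leaving longer runs untouched, which matches the min_duration_on semantics the new hyper-parameter names; binary_closing would do the opposite and fill short inactive gaps, i.e. min_duration_off behaviour. The (min_frames_on, 1) structuring element confines the operation to the time axis, so activity is never mixed across speaker columns. Below is a minimal, self-contained sketch of the effect, using toy array sizes and a hypothetical min_frames_on value standing in for num_frames(min_duration_on * sample_rate), not values taken from the pipeline:

    import numpy as np
    from scipy.ndimage import binary_opening

    # toy (num_frames, num_speakers) activity matrix
    data = np.zeros((10, 2), dtype=bool)
    data[3:5, 0] = True  # 2-frame blip for speaker 0
    data[2:7, 1] = True  # 5-frame segment for speaker 1

    # hypothetical stand-in for num_frames(min_duration_on * sample_rate)
    min_frames_on = 3

    # shape (min_frames_on, 1): erode/dilate along the time axis only
    structure = np.array([[True] * min_frames_on]).T

    filtered = binary_opening(data, structure=structure)
    assert not filtered[:, 0].any()   # 2-frame blip removed
    assert filtered[:, 1].sum() == 5  # 5-frame segment kept intact

The `if min_frames_on > 0` guard in the patch avoids calling the filter with an empty structuring element when min_duration_on is tuned to (near) zero, in which case filtering is a no-op anyway.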