Skip to content
This repository was archived by the owner on Mar 25, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Command-line speaker diarization using simple_diarizer.

Diarizes an input audio file. When an output file is given, an RTTM file
is written by the diarizer; otherwise a JSON-like summary of speakers and
segments is pretty-printed to stdout.
"""
import argparse
import pprint
import time

from simple_diarizer.diarizer import Diarizer

parser = argparse.ArgumentParser(
    description="Speaker diarization",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(dest='audio_name', type=str, help="Input audio file")
parser.add_argument(dest='outputfile', nargs="?", default=None, help="Optional output file")
parser.add_argument("--number_of_speakers", dest='number_of_speaker', default=None, type=int, help="Number of speakers (if known)")
parser.add_argument("--max_speakers", dest='max_speakers', default=25, type=int, help="Maximum number of speakers (if number of speaker is unknown)")
parser.add_argument("--embed_model", dest='embed_model', default="ecapa", type=str, help="Name of embedding")
parser.add_argument("--cluster_method", dest='cluster_method', default="nme-sc", type=str, help="Clustering method")
args = parser.parse_args()

diar = Diarizer(
    embed_model=args.embed_model,       # 'xvec' and 'ecapa' supported
    cluster_method=args.cluster_method  # 'ahc', 'sc' and 'nme-sc' supported
)

WAV_FILE = args.audio_name
# argparse already converts --number_of_speakers with type=int, so the value
# is an int or None; the original comparison against the string "None" was
# always True and therefore a no-op.
num_speakers = args.number_of_speaker
max_spk = args.max_speakers
output_file = args.outputfile

t0 = time.time()

segments = diar.diarize(
    WAV_FILE,
    num_speakers=num_speakers,
    max_speakers=max_spk,
    outfile=output_file,
)

print("Time used for processing:", time.time() - t0)

if not output_file:
    # No RTTM file requested: build and print a JSON-like summary instead.
    # (Renamed from `json` to avoid shadowing the stdlib module name.)
    summary = {}
    _segments = []
    _speakers = {}
    seg_id = 1
    spk_i = 1       # next free compact speaker index
    spk_i_dict = {} # diarizer cluster label -> compact speaker index

    for seg in segments:
        segment = {}
        segment["seg_id"] = seg_id

        # Assign a stable "spkN" id the first time each cluster label is seen.
        if seg['label'] not in spk_i_dict:
            spk_i_dict[seg['label']] = spk_i
            spk_i += 1

        spk_id = "spk" + str(spk_i_dict[seg['label']])
        segment["spk_id"] = spk_id
        segment["seg_begin"] = round(seg['start'])
        segment["seg_end"] = round(seg['end'])

        # Accumulate per-speaker statistics (total speech time, segment count).
        if spk_id not in _speakers:
            _speakers[spk_id] = {
                "spk_id": spk_id,
                "duration": seg['end'] - seg['start'],
                "nbr_seg": 1,
            }
        else:
            _speakers[spk_id]["duration"] += seg['end'] - seg['start']
            _speakers[spk_id]["nbr_seg"] += 1

        _segments.append(segment)
        seg_id += 1

    # Round accumulated durations only after summing, to limit rounding error.
    for spkstat in _speakers.values():
        spkstat["duration"] = round(spkstat["duration"])

    summary["speakers"] = list(_speakers.values())
    summary["segments"] = _segments

    pprint.pprint(summary)

6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
ipython>=7.9.0
matplotlib>=3.5.1
# ipython>=7.9.0
# matplotlib>=3.5.1
pandas>=1.3.5
scikit-learn>=1.0.2
speechbrain>=0.5.11
torchaudio>=0.10.1
onnxruntime>=1.14.0
scipy<=1.8.1 # newer version can provoke segmentation faults
2 changes: 1 addition & 1 deletion simple_diarizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import os

__version__ = os.getenv("GITHUB_REF_NAME", "latest")
__version__ = os.getenv("GITHUB_REF_NAME", "1.0.2")
54 changes: 48 additions & 6 deletions simple_diarizer/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy.ndimage import gaussian_filter
from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
from sklearn.metrics import pairwise_distances

from .spectral_clustering import NME_SpectralClustering

def similarity_matrix(embeds, metric="cosine"):
    """Return the pairwise distance matrix of the embedding rows.

    Args:
        embeds: 2D array of embeddings, one embedding per row.
        metric: distance metric name accepted by sklearn's
            ``pairwise_distances`` (defaults to "cosine").

    Returns:
        Square matrix of pairwise distances between rows of ``embeds``.
    """
    distances = pairwise_distances(embeds, metric=metric)
    return distances
Expand Down Expand Up @@ -43,9 +43,7 @@ def cluster_AHC(embeds, n_clusters=None, threshold=None, metric="cosine", **kwar
# A lot of these methods are lifted from
# https://github.com/wq2012/SpectralCluster
##########################################


def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs):
def cluster_SC(embeds, n_clusters=None, max_speakers= None, threshold=None, enhance_sim=True, **kwargs):
"""
Cluster embeds using Spectral Clustering
"""
Expand All @@ -59,7 +57,7 @@ def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwar
if n_clusters is None:
(eigenvalues, eigenvectors) = compute_sorted_eigenvectors(S)
# Get number of clusters.
k = compute_number_of_clusters(eigenvalues, 100, threshold)
k = compute_number_of_clusters(eigenvalues, max_speakers, threshold)

# Get spectral embeddings.
spectral_embeddings = eigenvectors[:, :k]
Expand All @@ -82,6 +80,25 @@ def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwar
return cluster_model.fit_predict(S)


def cluster_NME_SC(embeds, n_clusters=None, max_speakers=None, threshold=None, enhance_sim=True, **kwargs):
    """Cluster embeds using NME-Spectral Clustering.

    Unlike ``cluster_SC``, NME-SC estimates the number of clusters itself,
    so no ``threshold`` is required when ``n_clusters`` is None.  The
    original body carried a stray ``assert threshold`` copied from
    ``cluster_SC`` inside this docstring (dead text, never executed); it is
    intentionally not restored, since the default call path uses
    ``n_clusters=None`` with no threshold.

    Args:
        embeds: 2D array of speaker embeddings, one per row.
        n_clusters: known number of speakers, or None to auto-estimate.
        max_speakers: upper bound on the estimated number of speakers.
        threshold: accepted for interface parity with the other
            ``cluster_*`` functions; unused here.
        enhance_sim: accepted for interface parity; unused here.

    Returns:
        Array of integer cluster labels, one per embedding.
    """
    S = cos_similarity(embeds)

    labels = NME_SpectralClustering(
        S,
        num_clusters=n_clusters,
        max_num_clusters=max_speakers,
    )

    return labels


def diagonal_fill(A):
"""
Sets the diagonal elemnts of the matrix to the max of each row
Expand Down Expand Up @@ -134,7 +151,7 @@ def row_max_norm(A):
def sim_enhancement(A):
func_order = [
diagonal_fill,
gaussian_blur,

row_threshold_mult,
symmetrization,
diffusion,
Expand All @@ -144,6 +161,31 @@ def sim_enhancement(A):
A = f(A)
return A

def cos_similarity(x):
    """Compute cosine similarity matrix in CPU & memory sensitive way

    Args:
        x (np.ndarray): embeddings, 2D array, embeddings are in rows

    Returns:
        np.ndarray: cosine similarity matrix

    """
    assert x.ndim == 2, f"x has {x.ndim} dimensions, it must be matrix"
    # Row-normalise; the epsilon avoids division by zero for null rows.
    row_norms = np.sqrt(np.sum(np.square(x), axis=1, keepdims=True))
    x = x / (row_norms + 1.0e-32)
    assert np.allclose(np.ones_like(x[:, 0]), np.sum(np.square(x), axis=1))

    n_rows = x.shape[0]
    n_cols = x.shape[1]
    # Process the feature dimension in chunks so no intermediate exceeds
    # roughly max_n_elm elements.
    max_n_elm = 200000000
    step = max(max_n_elm // (n_rows * n_rows), 1)
    sim = np.zeros(shape=(n_rows, n_rows), dtype=np.float64)
    left = np.expand_dims(x, 0)
    right = np.expand_dims(x, 1)
    for start in range(0, n_cols, step):
        stop = start + step
        chunk = left[:, :, start:stop] * right[:, :, start:stop]
        sim += np.sum(chunk, axis=2, keepdims=False)
    # Cosine similarities must lie in [-1, 1] up to float tolerance.
    assert np.all(sim >= -1.0001), sim
    assert np.all(sim <= 1.0001), sim
    return sim


def compute_affinity_matrix(X):
"""Compute the affinity matrix from data.
Expand Down
69 changes: 42 additions & 27 deletions simple_diarizer/diarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import pandas as pd
import torch
import torchaudio
from speechbrain.pretrained import EncoderClassifier
from speechbrain.inference.speaker import EncoderClassifier
from tqdm.autonotebook import tqdm

from .cluster import cluster_AHC, cluster_SC
from .cluster import cluster_AHC, cluster_SC, cluster_NME_SC
from .utils import check_wav_16khz_mono, convert_wavfile


Expand All @@ -25,12 +25,16 @@ def __init__(
assert cluster_method in [
"ahc",
"sc",
], "Only ahc and sc in the supported clustering options"
"nme-sc",
], "Only ahc,sc and nme-sc in the supported clustering options"

if cluster_method == "ahc":
self.cluster = cluster_AHC
if cluster_method == "sc":
self.cluster = cluster_SC
if cluster_method == "nme-sc":
self.cluster = cluster_NME_SC


self.vad_model, self.get_speech_ts = self.setup_VAD()

Expand All @@ -56,7 +60,7 @@ def __init__(

def setup_VAD(self):
model, utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad"
repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=True
)
# force_reload=True)

Expand Down Expand Up @@ -182,6 +186,7 @@ def diarize(
self,
wav_file,
num_speakers=2,
max_speakers=None,
threshold=None,
silence_tolerance=0.2,
enhance_sim=True,
Expand All @@ -194,6 +199,7 @@ def diarize(
Inputs:
wav_file (path): Path to input audio file
num_speakers (int) or NoneType: Number of speakers to cluster to
max_speakers (int)
threshold (float) or NoneType: Threshold to cluster to if
num_speakers is not defined
silence_tolerance (float): Same speaker segments which are close enough together
Expand Down Expand Up @@ -229,10 +235,10 @@ def diarize(
'cluster_labels': cluster_labels (list): cluster label for each embed in embeds
}

Uses AHC/SC to cluster
Uses AHC/SC/NME-SC to cluster
"""
recname = os.path.splitext(os.path.basename(wav_file))[0]

if check_wav_16khz_mono(wav_file):
signal, fs = torchaudio.load(wav_file)
else:
Expand All @@ -249,25 +255,34 @@ def diarize(
print("Running VAD...")
speech_ts = self.vad(signal[0])
print("Splitting by silence found {} utterances".format(len(speech_ts)))
assert len(speech_ts) >= 1, "Couldn't find any speech during VAD"

print("Extracting embeddings...")
embeds, segments = self.recording_embeds(signal, fs, speech_ts)

print("Clustering to {} speakers...".format(num_speakers))
cluster_labels = self.cluster(
embeds,
n_clusters=num_speakers,
threshold=threshold,
enhance_sim=enhance_sim,
)

print("Cleaning up output...")
cleaned_segments = self.join_segments(cluster_labels, segments)
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)
cleaned_segments = self.join_samespeaker_segments(
cleaned_segments, silence_tolerance=silence_tolerance
)
#assert len(speech_ts) >= 1, "Couldn't find any speech during VAD"

if len(speech_ts) >= 1:
print("Extracting embeddings...")
embeds, segments = self.recording_embeds(signal, fs, speech_ts)

[w,k]=embeds.shape
if w >= 2:
print('Clustering to {} speakers...'.format(num_speakers))
cluster_labels = self.cluster(embeds, n_clusters=num_speakers,max_speakers=max_speakers,
threshold=threshold, enhance_sim=enhance_sim)



cleaned_segments = self.join_segments(cluster_labels, segments)
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)
cleaned_segments = self.join_samespeaker_segments(cleaned_segments,
silence_tolerance=silence_tolerance)


else:
cluster_labels =[ 1]
cleaned_segments = self.join_segments(cluster_labels, segments)
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)

else:
cleaned_segments = []

print("Done!")
if outfile:
self.rttm_output(cleaned_segments, recname, outfile=outfile)
Expand All @@ -281,9 +296,9 @@ def diarize(
"cluster_labels": cluster_labels}

@staticmethod
def rttm_output(segments, recname, outfile=None):
def rttm_output(segments, recname, outfile=None, channel=0):
assert outfile, "Please specify an outfile"
rttm_line = "SPEAKER {} 0 {} {} <NA> <NA> {} <NA> <NA>\n"
rttm_line = "SPEAKER {} "+str(channel)+" {} {} <NA> <NA> {} <NA> <NA>\n"
with open(outfile, "w") as fp:
for seg in segments:
start = seg["start"]
Expand Down
Loading