Skip to content

Commit 41c548c

Browse files
committed
Resemblyzer diarization test module update
1 parent 2dc6ddb commit 41c548c

File tree

1 file changed

+132
-15
lines changed

1 file changed

+132
-15
lines changed

src/utils/diarize_resemblyzer.py

Lines changed: 132 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@
6060
COLLAPSE_CENT = 0.22 # collapse to 1 if min centroid distance <= this
6161
MIN_CLUSTER_SIZE = 5 # clusters smaller than this are absorbed
6262

63+
# Changepoint recognition
64+
CP_ENTER = 0.28 # enter speech-change state if cosine jump >= this
65+
CP_EXIT = 0.22 # exit back below this (hysteresis)
66+
MIN_REGION_SEC = 1.2 # min region length (seconds) after change-point merge
67+
USE_CHANGEPOINTS = True # default ON
68+
6369
# ---------- Logging ----------
6470
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
6571

@@ -211,6 +217,86 @@ def enforce_min_run(labels: np.ndarray, min_run: int = 3) -> np.ndarray:
211217
i = j
212218
return y
213219

220+
# --- Changepoint recognition helpers ---
221+
def _regionize_by_changepoints(X_win, ts_win, enter=CP_ENTER, exit=CP_EXIT, min_region_sec=MIN_REGION_SEC):
222+
"""
223+
X_win : (N, d) unit-norm window embeddings
224+
ts_win: list[(t0, t1)] for each window
225+
returns: regions = list[(start_idx, end_idx_inclusive)], ts_regions = list[(t0, t1)]
226+
"""
227+
X = _unit_norm(np.asarray(X_win, dtype=np.float32))
228+
N = len(X)
229+
if N == 0:
230+
return [], []
231+
if N == 1:
232+
return [(0, 0)], [ts_win[0]]
233+
234+
# cosine distance between consecutive windows
235+
d = np.maximum(0.0, (cosine_distances(X[:-1], X[1:]).diagonal()))
236+
# hysteresis thresholding
237+
on = False
238+
cuts = [0]
239+
for i, val in enumerate(d, start=1):
240+
if not on and val >= float(enter):
241+
on = True
242+
cuts.append(i)
243+
elif on and val <= float(exit):
244+
on = False
245+
cuts.append(i)
246+
if cuts[-1] != N:
247+
cuts.append(N)
248+
249+
# build regions from cuts, enforce min duration
250+
regions = []
251+
ts_regions = []
252+
cur_s = cuts[0]
253+
for cur_e in cuts[1:]:
254+
s, e = cur_s, cur_e - 1
255+
# grow region until long enough
256+
t0, _ = ts_win[s]
257+
_, t1 = ts_win[e]
258+
if (t1 - t0) >= float(min_region_sec):
259+
regions.append((s, e))
260+
ts_regions.append((t0, t1))
261+
cur_s = cur_e
262+
else:
263+
# too short; defer merging with next chunk
264+
continue
265+
# tail (if left)
266+
if cur_s < N:
267+
s, e = cur_s, N - 1
268+
t0, _ = ts_win[s]
269+
_, t1 = ts_win[e]
270+
if regions and (t1 - t0) < float(min_region_sec):
271+
# absorb into previous region
272+
ps, pe = regions[-1]
273+
regions[-1] = (ps, e)
274+
pt0, _ = ts_regions[-1]
275+
ts_regions[-1] = (pt0, t1)
276+
else:
277+
regions.append((s, e))
278+
ts_regions.append((t0, t1))
279+
280+
# guarantee at least one region
281+
if not regions:
282+
regions = [(0, N-1)]
283+
ts_regions = [(ts_win[0][0], ts_win[-1][1])]
284+
return regions, ts_regions
285+
286+
def _mean_embs_by_regions(X_win, regions):
287+
if not regions:
288+
return np.zeros((0, X_win.shape[1]), dtype=np.float32)
289+
means = []
290+
for s, e in regions:
291+
means.append(X_win[s:e+1].mean(axis=0))
292+
return _unit_norm(np.vstack(means).astype(np.float32))
293+
294+
def _expand_region_labels_to_windows(regions, lab_regions, N):
295+
y = np.zeros(N, dtype=int)
296+
for (s, e), lab in zip(regions, lab_regions):
297+
y[s:e+1] = int(lab)
298+
return y
299+
214300
# ---------- Clustering helpers ----------
215301
def single_speaker_guard(embeddings, guard_q90=GUARD_Q90):
216302
"""
@@ -457,11 +543,15 @@ def format_timestamp(seconds):
457543
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
458544
return f"{minutes:02d}:{secs:02d}"
459545

460-
def format_output_text(speaker_transcripts):
461-
return "\n".join(
462-
f"=== {format_timestamp(seg['start'])} ({seg['speaker']}) ===\n{seg['text']}\n"
463-
for seg in speaker_transcripts
464-
)
546+
def format_output_text(speaker_transcripts, include_end=True):
547+
lines = []
548+
for seg in speaker_transcripts:
549+
if include_end:
550+
head = f"=== {format_timestamp(seg['start'])}-{format_timestamp(seg['end'])} ({seg['speaker']}) ==="
551+
else:
552+
head = f"=== {format_timestamp(seg['start'])} ({seg['speaker']}) ==="
553+
lines.append(f"{head}\n{seg['text']}\n")
554+
return "\n".join(lines)
465555

466556
def _fmt_srt_time(t):
467557
ms = int(round(t * 1000))
@@ -522,6 +612,7 @@ def main(
522612
collapse_sil=COLLAPSE_SIL,
523613
collapse_centroid=COLLAPSE_CENT,
524614
seed=1337,
615+
include_span=True, # NEW: default show start–end in .txt headers
525616
):
526617
# Resolve device
527618
dev = resolve_device(device)
@@ -561,11 +652,25 @@ def main(
561652

562653
logging.info("Computing embeddings (%d windows)...", len(segments))
563654
encoder = VoiceEncoder(device=torch.device(dev))
564-
embeddings = get_embeddings(segments, encoder)
655+
win_embs = get_embeddings(segments, encoder)
656+
X_win = _unit_norm(win_embs)
657+
658+
# --- CHANGE-POINT REGIONIZATION (default ON) ---
659+
if USE_CHANGEPOINTS:
660+
regions, ts_regions = _regionize_by_changepoints(
661+
X_win, timestamps, enter=CP_ENTER, exit=CP_EXIT, min_region_sec=MIN_REGION_SEC
662+
)
663+
X_reg = _mean_embs_by_regions(X_win, regions)
664+
feat_for_cluster = X_reg
665+
logging.info("Regionized %d windows -> %d regions", len(X_win), len(regions))
666+
else:
667+
regions = [(i, i) for i in range(len(X_win))]
668+
ts_regions = timestamps
669+
feat_for_cluster = X_win
565670

566671
logging.info("Clustering (method=%s%s)...", method, f", force_n={force_n}" if force_n else "")
567-
labels = pick_labels(
568-
embeddings,
672+
lab_regions = pick_labels(
673+
feat_for_cluster,
569674
method=method,
570675
min_speakers=min_speakers,
571676
max_speakers=max_speakers,
@@ -579,26 +684,29 @@ def main(
579684
min_cluster_size=MIN_CLUSTER_SIZE,
580685
)
581686

582-
# Temporal post-processing: smooth, then enforce min-run
687+
# Temporal post-processing on REGION labels
688+
labels = lab_regions
583689
if labels.size and smoothing_window and smoothing_window > 1:
584-
logging.info("Smoothing labels (window=%d)...", smoothing_window)
585690
labels = smooth_labels(labels, smoothing_window)
586691
if labels.size and min_run and min_run > 1:
587692
labels = enforce_min_run(labels, min_run=min_run)
588693

589694
logging.info("Transcribing with Whisper (%s) on %s...", whisper_model, dev)
590695
transcript_segments = transcribe_audio(audio_filepath, model_name=whisper_model, language=language, device=dev)
591696

592-
logging.info("Assigning speaker labels to transcript segments...")
593-
speaker_transcripts = assign_speakers_to_transcripts(transcript_segments, labels, timestamps)
697+
# Use region timestamps for alignment
698+
diar_ts = ts_regions
699+
speaker_transcripts = assign_speakers_to_transcripts(transcript_segments, labels, diar_ts)
594700

595701
if merge_consecutive:
596702
logging.info("Merging consecutive segments by same speaker (max_gap=%.2fs)...", max_gap_merge)
597703
speaker_transcripts = merge_consecutive_speaker_segments(speaker_transcripts, max_gap=max_gap_merge)
598704

599-
# Emit
705+
# Emit (txt / srt / vtt)
706+
txt = format_output_text(speaker_transcripts, include_end=include_span)
707+
600708
if not output_filepath:
601-
print(format_output_text(speaker_transcripts))
709+
print(txt)
602710
else:
603711
low = output_filepath.lower()
604712
if low.endswith(".srt"):
@@ -607,7 +715,7 @@ def main(
607715
write_vtt(speaker_transcripts, output_filepath)
608716
else:
609717
with open(output_filepath, "w", encoding="utf-8") as f:
610-
f.write(format_output_text(speaker_transcripts))
718+
f.write(txt)
611719
logging.info("Saved to %s", output_filepath)
612720

613721
if __name__ == "__main__":
@@ -628,6 +736,14 @@ def main(
628736

629737
# VAD + windows
630738
ap.add_argument("--no-vad", action="store_true", help="Disable WebRTC VAD gating")
739+
# span flags (default ON)
740+
group = ap.add_mutually_exclusive_group()
741+
group.add_argument("--span", dest="span", action="store_true",
742+
help="Include end time in .txt headers (default)")
743+
group.add_argument("--no-span", dest="span", action="store_false",
744+
help="Hide end time in .txt headers")
745+
ap.set_defaults(span=True)
746+
631747
ap.add_argument("--window", type=float, default=WINDOW_SIZE, help="Window size (sec)")
632748
ap.add_argument("--hop", type=float, default=HOP_SIZE, help="Hop size (sec)")
633749
ap.add_argument("--vad-frame-ms", type=int, default=VAD_FRAME_MS, help="VAD frame size (10/20/30 ms)")
@@ -680,6 +796,7 @@ def main(
680796
collapse_sil=args.collapse_sil,
681797
collapse_centroid=args.collapse_centroid,
682798
seed=args.seed,
799+
include_span=args.span, # <— default True; can be disabled with --no-span
683800
)
684801

685802
# # /// ALT: defaults to 1; conservative.

0 commit comments

Comments
 (0)