6060COLLAPSE_CENT = 0.22 # collapse to 1 if min centroid distance <= this
6161MIN_CLUSTER_SIZE = 5 # clusters smaller than this are absorbed
6262
63+ # Changepoint recognition
64+ CP_ENTER = 0.28 # enter speech-change state if cosine jump >= this
65+ CP_EXIT = 0.22 # exit back below this (hysteresis)
66+ MIN_REGION_SEC = 1.2 # min region length (seconds) after change-point merge
67+ USE_CHANGEPOINTS = True # default ON
68+
6369# ---------- Logging ----------
6470logging .basicConfig (level = logging .INFO , format = "%(asctime)s - %(levelname)s - %(message)s" )
6571
@@ -211,6 +217,86 @@ def enforce_min_run(labels: np.ndarray, min_run: int = 3) -> np.ndarray:
211217 i = j
212218 return y
213219
220+ # --- Changepoint recognition helpers ---
221+ def _regionize_by_changepoints (X_win , ts_win , enter = CP_ENTER , exit = CP_EXIT , min_region_sec = MIN_REGION_SEC ):
222+ """
223+ X_win : (N, d) unit-norm window embeddings
224+ ts_win: list[(t0, t1)] for each window
225+ returns: regions = list[(start_idx, end_idx_inclusive)], ts_regions = list[(t0, t1)]
226+ """
227+ X = _unit_norm (np .asarray (X_win , dtype = np .float32 ))
228+ N = len (X )
229+ if N == 0 :
230+ return [], []
231+ if N == 1 :
232+ return [(0 , 0 )], [ts_win [0 ]]
233+
234+ # cosine distance between consecutive windows
235+ d = np .maximum (0.0 , (cosine_distances (X [:- 1 ], X [1 :]).diagonal ()))
236+ # hysteresis thresholding
237+ on = False
238+ cuts = [0 ]
239+ for i , val in enumerate (d , start = 1 ):
240+ if not on and val >= float (enter ):
241+ on = True
242+ cuts .append (i )
243+ elif on and val <= float (exit ):
244+ on = False
245+ cuts .append (i )
246+ if cuts [- 1 ] != N :
247+ cuts .append (N )
248+
249+ # build regions from cuts, enforce min duration
250+ regions = []
251+ ts_regions = []
252+ cur_s = cuts [0 ]
253+ for cur_e in cuts [1 :]:
254+ s , e = cur_s , cur_e - 1
255+ # grow region until long enough
256+ t0 , _ = ts_win [s ]
257+ _ , t1 = ts_win [e ]
258+ if (t1 - t0 ) >= float (min_region_sec ):
259+ regions .append ((s , e ))
260+ ts_regions .append ((t0 , t1 ))
261+ cur_s = cur_e
262+ else :
263+ # too short; defer merging with next chunk
264+ continue
265+ # tail (if left)
266+ if cur_s < N :
267+ s , e = cur_s , N - 1
268+ t0 , _ = ts_win [s ]
269+ _ , t1 = ts_win [e ]
270+ if regions and (t1 - t0 ) < float (min_region_sec ):
271+ # absorb into previous region
272+ ps , pe = regions [- 1 ]
273+ regions [- 1 ] = (ps , e )
274+ pt0 , _ = ts_regions [- 1 ]
275+ ts_regions [- 1 ] = (pt0 , t1 )
276+ else :
277+ regions .append ((s , e ))
278+ ts_regions .append ((t0 , t1 ))
279+
280+ # guarantee at least one region
281+ if not regions :
282+ regions = [(0 , N - 1 )]
283+ ts_regions = [(ts_win [0 ][0 ], ts_win [- 1 ][1 ])]
284+ return regions , ts_regions
285+
286+ def _mean_embs_by_regions (X_win , regions ):
287+ if not regions :
288+ return np .zeros ((0 , X_win .shape [1 ]), dtype = np .float32 )
289+ means = []
290+ for s , e in regions :
291+ means .append (X_win [s :e + 1 ].mean (axis = 0 ))
292+ return _unit_norm (np .vstack (means ).astype (np .float32 ))
293+
294+ def _expand_region_labels_to_windows (regions , lab_regions , N ):
295+ y = np .zeros (N , dtype = int )
296+ for (s , e ), lab in zip (regions , lab_regions ):
297+ y [s :e + 1 ] = int (lab )
298+ return y
299+
214300# ---------- Clustering helpers ----------
215301def single_speaker_guard (embeddings , guard_q90 = GUARD_Q90 ):
216302 """
@@ -457,11 +543,15 @@ def format_timestamp(seconds):
457543 return f"{ hours :02d} :{ minutes :02d} :{ secs :02d} "
458544 return f"{ minutes :02d} :{ secs :02d} "
459545
460- def format_output_text (speaker_transcripts ):
461- return "\n " .join (
462- f"=== { format_timestamp (seg ['start' ])} ({ seg ['speaker' ]} ) ===\n { seg ['text' ]} \n "
463- for seg in speaker_transcripts
464- )
546+ def format_output_text (speaker_transcripts , include_end = True ):
547+ lines = []
548+ for seg in speaker_transcripts :
549+ if include_end :
550+ head = f"=== { format_timestamp (seg ['start' ])} -{ format_timestamp (seg ['end' ])} ({ seg ['speaker' ]} ) ==="
551+ else :
552+ head = f"=== { format_timestamp (seg ['start' ])} ({ seg ['speaker' ]} ) ==="
553+ lines .append (f"{ head } \n { seg ['text' ]} \n " )
554+ return "\n " .join (lines )
465555
466556def _fmt_srt_time (t ):
467557 ms = int (round (t * 1000 ))
@@ -522,6 +612,7 @@ def main(
522612 collapse_sil = COLLAPSE_SIL ,
523613 collapse_centroid = COLLAPSE_CENT ,
524614 seed = 1337 ,
615+ include_span = True , # NEW: default show start–end in .txt headers
525616):
526617 # Resolve device
527618 dev = resolve_device (device )
@@ -561,11 +652,25 @@ def main(
561652
562653 logging .info ("Computing embeddings (%d windows)..." , len (segments ))
563654 encoder = VoiceEncoder (device = torch .device (dev ))
564- embeddings = get_embeddings (segments , encoder )
655+ win_embs = get_embeddings (segments , encoder )
656+ X_win = _unit_norm (win_embs )
657+
658+ # --- CHANGE-POINT REGIONIZATION (default ON) ---
659+ if USE_CHANGEPOINTS :
660+ regions , ts_regions = _regionize_by_changepoints (
661+ X_win , timestamps , enter = CP_ENTER , exit = CP_EXIT , min_region_sec = MIN_REGION_SEC
662+ )
663+ X_reg = _mean_embs_by_regions (X_win , regions )
664+ feat_for_cluster = X_reg
665+ logging .info ("Regionized %d windows -> %d regions" , len (X_win ), len (regions ))
666+ else :
667+ regions = [(i , i ) for i in range (len (X_win ))]
668+ ts_regions = timestamps
669+ feat_for_cluster = X_win
565670
566671 logging .info ("Clustering (method=%s%s)..." , method , f", force_n={ force_n } " if force_n else "" )
567- labels = pick_labels (
568- embeddings ,
672+ lab_regions = pick_labels (
673+ feat_for_cluster ,
569674 method = method ,
570675 min_speakers = min_speakers ,
571676 max_speakers = max_speakers ,
@@ -579,26 +684,29 @@ def main(
579684 min_cluster_size = MIN_CLUSTER_SIZE ,
580685 )
581686
582- # Temporal post-processing: smooth, then enforce min-run
687+ # Temporal post-processing on REGION labels
688+ labels = lab_regions
583689 if labels .size and smoothing_window and smoothing_window > 1 :
584- logging .info ("Smoothing labels (window=%d)..." , smoothing_window )
585690 labels = smooth_labels (labels , smoothing_window )
586691 if labels .size and min_run and min_run > 1 :
587692 labels = enforce_min_run (labels , min_run = min_run )
588693
589694 logging .info ("Transcribing with Whisper (%s) on %s..." , whisper_model , dev )
590695 transcript_segments = transcribe_audio (audio_filepath , model_name = whisper_model , language = language , device = dev )
591696
592- logging .info ("Assigning speaker labels to transcript segments..." )
593- speaker_transcripts = assign_speakers_to_transcripts (transcript_segments , labels , timestamps )
697+ # Use region timestamps for alignment
698+ diar_ts = ts_regions
699+ speaker_transcripts = assign_speakers_to_transcripts (transcript_segments , labels , diar_ts )
594700
595701 if merge_consecutive :
596702 logging .info ("Merging consecutive segments by same speaker (max_gap=%.2fs)..." , max_gap_merge )
597703 speaker_transcripts = merge_consecutive_speaker_segments (speaker_transcripts , max_gap = max_gap_merge )
598704
599- # Emit
705+ # Emit (txt / srt / vtt)
706+ txt = format_output_text (speaker_transcripts , include_end = include_span )
707+
600708 if not output_filepath :
601- print (format_output_text ( speaker_transcripts ) )
709+ print (txt )
602710 else :
603711 low = output_filepath .lower ()
604712 if low .endswith (".srt" ):
@@ -607,7 +715,7 @@ def main(
607715 write_vtt (speaker_transcripts , output_filepath )
608716 else :
609717 with open (output_filepath , "w" , encoding = "utf-8" ) as f :
610- f .write (format_output_text ( speaker_transcripts ) )
718+ f .write (txt )
611719 logging .info ("Saved to %s" , output_filepath )
612720
613721if __name__ == "__main__" :
@@ -628,6 +736,14 @@ def main(
628736
629737 # VAD + windows
630738 ap .add_argument ("--no-vad" , action = "store_true" , help = "Disable WebRTC VAD gating" )
739+ # span flags (default ON)
740+ group = ap .add_mutually_exclusive_group ()
741+ group .add_argument ("--span" , dest = "span" , action = "store_true" ,
742+ help = "Include end time in .txt headers (default)" )
743+ group .add_argument ("--no-span" , dest = "span" , action = "store_false" ,
744+ help = "Hide end time in .txt headers" )
745+ ap .set_defaults (span = True )
746+
631747 ap .add_argument ("--window" , type = float , default = WINDOW_SIZE , help = "Window size (sec)" )
632748 ap .add_argument ("--hop" , type = float , default = HOP_SIZE , help = "Hop size (sec)" )
633749 ap .add_argument ("--vad-frame-ms" , type = int , default = VAD_FRAME_MS , help = "VAD frame size (10/20/30 ms)" )
@@ -680,6 +796,7 @@ def main(
680796 collapse_sil = args .collapse_sil ,
681797 collapse_centroid = args .collapse_centroid ,
682798 seed = args .seed ,
799+ include_span = args .span , # <— default True; can be disabled with --no-span
683800 )
684801
685802# # /// ALT: defaults to 1; conservative.
0 commit comments