Skip to content

Commit fac2270

Browse files
committed
NVIDIA NeMo diarization test module
1 parent b922e5e commit fac2270

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

src/utils/diarize_nemo.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@
6969
# --- Graceful termination (SIGINT/SIGTERM) ------------------------------------
7070
_SIG_CAUGHT = False
7171

72-
print("[NOTE/WARNING] This script is EXPERIMENTAL. Do NOT leave bug reports on it unless you're willing to fix them yourself.")
73-
7472
def _graceful_exit(signum, frame):
7573
# Map common signals to human-friendly names (best-effort, portable)
7674
name = {getattr(signal, "SIGINT", 2): "SIGINT",
@@ -1060,10 +1058,10 @@ def _get_diarizer_class(msdd_enabled: bool):
10601058
# elif not args.no_asr:
10611059
# logging.warning("[PIPELINE] ASR model was provided, but it will be ignored because the clustering-only diarizer (no MSDD) does not support transcription.")
10621060

1063-
# Re-create manifest path (safe to do again here)
1064-
manifest_fp = create_manifest(
1065-
processed_audio_path, args.num_speakers, args.output_dir, args.workdir, args.allow_outside_workdir
1066-
)
1061+
# # Re-create manifest path (safe to do again here)
1062+
# manifest_fp = create_manifest(
1063+
# processed_audio_path, args.num_speakers, args.output_dir, args.workdir, args.allow_outside_workdir
1064+
# )
10671065

10681066
# Build a dict that matches NeMo’s expected structure 1:1
10691067
cfg_dict = {
@@ -1135,7 +1133,9 @@ def _get_diarizer_class(msdd_enabled: bool):
11351133
if not args.no_asr and 'asr_model' in model_paths:
11361134
cfg_dict['diarizer']['asr'] = {'model_path': model_paths['asr_model']}
11371135
elif not args.no_asr:
1138-
logging.warning("[PIPELINE] ASR model was provided, but it will be ignored because the clustering-only diarizer (no MSDD) does not support transcription.")
1136+
# logging.warning("[PIPELINE] ASR model was provided, but it will be ignored because the clustering-only diarizer (no MSDD) does not support transcription.")
1137+
logging.critical("FATAL: Transcription requires MSDD diarizer. Use --msdd-model <model> or add --no-asr.")
1138+
sys.exit(EXIT_CODE_CLI_ERROR)
11391139

11401140
# Convert to OmegaConf
11411141
cfg = om.create(cfg_dict)
@@ -1166,6 +1166,7 @@ def _get_diarizer_class(msdd_enabled: bool):
11661166
# except Exception:
11671167
# pass
11681168

1169+
# --- merge user diarizer config ONCE ---
11691170
if args.diarizer_config:
11701171
try:
11711172
override_cfg = om.load(args.diarizer_config)
@@ -1178,17 +1179,14 @@ def _get_diarizer_class(msdd_enabled: bool):
11781179
output_str, success = "", False
11791180
try:
11801181
logging.info("[PIPELINE] Starting NeMo diarization/transcription job...")
1182+
1183+
# Apply last so user YAML can’t re-enable overlap smoothing by accident
1184+
_force_no_overlap(cfg)
1185+
1186+
# Pick diarizer class once
11811187
DiarizerClass, diarizer_label = _get_diarizer_class(msdd_enabled=not msdd_disabled)
11821188
logging.info(f"[PIPELINE] Using diarizer class: {diarizer_label}")
11831189

1184-
if args.diarizer_config:
1185-
override_cfg = om.load(args.diarizer_config)
1186-
cfg = om.merge(cfg, override_cfg)
1187-
logging.info(f"[PIPELINE] Merged diarizer config from {args.diarizer_config}")
1188-
1189-
_force_no_overlap(cfg)
1190-
DiarizerClass, diarizer_label = _get_diarizer_class(msdd_enabled=not msdd_disabled)
1191-
11921190
dm = DiarizerClass(cfg=cfg)
11931191

11941192
# Some NeMo builds are Lightning modules; .to() may or may not exist—be liberal.
@@ -1346,6 +1344,9 @@ def _assert_transformers_version():
13461344
class CustomHelpFormatter(argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter): pass
13471345

13481346
if __name__ == "__main__":
1347+
# warning
1348+
print("[NOTE/WARNING] This script is EXPERIMENTAL. Do NOT leave bug reports on it unless you're willing to fix them yourself.", file=sys.stderr)
1349+
13491350
parser = argparse.ArgumentParser(description=SCRIPT_BANNER, formatter_class=CustomHelpFormatter)
13501351
parser.add_argument("audio_filepath", nargs='?', default=None, help="Path to the audio file to process.")
13511352
parser.add_argument("--workdir", default="./data", help="Root working directory for all runs and artifacts.")

0 commit comments

Comments
 (0)