diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ef09b72b..6ff016ee0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,13 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Added +- Add `NaNCapture` for fine-tuning debugging: when a NaN/Inf is detected in parameter + gradients during training, save a self-contained capture (model state dict + + TrainModel class/init kwargs + the step's microbatches + RNG state) to + `out_dir/debug/nan_capture/rank{R}/nan_capture.pt` and halt training. Replay via + `lightly_train._debug.nan_capture.load_nan_capture(dir).replay()` to deterministically + reproduce the failure in a notebook/REPL. Enable with + `debug_args.nancapture.enabled=True`. - Add ONNX and TensorRT export for depth estimation models via the `export_onnx` and `export_tensorrt` methods of `DepthAnythingDepthEstimation`. diff --git a/src/lightly_train/_commands/train_task.py b/src/lightly_train/_commands/train_task.py index 53b7ff30b..f492d0894 100644 --- a/src/lightly_train/_commands/train_task.py +++ b/src/lightly_train/_commands/train_task.py @@ -58,6 +58,7 @@ ) from lightly_train._debug import debug_args, underflow_overflow from lightly_train._debug.debug_args import DebugArgs +from lightly_train._debug.nan_capture import NaNCaptureMonitor from lightly_train._debug.underflow_overflow import UnderflowOverflowMonitor from lightly_train._events import tracker from lightly_train._loggers.task_logger_args import TaskLoggerArgs @@ -227,11 +228,12 @@ def train_image_classification( Set to 1 to explicitly disable gradient accumulation. debug_args: Debug configuration dict. `None` disables debugging; keys - configure individual debug tools. The currently supported key is - ``underflow_overflow``, which enables detection of inf/nan - activations and weights via forward hooks on all model modules. - Reports are written to per-rank log files in ``out/debug/``. - Cannot be combined with + configure individual debug tools. 
Supported keys are + ``underflow_overflow`` (detect inf/nan activations and weights via + forward hooks on all model modules) and ``nancapture`` (capture a + replayable bad-batch snapshot when parameter gradients become + non-finite). Reports/captures are written under ``out/debug/``. + ``underflow_overflow`` cannot be combined with ``torch_compile_args={"disable": False}``. """ kwargs = {**locals()} @@ -381,11 +383,12 @@ def train_image_classification_multihead( Set to 1 to explicitly disable gradient accumulation. debug_args: Debug configuration dict. `None` disables debugging; keys - configure individual debug tools. The currently supported key is - ``underflow_overflow``, which enables detection of inf/nan - activations and weights via forward hooks on all model modules. - Reports are written to per-rank log files in ``out/debug/``. - Cannot be combined with + configure individual debug tools. Supported keys are + ``underflow_overflow`` (detect inf/nan activations and weights via + forward hooks on all model modules) and ``nancapture`` (capture a + replayable bad-batch snapshot when parameter gradients become + non-finite). Reports/captures are written under ``out/debug/``. + ``underflow_overflow`` cannot be combined with ``torch_compile_args={"disable": False}``. """ kwargs = {**locals()} @@ -562,11 +565,12 @@ def train_instance_segmentation( Set to 1 to explicitly disable gradient accumulation. debug_args: Debug configuration dict. `None` disables debugging; keys - configure individual debug tools. The currently supported key is - ``underflow_overflow``, which enables detection of inf/nan - activations and weights via forward hooks on all model modules. - Reports are written to per-rank log files in ``out/debug/``. - Cannot be combined with + configure individual debug tools. 
Supported keys are + ``underflow_overflow`` (detect inf/nan activations and weights via + forward hooks on all model modules) and ``nancapture`` (capture a + replayable bad-batch snapshot when parameter gradients become + non-finite). Reports/captures are written under ``out/debug/``. + ``underflow_overflow`` cannot be combined with ``torch_compile_args={"disable": False}``. """ tracker.track_training_started( @@ -730,11 +734,12 @@ def train_object_detection( Set to 1 to explicitly disable gradient accumulation. debug_args: Debug configuration dict. `None` disables debugging; keys - configure individual debug tools. The currently supported key is - ``underflow_overflow``, which enables detection of inf/nan - activations and weights via forward hooks on all model modules. - Reports are written to per-rank log files in ``out/debug/``. - Cannot be combined with + configure individual debug tools. Supported keys are + ``underflow_overflow`` (detect inf/nan activations and weights via + forward hooks on all model modules) and ``nancapture`` (capture a + replayable bad-batch snapshot when parameter gradients become + non-finite). Reports/captures are written under ``out/debug/``. + ``underflow_overflow`` cannot be combined with ``torch_compile_args={"disable": False}``. """ tracker.track_training_started( @@ -899,11 +904,12 @@ def train_panoptic_segmentation( Set to 1 to explicitly disable gradient accumulation. debug_args: Debug configuration dict. `None` disables debugging; keys - configure individual debug tools. The currently supported key is - ``underflow_overflow``, which enables detection of inf/nan - activations and weights via forward hooks on all model modules. - Reports are written to per-rank log files in ``out/debug/``. - Cannot be combined with + configure individual debug tools. 
Supported keys are + ``underflow_overflow`` (detect inf/nan activations and weights via + forward hooks on all model modules) and ``nancapture`` (capture a + replayable bad-batch snapshot when parameter gradients become + non-finite). Reports/captures are written under ``out/debug/``. + ``underflow_overflow`` cannot be combined with ``torch_compile_args={"disable": False}``. """ tracker.track_training_started( @@ -1067,11 +1073,12 @@ def train_semantic_segmentation( Set to 1 to explicitly disable gradient accumulation. debug_args: Debug configuration dict. `None` disables debugging; keys - configure individual debug tools. The currently supported key is - ``underflow_overflow``, which enables detection of inf/nan - activations and weights via forward hooks on all model modules. - Reports are written to per-rank log files in ``out/debug/``. - Cannot be combined with + configure individual debug tools. Supported keys are + ``underflow_overflow`` (detect inf/nan activations and weights via + forward hooks on all model modules) and ``nancapture`` (capture a + replayable bad-batch snapshot when parameter gradients become + non-finite). Reports/captures are written under ``out/debug/``. + ``underflow_overflow`` cannot be combined with ``torch_compile_args={"disable": False}``. """ tracker.track_training_started( @@ -1200,11 +1207,12 @@ def train_semantic_segmentation_multihead( Set to 1 to explicitly disable gradient accumulation. debug_args: Debug configuration dict. `None` disables debugging; keys - configure individual debug tools. The currently supported key is - ``underflow_overflow``, which enables detection of inf/nan - activations and weights via forward hooks on all model modules. - Reports are written to per-rank log files in ``out/debug/``. - Cannot be combined with + configure individual debug tools. 
Supported keys are + ``underflow_overflow`` (detect inf/nan activations and weights via + forward hooks on all model modules) and ``nancapture`` (capture a + replayable bad-batch snapshot when parameter gradients become + non-finite). Reports/captures are written under ``out/debug/``. + ``underflow_overflow`` cannot be combined with ``torch_compile_args={"disable": False}``. """ tracker.track_training_started( @@ -1577,7 +1585,21 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: else: monitor_ctx = contextlib.nullcontext(None) - with monitor_ctx as underflow_overflow_monitor: + nancapture_ctx: contextlib.AbstractContextManager[NaNCaptureMonitor | None] + if config.debug_args.is_nancapture_enabled(): + assert config.debug_args.nancapture is not None + nancapture_ctx = NaNCaptureMonitor( + train_model=train_model, + train_model_init_kwargs=train_model_init_kwargs, + debug_args=config.debug_args.nancapture, + out_dir=out_dir, + global_rank=fabric.global_rank, + gradient_accumulation_steps=config.gradient_accumulation_steps, + ) + else: + nancapture_ctx = contextlib.nullcontext(None) + + with monitor_ctx as underflow_overflow_monitor, nancapture_ctx as nancapture_monitor: logger.info( f"Resolved Args: {helpers.pretty_format_args(args=config.model_dump())}" ) @@ -1708,6 +1730,12 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: if underflow_overflow_monitor is not None: underflow_overflow_monitor.set_step(step=step) + # Reset the NaNCapture monitor's per-step buffer and snapshot RNG so + # that, if a NaN is detected at this step, the capture includes the + # exact microbatches and stochastic state needed for replay. 
+ if nancapture_monitor is not None: + nancapture_monitor.begin_step(step=step) + train_transform.set_step(step) train_collate_fn.set_step(step) @@ -1729,6 +1757,11 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: batch = next(infinite_train_dataloader) timer.end_step("train_dataload") + # Clone+detach this microbatch into the NaNCapture buffer so the + # captured state is decoupled from the live autograd graph. + if nancapture_monitor is not None: + nancapture_monitor.collect_batch(batch) + # Type ignore is needed because `train_model` is not recognized as an # instance of `_FabricModule` with fabric.no_backward_sync(train_model, enabled=is_accumulating): # type: ignore[arg-type] @@ -1749,6 +1782,12 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: ) # Optimizer step and scheduler step. + # Scan parameter gradients for NaN/Inf before the optimizer step + # (which would corrupt the model via the bad gradient). On detection, + # the monitor saves the capture and raises NaNDetectedError, halting + # training. + if nancapture_monitor is not None: + nancapture_monitor.check_and_capture(train_model) # clip_gradients returns the total gradient norm before clipping. It is # None for models that do not support gradient norm logging. gradient_norm = train_model.clip_gradients( @@ -1965,8 +2004,8 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: ) train_model.set_train_mode() fabric.barrier() - timer.stop() - logger.info("Training completed.") + timer.stop() + logger.info("Training completed.") class TrainTaskConfig(PydanticConfig): diff --git a/src/lightly_train/_debug/debug_args.py b/src/lightly_train/_debug/debug_args.py index 3d5d7b1df..d648317bc 100644 --- a/src/lightly_train/_debug/debug_args.py +++ b/src/lightly_train/_debug/debug_args.py @@ -70,13 +70,42 @@ def _validate_trace_batch_nums(cls, v: list[int]) -> list[int]: return v +class NaNCaptureArgs(PydanticConfig): + """Arguments for NaN/Inf capture debugging. 
+ + When enabled, the monitor scans parameter gradients for NaN/Inf after each + gradient-accumulation step (before ``clip_gradients``/``optimizer.step``). + On detection, it saves a self-contained capture to + ``out_dir/debug/nan_capture/rank{R}/nan_capture.pt`` holding the model + state, the step's microbatches, RNG state, and metadata — then raises + :class:`NaNDetectedError` to halt training. The capture is reproducibly + loadable via :func:`lightly_train._debug.nan_capture.load_nan_capture` + and ``.replay()``. + """ + + enabled: bool = Field( + default=False, + description=( + "Whether to enable NaN/Inf capture debugging. When True, scans " + "parameter gradients after each accumulated training step and, on " + "first non-finite gradient, saves a replayable capture to " + "out_dir/debug/nan_capture/rank{R}/nan_capture.pt before aborting." + ), + ) + + class DebugArgs(PydanticConfig): underflow_overflow: DebugUnderflowOverflowArgs | None = None + nancapture: NaNCaptureArgs | None = None def is_underflow_overflow_enabled(self) -> bool: """Returns True if underflow/overflow debugging is enabled.""" return self.underflow_overflow is not None and self.underflow_overflow.enabled + def is_nancapture_enabled(self) -> bool: + """Returns True if NaN/Inf capture debugging is enabled.""" + return self.nancapture is not None and self.nancapture.enabled + def get_debug_args(debug_args: dict[str, Any] | DebugArgs | None) -> DebugArgs: """Resolves the debug arguments into a :class:`DebugArgs` instance.""" diff --git a/src/lightly_train/_debug/nan_capture.py b/src/lightly_train/_debug/nan_capture.py new file mode 100644 index 000000000..43f79969f --- /dev/null +++ b/src/lightly_train/_debug/nan_capture.py @@ -0,0 +1,463 @@ +# +# Copyright (c) Lightly AG and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +"""NaN/Inf capture for reproducing bad batches in fine-tuning training. + +Self-contained capture + replay tool for the manual Fabric training loop +(``lightly_train._commands.train_task``). + +**Scheme.** When enabled via ``DebugArgs.nancapture.enabled=True``: + +- On each training step, :class:`NaNCaptureMonitor` clones+detaches each + microbatch into a per-step buffer and snapshots the torch (and CUDA, if + available) RNG state. +- After the gradient-accumulation loop completes and *before* + ``clip_gradients``/``optimizer.step()`` (the point at which a NaN gradient + would corrupt the model via the optimizer step), the monitor scans all + parameter ``.grad`` tensors for NaN/Inf. +- On detection, the monitor writes a single self-contained + ``out_dir/debug/nan_capture/rank{R}/nan_capture.pt`` holding the model + state dict, the TrainModel class path + init kwargs for reconstruction, + the step's microbatches, RNG state, and metadata — then raises + :class:`NaNDetectedError` to halt training. + +The standard ``checkpoints/last.ckpt`` is *not* touched; ``resume_interrupted`` +from ``out_dir`` is unaffected. + +**Replay.** Capture files are reproducibly loadable via +:func:`load_nan_capture` + :meth:`NaNCaptureState.replay`. The replay +reconstructs the TrainModel, restores RNG, and re-runs the triggering +forward+backward sequence — zero-setup (auto-creates a single-device Fabric +if none is passed). Stops before ``clip_gradients``/``optimizer.step`` since +the NaN lives in gradients/activations, and the corruption path is the +optimizer step which the training loop never reached. + +Caveat: replay uses default Fabric precision (``"32-true"``). To reproduce a +mixed-precision failure (e.g. bfloat16 overflow), construct your own Fabric +matching the captured run's precision and pass it: ``cap.replay(fabric=f)``. + +Known limitation: the capture holds the model state **after** the triggering +step's forward/backward. 
Model **parameters** are unaffected by forward/backward +(they are only changed by the optimizer step, which never ran), so weight-based +reproduction is faithful. But training-mode **buffers** mutated during the +forward pass (e.g. BatchNorm running stats) are captured at their post-batch +values, not pre-batch — so replay may start from slightly stale buffers. This is +accepted for v1; it does not affect dropout/data-driven NaN debugging. If the +suspect path runs through BatchNorm-style buffers, treat replay results with +that caveat. + +**References.** The capture-and-replay scheme (clone microbatches, snapshot RNG, +write the capture before the optimizer step, halt on NaN/Inf) follows Chaim Rand, +"Debugging the Dreaded NaN — Capturing and Reproducing Failures in PyTorch +Training with Lightning", Feb 2025. +https://chaimrand.medium.com/debugging-the-dreaded-nan-ac3f9feac5b2 +""" + +from __future__ import annotations + +import datetime +import importlib +import logging +import os +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from lightning_fabric import Fabric +from torch import Tensor +from torch.nn import Module + +from lightly_train._debug.debug_args import NaNCaptureArgs + +logger = logging.getLogger(__name__) + +_CAPTURE_FILENAME = "nan_capture.pt" + + +class NaNDetectedError(RuntimeError): + """Raised when a NaN/Inf is detected in parameter gradients. + + The capture (model state, microbatches, RNG, metadata) has already been + written to ``capture_path`` before this exception is raised. 
To reproduce + the failure:: + + from lightly_train._debug.nan_capture import load_nan_capture + cap = load_nan_capture("") + cap.replay() + """ + + def __init__(self, nan_param_names: list[str], capture_path: Path) -> None: + self.nan_param_names = list(nan_param_names) + self.capture_path = Path(capture_path) + first = self.nan_param_names[0] if self.nan_param_names else "" + super().__init__( + f"NaN/Inf detected in gradients at {len(self.nan_param_names)} " + f"parameter(s). First failing parameter: '{first}'. Capture saved " + f"to '{self.capture_path}'. Reproduce with: " + f"lightly_train._debug.nan_capture.load_nan_capture(" + f"'{self.capture_path}').replay()." + ) + + +@dataclass +class NaNCaptureMetadata: + """Metadata for a captured NaN/Inf occurrence.""" + + step: int + rank: int + timestamp: str + nan_param_names: list[str] + gradient_accumulation_steps: int + train_model_class_path: str + train_model_device: str + + +@dataclass +class NaNReplayResult: + """Result of replaying a captured NaN/Inf. + + Attributes: + step: Training step that was replayed. + results: Per-microbatch ``TaskStepResult`` outputs (in order). + nan_param_names: Parameter names whose ``.grad`` is NaN/Inf after the + replayed accumulation; empty if the NaN did not reproduce. + reproduced: ``True`` iff ``nan_param_names`` is non-empty. + """ + + step: int + results: list[Any] + nan_param_names: list[str] + reproduced: bool + + def raise_if_reproduced(self) -> None: + """Re-raise :class:`NaNDetectedError` if the NaN reproduced.""" + if self.reproduced: + raise NaNDetectedError( + nan_param_names=self.nan_param_names, + capture_path=Path("(replayed capture)"), + ) + + +@dataclass +class NaNCaptureState: + """Loaded NaN capture, ready for replay. + + Construct via :func:`load_nan_capture`. Call :meth:`replay` to reproduce + the captured failure with zero setup. 
+ """ + + step: int + rank: int + gradient_accumulation_steps: int + model: Any # TrainModel-like; replay requires ``training_step``. + batches: list[Any] + rng_state: dict[str, Any] + metadata: NaNCaptureMetadata + + def replay(self, fabric: Fabric | None = None) -> NaNReplayResult: + """Re-run the triggering step with saved microbatches and RNG. + + Zero-setup: if ``fabric`` is None, creates a single-device Fabric + (``accelerator="auto"``, ``devices=1``) — no ``launch()`` since replay + runs in the main process. Mirrors the training loop exactly: restores + RNG, runs ``training_step`` + ``fabric.backward(loss / grad_accum)`` + over saved microbatches (with ``fabric.no_backward_sync`` on all but + the last microbatch), then scans grads. Stops before + ``clip_gradients``/``optimizer.step`` (the corruption path that the + training loop never reached). + + Caveat: auto-Fabric uses default precision (``"32-true"``). To + reproduce a mixed-precision failure, pass a Fabric matching the + captured run's precision. + """ + if fabric is None: + fabric = Fabric(accelerator="auto", devices=1) + + _restore_rng(self.rng_state) + + self.model.train() + model = fabric.setup_module(self.model) + + device = next(model.parameters()).device + grad_accum = self.gradient_accumulation_steps + + results: list[Any] = [] + for acc_step in range(grad_accum): + is_accumulating = acc_step < grad_accum - 1 + batch = _batch_to_device(self.batches[acc_step], device) + with fabric.no_backward_sync(model, enabled=is_accumulating): + result = model.training_step(fabric=fabric, batch=batch, step=self.step) + fabric.backward(result.loss / grad_accum) + results.append(result) + + nan_names = _scan_grads_nan(model) + return NaNReplayResult( + step=self.step, + results=results, + nan_param_names=nan_names, + reproduced=bool(nan_names), + ) + + +class NaNCaptureMonitor: + """In-training monitor that captures training state on NaN/Inf detection. 
+ + Wired into the fine-tuning training loop after gradient accumulation, + before ``clip_gradients``/``optimizer.step()``. All hooks are no-ops when + :attr:`enabled` is False. + + Lifecycle:: + + with NaNCaptureMonitor(...) as monitor: + for step in range(steps): + monitor.begin_step(step) + for acc_step in range(grad_accum): + batch = next(dataloader) + monitor.collect_batch(batch) + # ... training_step + backward ... + monitor.check_and_capture(train_model) # raises on NaN + """ + + def __init__( + self, + train_model: Module, + train_model_init_kwargs: Mapping[str, object], + debug_args: NaNCaptureArgs, + out_dir: Path, + global_rank: int, + gradient_accumulation_steps: int, + ) -> None: + self._train_model = train_model + self._train_model_class_path = _get_model_class_path(train_model) + self._train_model_init_kwargs = dict(train_model_init_kwargs) + self._out_dir = Path(out_dir) + self._rank = global_rank + self._grad_accum = gradient_accumulation_steps + self._enabled = bool(debug_args.enabled) + self._microbatches: list[Any] = [] + self._step: int | None = None + self._rng_state: dict[str, Any] = {} + + @property + def enabled(self) -> bool: + return self._enabled + + def __enter__(self) -> NaNCaptureMonitor: + return self + + def __exit__(self, *exc: object) -> None: + self.close() + + def close(self) -> None: + """Release per-step buffered state. 
Idempotent.""" + self._microbatches = [] + self._rng_state = {} + self._step = None + + def begin_step(self, step: int) -> None: + if not self._enabled: + return + self._step = step + self._microbatches = [] + self._rng_state = {"torch": torch.get_rng_state()} + if torch.cuda.is_available(): + device = next(self._train_model.parameters()).device + if device.type == "cuda": + self._rng_state["cuda"] = torch.cuda.get_rng_state(device) + + def collect_batch(self, batch: Any) -> None: + if not self._enabled: + return + # Clone + detach + move to CPU immediately so the per-step buffer lives + # on host RAM (not accelerator) and is decoupled from the live autograd + # graph. Keeps debug-time accelerator memory bounded to the forward + # pass instead of holding the whole microbatch set until save. + self._microbatches.append(_clone_batch_to_cpu(batch)) + + def check_and_capture(self, train_model: Module) -> None: + if not self._enabled: + return + nan_names = _scan_grads_nan(train_model) + if not nan_names: + return + capture_dir = self._out_dir / "debug" / "nan_capture" / f"rank{self._rank}" + capture_dir.mkdir(parents=True, exist_ok=True) + capture_path = capture_dir / _CAPTURE_FILENAME + + state_dict_cpu = _state_dict_to_cpu(train_model.state_dict()) + microbatches_cpu = list(self._microbatches) + device_str = str(next(train_model.parameters()).device) + metadata = NaNCaptureMetadata( + step=self._step if self._step is not None else -1, + rank=self._rank, + timestamp=datetime.datetime.now(datetime.timezone.utc).isoformat(), + nan_param_names=nan_names, + gradient_accumulation_steps=self._grad_accum, + train_model_class_path=self._train_model_class_path, + train_model_device=device_str, + ) + + payload = { + "train_model_state_dict": state_dict_cpu, + "train_model_class_path": self._train_model_class_path, + "train_model_init_kwargs": self._train_model_init_kwargs, + "microbatches": microbatches_cpu, + "rng_state": self._rng_state, + "metadata": metadata, + } + # Atomic 
write: save to a temp file then os.replace onto the final + # path so a crash/disk-full/interruption can never leave a partially + # written nan_capture.pt (which would be worse than no capture). + tmp_path = capture_path.with_suffix(".pt.tmp") + try: + torch.save(payload, tmp_path) + os.replace(tmp_path, capture_path) + finally: + if tmp_path.exists(): + tmp_path.unlink(missing_ok=True) + logger.error( + f"NaN/Inf captured at step {self._step} on rank {self._rank}: " + f"{len(nan_names)} NaN/Inf parameter(s). Capture saved to " + f"'{capture_path}'." + ) + raise NaNDetectedError(nan_param_names=nan_names, capture_path=capture_path) + + +def load_nan_capture(capture_dir: Path, device: str = "cpu") -> NaNCaptureState: + """Load a NaN capture from ``capture_dir`` for replay. + + ``capture_dir`` is the rank directory (e.g. + ``out_dir/debug/nan_capture/rank0``). Reconstructs the TrainModel from + the captured class path + init kwargs (with ``load_weights=False`` since + the captured state dict supplies the weights), loads the saved raw TrainModel + state dict strictly, restores the saved microbatches and RNG state. Does + not run anything — call :meth:`NaNCaptureState.replay` to reproduce the + failure. + """ + capture_dir = Path(capture_dir) + capture_path = capture_dir / _CAPTURE_FILENAME + if not capture_path.is_file(): + raise FileNotFoundError(f"NaN capture file not found at '{capture_path}'.") + + payload = torch.load(capture_path, map_location=device, weights_only=False) + state_dict = payload["train_model_state_dict"] + class_path = payload["train_model_class_path"] + init_kwargs = dict(payload["train_model_init_kwargs"]) + microbatches = payload["microbatches"] + rng_state = payload["rng_state"] + metadata = payload["metadata"] + + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + train_model_cls = getattr(module, class_name) + # The captured state dict supplies weights; don't re-download the backbone. 
+ init_kwargs["load_weights"] = False + model = train_model_cls(**init_kwargs) + # A capture is meant to be a self-contained, faithful TrainModel snapshot: + # replay must restore exactly the raw state that produced the bad gradients. + # Do not use task-specific checkpoint/export loaders here; some prefer EMA + # weights when present, while the live model weights produced the NaN. + model.load_state_dict(state_dict, strict=True) + logger.info( + "Loaded NaN capture from '%s' (step %s, rank %s).", + capture_path, + metadata.step, + metadata.rank, + ) + model.to(device) + + return NaNCaptureState( + step=metadata.step, + rank=metadata.rank, + gradient_accumulation_steps=metadata.gradient_accumulation_steps, + model=model, + batches=microbatches, + rng_state=rng_state, + metadata=metadata, + ) + + +# ---- helpers ---------------------------------------------------------------- + + +def _get_model_class_path(model: Module) -> str: + """Return the import path for the real TrainModel behind Fabric wrappers.""" + unwrapped = _unwrap_train_model(model) + return f"{unwrapped.__class__.__module__}.{unwrapped.__class__.__qualname__}" + + +def _unwrap_train_model(model: Module) -> Module: + """Return the original TrainModel for Lightning Fabric-wrapped modules.""" + if model.__class__.__name__ == "_FabricModule": + unwrapped = getattr(model, "module", model) + if isinstance(unwrapped, Module): + return unwrapped + return model + + +def _clone_batch_to_cpu(batch: Any) -> Any: + """Clone + detach + move to CPU all tensors in ``batch``. + + Recurses dicts/lists/tuples. Cloning+detaching breaks autograd links with + the live forward pass; moving to CPU keeps the per-step microbatch buffer + in host RAM rather than holding accelerator memory until the capture save. 
+ """ + if isinstance(batch, torch.Tensor): + return batch.detach().clone().cpu() + if isinstance(batch, dict): + return {k: _clone_batch_to_cpu(v) for k, v in batch.items()} + if isinstance(batch, list): + return [_clone_batch_to_cpu(v) for v in batch] + if isinstance(batch, tuple): + return tuple(_clone_batch_to_cpu(v) for v in batch) + return batch + + +def _batch_to_device(batch: Any, device: torch.device) -> Any: + if isinstance(batch, torch.Tensor): + return batch.to(device) + if isinstance(batch, dict): + return {k: _batch_to_device(v, device) for k, v in batch.items()} + if isinstance(batch, list): + return [_batch_to_device(v, device) for v in batch] + if isinstance(batch, tuple): + return tuple(_batch_to_device(v, device) for v in batch) + return batch + + +def _state_dict_to_cpu( + state_dict: Mapping[str, Any], +) -> dict[str, Any]: + return {k: (v.cpu() if isinstance(v, Tensor) else v) for k, v in state_dict.items()} + + +def _scan_grads_nan(model: Module) -> list[str]: + """Return names of parameters whose ``.grad`` contains NaN or Inf. + + Strips the ``_forward_module.`` prefix that Lightning Fabric's + ``_FabricModule`` adds to ``named_parameters()``, so reported names + match the user's model layout (e.g. ``lin.weight`` rather than + ``_forward_module.lin.weight``). 
+ """ + nan_names: list[str] = [] + for name, p in model.named_parameters(): + if p.grad is None: + continue + if not torch.isfinite(p.grad).all(): + clean = name + if clean.startswith("_forward_module."): + clean = clean[len("_forward_module.") :] + nan_names.append(clean) + return nan_names + + +def _restore_rng(rng_state: dict[str, Any]) -> None: + torch.set_rng_state(rng_state["torch"]) + if "cuda" in rng_state and torch.cuda.is_available(): + torch.cuda.set_rng_state(rng_state["cuda"]) diff --git a/tests/_debug/test_debug_args.py b/tests/_debug/test_debug_args.py index 5a720e02d..93ee5f3cc 100644 --- a/tests/_debug/test_debug_args.py +++ b/tests/_debug/test_debug_args.py @@ -9,7 +9,7 @@ import pytest -from lightly_train._debug.debug_args import get_debug_args +from lightly_train._debug.debug_args import DebugArgs, NaNCaptureArgs, get_debug_args from lightly_train.errors import ConfigValidationError @@ -33,3 +33,25 @@ def test_get_debug_args__rejects_unknown_keys() -> None: def test_get_debug_args__trace_batch_nums_rejects_negative() -> None: with pytest.raises(ConfigValidationError, match="non-negative"): get_debug_args({"underflow_overflow": {"trace_batch_nums": [-1]}}) + + +def test_nancapture_args__default_disabled() -> None: + args = NaNCaptureArgs() + assert args.enabled is False + + +def test_debug_args__is_nancapture_enabled() -> None: + assert DebugArgs().is_nancapture_enabled() is False + assert DebugArgs(nancapture=NaNCaptureArgs()).is_nancapture_enabled() is False + assert ( + DebugArgs(nancapture=NaNCaptureArgs(enabled=True)).is_nancapture_enabled() + is True + ) + + +def test_get_debug_args__parses_nancapture() -> None: + debug_args = get_debug_args({"nancapture": {"enabled": True}}) + assert isinstance(debug_args, DebugArgs) + assert debug_args.nancapture is not None + assert debug_args.nancapture.enabled is True + assert debug_args.is_nancapture_enabled() is True diff --git a/tests/_debug/test_nan_capture.py b/tests/_debug/test_nan_capture.py 
new file mode 100644 index 000000000..224aed468 --- /dev/null +++ b/tests/_debug/test_nan_capture.py @@ -0,0 +1,397 @@ +# +# Copyright (c) Lightly AG and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.nn as nn +from lightning_fabric import Fabric + +from lightly_train._debug.debug_args import NaNCaptureArgs +from lightly_train._debug.nan_capture import ( + NaNCaptureMonitor, + NaNCaptureState, + NaNDetectedError, + NaNReplayResult, + load_nan_capture, +) + +# --------------------------------------------------------------------------- +# Toy models +# --------------------------------------------------------------------------- + + +class _MonitorToyModel(nn.Module): + """Tiny model used to exercise the monitor's grad scan + buffer.""" + + def __init__(self) -> None: + super().__init__() + self.lin = nn.Linear(4, 4) + + +class _ReplayToyModel(nn.Module): + """Tiny model used to exercise the replay path (load + forward/backward). + + Implements the ``training_step`` method that replay relies on. Kept + self-contained so the test does not depend on a full task pipeline. 
def _make_monitor(
    model: nn.Module, tmp_path: Path, **overrides: object
) -> NaNCaptureMonitor:
    """Build an enabled rank-0 monitor for ``model`` that writes under ``tmp_path``.

    Any extra keyword arguments are forwarded to ``NaNCaptureArgs`` next to
    ``enabled=True``. Gradient accumulation is fixed to a single step and the
    recorded init kwargs are empty.
    """
    return NaNCaptureMonitor(
        train_model=model,
        train_model_init_kwargs={},
        debug_args=NaNCaptureArgs(enabled=True, **overrides),  # type: ignore[arg-type]
        out_dir=tmp_path,
        global_rank=0,
        gradient_accumulation_steps=1,
    )
def test_collect_batch__clones_detaches_and_moves_to_cpu(
    self, tmp_path: Path
) -> None:
    """``collect_batch`` must buffer an independent, detached CPU copy."""
    monitor = _make_monitor(_MonitorToyModel(), tmp_path)
    monitor.begin_step(0)
    source = torch.randn(2, 4, requires_grad=True)
    monitor.collect_batch({"x": source})

    first_microbatch = monitor._microbatches[0]
    buffered = first_microbatch["x"]  # type: ignore[index]

    # A different object with no autograd history ...
    assert buffered is not source
    assert buffered.grad_fn is None
    # ... that already lives in host RAM, so the buffer does not keep
    # accelerator memory alive until the capture is saved.
    assert buffered.device.type == "cpu"

    # In-place writes are legal on the detached copy and must not leak back
    # into the tensor the training step is still using.
    buffered.fill_(1.0)
    assert not torch.equal(buffered, source)
    assert source.requires_grad  # the live tensor is untouched
def test_check_and_capture__clean_does_not_raise_or_save(
    self, tmp_path: Path
) -> None:
    """A step without non-finite gradients must neither raise nor write a capture.

    Two clean situations are covered: parameters that have no gradient at all
    (no backward pass ran) and parameters whose gradients are finite.
    """
    model = _MonitorToyModel()
    monitor = _make_monitor(model, tmp_path)
    rank_dir = tmp_path / "debug" / "nan_capture" / "rank0"

    monitor.begin_step(0)
    batch = {"x": torch.randn(2, 4)}
    monitor.collect_batch(batch)

    # Case 1: no backward called -> no .grad set -> nothing to flag.
    monitor.check_and_capture(model)
    assert not rank_dir.exists()

    # Case 2: a real backward pass leaves finite gradients on every
    # parameter, so the finiteness check itself runs and must stay silent.
    model.lin(batch["x"]).sum().backward()
    assert all(p.grad is not None for p in model.parameters())
    monitor.check_and_capture(model)
    assert not rank_dir.exists()
def _save_capture(
    model: nn.Module,
    batches: list[dict[str, torch.Tensor]],
    tmp_path: Path,
    rank: int = 0,
    grad_accum: int = 1,
    step: int = 42,
    *,
    induce_nan_grad: bool = True,
) -> Path:
    """Drive a fresh monitor through one step and return the rank's capture dir.

    All ``batches`` are collected for ``step``. With ``induce_nan_grad`` the
    first parameter's gradient is overwritten with NaN and the resulting
    ``NaNDetectedError`` is asserted; otherwise the check runs on the model
    as it is.
    """
    capture_dir = tmp_path / "debug" / "nan_capture" / f"rank{rank}"

    monitor = NaNCaptureMonitor(
        train_model=model,
        train_model_init_kwargs={},
        debug_args=NaNCaptureArgs(enabled=True),
        out_dir=tmp_path,
        global_rank=rank,
        gradient_accumulation_steps=grad_accum,
    )
    monitor.begin_step(step)
    for microbatch in batches:
        monitor.collect_batch(microbatch)

    if not induce_nan_grad:
        monitor.check_and_capture(model)
        return capture_dir

    # Poison the first parameter's gradient so the check fires.
    poisoned = next(model.parameters())
    poisoned.grad = torch.full_like(poisoned, float("nan"))
    with pytest.raises(NaNDetectedError):
        monitor.check_and_capture(model)
    return capture_dir
def test_replay_runs_forward_backward_without_nan(self, tmp_path: Path) -> None:
    """Replaying a capture of a healthy model must report "not reproduced"."""
    model = _ReplayToyModel()
    microbatches = [{"x": torch.randn(2, 4)} for _ in range(2)]
    capture_dir = _save_capture(model, microbatches, tmp_path, grad_accum=2)

    replayed = load_nan_capture(capture_dir, device="cpu").replay()

    assert isinstance(replayed, NaNReplayResult)
    assert replayed.step == 42
    assert replayed.reproduced is False
    assert replayed.nan_param_names == []
    # One step result per captured microbatch.
    assert len(replayed.results) == 2
class TestNaNCaptureIntegration:
    """Capture -> load -> replay round trip through a Fabric-wrapped model."""

    def test_capture_load_replay_through_fabric_setup(self, tmp_path: Path) -> None:
        # Build the real model first and pin its weights to known values so
        # the reloaded copy can be compared exactly.
        underlying = _ReplayFabricToyModel(config=_ReplayConfig(scale=2.0))
        with torch.no_grad():
            underlying.lin.weight.fill_(0.5)
            underlying.lin.bias.fill_(-0.25)
        saved_weight = underlying.lin.weight.detach().clone()
        saved_bias = underlying.lin.bias.detach().clone()

        # Wrap with Fabric so the monitor is handed a _FabricModule and has to
        # unwrap it to the real model class when capturing.
        fabric = Fabric(accelerator="cpu", devices=1)
        wrapped = fabric.setup_module(underlying)
        assert wrapped.__class__.__name__ == "_FabricModule"  # sanity: wrapped

        monitor = NaNCaptureMonitor(
            train_model=wrapped,
            # Pass the config OBJECT (real ctors need it, not a dumped dict).
            train_model_init_kwargs={"config": _ReplayConfig(scale=2.0)},
            debug_args=NaNCaptureArgs(enabled=True),
            out_dir=tmp_path,
            global_rank=0,
            gradient_accumulation_steps=1,
        )
        monitor.begin_step(99)
        monitor.collect_batch({"x": torch.randn(2, 4)})

        # Poison one gradient through the wrapper to trigger the capture.
        wrapped_weight = wrapped.lin.weight  # type: ignore[attr-defined]
        wrapped_weight.grad = torch.full_like(wrapped_weight, float("nan"))
        with pytest.raises(NaNDetectedError):
            monitor.check_and_capture(wrapped)

        cap = load_nan_capture(
            tmp_path / "debug" / "nan_capture" / "rank0", device="cpu"
        )

        # Regression guard 1: the reconstructed model is the real class, not
        # _FabricModule, and was built from the config object.
        assert isinstance(cap.model, _ReplayFabricToyModel)
        assert cap.model.config.scale == 2.0  # config object survived
        # Regression guard 2: weights saved from the Fabric-wrapped model load
        # into the unwrapped model and match the original values.
        assert torch.equal(cap.model.lin.weight.detach().cpu(), saved_weight)
        assert torch.equal(cap.model.lin.bias.detach().cpu(), saved_bias)
        assert cap.step == 99

        # Replay on the reconstructed model: weights are healthy, so the
        # result must be clean (not reproduced).
        result = cap.replay()
        assert isinstance(result, NaNReplayResult)
        assert result.step == 99
        assert result.reproduced is False
        assert len(result.results) == 1