Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

### Added

- Add `NaNCapture` for fine-tuning debugging: when a NaN/Inf is detected in parameter
gradients during training, save a self-contained capture (model state dict +
TrainModel class/init kwargs + the step's microbatches + RNG state) to
`out_dir/debug/nan_capture/rank{R}/nan_capture.pt` and halt training. Replay via
`lightly_train._debug.nan_capture.load_nan_capture(dir).replay()` to deterministically
reproduce the failure in a notebook/REPL. Enable with
`debug_args.nancapture.enabled=True`.
- Add ONNX and TensorRT export for depth estimation models via the `export_onnx` and
`export_tensorrt` methods of `DepthAnythingDepthEstimation`.

Expand Down
115 changes: 77 additions & 38 deletions src/lightly_train/_commands/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
)
from lightly_train._debug import debug_args, underflow_overflow
from lightly_train._debug.debug_args import DebugArgs
from lightly_train._debug.nan_capture import NaNCaptureMonitor
from lightly_train._debug.underflow_overflow import UnderflowOverflowMonitor
from lightly_train._events import tracker
from lightly_train._loggers.task_logger_args import TaskLoggerArgs
Expand Down Expand Up @@ -227,11 +228,12 @@ def train_image_classification(
Set to 1 to explicitly disable gradient accumulation.
debug_args:
Debug configuration dict. `None` disables debugging; keys
configure individual debug tools. The currently supported key is
``underflow_overflow``, which enables detection of inf/nan
activations and weights via forward hooks on all model modules.
Reports are written to per-rank log files in ``out/debug/``.
Cannot be combined with
configure individual debug tools. Supported keys are
``underflow_overflow`` (detect inf/nan activations and weights via
forward hooks on all model modules) and ``nancapture`` (capture a
replayable bad-batch snapshot when parameter gradients become
non-finite). Reports/captures are written under ``out/debug/``.
``underflow_overflow`` cannot be combined with
``torch_compile_args={"disable": False}``.
"""
kwargs = {**locals()}
Expand Down Expand Up @@ -381,11 +383,12 @@ def train_image_classification_multihead(
Set to 1 to explicitly disable gradient accumulation.
debug_args:
Debug configuration dict. `None` disables debugging; keys
configure individual debug tools. The currently supported key is
``underflow_overflow``, which enables detection of inf/nan
activations and weights via forward hooks on all model modules.
Reports are written to per-rank log files in ``out/debug/``.
Cannot be combined with
configure individual debug tools. Supported keys are
``underflow_overflow`` (detect inf/nan activations and weights via
forward hooks on all model modules) and ``nancapture`` (capture a
replayable bad-batch snapshot when parameter gradients become
non-finite). Reports/captures are written under ``out/debug/``.
``underflow_overflow`` cannot be combined with
``torch_compile_args={"disable": False}``.
"""
kwargs = {**locals()}
Expand Down Expand Up @@ -562,11 +565,12 @@ def train_instance_segmentation(
Set to 1 to explicitly disable gradient accumulation.
debug_args:
Debug configuration dict. `None` disables debugging; keys
configure individual debug tools. The currently supported key is
``underflow_overflow``, which enables detection of inf/nan
activations and weights via forward hooks on all model modules.
Reports are written to per-rank log files in ``out/debug/``.
Cannot be combined with
configure individual debug tools. Supported keys are
``underflow_overflow`` (detect inf/nan activations and weights via
forward hooks on all model modules) and ``nancapture`` (capture a
replayable bad-batch snapshot when parameter gradients become
non-finite). Reports/captures are written under ``out/debug/``.
``underflow_overflow`` cannot be combined with
``torch_compile_args={"disable": False}``.
"""
tracker.track_training_started(
Expand Down Expand Up @@ -730,11 +734,12 @@ def train_object_detection(
Set to 1 to explicitly disable gradient accumulation.
debug_args:
Debug configuration dict. `None` disables debugging; keys
configure individual debug tools. The currently supported key is
``underflow_overflow``, which enables detection of inf/nan
activations and weights via forward hooks on all model modules.
Reports are written to per-rank log files in ``out/debug/``.
Cannot be combined with
configure individual debug tools. Supported keys are
``underflow_overflow`` (detect inf/nan activations and weights via
forward hooks on all model modules) and ``nancapture`` (capture a
replayable bad-batch snapshot when parameter gradients become
non-finite). Reports/captures are written under ``out/debug/``.
``underflow_overflow`` cannot be combined with
``torch_compile_args={"disable": False}``.
"""
tracker.track_training_started(
Expand Down Expand Up @@ -899,11 +904,12 @@ def train_panoptic_segmentation(
Set to 1 to explicitly disable gradient accumulation.
debug_args:
Debug configuration dict. `None` disables debugging; keys
configure individual debug tools. The currently supported key is
``underflow_overflow``, which enables detection of inf/nan
activations and weights via forward hooks on all model modules.
Reports are written to per-rank log files in ``out/debug/``.
Cannot be combined with
configure individual debug tools. Supported keys are
``underflow_overflow`` (detect inf/nan activations and weights via
forward hooks on all model modules) and ``nancapture`` (capture a
replayable bad-batch snapshot when parameter gradients become
non-finite). Reports/captures are written under ``out/debug/``.
``underflow_overflow`` cannot be combined with
``torch_compile_args={"disable": False}``.
"""
tracker.track_training_started(
Expand Down Expand Up @@ -1067,11 +1073,12 @@ def train_semantic_segmentation(
Set to 1 to explicitly disable gradient accumulation.
debug_args:
Debug configuration dict. `None` disables debugging; keys
configure individual debug tools. The currently supported key is
``underflow_overflow``, which enables detection of inf/nan
activations and weights via forward hooks on all model modules.
Reports are written to per-rank log files in ``out/debug/``.
Cannot be combined with
configure individual debug tools. Supported keys are
``underflow_overflow`` (detect inf/nan activations and weights via
forward hooks on all model modules) and ``nancapture`` (capture a
replayable bad-batch snapshot when parameter gradients become
non-finite). Reports/captures are written under ``out/debug/``.
``underflow_overflow`` cannot be combined with
``torch_compile_args={"disable": False}``.
"""
tracker.track_training_started(
Expand Down Expand Up @@ -1200,11 +1207,12 @@ def train_semantic_segmentation_multihead(
Set to 1 to explicitly disable gradient accumulation.
debug_args:
Debug configuration dict. `None` disables debugging; keys
configure individual debug tools. The currently supported key is
``underflow_overflow``, which enables detection of inf/nan
activations and weights via forward hooks on all model modules.
Reports are written to per-rank log files in ``out/debug/``.
Cannot be combined with
configure individual debug tools. Supported keys are
``underflow_overflow`` (detect inf/nan activations and weights via
forward hooks on all model modules) and ``nancapture`` (capture a
replayable bad-batch snapshot when parameter gradients become
non-finite). Reports/captures are written under ``out/debug/``.
``underflow_overflow`` cannot be combined with
``torch_compile_args={"disable": False}``.
"""
tracker.track_training_started(
Expand Down Expand Up @@ -1577,7 +1585,21 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
else:
monitor_ctx = contextlib.nullcontext(None)

with monitor_ctx as underflow_overflow_monitor:
nancapture_ctx: contextlib.AbstractContextManager[NaNCaptureMonitor | None]
if config.debug_args.is_nancapture_enabled():
assert config.debug_args.nancapture is not None
nancapture_ctx = NaNCaptureMonitor(
train_model=train_model,
train_model_init_kwargs=train_model_init_kwargs,
debug_args=config.debug_args.nancapture,
out_dir=out_dir,
global_rank=fabric.global_rank,
gradient_accumulation_steps=config.gradient_accumulation_steps,
)
else:
nancapture_ctx = contextlib.nullcontext(None)

with monitor_ctx as underflow_overflow_monitor, nancapture_ctx as nancapture_monitor:
logger.info(
f"Resolved Args: {helpers.pretty_format_args(args=config.model_dump())}"
)
Expand Down Expand Up @@ -1708,6 +1730,12 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
if underflow_overflow_monitor is not None:
underflow_overflow_monitor.set_step(step=step)

# Reset the NaNCapture monitor's per-step buffer and snapshot RNG so
# that, if a NaN is detected at this step, the capture includes the
# exact microbatches and stochastic state needed for replay.
if nancapture_monitor is not None:
nancapture_monitor.begin_step(step=step)

train_transform.set_step(step)
train_collate_fn.set_step(step)

Expand All @@ -1729,6 +1757,11 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
batch = next(infinite_train_dataloader)
timer.end_step("train_dataload")

# Clone+detach this microbatch into the NaNCapture buffer so the
# captured state is decoupled from the live autograd graph.
if nancapture_monitor is not None:
nancapture_monitor.collect_batch(batch)

# Type ignore is needed because `train_model` is not recognized as an
# instance of `_FabricModule`
with fabric.no_backward_sync(train_model, enabled=is_accumulating): # type: ignore[arg-type]
Expand All @@ -1749,6 +1782,12 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
)

# Optimizer step and scheduler step.
# Scan parameter gradients for NaN/Inf before the optimizer step
# (which would corrupt the model via the bad gradient). On detection,
# the monitor saves the capture and raises NaNDetectedError, halting
# training.
if nancapture_monitor is not None:
nancapture_monitor.check_and_capture(train_model)
# clip_gradients returns the total gradient norm before clipping. It is
# None for models that do not support gradient norm logging.
gradient_norm = train_model.clip_gradients(
Expand Down Expand Up @@ -1965,8 +2004,8 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
)
train_model.set_train_mode()
fabric.barrier()
timer.stop()
logger.info("Training completed.")
timer.stop()
logger.info("Training completed.")


class TrainTaskConfig(PydanticConfig):
Expand Down
29 changes: 29 additions & 0 deletions src/lightly_train/_debug/debug_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,42 @@ def _validate_trace_batch_nums(cls, v: list[int]) -> list[int]:
return v


class NaNCaptureArgs(PydanticConfig):
"""Arguments for NaN/Inf capture debugging.

When enabled, the monitor scans parameter gradients for NaN/Inf after each
gradient-accumulation step (before ``clip_gradients``/``optimizer.step``).
On detection, it saves a self-contained capture to
``out_dir/debug/nan_capture/rank{R}/nan_capture.pt`` holding the model
state, the step's microbatches, RNG state, and metadata — then raises
:class:`NaNDetectedError` to halt training. The capture is reproducibly
loadable via :func:`lightly_train._debug.nan_capture.load_nan_capture`
and ``.replay()``.
"""

enabled: bool = Field(
default=False,
description=(
"Whether to enable NaN/Inf capture debugging. When True, scans "
"parameter gradients after each accumulated training step and, on "
"first non-finite gradient, saves a replayable capture to "
"out_dir/debug/nan_capture/rank{R}/nan_capture.pt before aborting."
),
)


class DebugArgs(PydanticConfig):
underflow_overflow: DebugUnderflowOverflowArgs | None = None
nancapture: NaNCaptureArgs | None = None

def is_underflow_overflow_enabled(self) -> bool:
"""Returns True if underflow/overflow debugging is enabled."""
return self.underflow_overflow is not None and self.underflow_overflow.enabled

def is_nancapture_enabled(self) -> bool:
"""Returns True if NaN/Inf capture debugging is enabled."""
return self.nancapture is not None and self.nancapture.enabled


def get_debug_args(debug_args: dict[str, Any] | DebugArgs | None) -> DebugArgs:
"""Resolves the debug arguments into a :class:`DebugArgs` instance."""
Expand Down
Loading