feat: add NaN capture and replay for fine-tuning#826
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1d330952ed
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
|
/review |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 71b3fb4715
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
e20e8aa to
e18239b
Compare
532d2ba to
50b9143
Compare
50b9143 to
220a02a
Compare
liopeer
left a comment
There was a problem hiding this comment.
LGTM! Adressing the comments is optional.
40f6516 to
90ca5d7
Compare
…producing-bad-batch
What has changed and why?
NaNCapturedebug tool for fine-tuning (train_*task APIs): when aNaN/Inf is detected in parameter gradients (scanned after gradient
accumulation, before
clip_gradients/optimizer.step), it capturesreproducible state and halts training.
out_dir/debug/nan_capture/rank{R}/nan_capture.ptholding the model statedict, the
TrainModelclass path + init kwargs (for reconstruction), thestep's microbatches, and torch/CUDA RNG state. The standard
checkpoints/last.ckptis not touched, soresume_interruptedisunaffected.
load_nan_capture(dir)+NaNCaptureState.replay()for zero-setupreplay: reconstructs the model, restores RNG, and re-runs the triggering
forward+backward (mirrors the training loop; stops before the optimizer step)
to reproduce the NaN deterministically in a notebook/REPL.
_commands/train_task.pyalongside the existing underflow/overflowmonitor from feat: integrate HF DebugUnderflowOverflow into fine-tuning #814. Enable with
debug_args={"nancapture": {"enabled": True}}.Replay is debug-only — the training loop carries no replay flag.
Closes TRN-2256.
How has it been tested?
pytest tests/_debug/test_nan_capture.py→ 11/11 pass (config, monitor gradscan + capture payload, buffer clone/detach/reset, replay roundtrip +
reproduction).
make format→ clean (no unintended changes).mypyon the 4 changed files → clean.Did you update CHANGELOG.md?
Did you update the documentation?