native loader: detect and handle torchao quantization (closes #41, #42) #48
Merged
Conversation
`forgather.ml.sharded_checkpoint.load_checkpoint()` now detects torchao-quantized checkpoints (via config.json's `quantization_config` block, falling back to a state_dict scan of the first shard) and installs the matching quantized linear modules on the constructed model before `load_state_dict` runs. Every tool that loads through the native checkpoint loader inherits the support:

- `forgather eval test ... -M <quantized_dir>`
- `forgather inf server -m <quantized_dir> --from-checkpoint`
- Trainer `resume_from_checkpoint`

No CLI flag, no marker file: the recipe is recovered from config.json's `TorchAoConfig.quant_type` (written by finalize) or from the saved tensor subclass attributes (covers checkpoints that lack the config hint).

New module `src/forgather/ml/quantization_detect.py` centralises:

- `detect_torchao_quantization()` — two-tier detection
- `install_torchao_quantization()` — prepare→convert + safe-globals registration so `torch.load(weights_only=True)` accepts the subclasses

Mechanics on the loader side:

- when quantization is detected, force `assign=True` (`Tensor.copy_` does not handle quantized-to-quantized copies cleanly)
- override the load device to the module's existing device so the assigned tensors don't migrate the model to the caller-passed staging device (e.g. cpu when running on cuda:0)

Reverts PR #46's eval-side shim that detected the same config and forced HF's `from_pretrained()` path. The shim violated the project's "-c CHECKPOINT_PATH means native loader" contract; the fix is on the native loader instead.

Test coverage:

- `tests/test_quantization_detect.py` (11 unit tests, all recipes)
- `tests/integration/specs/tiny_llama_inference_quantized.yaml` + FinalizeSpec hook in `tests/integration/test_inference.py`

Verified end-to-end on the chinchilla baseline:

- bf16 eval_loss=1.367 (unchanged regression)
- quantized eval_loss=1.392 via native loader (matches PR #46's HF loader result within noise: 1.392)
- inference server `-c <quantized_dir>` serves coherent completions at 61.9 tok/s vs bf16 379.9 tok/s on a 4.43M model (slowdown documented in inference README — quantization is memory-bound; throughput wins appear at larger scale)

Closes #41, #42.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… test

- Drop the `weights_only=False` pickle fallback in `_peek_first_shard`. Register torchao safe-globals eagerly before the load so `weights_only=True` keeps PyTorch's safe-default posture on .bin shards. The previous fallback allowed arbitrary pickled code on any shard whose contents happened to not load under the strict allowlist.
- Document the silent device override in `load_checkpoint`'s docstring and note that tied weights are restored post-load via the trainer's retie step (eval/inference paths don't grad-update tied weights, so the missing retie is behaviourally invisible).
- Guard `_module_device` against meta-device modules: skip meta and defer to the caller's fallback, otherwise `assign=True` would silently produce a fully-meta model.
- Document the "canonical Forgather recipe" contract on `_base_config_from_tensor` so future readers know the state_dict reverse-lookup coerces non-default packing formats / mapping types to defaults. The config.json path preserves them; the state_dict path is best-effort for hand-crafted torchao configs.
- Widen the unknown-subclass ValueError with a float8-specific hint and point at restoring the config.json block as the cheaper alternative to re-finalizing.
- Add `test_forward_reverse_roundtrip` (parametrized over runnable recipes) asserting `_base_config_from_tensor` recovers the same fields `recipe_to_base_config` produced. Locks down drift between the forward and reverse maps.
- Fix the stale "routes through HF `from_pretrained()`" comment in the Quick Start code block of `qat-training.md` left over from PR #46.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
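The first bullet's pattern is PyTorch's `weights_only` allowlist mechanism. A minimal sketch of the idea, not the PR's actual code: the helper name is hypothetical, and the torchao class registered below is only illustrative; the real registration covers whatever subclasses the active recipe serialises.

```python
import torch

def register_torchao_safe_globals() -> None:
    """Allowlist torchao tensor subclasses so torch.load(weights_only=True)
    can unpickle quantized shards without relaxing the strict default."""
    try:
        # Illustrative subclass; the real code registers the classes the
        # installed torchao version actually serialises for the recipe.
        from torchao.dtypes import AffineQuantizedTensor
    except ImportError:
        return  # torchao not installed: nothing to allowlist
    torch.serialization.add_safe_globals([AffineQuantizedTensor])

# Register *before* peeking at the shard, so weights_only stays True.
register_torchao_safe_globals()
state_dict = torch.load(shard_path, map_location="cpu", weights_only=True)
```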
Owner
Author
Review pass complete

Three review agents ran in parallel — loader correctness (Rev #1), detection-logic semantics (Rev #2), docs+tests (Rev #3). One real blocker each, plus a useful set of nits. Applied in commit 8c6ca14.

Security / correctness:
Detection:
Docs:
Filed as follow-up
Skipped with reasoning
Verification after fixes
Summary
Closes #41 (reopened) and #42. Teaches Forgather's native checkpoint loader (`src/forgather/ml/sharded_checkpoint.py:load_checkpoint`) to detect and handle torchao-quantized checkpoints. Every tool that loads through the native loader inherits the support — `forgather eval`, `forgather inf server -c`, trainer resume — with no CLI flag and no marker file.
This supersedes PR #46's approach. PR #46 detected quantization at the eval entry point and forced HF's `from_pretrained()` path, which violated the project's `-c CHECKPOINT_PATH` → native-loader contract. This PR reverts that shim and fixes the native loader itself.
What's new
`src/forgather/ml/quantization_detect.py` (new module):
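A rough sketch of the two-tier detection, under the assumption that it returns a recipe hint or `None`. The helpers `_peek_first_shard` and `_base_config_from_tensor` are named in the follow-up commit, but their signatures and the exact config.json fields checked here are guesses:

```python
import json
import os
from typing import Optional

import torch

def detect_torchao_quantization(checkpoint_dir: str) -> Optional[str]:
    """Return the torchao recipe hint for a quantized checkpoint, or None.

    Tier 1: config.json's quantization_config block (written by finalize).
    Tier 2: scan the first shard's tensors for torchao subclasses.
    """
    config_path = os.path.join(checkpoint_dir, "config.json")
    if os.path.exists(config_path):
        with open(config_path) as f:
            qcfg = json.load(f).get("quantization_config")
        if qcfg and qcfg.get("quant_type"):
            return qcfg["quant_type"]

    # Fallback: peek at the first shard and look for quantized tensor subclasses.
    shard = _peek_first_shard(checkpoint_dir)        # helper named in the follow-up commit
    for tensor in shard.values():
        if type(tensor) is not torch.Tensor:         # tensor subclass => quantized
            return _base_config_from_tensor(tensor)  # reverse-lookup, best effort
    return None
```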
`src/forgather/ml/sharded_checkpoint.py:load_checkpoint`: inserts a single block before `load_state_dict`. When quantization is detected:
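A simplified sketch of what that block could look like, assuming `_module_device` returns `None` for meta-device modules (per the follow-up commit) and that the per-shard apply loop lives below it; this is an illustration of the described mechanics, not the PR's exact code:

```python
# Sketch of the block inserted near the top of load_checkpoint(),
# before any shard is applied to the model (control flow simplified).
recipe = detect_torchao_quantization(checkpoint_dir)
if recipe is not None:
    # Replace plain nn.Linear modules with their quantized counterparts
    # so the incoming subclass tensors have matching parameters to land on.
    install_torchao_quantization(model, recipe)

    # Tensor.copy_ does not handle quantized-to-quantized copies cleanly;
    # hand the loaded tensors over wholesale instead.
    assign = True

    # Keep the model on its current device: with assign=True the loaded
    # tensors replace the parameters outright, so loading onto the caller's
    # staging device (e.g. cpu) would silently migrate a cuda:0 model.
    device = _module_device(model) or device

# ... each shard is then read with weights_only=True, mapped to `device`,
# and applied with load_state_dict(..., assign=assign).
```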
`tools/finalize_model/finalize_model.py:_apply_quantize`: refactored to delegate to `install_torchao_quantization`, dropping the duplicate prepare/convert logic.
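Roughly, and assuming `_apply_quantize` keeps a model-plus-recipe shape (the argument names are guesses):

```python
from forgather.ml.quantization_detect import install_torchao_quantization

def _apply_quantize(model, recipe: str):
    # Previously duplicated the prepare/convert steps here; the single
    # implementation in quantization_detect now serves finalize and load.
    install_torchao_quantization(model, recipe)
    return model
```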
`scripts/eval_script.py`: reverts PR #46's `_model_is_quantized` + `resolve_checkpoint` autodetect shim. The function is back to its original shape.
Three-way comparison (regression check)
Full Tiny Stories eval on the 4.43M Tiny Llama chinchilla baseline, RTX 3090, recipe `int8-dynamic-act-int4-weight`:
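| Path | eval_loss |
| --- | --- |
| bf16 baseline | 1.367 |
| Quantized, native loader (this PR) | 1.392 |
| Quantized, HF `from_pretrained()` (PR #46) | 1.392 |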
Inference throughput
Quantized inference is ~6× slower than bf16 at batch=1 on this 4.43M model. The slowdown is dequant-overhead-dominated at this scale; quantization wins (memory footprint, longer context, larger batch) appear at larger model sizes. Documented in `tools/inference_server/README.md`.
Tests
Test plan
🤖 Generated with Claude Code