eval/finalize: load torchao-quantized models without a flag #46
Merged
Conversation
Finalize now writes a `quantization_config` block into the saved config.json so HF `from_pretrained()` runs the `TorchAoHfQuantizer` pre-process path on reload — the quantized linear modules are installed before `load_state_dict`, so the saved tensor subclasses land in slots that know how to hold them. Eval autodetects the block at load time and forces the `from_pretrained()` path regardless of `--checkpoint` / `--no-checkpoint`. The default checkpoint-resume path uses `from_config()` + `load_state_dict()`, which has no quantizer hook and fails with "'Parameter' object has no attribute 'tensor_data_names'" on torchao quantized tensor subclasses.

No new CLI flag and no marker file — issue #41 originally proposed both, but HF's existing config-based autodetection makes them unnecessary. The same mechanism benefits the inference server (issue #42) for free.

Three-way comparison verified on the 4.43M Tiny Llama (RTX 3090, full Tiny Stories eval, recipe `int8-dynamic-act-int4-weight`):

- bf16 baseline: eval_loss=1.366
- PTQ on bf16: eval_loss=1.392 (+0.026)
- QAT converted: eval_loss=1.391 (+0.025; -0.001 vs PTQ)

Numbers and the loader autodetection mechanism are documented in docs/trainers/qat-training.md, with cross-links from docs/guides/finalize-model.md and docs/guides/evaluating-models.md.

Closes #41.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
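For context, the finalize-side change boils down to embedding an HF-style `quantization_config` block when the model is saved. A minimal sketch of that idea as a plain JSON edit; the exact fields Forgather records are assumptions here, not taken from the diff:

```python
import json
from pathlib import Path

def write_quantization_config(save_dir: str, quant_type: str) -> None:
    """Embed a quantization_config block in the saved config.json so that
    transformers' from_pretrained() builds a TorchAoHfQuantizer on reload.
    Field names follow HF's torchao convention; the full set finalize writes
    may differ."""
    config_path = Path(save_dir) / "config.json"
    config = json.loads(config_path.read_text())
    config["quantization_config"] = {
        "quant_method": "torchao",  # selects the torchao quantizer in transformers
        "quant_type": quant_type,   # recipe identifier; exact serialized value is whatever finalize records
    }
    config_path.write_text(json.dumps(config, indent=2))
```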
…c polish

- Cache `resolve_checkpoint()` result on `args` so the autodetect log fires once per run instead of twice.
- Move the quantization probe ahead of the `--checkpoint` short-circuit: explicit `--checkpoint PATH` on a quantized model previously slid through to the failing `from_config()` path and crashed with the tensor_data_names error. Now it logs a WARNING naming the overridden path and routes through `from_pretrained()`.
- Factor out a `_model_is_quantized()` helper.
- Dry-run logs that it would write a `quantization_config` block.
- Docs: clarify that autodetect overrides both checkpoint flags (not just "checkpoint resume"); s/perplexity premium/eval-loss premium/ in the three-way comparison; forward-reference the measured Δ from the Choosing-a-Recipe table to the Three-Way Comparison section.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
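For reference, the `_model_is_quantized()` helper mentioned above is essentially a config probe. A sketch of roughly what it checks; only the name comes from the commit, the body is an assumption:

```python
import json
from pathlib import Path

def _model_is_quantized(model_dir: str) -> bool:
    """Probe a saved model directory for finalize's quantization_config block.

    Returns False on any read or parse problem so the caller can fall through
    to the normal (non-quantized) load path.
    """
    config_path = Path(model_dir) / "config.json"
    if not config_path.is_file():
        return False
    try:
        config = json.loads(config_path.read_text())
    except (OSError, json.JSONDecodeError):
        return False
    return bool(config.get("quantization_config"))
```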
Review pass complete. Two review agents ran in parallel; both said no blockers. Applied in commit be638c5.
Also factored out a tiny `_model_is_quantized()` helper while refactoring `resolve_checkpoint()`; it keeps the function readable. Filed as a follow-up.
Skipped with reasoning
This was referenced May 15, 2026
jdinalt added a commit that referenced this pull request on May 15, 2026
… test

- Drop the `weights_only=False` pickle fallback in `_peek_first_shard`. Register torchao safe-globals eagerly before the load so `weights_only=True` keeps PyTorch's safe-default posture on .bin shards. The previous fallback allowed arbitrary pickled code on any shard whose contents happened to not load under the strict allowlist.
- Document the silent device override in `load_checkpoint`'s docstring and note that tied weights are restored post-load via the trainer's retie step (eval/inference paths don't grad-update tied weights, so the missing retie is behaviourally invisible).
- Guard `_module_device` against meta-device modules: skip meta and defer to the caller's fallback, otherwise `assign=True` would silently produce a fully-meta model.
- Document the "canonical Forgather recipe" contract on `_base_config_from_tensor` so future readers know the state_dict reverse-lookup coerces non-default packing formats / mapping types to defaults. The config.json path preserves them; the state_dict path is best-effort for hand-crafted torchao configs.
- Widen the unknown-subclass ValueError with a float8-specific hint and point at restoring the config.json block as the cheaper alternative to re-finalizing.
- Add `test_forward_reverse_roundtrip` (parametrized over runnable recipes) asserting `_base_config_from_tensor` recovers the same fields `recipe_to_base_config` produced. Locks down drift between the forward and reverse maps.
- Fix the stale "routes through HF `from_pretrained()`" comment in the Quick Start code block of `qat-training.md` left over from PR #46.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
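The eager safe-globals registration in the first bullet would look roughly like this. `torch.serialization.add_safe_globals` is the stock PyTorch API; the specific torchao classes to allowlist are an assumption and vary by recipe and torchao version:

```python
import torch

def _register_torchao_safe_globals() -> None:
    """Allowlist torchao tensor-subclass types up front so torch.load can keep
    weights_only=True (no pickle fallback) when reading .bin shards."""
    # Assumed import path; the classes that actually need registering depend on
    # the quantization recipe and the installed torchao version.
    from torchao.dtypes import AffineQuantizedTensor
    torch.serialization.add_safe_globals([AffineQuantizedTensor])

def _peek_first_shard(shard_path: str) -> dict:
    """Load the first shard with PyTorch's safe default; never fall back to
    weights_only=False."""
    _register_torchao_safe_globals()
    return torch.load(shard_path, map_location="cpu", weights_only=True)
```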
jdinalt added a commit that referenced this pull request on May 15, 2026
… (#48)

* native loader: detect and handle torchao quantization

forgather.ml.sharded_checkpoint.load_checkpoint() now detects torchao-quantized checkpoints (via config.json's quantization_config block, falling back to a state_dict scan of the first shard) and installs the matching quantized linear modules on the constructed model before load_state_dict runs. Every tool that loads through the native checkpoint loader inherits the support:

- forgather eval test ... -M <quantized_dir>
- forgather inf server -m <quantized_dir> --from-checkpoint
- Trainer resume_from_checkpoint

No CLI flag, no marker file: the recipe is recovered from config.json's TorchAoConfig.quant_type (written by finalize) or from the saved tensor subclass attributes (covers checkpoints that lack the config hint).

New module src/forgather/ml/quantization_detect.py centralises:

- detect_torchao_quantization() — two-tier detection
- install_torchao_quantization() — prepare→convert + safe-globals registration so torch.load(weights_only=True) accepts the subclasses

Mechanics on the loader side:

- when quantization is detected, force assign=True (Tensor.copy_ does not handle quantized-to-quantized copies cleanly)
- override the load device to the module's existing device so the assigned tensors don't migrate the model to the caller-passed staging device (e.g. cpu when running on cuda:0)

Reverts PR #46's eval-side shim that detected the same config and forced HF's from_pretrained() path. The shim violated the project's "-c CHECKPOINT_PATH means native loader" contract; the fix is on the native loader instead.

Test coverage:

- tests/test_quantization_detect.py (11 unit tests, all recipes)
- tests/integration/specs/tiny_llama_inference_quantized.yaml + FinalizeSpec hook in tests/integration/test_inference.py

Verified end-to-end on the chinchilla baseline:

- bf16 eval_loss=1.367 (unchanged regression)
- quantized eval_loss=1.392 via native loader (matches PR #46's HF loader result within noise: 1.392)
- inference server -c <quantized_dir> serves coherent completions at 61.9 tok/s vs bf16 379.9 tok/s on a 4.43M model (slowdown documented in inference README — quantization is memory-bound; throughput wins appear at larger scale)

Closes #41, #42.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* review fixes: eager safe-globals, doc clarifications, forward/reverse test

- Drop the `weights_only=False` pickle fallback in `_peek_first_shard`. Register torchao safe-globals eagerly before the load so `weights_only=True` keeps PyTorch's safe-default posture on .bin shards. The previous fallback allowed arbitrary pickled code on any shard whose contents happened to not load under the strict allowlist.
- Document the silent device override in `load_checkpoint`'s docstring and note that tied weights are restored post-load via the trainer's retie step (eval/inference paths don't grad-update tied weights, so the missing retie is behaviourally invisible).
- Guard `_module_device` against meta-device modules: skip meta and defer to the caller's fallback, otherwise `assign=True` would silently produce a fully-meta model.
- Document the "canonical Forgather recipe" contract on `_base_config_from_tensor` so future readers know the state_dict reverse-lookup coerces non-default packing formats / mapping types to defaults. The config.json path preserves them; the state_dict path is best-effort for hand-crafted torchao configs.
- Widen the unknown-subclass ValueError with a float8-specific hint and point at restoring the config.json block as the cheaper alternative to re-finalizing.
- Add `test_forward_reverse_roundtrip` (parametrized over runnable recipes) asserting `_base_config_from_tensor` recovers the same fields `recipe_to_base_config` produced. Locks down drift between the forward and reverse maps.
- Fix the stale "routes through HF `from_pretrained()`" comment in the Quick Start code block of `qat-training.md` left over from PR #46.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
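As a rough sketch of the two-tier detection described above (only the config.json tier is spelled out in the commit message; the return type and the state_dict fallback here are assumptions):

```python
import json
from pathlib import Path
from typing import Optional

def detect_torchao_quantization(model_dir: str) -> Optional[str]:
    """Return the quantization recipe for a saved model, or None if unquantized.

    Tier 1: config.json's quantization_config block (written by finalize).
    Tier 2: scan the first shard's tensor subclasses and map their attributes
            back to a canonical Forgather recipe (omitted here; this is the job
            _base_config_from_tensor does in quantization_detect.py).
    """
    config_path = Path(model_dir) / "config.json"
    if config_path.is_file():
        qconf = json.loads(config_path.read_text()).get("quantization_config") or {}
        if qconf.get("quant_method") == "torchao":
            return qconf.get("quant_type")
    # Tier 2 placeholder: a real implementation peeks at the first shard here.
    return None
```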
Summary
Closes #41. `forgather eval` can now load and evaluate `forgather finalize --quantize`-produced artifacts with no extra CLI flag and no marker file. Issue #41 originally proposed both; investigation showed HF's existing config-based autodetection makes them unnecessary.
Two minimal changes:
1. Finalize writes a `quantization_config` block into the saved `config.json` (lines added after `_apply_quantize`). On reload, HF `from_pretrained()` instantiates a `TorchAoHfQuantizer`, pre-processes the model with the right quantized linear modules, then runs `load_state_dict` — so torchao tensor subclasses land in slots that know how to hold them.
2. Eval (`scripts/eval_script.py:resolve_checkpoint`) reads `config.json` at load time. If `quantization_config` is present, it forces the `from_pretrained()` path regardless of `--checkpoint` / `--no-checkpoint` flags. Without this, eval's default checkpoint-resume path uses `from_config()` + `load_state_dict()`, which has no quantizer hook and fails with `'Parameter' object has no attribute 'tensor_data_names'` on quantized subclasses. A sketch of the two load paths follows below.
Bf16 models are unaffected — autodetect is a no-op when the config block is absent.
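In loader terms, the difference between the two paths is roughly the following. This is a sketch that uses the transformers Auto classes as stand-ins for eval's actual model construction in `scripts/eval_script.py`:

```python
from transformers import AutoConfig, AutoModelForCausalLM

def load_for_eval(model_dir: str, quantized: bool):
    if quantized:
        # from_pretrained() sees quantization_config in config.json, builds a
        # TorchAoHfQuantizer, installs quantized linear modules, and only then
        # loads weights -- so torchao tensor subclasses have somewhere to land.
        return AutoModelForCausalLM.from_pretrained(model_dir)

    # Default checkpoint-resume path: construct from config, then load the
    # state dict. There is no quantizer hook here, which is why it fails on
    # quantized artifacts ("'Parameter' object has no attribute
    # 'tensor_data_names'").
    config = AutoConfig.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_config(config)
    # ... load_state_dict() from checkpoint shards would happen here ...
    return model
```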
Three-way comparison
Run on the 4.43M Tiny Llama from PR #39 verification (RTX 3090, full Tiny Stories eval, recipe `int8-dynamic-act-int4-weight`):

| Model | eval_loss | Δ vs bf16 |
|---|---|---|
| bf16 baseline | 1.366 | (baseline) |
| PTQ on bf16 | 1.392 | +0.026 |
| QAT converted | 1.391 | +0.025 (-0.001 vs PTQ) |

QAT shaves ~0.0013 eval-loss off PTQ at this scale — real but small. The doc is honest about it.
Same mechanism unblocks issue #42 (inference server already uses `from_pretrained()`; once the artifact carries `quantization_config` it picks up the optimized path automatically — #42 stays open for the webui-control + docs work).
Files changed
Test plan
🤖 Generated with Claude Code