
eval/finalize: load torchao-quantized models without a flag #46

Merged
jdinalt merged 2 commits into dev from feature/eval-quantized-models on May 15, 2026

Conversation

jdinalt (Owner) commented May 15, 2026

Summary

Closes #41. `forgather eval` can now load and evaluate `forgather finalize --quantize`-produced artifacts with no extra CLI flag and no marker file. Issue #41 originally proposed both; investigation showed HF's existing config-based autodetection makes them unnecessary.

Two minimal changes:

  1. Finalize writes a `quantization_config` block into the saved `config.json` (written just after `_apply_quantize` runs). On reload, HF `from_pretrained()` instantiates a `TorchAoHfQuantizer`, pre-processes the model with the right quantized linear modules, then runs `load_state_dict`, so torchao tensor subclasses land in slots that know how to hold them.

  2. Eval (`scripts/eval_script.py:resolve_checkpoint`) reads `config.json` at load time. If `quantization_config` is present, it forces the `from_pretrained()` path regardless of `--checkpoint` / `--no-checkpoint` flags. Without this, eval's default checkpoint-resume path uses `from_config()` + `load_state_dict()` which has no quantizer hook and fails with `'Parameter' object has no attribute 'tensor_data_names'` on quantized subclasses.

Bf16 models are unaffected — autodetect is a no-op when the config block is absent.
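
For orientation, here is a minimal sketch of both halves, assuming a transformers version that ships `TorchAoConfig`; the helper names and the hard-coded quant type are illustrative, not the actual Forgather code:

```python
# Illustrative sketch only; whether the quant-type string below is accepted
# depends on the transformers/torchao versions in play.
import json
from pathlib import Path

from transformers import TorchAoConfig


def save_with_quant_block(model, output_dir: str) -> None:
    """Finalize side: attach the block so save_pretrained() serializes it
    into config.json under the "quantization_config" key."""
    model.config.quantization_config = TorchAoConfig("int8_dynamic_activation_int4_weight")
    model.save_pretrained(output_dir)


def model_is_quantized(model_dir: str) -> bool:
    """Eval side: when True, route through from_pretrained() so the
    TorchAoHfQuantizer installs quantized linears before load_state_dict."""
    config = json.loads((Path(model_dir) / "config.json").read_text())
    return "quantization_config" in config
```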

Three-way comparison

Run on the 4.43M Tiny Llama from PR #39 verification (RTX 3090, full Tiny Stories eval, recipe `int8-dynamic-act-int4-weight`):

| Model | eval_loss | perplexity | Δ eval_loss vs bf16 |
|---|---|---|---|
| bf16 baseline | 1.3656 | 3.918 | |
| PTQ on bf16 baseline | 1.3917 | 4.022 | +0.0262 |
| QAT-trained + converted | 1.3905 | 4.017 | +0.0249 |

QAT shaves ~0.0013 eval-loss off PTQ at this scale: real but small, and the docs say so plainly.

The same mechanism unblocks issue #42: the inference server already uses `from_pretrained()`, so once the artifact carries `quantization_config` it picks up the optimized path automatically. #42 stays open for the webui-control and docs work.

Files changed

  • `tools/finalize_model/finalize_model.py` — attach `config.quantization_config = TorchAoConfig(...)` after `_apply_quantize`.
  • `scripts/eval_script.py` — `resolve_checkpoint()` reads config.json, forces `from_pretrained()` path on quantized models.
  • `docs/trainers/qat-training.md` — new "Evaluating Quantized Models" section with autodetect mechanism + three-way comparison table.
  • `docs/guides/finalize-model.md` — note on the new `quantization_config` write behavior.
  • `docs/guides/evaluating-models.md` — note in "Model loading" about quantized-model autodetect.

Test plan

  • Smoke: `forgather eval test tinystories -M <quantized_dir>` returns finite eval_loss (no `--no-checkpoint` needed).
  • Autodetect log fires for quantized models, not for bf16 ones.
  • Regression: `forgather eval test tinystories -M <bf16_dir>` unchanged.
  • `config.json` after finalize contains a `quant_method: torchao` block with the right `quant_type`.
  • Full three-way comparison (numbers above).
  • Webui eval submit (user-side check, not on this branch).

🤖 Generated with Claude Code

jdinalt and others added 2 commits May 15, 2026 02:51
Finalize now writes a `quantization_config` block into the saved
config.json so HF `from_pretrained()` runs the `TorchAoHfQuantizer`
pre-process path on reload — the quantized linear modules are
installed before `load_state_dict`, so the saved tensor subclasses
land in slots that know how to hold them.

Eval autodetects the block at load time and forces the
`from_pretrained()` path regardless of `--checkpoint` /
`--no-checkpoint`. The default checkpoint-resume path uses
`from_config()` + `load_state_dict()` which has no quantizer hook and
fails with "'Parameter' object has no attribute 'tensor_data_names'"
on torchao quantized tensor subclasses.

No new CLI flag and no marker file — issue #41 originally proposed
both, but HF's existing config-based autodetection makes them
unnecessary. Same mechanism benefits the inference server (issue #42)
for free.

Three-way comparison verified on 4.43M Tiny Llama (RTX 3090, full
Tiny Stories eval, recipe `int8-dynamic-act-int4-weight`):
- bf16 baseline: eval_loss=1.366
- PTQ on bf16:   eval_loss=1.392 (+0.026)
- QAT converted: eval_loss=1.391 (+0.025; -0.001 vs PTQ)

Numbers and the loader autodetection mechanism documented in
docs/trainers/qat-training.md, with cross-links from
docs/guides/finalize-model.md and docs/guides/evaluating-models.md.

Closes #41.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…c polish

- Cache `resolve_checkpoint()` result on `args` so the autodetect log
  fires once per run instead of twice.
- Move the quantization probe ahead of the `--checkpoint` short-circuit:
  explicit `--checkpoint PATH` on a quantized model previously slid
  through to the failing `from_config()` path and crashed with the
  tensor_data_names error. Now it logs a WARNING naming the overridden
  path and routes through `from_pretrained()`.
- Factor out `_model_is_quantized()` helper.
- Dry-run logs that it would write a `quantization_config` block.
- Docs: clarify that autodetect overrides both checkpoint flags (not
  just "checkpoint resume"); s/perplexity premium/eval-loss premium/
  in the three-way comparison; forward-reference the measured Δ from
  the Choosing-a-Recipe table to the Three-Way Comparison section.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jdinalt (Owner, Author) commented May 15, 2026

Review pass complete

Two review agents ran in parallel. Both said no blockers.

Applied in commit be638c5

  • Autodetect cache — `resolve_checkpoint()` now caches its result on `args`, so the "Detected quantization_config..." log fires once per run instead of twice. Both reviewers flagged the duplicate.
  • Explicit `--checkpoint PATH` override warning — previously, an explicit `--checkpoint PATH` on a quantized model short-circuited before the autodetect check and crashed with the `tensor_data_names` error. Moved the quantization probe ahead of the checkpoint branch; explicit paths now log a WARNING naming the overridden path and route through `from_pretrained()`. (Reviewer #1, nit 1.)
  • Dry-run log — `--quantize --dry-run` now logs "Would write 'quantization_config' block to config.json". (Reviewer #1, nit 2.)
  • Doc clarifications:
    • Log message + prose changed from "instead of checkpoint resume" to "ignoring checkpoint flags" — accurate phrasing for the override behavior. (Both reviewers.)
    • s/perplexity premium/eval-loss premium/ in the Three-Way Comparison section — the +0.025 number is eval-loss, not perplexity. (Reviewer #2, nit 4.)
    • Forward reference from the "Choosing a Recipe" table to the measured 0.0013 Δ in the Three-Way Comparison section, so the qualitative claim and the empirical datapoint don't read as contradictory. (Reviewer #2, nit 1.)

Also factored out a tiny `_model_is_quantized()` helper while refactoring `resolve_checkpoint()` — keeps the function readable.
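
Roughly the shape this lands in, as a hypothetical sketch; the attribute and argument names here are assumptions, not the actual `eval_script.py` code:

```python
# Hypothetical sketch of the post-review shape of resolve_checkpoint().
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def _model_is_quantized(model_dir: str) -> bool:
    config_path = Path(model_dir) / "config.json"
    return config_path.is_file() and "quantization_config" in json.loads(config_path.read_text())


def resolve_checkpoint(args):
    # Cache on args so the autodetect log fires once per run, not twice.
    if getattr(args, "_resolved", None) is not None:
        return args._resolved
    # The quantization probe runs BEFORE the --checkpoint short-circuit, so an
    # explicit path can no longer slide into the failing from_config() route.
    if _model_is_quantized(args.model_dir):
        if args.checkpoint:
            logger.warning(
                "Quantized model: ignoring --checkpoint %s, loading via from_pretrained()",
                args.checkpoint,
            )
        args._resolved = ("from_pretrained", args.model_dir)
    elif args.checkpoint:
        args._resolved = ("checkpoint", args.checkpoint)
    else:
        args._resolved = ("default", args.model_dir)
    return args._resolved
```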

Filed as follow-up

Skipped with reasoning

jdinalt merged commit 7041259 into dev on May 15, 2026
1 check passed
jdinalt deleted the feature/eval-quantized-models branch on May 15, 2026 at 03:04
jdinalt added a commit that referenced this pull request May 15, 2026
… test

- Drop the `weights_only=False` pickle fallback in `_peek_first_shard`.
  Register torchao safe-globals eagerly before the load so
  `weights_only=True` keeps PyTorch's safe-default posture on .bin
  shards. The previous fallback allowed arbitrary pickled code on any
  shard whose contents happened to not load under the strict allowlist.
- Document the silent device override in `load_checkpoint`'s docstring
  and note that tied weights are restored post-load via the trainer's
  retie step (eval/inference paths don't grad-update tied weights, so
  the missing retie is behaviourally invisible).
- Guard `_module_device` against meta-device modules: skip meta and
  defer to the caller's fallback, otherwise `assign=True` would silently
  produce a fully-meta model.
- Document the "canonical Forgather recipe" contract on
  `_base_config_from_tensor` so future readers know the state_dict
  reverse-lookup coerces non-default packing formats / mapping types
  to defaults. Config.json path preserves them; state_dict path is
  best-effort for hand-crafted torchao configs.
- Widen the unknown-subclass ValueError with a float8-specific hint and
  point at restoring the config.json block as the cheaper alternative
  to re-finalizing.
- Add `test_forward_reverse_roundtrip` (parametrized over runnable
  recipes) asserting `_base_config_from_tensor` recovers the same
  fields `recipe_to_base_config` produced. Locks down drift between
  the forward and reverse maps.
- Fix the stale "routes through HF `from_pretrained()`" comment in the
  Quick Start code block of `qat-training.md` left over from PR #46.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
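
For reference, eager safe-globals registration looks roughly like this, assuming PyTorch >= 2.4 with torchao installed; the subclass listed is one example, not the loader's full allowlist:

```python
# Sketch of eager safe-globals registration; the shard filename and the
# single subclass are illustrative.
import torch
from torchao.dtypes import AffineQuantizedTensor

# Register BEFORE the load so weights_only=True accepts the torchao tensor
# subclasses on .bin shards instead of needing an unsafe pickle fallback.
torch.serialization.add_safe_globals([AffineQuantizedTensor])

state_dict = torch.load(
    "pytorch_model-00001-of-00002.bin", map_location="cpu", weights_only=True
)
```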
jdinalt added a commit that referenced this pull request May 15, 2026
… (#48)

* native loader: detect and handle torchao quantization

forgather.ml.sharded_checkpoint.load_checkpoint() now detects
torchao-quantized checkpoints (via config.json's quantization_config
block, falling back to a state_dict scan of the first shard) and
installs the matching quantized linear modules on the constructed
model before load_state_dict runs. Every tool that loads through the
native checkpoint loader inherits the support:

- forgather eval test ... -M <quantized_dir>
- forgather inf server -m <quantized_dir> --from-checkpoint
- Trainer resume_from_checkpoint

No CLI flag, no marker file: the recipe is recovered from
config.json's TorchAoConfig.quant_type (written by finalize) or from
the saved tensor subclass attributes (covers checkpoints that lack the
config hint).

New module src/forgather/ml/quantization_detect.py centralises:
- detect_torchao_quantization() — two-tier detection
- install_torchao_quantization() — prepare→convert + safe-globals
  registration so torch.load(weights_only=True) accepts the subclasses

Mechanics on the loader side:
- when quantization is detected, force assign=True (Tensor.copy_ does
  not handle quantized-to-quantized copies cleanly)
- override the load device to the module's existing device so the
  assigned tensors don't migrate the model to the caller-passed
  staging device (e.g. cpu when running on cuda:0)

Reverts PR #46's eval-side shim that detected the same config and
forced HF's from_pretrained() path. The shim violated the project's
"-c CHECKPOINT_PATH means native loader" contract; fix is on the
native loader instead.

Test coverage:
- tests/test_quantization_detect.py (11 unit tests, all recipes)
- tests/integration/specs/tiny_llama_inference_quantized.yaml +
  FinalizeSpec hook in tests/integration/test_inference.py

Verified end-to-end on the chinchilla baseline:
- bf16 eval_loss=1.367 (unchanged regression)
- quantized eval_loss=1.392 via native loader (matches PR #46's HF
  loader result within noise: 1.392)
- inference server -c <quantized_dir> serves coherent completions at
  61.9 tok/s vs bf16 379.9 tok/s on a 4.43M model (slowdown documented
  in inference README — quantization is memory-bound; throughput wins
  appear at larger scale)

Closes #41, #42.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* review fixes: eager safe-globals, doc clarifications, forward/reverse test

- Drop the `weights_only=False` pickle fallback in `_peek_first_shard`.
  Register torchao safe-globals eagerly before the load so
  `weights_only=True` keeps PyTorch's safe-default posture on .bin
  shards. The previous fallback allowed arbitrary pickled code on any
  shard whose contents happened to not load under the strict allowlist.
- Document the silent device override in `load_checkpoint`'s docstring
  and note that tied weights are restored post-load via the trainer's
  retie step (eval/inference paths don't grad-update tied weights, so
  the missing retie is behaviourally invisible).
- Guard `_module_device` against meta-device modules: skip meta and
  defer to the caller's fallback, otherwise `assign=True` would silently
  produce a fully-meta model.
- Document the "canonical Forgather recipe" contract on
  `_base_config_from_tensor` so future readers know the state_dict
  reverse-lookup coerces non-default packing formats / mapping types
  to defaults. Config.json path preserves them; state_dict path is
  best-effort for hand-crafted torchao configs.
- Widen the unknown-subclass ValueError with a float8-specific hint and
  point at restoring the config.json block as the cheaper alternative
  to re-finalizing.
- Add `test_forward_reverse_roundtrip` (parametrized over runnable
  recipes) asserting `_base_config_from_tensor` recovers the same
  fields `recipe_to_base_config` produced. Locks down drift between
  the forward and reverse maps.
- Fix the stale "routes through HF `from_pretrained()`" comment in the
  Quick Start code block of `qat-training.md` left over from PR #46.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
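
A hypothetical sketch of the loader-side detection described in the PR #48 message above; the function name mirrors the commit message, but the body is illustrative and tier 2 is only stubbed:

```python
import json
from pathlib import Path


def detect_torchao_quantization(model_dir: str):
    # Tier 1: the quantization_config block finalize writes into config.json.
    config_path = Path(model_dir) / "config.json"
    if config_path.is_file():
        qconfig = json.loads(config_path.read_text()).get("quantization_config")
        if qconfig and qconfig.get("quant_method") == "torchao":
            return qconfig.get("quant_type")
    # Tier 2 (stub): peek at the first shard's tensor subclass attributes and
    # reverse-map them to a canonical recipe, best-effort.
    return None


recipe = detect_torchao_quantization("out/tiny_llama_quantized")
if recipe is not None:
    # The loader then installs the matching quantized linears and loads with
    # assign=True, pinning the load device to the module's existing placement.
    print(f"torchao checkpoint detected, recipe: {recipe}")
```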
