
native loader: detect and handle torchao quantization (closes #41, #42) #48

Merged

jdinalt merged 2 commits into dev from feature/native-loader-quantization on May 15, 2026

Conversation

@jdinalt (Owner) commented May 15, 2026

Summary

Closes #41 (reopened) and #42. Teaches Forgather's native checkpoint loader (`src/forgather/ml/sharded_checkpoint.py:load_checkpoint`) to detect and handle torchao-quantized checkpoints. Every tool that loads through the native loader inherits the support — `forgather eval`, `forgather inf server -c`, trainer resume — with no CLI flag and no marker file.

This supersedes PR #46's approach. PR #46 detected quantization at the eval entry point and forced HF's `from_pretrained()` path, which violated the project's `-c CHECKPOINT_PATH` → native-loader contract. This PR reverts that shim and fixes the native loader itself.

What's new

`src/forgather/ml/quantization_detect.py` (new module):

  • `detect_torchao_quantization(model_dir=..., state_dict=...)` — two-tier detection:
    1. Parse `<model_dir>/config.json`'s `quantization_config` block through HF's `TorchAoConfig.from_dict()` (fast path; written by `forgather finalize --quantize`).
    2. Scan the saved state_dict for torchao tensor subclasses and reverse-lookup the base config from `IntxUnpackedToInt8Tensor` / `Int4Tensor` attributes. Covers checkpoints that lack the config hint.
  • `install_torchao_quantization(module, base_config)` — runs torchao's `prepare` → `convert` on the constructed module and registers torchao tensor subclasses as PyTorch safe-globals (so subsequent `torch.load(weights_only=True)` accepts them). Both functions are sketched after this list.
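
A minimal sketch of both functions, assuming only the names above; the bodies are illustrative rather than the merged implementation, and the install sketch condenses the PR's prepare → convert flow into torchao's `quantize_` entry point:

```python
import json
import os

import torch
from torchao.quantization import quantize_


def _base_config_from_tensor(tensor):
    # Reverse-lookup stub: the real helper inspects subclass attributes
    # (e.g. IntxUnpackedToInt8Tensor / Int4Tensor fields) to rebuild the
    # base config for canonical Forgather recipes.
    raise NotImplementedError


def detect_torchao_quantization(model_dir=None, state_dict=None):
    """Return a torchao base config if the checkpoint is quantized, else None."""
    # Tier 1 (fast path): the quantization_config block in config.json,
    # written by `forgather finalize --quantize`.
    if model_dir is not None:
        config_path = os.path.join(model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path) as f:
                cfg = json.load(f)
            qcfg = cfg.get("quantization_config")
            if qcfg is not None and qcfg.get("quant_method") == "torchao":
                from transformers import TorchAoConfig

                return TorchAoConfig.from_dict(qcfg).quant_type

    # Tier 2: scan the saved tensors for torchao subclasses and reverse-map
    # their attributes to a base config (covers checkpoints without the hint).
    if state_dict is not None:
        for tensor in state_dict.values():
            if type(tensor).__module__.startswith("torchao"):
                return _base_config_from_tensor(tensor)
    return None


def install_torchao_quantization(module, base_config):
    """Quantize the constructed module in place, then allowlist the installed
    tensor subclasses so torch.load(weights_only=True) accepts them."""
    quantize_(module, base_config)
    subclasses = {type(t) for t in module.state_dict().values()}
    torch.serialization.add_safe_globals(
        [t for t in subclasses if t.__module__.startswith("torchao")]
    )
```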

`src/forgather/ml/sharded_checkpoint.py:load_checkpoint`: inserts a single block before `load_state_dict` (sketched after the list below). When quantization is detected:

  • Installs quantized linear modules on the already-constructed module.
  • Forces `assign=True` (`Tensor.copy_` doesn't handle quantized-to-quantized copies cleanly).
  • Overrides the load device to the module's existing device — `assign=True` rebinds, so if we let the caller's staging device (often CPU in the trainer) become the map_location, the whole model migrates off GPU.
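
A condensed sketch of that block, under the same naming assumptions; `checkpoint_dir`, `first_shard`, and `map_location` paraphrase the surrounding loader code rather than quote it:

```python
# Inside load_checkpoint, just before load_state_dict (paraphrased sketch).
base_config = detect_torchao_quantization(
    model_dir=checkpoint_dir, state_dict=first_shard
)
if base_config is not None:
    # Rebuild quantized linears on the already-constructed module and
    # allowlist the torchao subclasses for weights_only loading.
    install_torchao_quantization(module, base_config)
    # Tensor.copy_ can't copy quantized -> quantized, so rebind instead.
    assign = True
    # assign=True rebinds, so map_location must be the module's current
    # device, or the whole model silently migrates to the caller's
    # staging device (often CPU in the trainer).
    try:
        map_location = next(module.parameters()).device
    except StopIteration:
        pass  # parameterless module: keep the caller's device

state_dict = torch.load(shard_path, map_location=map_location, weights_only=True)
module.load_state_dict(state_dict, assign=assign)
```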

`tools/finalize_model/finalize_model.py:_apply_quantize`: refactored to delegate to `install_torchao_quantization`, dropping the duplicate prepare/convert logic.
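
After the refactor, `_apply_quantize` reduces to roughly this shape (a sketch; module paths and `recipe_to_base_config`, the forward map named in the review discussion below, are assumptions):

```python
# Module paths are assumptions for illustration.
from forgather.ml.qat_recipes import recipe_to_base_config
from forgather.ml.quantization_detect import install_torchao_quantization


def _apply_quantize(model, recipe: str):
    # Map a recipe name (e.g. "int8-dynamic-act-int4-weight") to a torchao
    # base config, then reuse the shared install path instead of keeping a
    # local copy of the prepare/convert logic.
    base_config = recipe_to_base_config(recipe)
    install_torchao_quantization(model, base_config)
```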

`scripts/eval_script.py`: reverts PR #46's `_model_is_quantized` + `resolve_checkpoint` autodetect shim. The function is back to its original shape.

Three-way comparison (regression check)

Full Tiny Stories eval on the 4.43M Tiny Llama chinchilla baseline, RTX 3090, recipe `int8-dynamic-act-int4-weight`:

| Path | eval_loss | Notes |
| --- | --- | --- |
| bf16 baseline | 1.3672 | Unchanged (regression check). |
| Quantized (native loader, this PR) | 1.3916 | Same recipe, same artifact. |
| Quantized (HF loader, PR #46) | 1.3917 | Within noise of the native-loader result; different load path. |

Inference throughput

| Variant | tok/s |
| --- | --- |
| bf16 baseline | 379.9 |
| `int8-dynamic-act-int4-weight` | 61.9 |

Quantized inference is ~6× slower than bf16 at batch=1 on this 4.43M model. At this scale the cost is dominated by dequantization overhead; quantization's wins (smaller memory footprint, longer context, larger batches) appear at larger model sizes. Documented in `tools/inference_server/README.md`.

Tests

  • Unit: `tests/test_quantization_detect.py` — 11 tests covering config-based detection, state_dict reverse-lookup for the two v1 recipes that work on this hardware (`int8-dynamic-act-int4-weight`, `int4-weight-only`), the `float8` config-only path, walk-up-from-checkpoint-subdir behavior, plain bf16 → None, unknown subclass → ValueError, and install-swaps-linear correctness. All passing. The simplest detection case is sketched after this list.
  • Integration: new `tests/integration/specs/tiny_llama_inference_quantized.yaml` plus a `FinalizeSpec` hook in `tests/integration/test_inference.py` that runs `forgather finalize --quantize` between train and serve. Full train→finalize→serve→perplexity path.
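
The plain-bf16 case might look like this (an illustrative sketch, not the file's actual contents):

```python
import json

from forgather.ml.quantization_detect import detect_torchao_quantization


def test_plain_bf16_returns_none(tmp_path):
    # A config.json without a quantization_config block must not trigger
    # the quantized load path.
    (tmp_path / "config.json").write_text(json.dumps({"model_type": "llama"}))
    assert detect_torchao_quantization(model_dir=str(tmp_path)) is None
```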

Test plan

  • Unit tests pass (11/11).
  • `forgather eval test tinystories -M <quantized_dir>` works through the native loader (no `--no-checkpoint` flag, no autodetect override log); eval_loss matches the HF-loader baseline.
  • bf16 regression: `forgather eval test tinystories -M <bf16_dir>` unchanged.
  • `forgather inf server -m <quantized_dir> --from-checkpoint` serves coherent completions.
  • state_dict-only detection: deleting `quantization_config` from config.json and re-loading still detects via the shard scan.
  • Integration test `tiny_llama_inference_quantized` passes (user-side / CI; takes ~10 min).
  • Webui inference modal submit (user-side check; unchanged code, but worth verifying nothing regressed).

🤖 Generated with Claude Code

jdinalt and others added 2 commits May 15, 2026 06:03
forgather.ml.sharded_checkpoint.load_checkpoint() now detects
torchao-quantized checkpoints (via config.json's quantization_config
block, falling back to a state_dict scan of the first shard) and
installs the matching quantized linear modules on the constructed
model before load_state_dict runs. Every tool that loads through the
native checkpoint loader inherits the support:

- forgather eval test ... -M <quantized_dir>
- forgather inf server -m <quantized_dir> --from-checkpoint
- Trainer resume_from_checkpoint

No CLI flag, no marker file: the recipe is recovered from
config.json's TorchAoConfig.quant_type (written by finalize) or from
the saved tensor subclass attributes (covers checkpoints that lack the
config hint).

New module src/forgather/ml/quantization_detect.py centralises:
- detect_torchao_quantization() — two-tier detection
- install_torchao_quantization() — prepare→convert + safe-globals
  registration so torch.load(weights_only=True) accepts the subclasses

Mechanics on the loader side:
- when quantization is detected, force assign=True (Tensor.copy_ does
  not handle quantized-to-quantized copies cleanly)
- override the load device to the module's existing device so the
  assigned tensors don't migrate the model to the caller-passed
  staging device (e.g. cpu when running on cuda:0)

Reverts PR #46's eval-side shim that detected the same config and
forced HF's from_pretrained() path. The shim violated the project's
"-c CHECKPOINT_PATH means native loader" contract; fix is on the
native loader instead.

Test coverage:
- tests/test_quantization_detect.py (11 unit tests, all recipes)
- tests/integration/specs/tiny_llama_inference_quantized.yaml +
  FinalizeSpec hook in tests/integration/test_inference.py

Verified end-to-end on the chinchilla baseline:
- bf16 eval_loss=1.367 (unchanged regression)
- quantized eval_loss=1.392 via native loader (matches PR #46's HF
  loader result within noise: 1.392)
- inference server -c <quantized_dir> serves coherent completions at
  61.9 tok/s vs bf16 379.9 tok/s on a 4.43M model (slowdown documented
  in inference README — quantization is memory-bound; throughput wins
  appear at larger scale)

Closes #41, #42.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… test

- Drop the `weights_only=False` pickle fallback in `_peek_first_shard`.
  Register torchao safe-globals eagerly before the load so
  `weights_only=True` keeps PyTorch's safe-default posture on .bin
  shards. The previous fallback allowed arbitrary pickled code to run on
  any shard that failed to load under the strict allowlist.
- Document the silent device override in `load_checkpoint`'s docstring
  and note that tied weights are restored post-load via the trainer's
  retie step (eval/inference paths don't grad-update tied weights, so
  the missing retie is behaviourally invisible).
- Guard `_module_device` against meta-device modules: skip meta and
  defer to the caller's fallback, otherwise `assign=True` would silently
  produce a fully-meta model.
- Document the "canonical Forgather recipe" contract on
  `_base_config_from_tensor` so future readers know the state_dict
  reverse-lookup coerces non-default packing formats / mapping types
  to defaults. Config.json path preserves them; state_dict path is
  best-effort for hand-crafted torchao configs.
- Widen the unknown-subclass ValueError with a float8-specific hint and
  point at restoring the config.json block as the cheaper alternative
  to re-finalizing.
- Add `test_forward_reverse_roundtrip` (parametrized over runnable
  recipes) asserting `_base_config_from_tensor` recovers the same
  fields `recipe_to_base_config` produced. Locks down drift between
  the forward and reverse maps.
- Fix the stale "routes through HF `from_pretrained()`" comment in the
  Quick Start code block of `qat-training.md` left over from PR #46.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jdinalt (Owner, Author) commented May 15, 2026

Review pass complete

Three review agents ran in parallel — loader correctness (Rev #1), detection-logic semantics (Rev #2), docs+tests (Rev #3). One real blocker each, plus a useful set of nits.

Applied in commit 8c6ca14

Detection:

  • Rev #2 strong recommend: new `test_forward_reverse_roundtrip` parametrized over runnable recipes, asserting `_base_config_from_tensor(saved(forward(recipe)))` matches `forward(recipe)` for the user-facing fields. Locks down drift between `qat_recipes.recipe_to_base_config` and `quantization_detect._base_config_from_tensor`. Sketched after this list.
  • Rev #2 nit 1: documented the "assumes canonical Forgather recipe" contract on `_base_config_from_tensor`. Non-default packing formats / mapping types are silently coerced on the state_dict path; the config.json path preserves them.
  • Rev #2 nits 3+5 (combined): widened the unknown-subclass `ValueError` to suggest restoring the `config.json` block (cheaper than re-finalizing) and added a float8-specific hint when the unknown subclass is a Float8 variant.
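
A sketch of what that roundtrip test asserts; module paths, the `tiny_model` fixture, and its single-`linear` layout are assumptions for illustration:

```python
import pytest

# Module paths are assumptions for illustration.
from forgather.ml.qat_recipes import recipe_to_base_config
from forgather.ml.quantization_detect import (
    _base_config_from_tensor,
    install_torchao_quantization,
)

RUNNABLE_RECIPES = ["int8-dynamic-act-int4-weight", "int4-weight-only"]


@pytest.mark.parametrize("recipe", RUNNABLE_RECIPES)
def test_forward_reverse_roundtrip(recipe, tiny_model):
    # forward: recipe name -> torchao base config.
    forward = recipe_to_base_config(recipe)
    # Quantize a tiny model, then reverse-map one of its saved tensors.
    install_torchao_quantization(tiny_model, forward)
    weight = tiny_model.linear.weight.data
    reverse = _base_config_from_tensor(weight)
    # The user-facing fields must survive the roundtrip.
    assert type(reverse) is type(forward)
```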


Verification after fixes

  • 13/13 unit tests pass.
  • Quantized eval through native loader: eval_loss=1.391616 (unchanged from pre-fix run; `weights_only=True` works end-to-end via eager safe-globals registration).

@jdinalt jdinalt merged commit f3ebdfc into dev May 15, 2026
1 check passed
@jdinalt jdinalt deleted the feature/native-loader-quantization branch May 15, 2026 06:18