
torchao QAT support: --qat-recipe (trainer) + --qat-convert (finalize) #39

Merged
jdinalt merged 3 commits into dev from feature/qat-training on May 14, 2026

Conversation

@jdinalt (Owner) commented May 14, 2026

Summary

Adds torchao Quantization-Aware Training (QAT) to Forgather, split across the two natural integration points so train-time and deploy-time concerns stay separated:

  • Trainer (prepare) — new --qat-recipe argument. When set, the trainer runs quantize_(model, QATConfig(base_config, step="prepare")) to install FakeQuantizedLinear modules. Forward simulates the target low-bit precision while backward stays in full precision. Mutually exclusive with --fp8-recipe.
  • Finalize (convert) — new --qat-convert argument on forgather finalize, also exposed in the webui FinalizeModal. Runs quantize_(model, QATConfig(base_config, step="convert")) against the loaded checkpoint, producing a deployable low-bit artifact. No-op (with warning) on non-QAT-trained models so the new flag is purely additive.

Both endpoints resolve their recipe string through a shared src/forgather/ml/qat_recipes.py so prepare and convert always agree on the underlying torchao base config.
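
In code, prepare and convert are the same `quantize_` call with a different `step`. A minimal sketch, assuming a resolver named `get_base_config` in qat_recipes.py (the resolver name is illustrative, and torchao import paths may differ between versions):

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig

from forgather.ml.qat_recipes import get_base_config  # hypothetical resolver name


def qat_prepare(model, recipe: str = "int8-dynamic-act-int4-weight"):
    # Trainer side (--qat-recipe): install FakeQuantizedLinear so the forward pass
    # simulates the target low-bit precision while backward stays full precision.
    quantize_(model, QATConfig(get_base_config(recipe), step="prepare"))
    return model


def qat_convert(model, recipe: str = "int8-dynamic-act-int4-weight"):
    # Finalize side (--qat-convert): swap the fake-quantized modules for the real
    # low-bit quantized linear ops, computing scales from the trained weights.
    quantize_(model, QATConfig(get_base_config(recipe), step="convert"))
    return model
```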

Recipes

| Recipe | Activations | Weights | torchao base config |
| --- | --- | --- | --- |
| int8-dynamic-act-int4-weight (default) | int8 per-token dynamic | int4 per-group (gs=32) | Int8DynamicActivationInt4WeightConfig |
| int4-weight-only | full precision | int4 per-group (gs=128) | Int4WeightOnlyConfig |
| float8-dynamic-act-float8-weight | float8 per-row | float8 per-row | Float8DynamicActivationFloat8WeightConfig |
| float8-dynamic-act-int4-weight | float8 per-row | int4 per-group | Float8DynamicActivationInt4WeightConfig |
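
For reference, a rough sketch of what the shared map might look like (the QAT_RECIPES name appears in later commits; the factory structure and the float8 constructor arguments are assumptions, while the int4 group sizes follow the table above):

```python
from torchao.quantization import (
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Float8DynamicActivationFloat8WeightConfig,
    Float8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.granularity import PerRow

# Recipe string -> factory for the torchao base config shared by prepare and convert.
QAT_RECIPES = {
    "int8-dynamic-act-int4-weight": lambda: Int8DynamicActivationInt4WeightConfig(group_size=32),
    "int4-weight-only": lambda: Int4WeightOnlyConfig(group_size=128),
    "float8-dynamic-act-float8-weight": lambda: Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),  # per-row, per the table above; argument assumed
    ),
    "float8-dynamic-act-int4-weight": lambda: Float8DynamicActivationInt4WeightConfig(),
}


def get_base_config(recipe: str):
    try:
        return QAT_RECIPES[recipe]()
    except KeyError:
        raise ValueError(f"Unknown QAT recipe {recipe!r}; choose from {sorted(QAT_RECIPES)}")
```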

Save format constraint

forgather finalize --qat-convert always writes .bin. When --safetensors is also set, it is overridden with a warning, because torchao's quantized tensor subclasses (Int8DynActInt4WeightLinear etc.) wrap multiple inner tensors and don't expose a single .storage().data_ptr() — which is what the safetensors writer requires. Until torchao ships explicit safetensors serialisation hooks, .bin is the working save format.
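
The guard itself is small. A hedged sketch of the kind of fallback finalize applies (function and argument names are illustrative, not the actual finalize_model.py code):

```python
import logging

log = logging.getLogger(__name__)


def resolve_save_format(qat_convert: str | None, use_safetensors: bool) -> str:
    # torchao's quantized tensor subclasses wrap several inner tensors and expose no
    # single storage().data_ptr(), so the safetensors writer cannot handle them.
    if qat_convert and use_safetensors:
        log.warning("--safetensors is ignored with --qat-convert; falling back to .bin")
        return "bin"
    return "safetensors" if use_safetensors else "bin"
```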

Files

  • src/forgather/ml/qat_recipes.py (new) — shared recipe→config map
  • src/forgather/ml/trainer/base_trainer.py — qat_recipe TrainingArgument + validator + three-way mutex with fp8_recipe
  • src/forgather/ml/trainer/trainer.py — _apply_qat_training, wired in _prepare_model
  • templatelib/examples/projects/lm_training_project.yaml — --qat-recipe dynamic arg + trainer_args pass-through + comment-table entry
  • src/forgather/cli/finalize.py — --qat-convert flag + job_params marshalling
  • tools/finalize_model/finalize_model.py — _apply_qat_convert helper, new pipeline step between dry-run check and write_finalized_checkpoint, auto-fallback to .bin
  • tools/forgather_server/{finalize_ops,launcher,scheduler}.py — pass-through plumbing
  • tools/forgather_server/webui/src/components/FinalizeModal.tsx — "QAT Convert" dropdown in Output section
  • docs/trainers/qat-training.md (new) — recipe table, prepare/convert flow, save-format constraint
  • docs/guides/finalize-model.md — new Quantization subsection

Verification done locally (2× RTX 3090, sm_86)

  • forgather -t v2.yaml train --qat-recipe int8-dynamic-act-int4-weight --max-steps 30: prepare converted 29/29 nn.Linear, training ran clean (14.5 steps/sec).
  • Mutex check: --fp8-recipe tensorwise --qat-recipe int4-weight-only raises ValueError: fp8_recipe and qat_recipe are mutually exclusive at construction.
  • End-to-end Python harness (small Llama via LlamaConfig/LlamaForCausalLM): prepare → 20-step train (loss finite, backward clean) → convert (15 FakeQuantizedLinear → 0) → forward on converted model (finite logits) → save via forgather.ml.sharded_checkpoint.save_checkpoint (.bin path) succeeds. A condensed sketch of this harness follows the list.
  • Safetensors save against converted model fails as expected (torchao subclass issue). Finalize now warns and forces .bin.
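
A condensed stand-in for that harness (layer sizes, step count, and CPU execution are illustrative; the real run also exercised the sharded-checkpoint save path):

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

config = LlamaConfig(hidden_size=128, intermediate_size=256, num_hidden_layers=2,
                     num_attention_heads=4, num_key_value_heads=4, vocab_size=512)
model = LlamaForCausalLM(config)

base = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base, step="prepare"))   # installs FakeQuantizedLinear
n_fq = sum(1 for m in model.modules() if type(m).__name__ == "FakeQuantizedLinear")
print(f"fake-quantized linears: {n_fq}")

opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(20):                                  # short QAT training loop
    ids = torch.randint(0, config.vocab_size, (2, 64))
    loss = model(input_ids=ids, labels=ids).loss
    loss.backward()
    opt.step()
    opt.zero_grad()

quantize_(model, QATConfig(base, step="convert"))    # swap in real low-bit linears
with torch.no_grad():
    logits = model(torch.randint(0, config.vocab_size, (1, 16))).logits
assert torch.isfinite(logits).all()
```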

Test plan

  • Train ≥200 steps on hal9000 (RTX 4090) with --qat-recipe int8-dynamic-act-int4-weight and confirm loss trajectory tracks bf16 within roughly +0.05 at step 200+
  • Run forgather finalize --source <run> --dest <out> --qat-convert int8-dynamic-act-int4-weight against that checkpoint; confirm the produced .bin shards contain quantized linear keys and AutoModelForCausalLM.from_pretrained(<out>, trust_remote_code=True) can do a forward pass (after re-applying the convert config to the freshly-loaded module)
  • Submit a finalize job via the webui's FinalizeModal with the QAT Convert dropdown set to int8-dynamic-act-int4-weight; confirm the resulting job's stdout shows --qat-convert int8-dynamic-act-int4-weight was passed through
  • Run finalize with --qat-convert against a model that was not QAT-trained and confirm the warning fires and a normal (non-quantized) artifact is produced

🤖 Generated with Claude Code

jdinalt and others added 2 commits May 14, 2026 20:09
…t in finalize)

Adds a two-phase QAT workflow:

  1. Train-time: --qat-recipe inserts FakeQuantizedLinear via
     torchao.quantize_(model, QATConfig(base, step="prepare")). Forward
     simulates the target low-bit precision; backward stays in full
     precision. Mutually exclusive with --fp8-recipe.

  2. Deploy-time: forgather finalize --qat-convert <recipe> swaps every
     FakeQuantizedLinear for the real low-bit quantized linear op. Reuses
     the existing finalize pipeline (model already loaded + reshape-aware).
     No-op with a warning on non-QAT-trained models.

Recipes (str: int8-dynamic-act-int4-weight, int4-weight-only,
float8-dynamic-act-float8-weight, float8-dynamic-act-int4-weight) map to
torchao base configs via the shared src/forgather/ml/qat_recipes.py so
prepare and convert resolve through the same function.

Webui FinalizeModal grows a QAT Convert dropdown wired into job_params;
the server scheduler/launcher/finalize_ops chain passes --qat-convert
through to the subprocess.

Saves of QAT-converted models are forced to .bin: torchao's quantized
tensor subclasses lack a single storage().data_ptr(), which safetensors
requires. finalize warns and overrides --safetensors when --qat-convert
is set.

Docs: new docs/trainers/qat-training.md, finalize-model.md gets a
Quantization subsection.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…pe, single-list mutex, template-helper for choices

Addresses reviewer feedback on PR #39:

- qat_recipes.py: int8-dynamic-act-int4-weight now resolves to
  Int8DynamicActivationIntxWeightConfig (Intx, version=2) instead of the
  deprecated Int8DynamicActivationInt4WeightConfig (pytorch/ao#2752).
  Same int8-per-token-dyn / int4-per-group-symmetric semantics.

- Drop float8-dynamic-act-int4-weight from QAT_RECIPES. torchao's underlying
  kernel asserts preshuffled int4 packing which is Hopper-only (SM90+,
  FBGEMM); it would crash at finalize on any non-Hopper GPU. Comment in
  qat_recipes.py points at the runtime-capability gate to add later.

- BaseTrainingArguments mutex refactored from pairwise checks to a single
  _LINEAR_SWAP_RECIPES tuple + len(active) > 1 check, so adding a third
  swap-style recipe is one tuple-entry mechanical edit.

- Expose QAT_RECIPES as a `qat_recipes` Jinja global in preprocess.py;
  lm_training_project.yaml's --qat-recipe choices: now renders via
  `{{ qat_recipes() | toyaml }}` instead of hand-copying the four strings.
  Drift between the Python source of truth and the template is now
  impossible.

- Defensive comment at trainer.py _prepare_model: explain why the
  if-chain is sequential (not elif) — if the mutex is ever relaxed the
  second swap will surface 0/N converted rather than silently shipping a
  single-recipe model.

- Finalize doc example no longer shows --safetensors together with
  --qat-convert (silent fallback would contradict the visible flag);
  cross-reference comments added in qat_recipes.py and FinalizeModal.tsx
  for the remaining TSX/Python duplication.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@jdinalt (Owner, Author) commented May 14, 2026

Reviewer feedback applied in 0cd8da2. Summary of changes:

From the trainer-side review:

  • Switched off the deprecated config. int8-dynamic-act-int4-weight now resolves to Int8DynamicActivationIntxWeightConfig(weight_dtype=torch.int4, weight_granularity=PerGroup(group_size=32)) — the v2 successor to the deprecated Int8DynamicActivationInt4WeightConfig. The UserWarning (Migrating from AffineQuantizedTensor + Layouts to new structure of tensor subclasses pytorch/ao#2752) is gone from training output. Verified prepare still converts 29/29 Linear layers.
  • Dropped float8-dynamic-act-int4-weight from QAT_RECIPES. Its kernel hardcodes preshuffled int4 packing (Hopper-only, FBGEMM); it would have crashed at finalize on every non-Hopper GPU we own. A comment in qat_recipes.py flags where the runtime capability gate will live when we add it back.
  • Mutex refactored to single-list. New _LINEAR_SWAP_RECIPES = ("fp8_recipe", "qat_recipe") + len(active) > 1 check in __post_init__. Adding mx_recipe (or anything else) is a one-tuple-entry edit (sketched after this list).
  • Template-helper for choices:. Exposed QAT_RECIPES as a qat_recipes Jinja global in preprocess.py; the template now renders choices: {{ qat_recipes() | toyaml }} instead of hand-copying the strings. Verified --qat-recipe --help shows the three live recipes.
  • Defensive comment in trainer.py:_prepare_model explaining why the if-chain is sequential (not elif) so a future relaxed mutex still surfaces a clear 0/N converted log instead of silently producing a single-recipe model.
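
Roughly, the refactored mutex looks like this (a sketch; the dataclass fields and defaults are illustrative, only the tuple-plus-count check mirrors the mutex bullet above):

```python
from dataclasses import dataclass
from typing import Optional

# One entry per "swap every nn.Linear" style recipe; extend with e.g. "mx_recipe".
_LINEAR_SWAP_RECIPES = ("fp8_recipe", "qat_recipe")


@dataclass
class BaseTrainingArguments:
    fp8_recipe: Optional[str] = None
    qat_recipe: Optional[str] = None

    def __post_init__(self):
        active = [name for name in _LINEAR_SWAP_RECIPES if getattr(self, name)]
        if len(active) > 1:
            raise ValueError(f"{' and '.join(active)} are mutually exclusive")
```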

From the finalize-side review:

  • Doc example fix. docs/guides/finalize-model.md no longer shows --safetensors alongside --qat-convert (the visible flag would contradict the silent fallback). The fallback is now explicitly documented next to the example.
  • Cross-reference comments in qat_recipes.py and FinalizeModal.tsx for the remaining TSX/Python duplication of the recipe list.

Webui rebuilt; full suite of changes verified locally (template renders, prepare smoke-tests clean with the new Intx config, mutex error message correct, dropped recipe is rejected).

…hilla numbers

Two changes against the previous QAT commits:

1. **Fix: --qat-convert now actually quantizes**.

   Forgather's sharded checkpoint saver returns the plain float state_dict
   from FakeQuantizedLinear modules (their inner scale/zp state is not
   persistent), so by the time finalize loads the model the FQ modules
   are gone. The previous _apply_qat_convert gated on FQ-count > 0 and
   was a silent no-op for every QAT-trained checkpoint produced by the
   trainer.

   Fix: drop the FQ-count gate. Always run prepare-then-convert on the
   loaded model — prepare re-installs FakeQuantizedLinear on the loaded
   float weights, convert computes scales/zps from those (QAT-trained)
   weight statistics and swaps to the real low-bit linear ops. Verified
   end-to-end against the chinchilla checkpoint: 29/29 nn.Linear →
   29/29 IntxUnpackedToInt8Tensor on disk; final artifact is 6.6 MiB vs
   17.7 MiB float32 baseline.

   Side effect: running --qat-convert against a non-QAT model now also
   produces a quantized artifact (functionally PTQ). Tracked separately
   in #40 for a proper --ptq-quantize flag + UX split.

2. **Docs cross-links and chinchilla trajectory numbers**.

   - docs/README.md now lists qat-training.md alongside fp8-training.md
     under Training.
   - docs/trainers/fp8-training.md gets a See Also section pointing at
     qat-training.md (mutually-exclusive recipe siblings).
   - docs/guides/finalize-model.md cross-links back to qat-training.md.
   - docs/trainers/trainer_options.md documents qat_recipe alongside
     fp8_recipe in the trainer-args reference.
   - docs/trainers/qat-training.md gets the full 1-chinchilla Tiny Llama
     comparison: bf16 AMP baseline eval_loss=1.3601 vs QAT
     int8-dyn-act-int4-wt eval_loss=1.3774 (Δ +0.0173) at step 5140 on
     wopr (2× RTX 3090). Wall-clock: 197 s vs 329 s (QAT ~1.67×).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@jdinalt (Owner, Author) commented May 14, 2026

Verification update from local runs on wopr (2× RTX 3090):

Headline comparison — full 1-chinchilla Tiny Llama (5140 steps, 82.6M tokens, same seed):

| Run | Final eval_loss | Final train loss | Wall time | Steps/sec |
| --- | --- | --- | --- | --- |
| bf16 AMP baseline | 1.3601 | 1.3352 | 197 s | 26.1 |
| QAT int8-act-int4-wt | 1.3774 (Δ = +0.0173) | 1.3534 | 329 s | 15.6 |

QAT pays a ~+0.017 eval-loss premium for the int8-act / int4-weight quantization noise and ~1.67× wall-clock overhead. Full trajectory table is in docs/trainers/qat-training.md.

Real bug surfaced during finalize testing (now fixed in 3cfd50e): --qat-convert was a silent no-op for every QAT-trained checkpoint produced by the trainer. Forgather's sharded saver returns plain float weights from FakeQuantizedLinear modules (their inner state isn't persistent), so finalize was loading a plain-Linear model and the old FQ-count gate was rejecting the convert step. Fixed by dropping the gate and always running prepare → convert on the loaded float weights. Verified end-to-end: 29/29 nn.Linear → 29 IntxUnpackedToInt8Tensor weights on disk, artifact size 6.6 MiB vs 17.7 MiB float32. Side effect noted on issue #40 (it's now functionally PTQ-capable too).
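
For reference, the shape of the fixed helper (a sketch, not the literal finalize_model.py code):

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig


def _apply_qat_convert(model, base_config):
    # The sharded checkpoint holds plain float weights, so FakeQuantizedLinear must be
    # re-installed first: prepare derives quantization parameters from the (QAT-trained)
    # float weights, then convert swaps in the real low-bit linear ops.
    quantize_(model, QATConfig(base_config, step="prepare"))
    quantize_(model, QATConfig(base_config, step="convert"))
    return model
```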

Test status:

Round-trip load of the quantized artifact via AutoModelForCausalLM.from_pretrained hits torchao's subclass-dispatch limitations (IntxUnpackedToInt8Tensor doesn't yet support view/transpose in the load path). That's exactly what issues #41 (eval) and #42 (inference server) will address — they need the same marker-file + load-helper plumbing.

Three follow-up issues filed: #40 (a proper --ptq-quantize flag + UX split), #41 (eval), #42 (inference server).

Together #41+#42 unblock the three-way AMP / PTQ / QAT comparison on Tiny Stories that this PR's foundation makes possible.

jdinalt merged commit bdf9c17 into dev May 14, 2026
1 check passed
jdinalt deleted the feature/qat-training branch May 14, 2026 21:42