torchao QAT support: --qat-recipe (trainer) + --qat-convert (finalize) #39
…t in finalize)
Adds a two-phase QAT workflow:
1. Train-time: --qat-recipe inserts FakeQuantizedLinear via
torchao.quantize_(model, QATConfig(base, step="prepare")). Forward
simulates the target low-bit precision; backward stays in full
precision. Mutually exclusive with --fp8-recipe.
2. Deploy-time: forgather finalize --qat-convert <recipe> swaps every
FakeQuantizedLinear for the real low-bit quantized linear op. Reuses
the existing finalize pipeline (model already loaded + reshape-aware).
No-op with a warning on non-QAT-trained models.
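
A minimal sketch of that two-phase flow, mirroring the quantize_/QATConfig calls quoted above (the toy model is illustrative and import paths may differ slightly across torchao releases):

```python
from torch import nn
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig  # path may vary by torchao version

model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 32))
base = Int8DynamicActivationInt4WeightConfig()

# Phase 1 (trainer, --qat-recipe): install FakeQuantizedLinear modules.
# Forward fake-quantizes to the target precision; backward stays full precision.
quantize_(model, QATConfig(base, step="prepare"))

# ... normal training loop runs here on the fake-quantized model ...

# Phase 2 (forgather finalize --qat-convert): swap FakeQuantizedLinear for the
# real low-bit quantized linear ops, using the same base config.
quantize_(model, QATConfig(base, step="convert"))
```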
Recipe strings (int8-dynamic-act-int4-weight, int4-weight-only,
float8-dynamic-act-float8-weight, float8-dynamic-act-int4-weight) map to
torchao base configs via the shared src/forgather/ml/qat_recipes.py, so
prepare and convert resolve through the same function.
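
A rough sketch of what that shared resolver can look like (QAT_RECIPES is the dict name used elsewhere in this PR; resolve_qat_recipe and the exact structure are assumptions, not the actual qat_recipes.py contents):

```python
from torchao.quantization import (
    Int8DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Float8DynamicActivationFloat8WeightConfig,
    Float8DynamicActivationInt4WeightConfig,
)

QAT_RECIPES = {
    # A later feedback commit switches this to Int8DynamicActivationIntxWeightConfig.
    "int8-dynamic-act-int4-weight": Int8DynamicActivationInt4WeightConfig,
    "int4-weight-only": Int4WeightOnlyConfig,
    "float8-dynamic-act-float8-weight": Float8DynamicActivationFloat8WeightConfig,
    # Later dropped in a follow-up commit (Hopper-only FBGEMM kernel).
    "float8-dynamic-act-int4-weight": Float8DynamicActivationInt4WeightConfig,
}

def resolve_qat_recipe(name: str):
    """Map a recipe string to a torchao base config; both prepare and convert
    resolve through this one function so they can never disagree."""
    try:
        return QAT_RECIPES[name]()
    except KeyError:
        raise ValueError(f"unknown QAT recipe {name!r}; expected one of {sorted(QAT_RECIPES)}")
```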
Webui FinalizeModal grows a QAT Convert dropdown wired into job_params;
the server scheduler/launcher/finalize_ops chain passes --qat-convert
through to the subprocess.
Saves of QAT-converted models are forced to .bin: torchao's quantized
tensor subclasses lack a single storage().data_ptr(), which safetensors
requires. finalize warns and overrides --safetensors when --qat-convert
is set.
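
A hedged sketch of that save-format guard as it might sit in finalize (argument and function names here are illustrative, not the actual finalize_model.py code):

```python
import logging

logger = logging.getLogger(__name__)

def pick_save_format(qat_convert: str | None, safetensors: bool) -> str:
    """QAT-converted models must go to .bin: torchao's quantized tensor
    subclasses wrap several inner tensors and expose no single
    storage().data_ptr(), which the safetensors writer requires."""
    if qat_convert and safetensors:
        logger.warning("--safetensors is ignored with --qat-convert; writing .bin instead")
        return "bin"
    return "safetensors" if safetensors else "bin"
```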
Docs: new docs/trainers/qat-training.md, finalize-model.md gets a
Quantization subsection.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…pe, single-list mutex, template-helper for choices

Addresses reviewer feedback on PR #39:

- qat_recipes.py: int8-dynamic-act-int4-weight now resolves to Int8DynamicActivationIntxWeightConfig (Intx, version=2) instead of the deprecated Int8DynamicActivationInt4WeightConfig (pytorch/ao#2752). Same int8-per-token-dyn / int4-per-group-symmetric semantics.
- Drop float8-dynamic-act-int4-weight from QAT_RECIPES. torchao's underlying kernel asserts preshuffled int4 packing, which is Hopper-only (SM90+, FBGEMM); it would crash at finalize on any non-Hopper GPU. A comment in qat_recipes.py points at the runtime-capability gate to add later.
- BaseTrainingArguments mutex refactored from pairwise checks to a single _LINEAR_SWAP_RECIPES tuple + len(active) > 1 check, so adding a third swap-style recipe is a one-tuple-entry mechanical edit.
- Expose QAT_RECIPES as a `qat_recipes` Jinja global in preprocess.py; lm_training_project.yaml's --qat-recipe choices: now renders via `{{ qat_recipes() | toyaml }}` instead of hand-copying the four strings. Drift between the Python source of truth and the template is now impossible.
- Defensive comment at trainer.py _prepare_model: explains why the if-chain is sequential (not elif) — if the mutex is ever relaxed, the second swap will surface 0/N converted rather than silently shipping a single-recipe model.
- Finalize doc example no longer shows --safetensors together with --qat-convert (the silent fallback would contradict the visible flag); cross-reference comments added in qat_recipes.py and FinalizeModal.tsx for the remaining TSX/Python duplication.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
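
A sketch of the single-list mutex described in the commit above; the _LINEAR_SWAP_RECIPES tuple and the fp8_recipe/qat_recipe field names come from the PR, while the validator name and surrounding details are assumptions:

```python
# Recipes that swap nn.Linear implementations; only one may be active at a time.
# Adding a third swap-style recipe is a one-entry edit to this tuple.
_LINEAR_SWAP_RECIPES = ("fp8_recipe", "qat_recipe")

def _validate_linear_swap_mutex(args) -> None:
    """Raise at construction time if more than one linear-swapping recipe is set."""
    active = [name for name in _LINEAR_SWAP_RECIPES if getattr(args, name, None)]
    if len(active) > 1:
        raise ValueError(f"{' and '.join(active)} are mutually exclusive")
```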
Reviewer feedback applied in the commit above.

From the trainer-side review:
From the finalize-side review:
Webui rebuilt; full suite of changes verified locally (template renders, prepare smoke-tests clean with the new Intx config, mutex error message correct, dropped recipe is rejected).
…hilla numbers

Two changes against the previous QAT commits:

1. **Fix: --qat-convert now actually quantizes.** Forgather's sharded checkpoint saver returns the plain float state_dict from FakeQuantizedLinear modules (their inner scale/zp state is not persistent), so by the time finalize loads the model the FQ modules are gone. The previous _apply_qat_convert gated on FQ-count > 0 and was a silent no-op for every QAT-trained checkpoint produced by the trainer. Fix: drop the FQ-count gate. Always run prepare-then-convert on the loaded model — prepare re-installs FakeQuantizedLinear on the loaded float weights, convert computes scales/zps from those (QAT-trained) weight statistics and swaps to the real low-bit linear ops. Verified end-to-end against the chinchilla checkpoint: 29/29 nn.Linear → 29/29 IntxUnpackedToInt8Tensor on disk; final artifact is 6.6 MiB vs 17.7 MiB float32 baseline. Side effect: running --qat-convert against a non-QAT model now also produces a quantized artifact (functionally PTQ). Tracked separately in #40 for a proper --ptq-quantize flag + UX split.
2. **Docs cross-links and chinchilla trajectory numbers.**
   - docs/README.md now lists qat-training.md alongside fp8-training.md under Training.
   - docs/trainers/fp8-training.md gets a See Also section pointing at qat-training.md (mutually-exclusive recipe siblings).
   - docs/guides/finalize-model.md cross-links back to qat-training.md.
   - docs/trainers/trainer_options.md documents qat_recipe alongside fp8_recipe in the trainer-args reference.
   - docs/trainers/qat-training.md gets the full 1-chinchilla Tiny Llama comparison: bf16 AMP baseline eval_loss=1.3601 vs QAT int8-dyn-act-int4-wt eval_loss=1.3774 (Δ +0.0173) at step 5140 on wopr (2× RTX 3090). Wall-clock: 197 s vs 329 s (QAT ~1.67×).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
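
A sketch of the fixed prepare-then-convert path described in item 1; only the unconditional prepare → convert sequence comes from the commit, the function name and signature are illustrative:

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig  # path may vary by torchao version

def apply_qat_convert(model, base_config):
    """Quantize a model loaded from a sharded Forgather checkpoint.

    The saved checkpoint holds plain float weights (FakeQuantizedLinear state
    is not persistent), so finalize cannot gate on an FQ-module count.
    """
    # Re-install FakeQuantizedLinear on the loaded (QAT-trained) float weights...
    quantize_(model, QATConfig(base_config, step="prepare"))
    # ...then compute scales/zero-points from those weight statistics and swap
    # in the real low-bit linear ops. No FQ-count gate: this always quantizes.
    quantize_(model, QATConfig(base_config, step="convert"))
    return model
```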
Verification update from local runs on wopr (2× RTX 3090).

Headline comparison — full 1-chinchilla Tiny Llama (5140 steps, 82.6M tokens, same seed): bf16 AMP baseline eval_loss 1.3601 vs QAT int8-dyn-act-int4-wt 1.3774 at step 5140; wall-clock 197 s vs 329 s. QAT pays a ~+0.017 eval-loss premium for the int8-act / int4-weight quantization noise and ~1.67× wall-clock overhead. The full trajectory table is in docs/trainers/qat-training.md.

A real bug surfaced during finalize testing (now fixed in the commit above): the sharded checkpoint saver drops the FakeQuantizedLinear state, so the old FQ-count gate made --qat-convert a silent no-op on trainer-produced checkpoints.

Test status: round-trip load of the quantized artifact via …

Three follow-up issues filed (#40 for a proper --ptq-quantize flag + UX split, plus #41 and #42). Together #41+#42 unblock the three-way AMP / PTQ / QAT comparison on Tiny Stories that this PR's foundation makes possible.
Summary
Adds torchao Quantization-Aware Training to Forgather, split across the two natural integration points so train and deploy concerns stay separated:
- A `--qat-recipe` argument. When set, the trainer runs `quantize_(model, QATConfig(base_config, step="prepare"))` to install `FakeQuantizedLinear` modules. Forward simulates the target low-bit precision while backward stays in full precision. Mutually exclusive with `--fp8-recipe`.
- A `--qat-convert` argument on `forgather finalize`, also exposed in the webui FinalizeModal. Runs `quantize_(model, QATConfig(base_config, step="convert"))` against the loaded checkpoint, producing a deployable low-bit artifact. No-op (with warning) on non-QAT-trained models, so the new flag is purely additive.

Both endpoints resolve their recipe string through a shared `src/forgather/ml/qat_recipes.py` so prepare and convert always agree on the underlying torchao base config.

Recipes
| Recipe | torchao base config |
| --- | --- |
| `int8-dynamic-act-int4-weight` (default) | `Int8DynamicActivationInt4WeightConfig` |
| `int4-weight-only` | `Int4WeightOnlyConfig` |
| `float8-dynamic-act-float8-weight` | `Float8DynamicActivationFloat8WeightConfig` |
| `float8-dynamic-act-int4-weight` | `Float8DynamicActivationInt4WeightConfig` |

Save format constraint
`forgather finalize --qat-convert` always writes `.bin`. The `--safetensors` flag is overridden with a warning when both are set, because torchao's quantized tensor subclasses (`Int8DynActInt4WeightLinear` etc.) wrap multiple inner tensors and don't expose a single `.storage().data_ptr()`, which is what the safetensors writer requires. Until torchao ships explicit safetensors serialisation hooks, `.bin` is the working save format.

Files
- `src/forgather/ml/qat_recipes.py` (new) — shared recipe→config map
- `src/forgather/ml/trainer/base_trainer.py` — `qat_recipe` TrainingArgument + validator + three-way mutex with `fp8_recipe`
- `src/forgather/ml/trainer/trainer.py` — `_apply_qat_training`, wired in `_prepare_model`
- `templatelib/examples/projects/lm_training_project.yaml` — `--qat-recipe` dynamic arg + trainer_args pass-through + comment-table entry
- `src/forgather/cli/finalize.py` — `--qat-convert` flag + job_params marshalling
- `tools/finalize_model/finalize_model.py` — `_apply_qat_convert` helper, new pipeline step between dry-run check and `write_finalized_checkpoint`, auto-fallback to `.bin`
- `tools/forgather_server/{finalize_ops,launcher,scheduler}.py` — pass-through plumbing
- `tools/forgather_server/webui/src/components/FinalizeModal.tsx` — "QAT Convert" dropdown in Output section
- `docs/trainers/qat-training.md` (new) — recipe table, prepare/convert flow, save-format constraint
- `docs/guides/finalize-model.md` — new Quantization subsection

Verification done locally (2× RTX 3090, sm_86)
- `forgather -t v2.yaml train --qat-recipe int8-dynamic-act-int4-weight --max-steps 30`: prepare converted 29/29 nn.Linear, training ran clean (14.5 steps/sec).
- `--fp8-recipe tensorwise --qat-recipe int4-weight-only` raises `ValueError: fp8_recipe and qat_recipe are mutually exclusive` at construction.
- Tiny-model smoke test (`LlamaConfig`/`LlamaForCausalLM`): prepare → 20-step train (loss finite, backward clean) → convert (15 FakeQuantizedLinear → 0) → forward on converted model (finite logits) → save via `forgather.ml.sharded_checkpoint.save_checkpoint` (`.bin` path) succeeds.
- Converted artifact on disk is `.bin`.

Test plan
- Train a run with `--qat-recipe int8-dynamic-act-int4-weight` and confirm the loss trajectory tracks bf16 within roughly +0.05 at step 200+
- Run `forgather finalize --source <run> --dest <out> --qat-convert int8-dynamic-act-int4-weight` against that checkpoint; confirm the produced `.bin` shards contain quantized linear keys and `AutoModelForCausalLM.from_pretrained(<out>, trust_remote_code=True)` can do a forward pass (after re-applying the convert config to the freshly-loaded module)
- From the webui FinalizeModal, set QAT Convert to `int8-dynamic-act-int4-weight`; confirm the resulting job's stdout shows `--qat-convert int8-dynamic-act-int4-weight` was passed through
- Run `--qat-convert` against a model that was not QAT-trained and confirm the warning fires and a normal (non-quantized) artifact is produced

🤖 Generated with Claude Code