1 change: 1 addition & 0 deletions docs/README.md
@@ -34,6 +34,7 @@ Source code and examples: [github.com/jdinalt/forgather](https://github.com/jdinalt/forgather)
- **[Training Performance Metrics](trainers/training-performance-metrics.md)** - Token throughput, FLOP tracking, and MFU
- **[DiLoCo](trainers/diloco.md)** - Distributed Local-SGD training across heterogeneous machines on LAN
- **[FP8 Training](trainers/fp8-training.md)** - FP8 training via torchao
- **[QAT Training](trainers/qat-training.md)** - Quantization-aware training via torchao; pair with `forgather finalize --qat-convert` for the deployable low-bit artifact
- **[Checkpointing](checkpointing/README.md)** - Distributed checkpoint system for multi-GPU and multi-node training
- **[Torch Titan Integration](trainers/torchtitan.md)** - Forgather integration with PyTorch's Torch Titan training framework
- **[Adafactor Triton Performance](trainers/adafactor-triton-performance.md)** - Performance analysis for the Triton-optimized Adafactor kernel
22 changes: 22 additions & 0 deletions docs/guides/finalize-model.md
@@ -99,6 +99,26 @@ the (possibly-updated) tokenizer last.
| `--dtype {bfloat16,float16,float32}` | Cast weights to this dtype before saving. Default: keep the dtype the source checkpoint was saved in. |
| `--device STR` | Device for loading the model during finalize (default `cpu`). |

### Quantization

| Option | Description |
|--------|-------------|
| `--qat-convert RECIPE` | Run the torchao QAT convert step before saving: swap `FakeQuantizedLinear` modules for the real low-bit quantized linear ops described by `RECIPE`. Use the same recipe string that was used at training time (`--qat-recipe`). On models without fake-quantized modules this is a no-op with a warning. See [QAT Training](../trainers/qat-training.md) for the recipe list. |

Example:

```bash
# After training with --qat-recipe int8-dynamic-act-int4-weight, produce the
# deployable quantized artifact:
forgather finalize output_models/qat_run out/qat_int8_int4 \
    --qat-convert int8-dynamic-act-int4-weight
```

When `--qat-convert` is set, finalize always writes `.bin`: torchao's
quantized tensor subclasses don't expose a single `.storage().data_ptr()`,
which the safetensors writer requires. If `--safetensors` is passed
alongside `--qat-convert`, it is ignored and a warning is emitted.

### Misc

| Option | Description |
@@ -166,3 +186,5 @@ pad_token:
- **[EOS Tokens and `generate()` Stopping Criteria](eos-and-generate-stopping.md)** --
theory of operation: how HF's `generate()` resolves stopping across the
multiple files that carry EOS information.
- **[QAT Training](../trainers/qat-training.md)** -- pair `--qat-convert` here
with `--qat-recipe` at training time to produce a low-bit deployable artifact.
8 changes: 8 additions & 0 deletions docs/trainers/fp8-training.md
@@ -186,3 +186,11 @@ performance. Update torchao to match your PyTorch version.
**No speedup observed**: Ensure `torch_compile=True` is set. Without compilation, the
overhead of FP8 scale computation and casting can offset the matmul speedup, especially
for small models.

## See Also

- **[QAT Training](qat-training.md)** -- the other torchao Linear-swap recipe. Mutually
exclusive with FP8: QAT inserts `FakeQuantizedLinear` for low-bit deployment, while
FP8 swaps to `Float8Linear` for faster training compute.
- **[Finalizing a Trained Model](../guides/finalize-model.md)** -- post-training packaging.
No FP8-specific options today; the deployable artifact retains the original FP precision.
224 changes: 224 additions & 0 deletions docs/trainers/qat-training.md
@@ -0,0 +1,224 @@
# Quantization-Aware Training (QAT)

Forgather supports torchao-style quantization-aware training. At training
time `nn.Linear` modules are wrapped in `FakeQuantizedLinear`, which
simulates the target low-bit precision in the forward pass while the
backward pass stays in full precision. The model learns to be robust to the
quantization noise so that the converted (real low-bit) artifact retains
most of the bf16 accuracy.

QAT is a two-phase workflow:

1. **Prepare** -- done at training time via `--qat-recipe`. Inserts fake
quantizers into the model. Training proceeds normally (the optimizer
updates full-precision master weights; the fake-quant scales/zero-points
are recomputed each step).
2. **Convert** -- done after training via `forgather finalize --qat-convert
<recipe>`. Swaps each `FakeQuantizedLinear` for the real low-bit
quantized linear op described by the recipe, producing a deployable
artifact.

## Requirements

- **GPU**: any CUDA GPU (or CPU). QAT runs in full precision; the fake
quantizers are pure PyTorch math with no hardware gating.
- **torchao**: `>=0.16.0`. Bundled in the Forgather Docker images.

## Quick Start

```bash
# 1. Train with fake quantizers installed
forgather -t config.yaml train --qat-recipe int8-dynamic-act-int4-weight

# 2. After training, produce the deployable quantized artifact
forgather finalize output_models/my_run out/my_run_int8_int4 \
    --qat-convert int8-dynamic-act-int4-weight
```

The recipe string passed to `--qat-recipe` and `--qat-convert` must be the
**same** -- the convert step needs the matching base config to know what
scales and dtypes to use. Recipe strings are validated against the registry
in `src/forgather/ml/qat_recipes.py`.

QAT is mutually exclusive with `fp8_recipe`. Both transform `nn.Linear`,
so the trainer rejects the combination at startup.
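
To sanity-check a recipe string programmatically (in a script or notebook), you can
query the same registry directly -- a minimal sketch using the module added in this
change:

```python
from forgather.ml.qat_recipes import QAT_RECIPES, recipe_to_base_config

recipe = "int8-dynamic-act-int4-weight"
assert recipe in QAT_RECIPES, f"unknown recipe; valid: {QAT_RECIPES}"

# The same base config object backs both the prepare and convert phases.
base_config = recipe_to_base_config(recipe)
print(type(base_config).__name__)  # Int8DynamicActivationIntxWeightConfig
```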

## Recipes

| Recipe | Activations | Weights | torchao base config |
|--------|-------------|---------|---------------------|
| `int8-dynamic-act-int4-weight` | int8 per-token dynamic | int4 per-group (group_size=32) | `Int8DynamicActivationIntxWeightConfig` |
| `int4-weight-only` | full precision | int4 per-group (group_size=128) | `Int4WeightOnlyConfig` |
| `float8-dynamic-act-float8-weight` | float8 per-row dynamic | float8 per-row | `Float8DynamicActivationFloat8WeightConfig` |

`float8-dynamic-act-int4-weight` is *not* exposed in v1 — torchao gates its
underlying kernel to the `preshuffled` int4 packing format which is Hopper-only
(SM90+ / FBGEMM). It will be added back behind a runtime capability check.

Recommended default: `int8-dynamic-act-int4-weight`. It's the most
broadly-validated production path -- the same recipe Meta and NVIDIA use
when shipping QAT'd LLMs for edge inference.

To add or tweak a recipe (e.g. change `group_size`), edit
`src/forgather/ml/qat_recipes.py:recipe_to_base_config`. Both the trainer
and finalize resolve through the same function, so they stay in sync.
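
For example, a hypothetical local edit that bumps the int4 weight group size for the
default recipe might look like this (sketch of the relevant branch only; 64 is an
illustrative value, not a shipped default):

```python
# In src/forgather/ml/qat_recipes.py, recipe_to_base_config() -- local tweak only.
if recipe == "int8-dynamic-act-int4-weight":
    return Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size=64),  # shipped default is 32
    )
```

Because prepare and convert must resolve to the same parameters, a tweaked recipe has
to be in place both at training time and when you later run
`forgather finalize --qat-convert`.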

## How It Works

At trainer init, when `qat_recipe` is set:

```python
quantize_(model, QATConfig(base_config, step="prepare"))
```

`quantize_` walks the module tree and swaps each `nn.Linear` for a
`FakeQuantizedLinear` instance. On every forward pass:

1. Activations are quantize-then-dequantize through the activation fake
quantizer (if the recipe has one).
2. Weights are quantize-then-dequantize through the weight fake quantizer.
3. The matmul runs in the original (bf16/fp32) dtype on the dequantized
tensors.

In the backward pass nothing about this is special: gradients flow through
the standard linear backward in full precision, and the optimizer updates
the original full-precision weights. The fake quantizers don't have learned
parameters by default -- their scales and zero-points are derived from the
current weight/activation statistics every step.
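
As a mental model only (not torchao's implementation), per-group symmetric weight
fake quantization in the forward pass looks roughly like this:

```python
import torch

def fake_quantize_per_group(w: torch.Tensor, n_bits: int = 4, group_size: int = 32) -> torch.Tensor:
    """Quantize-then-dequantize so the matmul sees int4-style rounding noise
    while the tensor keeps its original (bf16/fp32) dtype and shape."""
    qmax = 2 ** (n_bits - 1) - 1  # 7 for symmetric int4
    grouped = w.reshape(-1, group_size)  # assumes numel is divisible by group_size
    scale = grouped.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(grouped / scale), -qmax - 1, qmax)
    return (q * scale).reshape_as(w)
```

In real QAT the rounding is made differentiable with a straight-through estimator, and
the scales are recomputed from the current statistics each step, as described above.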

At finalize, when `--qat-convert <recipe>` is set:

```python
# 1. Re-install fake quantizers on top of the loaded float weights
quantize_(model, QATConfig(base_config, step="prepare"))
# 2. Swap them for the real low-bit quantized linear ops
quantize_(model, QATConfig(base_config, step="convert"))
```

The first call is necessary because Forgather's sharded checkpoint saver
serialises `state_dict()` which returns *float* weights — the
`FakeQuantizedLinear` modules' scale/zero-point inner state is not
persistent. We re-install fake quantizers from the float weights and then
let convert compute the final low-bit weights and scales. The scales the
convert step picks are derived from the QAT-trained weight statistics, so
the QAT training-time accuracy benefit is preserved.

The result is a model whose `nn.Linear` modules are now torchao subclasses
(`Int8DynActInt4WeightLinear`, etc.). Forgather's `save_checkpoint` writes
the resulting state_dict as PyTorch `.bin` (safetensors is incompatible —
see below).
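
A quick way to confirm the swap happened is to list the Linear-like module classes
after convert (a small helper sketch; the exact torchao class names can vary by
version):

```python
def list_linear_classes(model):
    # Print every module whose class name mentions "Linear", making the
    # FakeQuantizedLinear -> quantized-linear swap visible at a glance.
    for name, module in model.named_modules():
        if "Linear" in type(module).__name__:
            print(f"{name}: {type(module).__name__}")
```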

## Loss Trajectory: 1-Chinchilla Tiny Llama

Full-length training run of `examples/tutorials/tiny_llama:v2.yaml` (Tiny
Llama, 4.43M params, ~82.6M training tokens — chinchilla-optimal at
~20 tokens/param), single GPU (RTX 3090, sm_86, wopr), same seed, same
config. The baseline run uses the v2.yaml default precision settings (bf16
AMP via `mixed_precision: "bf16"`); the QAT run adds `--qat-recipe
int8-dynamic-act-int4-weight` on top.

| Eval step | bf16 AMP baseline (eval_loss) | QAT int8-act-int4-wt (eval_loss) | Δ (QAT − baseline) |
|-----------|-------------------------------|----------------------------------|--------------------|
| 642 | 2.0651 | 2.0789 | +0.0138 |
| 1284 | 1.6999 | 1.7142 | +0.0143 |
| 1926 | 1.5658 | 1.5799 | +0.0141 |
| 4494 | 1.3725 | 1.3896 | +0.0171 |
| 5136 | 1.3602 | 1.3776 | +0.0174 |
| **5140 (final)** | **1.3601** | **1.3774** | **+0.0173** |

Final train loss at step 5120 was 1.3352 vs 1.3534 (Δ +0.0182). The two
trajectories track each other from the very first eval through to
completion — QAT pays a stable ~+0.017 eval-loss premium throughout
training rather than a divergent late-training gap, which is the
encouraging signal: the model is learning under the fake-quant noise, not
just accumulating it.

**Wall-clock overhead.** Same GPU, same model, same data:

| Run | Wall time | Steps/sec | Tokens/sec |
|-----|-----------|-----------|------------|
| bf16 AMP baseline | 197 s | 26.1 | 419K |
| QAT int8-act-int4-wt | 329 s | 15.6 | 251K |

QAT is ~1.67× slower than the bf16 baseline (the cost of running the fake
quantizers in pure PyTorch in the forward pass). Whether it pays for
itself depends on what the converted artifact recovers — that comparison
needs `forgather eval` + inference-server support for quantized models
(tracked in #41 and #42).

## Save Format

`forgather finalize --qat-convert` always writes the converted artifact in
PyTorch (`.bin`) format. The `--safetensors` flag is ignored (a warning is
emitted) when both are set: torchao's quantized tensor subclasses
(`Int8DynActInt4WeightLinear`, `Int4Tensor`, etc.) wrap multiple inner
tensors and don't expose a single `.storage().data_ptr()`, which is what
the safetensors writer requires. Until torchao ships explicit safetensors
serialization, `.bin` is the working save format.

The default `.bin` artifact loads cleanly through `torch.load` + the
torchao `quantize_(model, QATConfig(base_config, step="convert"))` re-cast
applied at load time. See the programmatic example below.

## Behavior on Models Without QAT

If you pass `--qat-convert <recipe>` to `forgather finalize` on a model
that wasn't trained with `--qat-recipe`, the same prepare-then-convert
pipeline runs anyway -- which is functionally **post-training
quantization (PTQ)**: the recipe is applied, but the result lacks the
QAT training-time accuracy benefit. The deployable artifact is still
valid and loadable. A future `--ptq-quantize` flag (tracked in #40) will
make that PTQ-on-plain-model intent explicit, but until then
`--qat-convert` is the single entry point for both flows.

## Programmatic Usage

```python
from forgather.ml.trainer import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output_models/my_qat_run",
    qat_recipe="int8-dynamic-act-int4-weight",
    # ... other training args
)

trainer = Trainer(
    args=args,
    model_init=model_factory,
    train_dataset=train_dataset,
)
trainer.train()
```

To run convert programmatically:

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from forgather.ml.qat_recipes import recipe_to_base_config

base_config = recipe_to_base_config("int8-dynamic-act-int4-weight")

# The loaded checkpoint holds float weights, so re-install the fake quantizers
# before converting (see "How It Works" above).
quantize_(model, QATConfig(base_config, step="prepare"))
quantize_(model, QATConfig(base_config, step="convert"))

# torchao's quantized tensor subclasses can't go through safetensors yet,
# so save as PyTorch .bin (see "Save Format").
model.save_pretrained("out/my_quantized_model", safe_serialization=False)
```

## Out of Scope

The v1 integration intentionally omits a few torchao QAT knobs that aren't
needed for the common case:

- **Auto-convert at training end**: convert is run by `forgather finalize`,
not the trainer. Keeps training and deployment concerns separated.
- **Custom `group_size` / granularity flags on the CLI**: the per-recipe
defaults in `qat_recipes.py` are the standard values. Edit them locally
if you need to experiment.
- **Range learning** (learned per-channel scales): torchao supports it via
`IntxFakeQuantizeConfig(range_learning=True)`, but the v1 recipes leave
it off.

## See Also

- [FP8 Training](fp8-training.md) -- the other torchao Linear-swap recipe;
mutually exclusive with QAT.
- [Finalizing a Trained Model](../guides/finalize-model.md) -- the
`forgather finalize` reference (including `--qat-convert`).
6 changes: 4 additions & 2 deletions docs/trainers/trainer_options.md
@@ -214,11 +214,13 @@ On any GPU that supports TF32 (Ampere or newer), you usually want
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `mixed_precision` | str \| None | None | `None` / `"no"` disabled, `"bf16"` (no GradScaler), or `"fp16"` (with GradScaler). |
| `fp8_recipe` | str \| None | None | `"tensorwise"`, `"rowwise"`, or `"rowwise_with_gw_hp"`. Converts `nn.Linear` to `Float8Linear` via torchao. Orthogonal to `mixed_precision`. |
| `fp8_recipe` | str \| None | None | `"tensorwise"`, `"rowwise"`, or `"rowwise_with_gw_hp"`. Converts `nn.Linear` to `Float8Linear` via torchao. Orthogonal to `mixed_precision`. Mutually exclusive with `qat_recipe`. |
| `fp8_dim_alignment` | int | 16 | Minimum alignment for FP8 Linear layer dimensions; non-conforming layers are skipped. |
| `qat_recipe` | str \| None | None | `"int8-dynamic-act-int4-weight"`, `"int4-weight-only"`, or `"float8-dynamic-act-float8-weight"`. Installs `FakeQuantizedLinear` via torchao QAT (prepare phase). Run `forgather finalize --qat-convert <recipe>` after training to produce the deployable low-bit artifact. Mutually exclusive with `fp8_recipe`. |

FP8 requires CUDA SM >= 8.9 (RTX 4090, H100, etc.). See
[`fp8-training.md`](fp8-training.md).
[`fp8-training.md`](fp8-training.md). QAT has no hardware gate (runs on any
CUDA GPU or CPU); see [`qat-training.md`](qat-training.md).
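
Because both recipes rewrite `nn.Linear`, set at most one of them. A minimal sketch of
the two valid combinations (field names per the table above; other arguments omitted):

```python
from forgather.ml.trainer import TrainingArguments

# QAT prepare phase under bf16 AMP -- no fp8_recipe here.
qat_args = TrainingArguments(
    output_dir="output_models/qat_run",
    mixed_precision="bf16",
    qat_recipe="int8-dynamic-act-int4-weight",
)

# FP8 training -- no qat_recipe here.
fp8_args = TrainingArguments(
    output_dir="output_models/fp8_run",
    mixed_precision="bf16",
    fp8_recipe="rowwise",
)
```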

---

12 changes: 12 additions & 0 deletions src/forgather/cli/finalize.py
@@ -75,6 +75,7 @@ def _enqueue_finalize(args):
p.add_argument("--generation-config", default=None)
p.add_argument("--dtype", default=None)
p.add_argument("--device", default="cpu")
p.add_argument("--qat-convert", default=None)
p.add_argument("--priority", type=int, default=0)
p.add_argument("--server", default=None)
sub = p.parse_args(args.remainder)
@@ -104,6 +105,17 @@ def _enqueue_finalize(args):
job_params["generation_config"] = sub.generation_config
if sub.dtype:
job_params["dtype"] = sub.dtype
if sub.qat_convert:
from forgather.ml.qat_recipes import QAT_RECIPES

if sub.qat_convert not in QAT_RECIPES:
print(
f"--qat-convert must be one of {QAT_RECIPES}, "
f"got '{sub.qat_convert}'",
file=sys.stderr,
)
raise SystemExit(2)
job_params["qat_convert"] = sub.qat_convert

from .server_client import ServerClient, ServerUnreachable

62 changes: 62 additions & 0 deletions src/forgather/ml/qat_recipes.py
@@ -0,0 +1,62 @@
"""Shared QAT recipe registry for the trainer (prepare) and finalize (convert).

The same recipe string is supplied at training time via ``qat_recipe`` (which
inserts ``FakeQuantizedLinear`` modules) and at finalize time via
``--qat-convert`` (which swaps them for real low-bit quantized linear ops).
Both call sites resolve the string through :func:`recipe_to_base_config`.
"""

from __future__ import annotations


# Source of truth for QAT recipe names. Consumed by:
# - BaseTrainingArguments validator (src/forgather/ml/trainer/base_trainer.py)
# - the Forgather finalize CLI (src/forgather/cli/finalize.py)
# - finalize_model.py's --qat-convert (tools/finalize_model/finalize_model.py)
# - the lm_training_project.yaml template's --qat-recipe `choices:` list,
# rendered from this tuple via the `qat_recipes` Jinja global
# - tools/forgather_server/webui/src/components/FinalizeModal.tsx
#   (TSX duplicate — keep the recipe strings here and there in sync)
QAT_RECIPES: tuple[str, ...] = (
    "int8-dynamic-act-int4-weight",
    "int4-weight-only",
    "float8-dynamic-act-float8-weight",
)


def recipe_to_base_config(recipe: str):
    """Map a Forgather QAT recipe string to a torchao base config instance.

    The returned object is the ``base_config`` argument for
    ``torchao.quantization.qat.QATConfig(base_config, step=...)``. It must be
    the *same* config (same parameters) for both the prepare and convert
    phases.
    """
    import torch
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        Int4WeightOnlyConfig,
        Int8DynamicActivationIntxWeightConfig,
    )
    from torchao.quantization.granularity import PerGroup

    if recipe == "int8-dynamic-act-int4-weight":
        # Replaces the deprecated Int8DynamicActivationInt4WeightConfig
        # (see pytorch/ao#2752). Same semantics: int8 per-token dynamic
        # activations, int4 per-group symmetric weights.
        return Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(group_size=32),
        )
    if recipe == "int4-weight-only":
        return Int4WeightOnlyConfig(group_size=128)
    if recipe == "float8-dynamic-act-float8-weight":
        return Float8DynamicActivationFloat8WeightConfig()
    # `float8-dynamic-act-int4-weight` is intentionally not exposed in v1:
    # torchao's Float8DynamicActivationInt4WeightConfig requires the
    # `preshuffled` int4 packing format, which is Hopper-only (SM90+,
    # FBGEMM). When we add capability-gated recipe exposure, re-introduce
    # it behind a runtime check.
    raise ValueError(
        f"Unknown QAT recipe: {recipe!r}. Valid recipes: {QAT_RECIPES}"
    )