diff --git a/docs/README.md b/docs/README.md
index d3aa9201..1452103c 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -34,6 +34,7 @@ Source code and examples: [github.com/jdinalt/forgather](https://github.com/jdin
 - **[Training Performance Metrics](trainers/training-performance-metrics.md)** - Token throughput, FLOP tracking, and MFU
 - **[DiLoCo](trainers/diloco.md)** - Distributed Local-SGD training across heterogeneous machines on LAN
 - **[FP8 Training](trainers/fp8-training.md)** - FP8 training via torchao
+- **[QAT Training](trainers/qat-training.md)** - Quantization-aware training via torchao; pair with `forgather finalize --qat-convert` for the deployable low-bit artifact
 - **[Checkpointing](checkpointing/README.md)** - Distributed checkpoint system for multi-GPU and multi-node training
 - **[Torch Titan Integration](trainers/torchtitan.md)** - Forgather integration with PyTorch's Torch Titan training framework
 - **[Adafactor Triton Performance](trainers/adafactor-triton-performance.md)** - Performance analysis for the Triton-optimized Adafactor kernel
diff --git a/docs/guides/finalize-model.md b/docs/guides/finalize-model.md
index 2d484b48..dbeefe72 100644
--- a/docs/guides/finalize-model.md
+++ b/docs/guides/finalize-model.md
@@ -99,6 +99,26 @@ the (possibly-updated) tokenizer last.
 | `--dtype {bfloat16,float16,float32}` | Cast weights to this dtype before saving. Default: keep the dtype the source checkpoint was saved in. |
 | `--device STR` | Device for loading the model during finalize (default `cpu`). |
 
+### Quantization
+
+| Option | Description |
+|--------|-------------|
+| `--qat-convert RECIPE` | Run the torchao QAT convert step before saving: swap `FakeQuantizedLinear` modules for the real low-bit quantized linear ops described by `RECIPE`. Use the same recipe string that was used at training time (`--qat-recipe`). On a model that was not QAT-trained this still applies the recipe (plain post-training quantization); if the model has no `nn.Linear` modules at all, it is a no-op with a warning. See [QAT Training](../trainers/qat-training.md) for the recipe list. |
+
+Example:
+
+```bash
+# After training with --qat-recipe int8-dynamic-act-int4-weight, produce the
+# deployable quantized artifact:
+forgather finalize output_models/qat_run out/qat_int8_int4 \
+    --qat-convert int8-dynamic-act-int4-weight
+```
+
+When `--qat-convert` is set, finalize always writes `.bin`: torchao's
+quantized tensor subclasses don't expose a single `.storage().data_ptr()`,
+which the safetensors writer requires. If `--safetensors` is passed
+alongside `--qat-convert`, it is disabled with a warning.
+
 ### Misc
 
 | Option | Description |
@@ -166,3 +186,5 @@ pad_token:
 - **[EOS Tokens and `generate()` Stopping Criteria](eos-and-generate-stopping.md)** -- theory
   of operation: how HF's `generate()` resolves stopping across the multiple
   files that carry EOS information.
+- **[QAT Training](../trainers/qat-training.md)** -- pair `--qat-convert` here
+  with `--qat-recipe` at training time to produce a low-bit deployable artifact.
diff --git a/docs/trainers/fp8-training.md b/docs/trainers/fp8-training.md
index d30f9602..744a67bc 100644
--- a/docs/trainers/fp8-training.md
+++ b/docs/trainers/fp8-training.md
@@ -186,3 +186,11 @@ performance. Update torchao to match your PyTorch version.
 **No speedup observed**: Ensure `torch_compile=True` is set. Without
 compilation, the overhead of FP8 scale computation and casting can offset
 the matmul speedup, especially for small models.
+
+## See Also
+
+- **[QAT Training](qat-training.md)** -- the other torchao Linear-swap recipe.
Mutually
+  exclusive with FP8: QAT inserts `FakeQuantizedLinear` for low-bit deployment, while
+  FP8 swaps to `Float8Linear` for faster training compute.
+- **[Finalizing a Trained Model](../guides/finalize-model.md)** -- post-training packaging.
+  No FP8-specific options today; the deployable artifact retains the original FP precision.
diff --git a/docs/trainers/qat-training.md b/docs/trainers/qat-training.md
new file mode 100644
index 00000000..b12da82f
--- /dev/null
+++ b/docs/trainers/qat-training.md
@@ -0,0 +1,224 @@
+# Quantization-Aware Training (QAT)
+
+Forgather supports torchao-style quantization-aware training. At training
+time `nn.Linear` modules are wrapped in `FakeQuantizedLinear`, which
+simulates the target low-bit precision in the forward pass while the
+backward pass stays in full precision. The model learns to be robust to the
+quantization noise so that the converted (real low-bit) artifact retains
+most of the bf16 accuracy.
+
+QAT is a two-phase workflow:
+
+1. **Prepare** -- done at training time via `--qat-recipe`. Inserts fake
+   quantizers into the model. Training proceeds normally (the optimizer
+   updates full-precision master weights; the fake-quant scales/zero-points
+   are recomputed each step).
+2. **Convert** -- done after training via `forgather finalize --qat-convert
+   <recipe>`. Swaps each `FakeQuantizedLinear` for the real low-bit
+   quantized linear op described by the recipe, producing a deployable
+   artifact.
+
+## Requirements
+
+- **GPU**: any CUDA GPU (or CPU). QAT runs in full precision; the fake
+  quantizers are pure PyTorch math with no hardware gating.
+- **torchao**: `>=0.16.0`. Bundled in the Forgather Docker images.
+
+## Quick Start
+
+```bash
+# 1. Train with fake quantizers installed
+forgather -t config.yaml train --qat-recipe int8-dynamic-act-int4-weight
+
+# 2. After training, produce the deployable quantized artifact
+forgather finalize output_models/my_run out/my_run_int8_int4 \
+    --qat-convert int8-dynamic-act-int4-weight
+```
+
+The recipe string passed to `--qat-recipe` and `--qat-convert` must be the
+**same** -- the convert step needs the matching base config to know what
+scales and dtypes to use. Recipe strings are validated against the registry
+in `src/forgather/ml/qat_recipes.py`.
+
+QAT is mutually exclusive with `fp8_recipe`. Both transform `nn.Linear`,
+so the trainer rejects the combination at startup.
+
+## Recipes
+
+| Recipe | Activations | Weights | torchao base config |
+|--------|-------------|---------|---------------------|
+| `int8-dynamic-act-int4-weight` | int8 per-token dynamic | int4 per-group (group_size=32) | `Int8DynamicActivationIntxWeightConfig` |
+| `int4-weight-only` | full precision | int4 per-group (group_size=128) | `Int4WeightOnlyConfig` |
+| `float8-dynamic-act-float8-weight` | float8 per-row dynamic | float8 per-row | `Float8DynamicActivationFloat8WeightConfig` |
+
+`float8-dynamic-act-int4-weight` is *not* exposed in v1 — torchao gates its
+underlying kernel to the `preshuffled` int4 packing format, which is Hopper-only
+(SM90+ / FBGEMM). It will be added back behind a runtime capability check.
+
+Recommended default: `int8-dynamic-act-int4-weight`. It's the most
+broadly validated production path -- the same recipe Meta and NVIDIA use
+when shipping QAT'd LLMs for edge inference.
+
+To add or tweak a recipe (e.g. change `group_size`), edit
+`src/forgather/ml/qat_recipes.py:recipe_to_base_config`. Both the trainer
+and finalize resolve through the same function, so they stay in sync.
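+
+For example, a local variant with a finer weight group size might look like
+the sketch below (illustrative only -- `int8-dynamic-act-int4-weight-g16` is
+a hypothetical name, not a shipped recipe; register it in `QAT_RECIPES` too
+if you add it):
+
+```python
+import torch
+from torchao.quantization import Int8DynamicActivationIntxWeightConfig
+from torchao.quantization.granularity import PerGroup
+
+# Inside recipe_to_base_config(), alongside the existing branches:
+if recipe == "int8-dynamic-act-int4-weight-g16":
+    # Same int8 dynamic activations / int4 weights as the default recipe,
+    # but finer-grained per-group weight scales (group_size=16 vs. 32).
+    return Int8DynamicActivationIntxWeightConfig(
+        weight_dtype=torch.int4,
+        weight_granularity=PerGroup(group_size=16),
+    )
+```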
+
+## How It Works
+
+At trainer init, when `qat_recipe` is set:
+
+```python
+quantize_(model, QATConfig(base_config, step="prepare"))
+```
+
+`quantize_` walks the module tree and swaps each `nn.Linear` for a
+`FakeQuantizedLinear` instance. On every forward pass:
+
+1. Activations are quantized and immediately dequantized by the activation
+   fake quantizer (if the recipe has one).
+2. Weights are quantized and immediately dequantized by the weight fake
+   quantizer.
+3. The matmul runs in the original (bf16/fp32) dtype on the dequantized
+   tensors.
+
+The backward pass is unchanged: gradients flow through the standard linear
+backward in full precision, and the optimizer updates the original
+full-precision weights. The fake quantizers don't have learned parameters
+by default -- their scales and zero-points are derived from the current
+weight/activation statistics every step.
+
+At finalize, when `--qat-convert <recipe>` is set:
+
+```python
+# 1. Re-install fake quantizers on top of the loaded float weights
+quantize_(model, QATConfig(base_config, step="prepare"))
+# 2. Swap them for the real low-bit quantized linear ops
+quantize_(model, QATConfig(base_config, step="convert"))
+```
+
+The first call is necessary because Forgather's sharded checkpoint saver
+serialises `state_dict()`, which returns *float* weights — the
+`FakeQuantizedLinear` modules' scale/zero-point inner state is not
+persistent. We re-install fake quantizers from the float weights and then
+let convert compute the final low-bit weights and scales. The scales the
+convert step picks are derived from the QAT-trained weight statistics, so
+the QAT training-time accuracy benefit is preserved.
+
+The result is a model whose `nn.Linear` modules are now torchao subclasses
+(`Int8DynActInt4WeightLinear`, etc.). Forgather's `save_checkpoint` writes
+the resulting state_dict as PyTorch `.bin` (safetensors is incompatible —
+see below).
+
+## Loss Trajectory: 1-Chinchilla Tiny Llama
+
+Full-length training run of `examples/tutorials/tiny_llama:v2.yaml` (Tiny
+Llama, 4.43M params, ~82.6M training tokens — chinchilla-optimal at
+~20 tokens/param), single GPU (RTX 3090, sm_86, wopr), same seed, same
+config. The baseline run uses the v2.yaml default precision settings (bf16
+AMP via `mixed_precision: "bf16"`); the QAT run adds `--qat-recipe
+int8-dynamic-act-int4-weight` on top.
+
+| Eval step | bf16 AMP baseline (eval_loss) | QAT int8-act-int4-wt (eval_loss) | Δ (QAT − baseline) |
+|-----------|-------------------------------|----------------------------------|--------------------|
+| 642 | 2.0651 | 2.0789 | +0.0138 |
+| 1284 | 1.6999 | 1.7142 | +0.0143 |
+| 1926 | 1.5658 | 1.5799 | +0.0141 |
+| 4494 | 1.3725 | 1.3896 | +0.0171 |
+| 5136 | 1.3602 | 1.3776 | +0.0174 |
+| **5140 (final)** | **1.3601** | **1.3774** | **+0.0173** |
+
+Final train loss at step 5120 was 1.3352 vs 1.3534 (Δ +0.0182). The two
+trajectories track each other from the very first eval through to
+completion — QAT pays a stable ~+0.017 eval-loss premium throughout
+training rather than a divergent late-training gap, which is the
+encouraging signal: the model is learning under the fake-quant noise, not
+just accumulating it.
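+
+For intuition about what that noise is, here is a minimal standalone sketch
+of symmetric per-group int4 fake quantization (illustrative only -- this is
+not torchao's implementation, and real recipes also fake-quantize
+activations and handle zero-points):
+
+```python
+import torch
+
+def fake_quantize_int4_per_group(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
+    """Quantize-then-dequantize: the weight the QAT forward pass sees."""
+    g = w.reshape(-1, group_size)
+    # Symmetric int4 covers [-8, 7]; derive one scale per group of values.
+    scale = g.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 7.0
+    q = torch.clamp(torch.round(g / scale), -8, 7)
+    return (q * scale).reshape(w.shape)  # original dtype/shape + quantization noise
+
+w = torch.randn(64, 128)
+print((fake_quantize_int4_per_group(w) - w).abs().mean())  # mean absolute noise
+```
+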
**Wall-clock overhead.** Same GPU, same model, same data:
+
+| Run | Wall time | Steps/sec | Tokens/sec |
+|-----|-----------|-----------|------------|
+| bf16 AMP baseline | 197 s | 26.1 | 419K |
+| QAT int8-act-int4-wt | 329 s | 15.6 | 251K |
+
+QAT is ~1.67× slower than the bf16 baseline (the cost of running the fake
+quantizers in pure PyTorch in the forward pass). Whether it pays for
+itself depends on what the converted artifact recovers — that comparison
+needs `forgather eval` + inference-server support for quantized models
+(tracked in #41 and #42).
+
+## Save Format
+
+`forgather finalize --qat-convert` always writes the converted artifact in
+PyTorch (`.bin`) format. The `--safetensors` flag is disabled with
+a warning when both are set: torchao's quantized tensor subclasses
+(`Int8DynActInt4WeightLinear`, `Int4Tensor`, etc.) wrap multiple inner
+tensors and don't expose a single `.storage().data_ptr()`, which is what
+the safetensors writer requires. Until torchao ships explicit safetensors
+serialization, `.bin` is the working save format.
+
+The default `.bin` artifact loads cleanly through `torch.load` + the
+torchao `quantize_(model, QATConfig(base_config, step="convert"))` re-cast
+applied at load time. See the programmatic example below.
+
+## Behavior on Models Without QAT
+
+If you pass `--qat-convert <recipe>` to `forgather finalize` on a model
+that wasn't trained with `--qat-recipe`, the same prepare-then-convert
+pipeline runs anyway -- which is functionally **post-training
+quantization (PTQ)**: the recipe is applied, but the result lacks the
+QAT training-time accuracy benefit. The deployable artifact is still
+valid and loadable. A future `--ptq-quantize` flag (tracked in #40) will
+make that PTQ-on-plain-model intent explicit, but until then
+`--qat-convert` is the single entry point for both flows.
+
+## Programmatic Usage
+
+```python
+from forgather.ml.trainer import Trainer, TrainingArguments
+
+args = TrainingArguments(
+    output_dir="output_models/my_qat_run",
+    qat_recipe="int8-dynamic-act-int4-weight",
+    # ... other training args
+)
+
+trainer = Trainer(
+    args=args,
+    model_init=model_factory,
+    train_dataset=train_dataset,
+)
+trainer.train()
+```
+
+To run convert programmatically (on a model whose fake quantizers are
+still installed, e.g. right after `trainer.train()` in the same process):
+
+```python
+from torchao.quantization import quantize_
+from torchao.quantization.qat import QATConfig
+from forgather.ml.qat_recipes import recipe_to_base_config
+
+base_config = recipe_to_base_config("int8-dynamic-act-int4-weight")
+quantize_(model, QATConfig(base_config, step="convert"))
+# .bin only: torchao subclass tensors are incompatible with safetensors
+model.save_pretrained("out/my_quantized_model", safe_serialization=False)
+```
+
+## Out of Scope
+
+The v1 integration intentionally omits a few torchao QAT knobs that aren't
+needed for the common case:
+
+- **Auto-convert at training end**: convert is run by `forgather finalize`,
+  not the trainer. Keeps training and deployment concerns separated.
+- **Custom `group_size` / granularity flags on the CLI**: the per-recipe
+  defaults in `qat_recipes.py` are the standard values. Edit them locally
+  if you need to experiment.
+- **Range learning** (learned per-channel scales): torchao supports it via
+  `IntxFakeQuantizeConfig(range_learning=True)`, but the v1 recipes leave
+  it off.
+
+## See Also
+
+- [FP8 Training](fp8-training.md) -- the other torchao Linear-swap recipe;
+  mutually exclusive with QAT.
+- [Finalizing a Trained Model](../guides/finalize-model.md) -- the
+  `forgather finalize` reference (including `--qat-convert`).
diff --git a/docs/trainers/trainer_options.md b/docs/trainers/trainer_options.md
index a724ab1a..0f76725b 100644
--- a/docs/trainers/trainer_options.md
+++ b/docs/trainers/trainer_options.md
@@ -214,11 +214,13 @@ On any GPU that supports TF32 (Ampere or newer), you usually want
 | Field | Type | Default | Description |
 |-------|------|---------|-------------|
 | `mixed_precision` | str \| None | None | `None` / `"no"` disabled, `"bf16"` (no GradScaler), or `"fp16"` (with GradScaler). |
-| `fp8_recipe` | str \| None | None | `"tensorwise"`, `"rowwise"`, or `"rowwise_with_gw_hp"`. Converts `nn.Linear` to `Float8Linear` via torchao. Orthogonal to `mixed_precision`. |
+| `fp8_recipe` | str \| None | None | `"tensorwise"`, `"rowwise"`, or `"rowwise_with_gw_hp"`. Converts `nn.Linear` to `Float8Linear` via torchao. Orthogonal to `mixed_precision`. Mutually exclusive with `qat_recipe`. |
 | `fp8_dim_alignment` | int | 16 | Minimum alignment for FP8 Linear layer dimensions; non-conforming layers are skipped. |
+| `qat_recipe` | str \| None | None | `"int8-dynamic-act-int4-weight"`, `"int4-weight-only"`, or `"float8-dynamic-act-float8-weight"`. Installs `FakeQuantizedLinear` via torchao QAT (prepare phase). Run `forgather finalize --qat-convert <recipe>` after training to produce the deployable low-bit artifact. Mutually exclusive with `fp8_recipe`. |
 
 FP8 requires CUDA SM >= 8.9 (RTX 4090, H100, etc.). See
-[`fp8-training.md`](fp8-training.md).
+[`fp8-training.md`](fp8-training.md). QAT has no hardware gate (runs on any
+CUDA GPU or CPU); see [`qat-training.md`](qat-training.md).
 
 ---
 
diff --git a/src/forgather/cli/finalize.py b/src/forgather/cli/finalize.py
index 02a94f99..daef4f27 100644
--- a/src/forgather/cli/finalize.py
+++ b/src/forgather/cli/finalize.py
@@ -75,6 +75,7 @@ def _enqueue_finalize(args):
     p.add_argument("--generation-config", default=None)
     p.add_argument("--dtype", default=None)
     p.add_argument("--device", default="cpu")
+    p.add_argument("--qat-convert", default=None)
     p.add_argument("--priority", type=int, default=0)
     p.add_argument("--server", default=None)
     sub = p.parse_args(args.remainder)
@@ -104,6 +105,17 @@ def _enqueue_finalize(args):
         job_params["generation_config"] = sub.generation_config
     if sub.dtype:
         job_params["dtype"] = sub.dtype
+    if sub.qat_convert:
+        from forgather.ml.qat_recipes import QAT_RECIPES
+
+        if sub.qat_convert not in QAT_RECIPES:
+            print(
+                f"--qat-convert must be one of {QAT_RECIPES}, "
+                f"got '{sub.qat_convert}'",
+                file=sys.stderr,
+            )
+            raise SystemExit(2)
+        job_params["qat_convert"] = sub.qat_convert
 
     from .server_client import ServerClient, ServerUnreachable
 
diff --git a/src/forgather/ml/qat_recipes.py b/src/forgather/ml/qat_recipes.py
new file mode 100644
index 00000000..c9b46e83
--- /dev/null
+++ b/src/forgather/ml/qat_recipes.py
@@ -0,0 +1,62 @@
+"""Shared QAT recipe registry for the trainer (prepare) and finalize (convert).
+
+The same recipe string is supplied at training time via ``qat_recipe`` (which
+inserts ``FakeQuantizedLinear`` modules) and at finalize time via
+``--qat-convert`` (which swaps them for real low-bit quantized linear ops).
+Both call sites resolve the string through :func:`recipe_to_base_config`.
+"""
+
+from __future__ import annotations
+
+
+# Source of truth for QAT recipe names.
Consumed by:
+# - BaseTrainingArguments validator (src/forgather/ml/trainer/base_trainer.py)
+# - the Forgather finalize CLI (src/forgather/cli/finalize.py)
+# - finalize_model.py's --qat-convert (tools/finalize_model/finalize_model.py)
+# - the lm_training_project.yaml template's --qat-recipe `choices:` list,
+#   rendered from this tuple via the `qat_recipes` Jinja global
+# - tools/forgather_server/webui/src/components/FinalizeModal.tsx
+#   (TSX duplicate — keep the three strings here and there in sync)
+QAT_RECIPES: tuple[str, ...] = (
+    "int8-dynamic-act-int4-weight",
+    "int4-weight-only",
+    "float8-dynamic-act-float8-weight",
+)
+
+
+def recipe_to_base_config(recipe: str):
+    """Map a Forgather QAT recipe string to a torchao base config instance.
+
+    The returned object is the ``base_config`` argument for
+    ``torchao.quantization.qat.QATConfig(base_config, step=...)``. It must be
+    the *same* config (same parameters) for both the prepare and convert
+    phases.
+    """
+    import torch
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig,
+        Int4WeightOnlyConfig,
+        Int8DynamicActivationIntxWeightConfig,
+    )
+    from torchao.quantization.granularity import PerGroup
+
+    if recipe == "int8-dynamic-act-int4-weight":
+        # Replaces the deprecated Int8DynamicActivationInt4WeightConfig
+        # (see pytorch/ao#2752). Same semantics: int8 per-token dynamic
+        # activations, int4 per-group symmetric weights.
+        return Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=torch.int4,
+            weight_granularity=PerGroup(group_size=32),
+        )
+    if recipe == "int4-weight-only":
+        return Int4WeightOnlyConfig(group_size=128)
+    if recipe == "float8-dynamic-act-float8-weight":
+        return Float8DynamicActivationFloat8WeightConfig()
+    # `float8-dynamic-act-int4-weight` is intentionally not exposed in v1:
+    # torchao's Float8DynamicActivationInt4WeightConfig requires the
+    # `preshuffled` int4 packing format, which is Hopper-only (SM90+,
+    # FBGEMM). When we add capability-gated recipe exposure, re-introduce
+    # it behind a runtime check.
+    raise ValueError(
+        f"Unknown QAT recipe: {recipe!r}. Valid recipes: {QAT_RECIPES}"
+    )
diff --git a/src/forgather/ml/trainer/base_trainer.py b/src/forgather/ml/trainer/base_trainer.py
index 026cdba1..039ebbd2 100644
--- a/src/forgather/ml/trainer/base_trainer.py
+++ b/src/forgather/ml/trainer/base_trainer.py
@@ -138,6 +138,14 @@ class BaseTrainingArguments(MinimalTrainingArguments):
         Minimum alignment for FP8 ``Linear`` layer dimensions. Layers whose
         ``in_features`` or ``out_features`` are not divisible by this value
         are skipped. Hardware requires 16. Default is ``16``.
+    qat_recipe : str or None, optional
+        Quantization-aware training recipe via ``torchao``. Inserts
+        ``FakeQuantizedLinear`` modules so the forward pass simulates the
+        target low-bit precision while backward stays in full precision.
+        After training, run ``forgather finalize --qat-convert <recipe>`` to
+        produce the real low-bit deployment artifact. Mutually exclusive with
+        ``fp8_recipe``. See ``docs/trainers/qat-training.md`` for the recipe
+        list. Default is ``None``.
     """
 
     # Default torch dtype for model construction (e.g., "float32", "bfloat16", "float16")
@@ -203,6 +211,12 @@ class BaseTrainingArguments(MinimalTrainingArguments):
     # out_features not divisible by this value are skipped. Hardware requires 16.
     fp8_dim_alignment: int = 16
 
+    # Quantization-aware training (QAT) via torchao.
Inserts FakeQuantizedLinear
+    # modules in the prepare phase; convert is done post-training via
+    # `forgather finalize --qat-convert <recipe>`. Mutually exclusive with fp8_recipe.
+    # See src/forgather/ml/qat_recipes.py for the recipe table.
+    qat_recipe: str | None = None
+
     def __post_init__(self):
         if self.logging_dir is None:
             self.logging_dir = os.path.join(
@@ -238,6 +252,26 @@ def __post_init__(self):
                 f"fp8_recipe must be one of {_FP8_RECIPES}, got '{self.fp8_recipe}'"
             )
 
+        # Validate qat_recipe
+        if self.qat_recipe is not None:
+            from forgather.ml.qat_recipes import QAT_RECIPES
+
+            if self.qat_recipe not in QAT_RECIPES:
+                raise ValueError(
+                    f"qat_recipe must be one of {QAT_RECIPES}, got '{self.qat_recipe}'"
+                )
+
+        # Linear-swap recipes are mutually exclusive (each replaces nn.Linear
+        # with a different specialised class). Add new recipes here to keep
+        # the check single-source.
+        _LINEAR_SWAP_RECIPES = ("fp8_recipe", "qat_recipe")
+        _active = [name for name in _LINEAR_SWAP_RECIPES if getattr(self, name)]
+        if len(_active) > 1:
+            raise ValueError(
+                f"Linear-swap recipes are mutually exclusive "
+                f"(each replaces nn.Linear); set at most one. Got: {_active}"
+            )
+
 
 TBaseTrainingArguments = TypeVar("TBaseTrainingArguments", bound=BaseTrainingArguments)
diff --git a/src/forgather/ml/trainer/trainer.py b/src/forgather/ml/trainer/trainer.py
index 02eaebec..1a379569 100644
--- a/src/forgather/ml/trainer/trainer.py
+++ b/src/forgather/ml/trainer/trainer.py
@@ -859,8 +859,16 @@ def _prepare_model(self) -> None:
             case _:
                 raise ValueError("Requires one of: default|meta|device")
         assert self.model is not None
+        # Linear-swap recipes (fp8 / qat) are mutually exclusive — see the
+        # _LINEAR_SWAP_RECIPES check in BaseTrainingArguments.__post_init__.
+        # The if-chain is sequential rather than elif so a future relaxed
+        # mutex still surfaces a clear error (the second swap would find
+        # no nn.Linear left and report 0/N converted) instead of silently
+        # producing a single-recipe model.
         if self.args.fp8_recipe:
             self.model = self._apply_fp8_training(self.model)
+        if self.args.qat_recipe:
+            self.model = self._apply_qat_training(self.model)
         if self.args.gradient_checkpointing:
             if self.enable_activation_checkpoint_fn is None:
                 if self.dist.rank == 0:
@@ -950,6 +958,41 @@ def _filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
 
         return model
 
+    def _apply_qat_training(self, model: torch.nn.Module) -> torch.nn.Module:
+        """Install torchao FakeQuantizedLinear modules for quantization-aware training.
+
+        The forward pass simulates the target low-bit precision via fake
+        quantizers while the backward pass stays in full precision, letting
+        the model learn to be robust to quantization noise. The convert phase
+        (real low-bit ops) is run post-training by ``forgather finalize
+        --qat-convert <recipe>``.
+        """
+        from torchao.quantization import quantize_
+        from torchao.quantization.qat import FakeQuantizedLinear, QATConfig
+
+        from forgather.ml.qat_recipes import recipe_to_base_config
+
+        assert self.args.qat_recipe is not None
+        base_config = recipe_to_base_config(self.args.qat_recipe)
+        quantize_(model, QATConfig(base_config, step="prepare"))
+
+        converted = sum(
+            1 for m in model.modules() if isinstance(m, FakeQuantizedLinear)
+        )
+        total_linear = sum(
+            1
+            for m in model.modules()
+            if isinstance(m, (torch.nn.Linear, FakeQuantizedLinear))
+        )
+        logger.info(
+            f"QAT training ({self.args.qat_recipe}): "
+            f"converted {converted}/{total_linear} Linear layers to FakeQuantizedLinear. 
" + f"Run `forgather finalize --qat-convert {self.args.qat_recipe}` " + f"after training to produce a deployable quantized artifact." + ) + + return model + def _wrap_loss_fn(self): # Rescale loss by gradient accumulation steps. self.loss_fn = RescaleLoss( diff --git a/src/forgather/preprocess.py b/src/forgather/preprocess.py index 21b7e875..45b48c0a 100644 --- a/src/forgather/preprocess.py +++ b/src/forgather/preprocess.py @@ -84,6 +84,18 @@ def _detect_gpu_flops() -> tuple[str | None, float | None]: return None, None +def _qat_recipes_global() -> list[str]: + """Return the QAT recipe list for use as a Jinja global. + + Renders to YAML via the ``toyaml`` filter so the template's + ``--qat-recipe`` ``choices:`` list stays in sync with the Python source + of truth (``forgather.ml.qat_recipes.QAT_RECIPES``). + """ + from forgather.ml.qat_recipes import QAT_RECIPES + + return list(QAT_RECIPES) + + def get_peak_hardware_flops() -> float: """ Return the peak BF16 FLOP/s (FP32 accumulation) for a single GPU. @@ -577,6 +589,7 @@ class PPEnvironment(SandboxedEnvironment): "getcwd": os.getcwd, "forgather_config_dir": forgather_config_dir, "get_peak_hardware_flops": get_peak_hardware_flops, + "qat_recipes": _qat_recipes_global, # https://pypi.org/project/platformdirs/ "user_data_dir": user_data_dir, "user_cache_dir": user_cache_dir, diff --git a/templatelib/examples/projects/lm_training_project.yaml b/templatelib/examples/projects/lm_training_project.yaml index 1eff10a3..98af9eb8 100644 --- a/templatelib/examples/projects/lm_training_project.yaml +++ b/templatelib/examples/projects/lm_training_project.yaml @@ -93,6 +93,7 @@ ## (str) Approximate float32 matmul with bf16. Choices: highest, high, medium ## mixed_precision (str) AMP dtype. Choices: bf16, fp16, no ## fp8_recipe (str) FP8 recipe for linear layers. Choices: tensorwise, rowwise, rowwise_with_gw_hp +## qat_recipe (str) QAT recipe for linear layers. Choices: see forgather.ml.qat_recipes.QAT_RECIPES (rendered into --qat-recipe --help) ## compile (bool) Enable torch.compile. Default: False ## torch_compile_mode (str) torch.compile mode. Default: default ## Choices: default, max-autotune, max-autotune-no-cudagraphs. 
@@ -469,6 +470,7 @@ model: &model !call:getitem [ *model_dict, 'model' ]
     default_dtype: {{ default_dtype | toyaml(None) }}
     mixed_precision: {{ mixed_precision | toyaml(None) }}
     fp8_recipe: {{ fp8_recipe | toyaml(None) }}
+    qat_recipe: {{ qat_recipe | toyaml(None) }}
     gradient_checkpointing: {{ gradient_checkpointing | default(False) }}
     fuse_optim_with_backward: {{ fuse_optim_with_backward | default(False) }}
     enable_activation_offloading: {{ activation_offloading | default(False) }}
@@ -695,6 +697,11 @@ optimizer_groups: &optimizer_groups
         choices: [ "tensorwise", "rowwise", "rowwise_with_gw_hp" ]
         group: "Precision"
         help: "FP8 training recipe (torchao Float8Linear)"
+    qat_recipe:
+        names: "--qat-recipe"
+        choices: {{ qat_recipes() | toyaml }}
+        group: "Precision"
+        help: "Quantization-aware training recipe (torchao QAT prepare; run `forgather finalize --qat-convert <recipe>` after training)"
     compile:
         names: "--compile"
         type: bool
diff --git a/tools/finalize_model/finalize_model.py b/tools/finalize_model/finalize_model.py
index 35bc37a7..e1c88b50 100644
--- a/tools/finalize_model/finalize_model.py
+++ b/tools/finalize_model/finalize_model.py
@@ -174,6 +174,19 @@ def parse_args(argv=None):
         default="cpu",
         help="Device to load the model onto during finalize (default: cpu)",
     )
+    parser.add_argument(
+        "--qat-convert",
+        type=str,
+        default=None,
+        help=(
+            "Run the torchao QAT convert step before saving: swaps "
+            "FakeQuantizedLinear modules for real low-bit quantized linear "
+            "ops. Pass the same recipe string that was used at training "
+            "time (e.g. 'int8-dynamic-act-int4-weight'). On models that "
+            "were not QAT-trained this still applies the recipe (plain "
+            "post-training quantization); models with no nn.Linear modules "
+            "are a no-op with a warning. See docs/trainers/qat-training.md."
+        ),
+    )
     parser.add_argument(
         "--dry-run",
         action="store_true",
@@ -198,6 +211,66 @@ def _resolve_dtype(dtype_str: Optional[str]):
     return torch_dtype(dtype_str)
 
 
+def _apply_qat_convert(model, recipe: str) -> None:
+    """Run torchao's QAT prepare-then-convert pipeline in-place on a loaded model.
+
+    The loaded model's state_dict contains plain float weights (Forgather's
+    sharded saver doesn't persist FakeQuantizedLinear inner state), so we
+    re-install fake quantizers on top of the trained weights via
+    ``step="prepare"``, then swap them for the real low-bit quantized linear
+    ops via ``step="convert"``. The scales/zero-points the convert step
+    computes are derived from the (QAT-trained) weight statistics.
+
+    Running this on a model that was *not* trained with ``--qat-recipe`` is
+    functionally a post-training-quantization (PTQ) pass: the recipe will
+    still be applied, but the result lacks the QAT training-time accuracy
+    benefit. A future ``--ptq-quantize`` flag (tracked separately) will make
+    that PTQ-on-plain-model intent explicit.
+    """
+    import torch
+    from torchao.quantization import quantize_
+    from torchao.quantization.qat import FakeQuantizedLinear, QATConfig
+
+    from forgather.ml.qat_recipes import QAT_RECIPES, recipe_to_base_config
+
+    if recipe not in QAT_RECIPES:
+        raise ValueError(
+            f"--qat-convert must be one of {QAT_RECIPES}, got {recipe!r}"
+        )
+
+    base_config = recipe_to_base_config(recipe)
+
+    # If the loaded model already has FakeQuantizedLinear modules (e.g. a
+    # future Forgather saver that preserves them), skip prepare and go
+    # straight to convert.
+ fq_count = sum(1 for m in model.modules() if isinstance(m, FakeQuantizedLinear)) + if fq_count == 0: + linear_count = sum( + 1 for m in model.modules() if isinstance(m, torch.nn.Linear) + ) + if linear_count == 0: + logger.warning( + "--qat-convert %r requested but model has no nn.Linear " + "modules to quantize; skipping convert step.", + recipe, + ) + return + logger.info( + f"QAT convert ({recipe}): re-installing fake quantizers on " + f"{linear_count} nn.Linear modules before convert" + ) + quantize_(model, QATConfig(base_config, step="prepare")) + fq_count = sum( + 1 for m in model.modules() if isinstance(m, FakeQuantizedLinear) + ) + + quantize_(model, QATConfig(base_config, step="convert")) + logger.info( + f"QAT convert ({recipe}): converted {fq_count} FakeQuantizedLinear " + f"modules to real quantized linear ops" + ) + + def main(argv=None): args = parse_args(argv) @@ -351,9 +424,28 @@ def main(argv=None): logger.info( f"Would write generation_config.json (mode={args.generation_config})" ) + if args.qat_convert: + logger.info( + f"Would run QAT convert step with recipe '{args.qat_convert}'" + ) return 0 - # ---- 6. Materialize destination ------------------------------------ + # ---- 6. QAT convert (optional) ------------------------------------- + if args.qat_convert: + _apply_qat_convert(model, args.qat_convert) + if args.safetensors: + # torchao's quantized tensor subclasses wrap multiple inner + # tensors and do not expose a single .storage().data_ptr(), + # so safetensors saves fail with "Attempted to access the + # data pointer on an invalid python storage". Force .bin. + logger.warning( + "--safetensors is incompatible with QAT-converted models " + "(torchao subclass tensors lack a single storage pointer). " + "Saving as PyTorch (.bin) instead." + ) + args.safetensors = False + + # ---- 7. Materialize destination ------------------------------------ os.makedirs(dest, exist_ok=False) copy_model_source(source, dest) diff --git a/tools/forgather_server/finalize_ops.py b/tools/forgather_server/finalize_ops.py index 865bb54d..8dc61199 100644 --- a/tools/forgather_server/finalize_ops.py +++ b/tools/forgather_server/finalize_ops.py @@ -35,6 +35,7 @@ def build_finalize_command( device: Optional[str] = None, dry_run: bool = False, log_level: str = "INFO", + qat_convert: Optional[str] = None, ) -> List[str]: """Build the argv for ``tools/finalize_model/finalize_model.py``. @@ -70,6 +71,8 @@ def build_finalize_command( cmd.extend(["--dtype", dtype]) if device: cmd.extend(["--device", device]) + if qat_convert: + cmd.extend(["--qat-convert", qat_convert]) if dry_run: cmd.append("--dry-run") cmd.extend(["--log-level", log_level]) diff --git a/tools/forgather_server/launcher.py b/tools/forgather_server/launcher.py index e3b6e991..f8a9cee0 100644 --- a/tools/forgather_server/launcher.py +++ b/tools/forgather_server/launcher.py @@ -461,6 +461,7 @@ def spawn_finalize_process( device: Optional[str] = None, dry_run: bool = False, log_level: str = "INFO", + qat_convert: Optional[str] = None, extra_env: Optional[Dict[str, str]] = None, ) -> LaunchResult: """Spawn a ``forgather finalize`` run. 
@@ -486,6 +487,7 @@
         device=device,
         dry_run=dry_run,
         log_level=log_level,
+        qat_convert=qat_convert,
     )
     return _spawn_subprocess(cmd, gpu_indices, tty_log_path, extra_env)
diff --git a/tools/forgather_server/scheduler.py b/tools/forgather_server/scheduler.py
index c7abaee0..c858a412 100644
--- a/tools/forgather_server/scheduler.py
+++ b/tools/forgather_server/scheduler.py
@@ -454,6 +454,7 @@ def _build_finalize(item, gpu_indices, tty_path):
         device=p.get("device"),
         dry_run=bool(p.get("dry_run", False)),
         log_level=p.get("log_level", "INFO"),
+        qat_convert=p.get("qat_convert"),
         gpu_indices=gpu_indices,
         tty_log_path=tty_path,
     )
diff --git a/tools/forgather_server/webui/src/components/FinalizeModal.tsx b/tools/forgather_server/webui/src/components/FinalizeModal.tsx
index a1c2890b..a09fa834 100644
--- a/tools/forgather_server/webui/src/components/FinalizeModal.tsx
+++ b/tools/forgather_server/webui/src/components/FinalizeModal.tsx
@@ -31,8 +31,17 @@ interface PersistedFinalize {
   dryRun: boolean;
   logLevel: string;
   requestedGpus: number;
+  /** Torchao QAT convert recipe. Empty string means "skip QAT convert". */
+  qatConvert: string;
 }
 
+// Keep in sync with QAT_RECIPES in src/forgather/ml/qat_recipes.py.
+const QAT_CONVERT_RECIPES = [
+  "int8-dynamic-act-int4-weight",
+  "int4-weight-only",
+  "float8-dynamic-act-float8-weight",
+] as const;
+
 const STORAGE_KEY = "forgather-global-finalize-v1";
 
 const DEFAULTS: PersistedFinalize = {
@@ -56,6 +65,7 @@
   dryRun: false,
   logLevel: "INFO",
   requestedGpus: 0,
+  qatConvert: "",
 };
 
 function loadPersisted(): Partial<PersistedFinalize> {
@@ -200,6 +210,7 @@ export function FinalizeModal({ initialSource, onClose, onSubmitted }: Props) {
   const [requestedGpus, setRequestedGpus] = useState(
     initial.requestedGpus ?? 0,
   );
+  const [qatConvert, setQatConvert] = useState(initial.qatConvert ?? "");
   const [priority, setPriority] = useState(0);
 
   // Backfill tokenizer defaults once quick-paths resolves, but only
@@ -248,6 +259,7 @@
     setDryRun(DEFAULTS.dryRun);
     setLogLevel(DEFAULTS.logLevel);
     setRequestedGpus(DEFAULTS.requestedGpus);
+    setQatConvert(DEFAULTS.qatConvert);
   };
 
   const enqueue = useMutation({
@@ -289,6 +301,7 @@
       dryRun,
       logLevel,
       requestedGpus,
+      qatConvert,
     });
 
     const job_params: Record<string, unknown> = {
@@ -317,6 +330,7 @@
     if (dtype && dtype !== "keep") job_params.dtype = dtype;
     const dev = device.trim();
     if (dev) job_params.device = dev;
+    if (qatConvert) job_params.qat_convert = qatConvert;
 
     enqueue.mutate({
       // project_dir isn't meaningful for finalize; use the dest path so
@@ -572,6 +586,26 @@
             resolve only; don't write
+
+            <div>
+              <label>QAT convert</label>
+              <select
+                value={qatConvert}
+                onChange={(e) => setQatConvert(e.target.value)}
+              >
+                <option value="">(skip QAT convert)</option>
+                {QAT_CONVERT_RECIPES.map((r) => (
+                  <option key={r} value={r}>
+                    {r}
+                  </option>
+                ))}
+              </select>
+            </div>
+
Scheduling