11 changes: 11 additions & 0 deletions docs/guides/evaluating-models.md
@@ -119,6 +119,17 @@ same mechanism `forgather train` uses). Pass an explicit path with
`--checkpoint PATH` to pin a specific one, or `--no-checkpoint` to load via
`AutoModelForCausalLM.from_pretrained` on the model directory.

**Quantized models** (artifacts produced by `forgather finalize --quantize`)
are autodetected: if `config.json` has a `quantization_config` block, eval
forces the `from_pretrained()` path and ignores both `--checkpoint` and
`--no-checkpoint`. An explicit `--checkpoint PATH` logs a warning before
being overridden. The checkpoint-resume path uses `from_config()` +
`load_state_dict()`, which has no quantizer hook and fails on quantized
tensor subclasses; `from_pretrained()` runs HF's `TorchAoHfQuantizer`
pre-process so the right linear modules are in place before the weights
load. See [QAT Training § Evaluating Quantized
Models](../trainers/qat-training.md#evaluating-quantized-models).

The tokenizer is always loaded directly from `--model` via
`AutoTokenizer.from_pretrained`.

8 changes: 8 additions & 0 deletions docs/guides/finalize-model.md
@@ -122,6 +122,14 @@ quantized tensor subclasses don't expose a single `.storage().data_ptr()`,
which the safetensors writer requires. If `--safetensors` is passed
alongside `--quantize`, it is disabled and a warning is logged.

Finalize also writes a `quantization_config` block into `config.json`
with the recipe. This makes HF `AutoModelForCausalLM.from_pretrained()`
auto-detect the quantization on reload and run the
`TorchAoHfQuantizer` pre-process path — so `forgather eval` and the
inference server load the artifact correctly without any caller-side
flag. See [Evaluating Quantized
Models](../trainers/qat-training.md#evaluating-quantized-models).

### Misc

| Option | Description |
57 changes: 52 additions & 5 deletions docs/trainers/qat-training.md
@@ -117,7 +117,7 @@ Whether that's worth it depends heavily on the recipe:

| Recipe | QAT vs PTQ delta (expected) | Recommendation |
|--------|---------------------------|----------------|
| `int8-dynamic-act-int4-weight` | Largest QAT benefit — int4 weights are aggressive enough that plain PTQ can drift noticeably. | QAT if you care about the last point of eval loss; PTQ fine for prototyping. |
| `int8-dynamic-act-int4-weight` | Largest QAT benefit — int4 weights are aggressive enough that plain PTQ can drift noticeably. Measured Δ at 4.43M params is small (0.0013, see [Three-Way Comparison](#three-way-comparison-bf16--ptq--qat)); the gap is expected to widen at larger scale. | QAT if you care about the last point of eval loss; PTQ fine for prototyping. |
| `int4-weight-only` | Moderate QAT benefit. Per-group int4 + bf16 matmul is already quite robust. | PTQ first; reach for QAT only if eval drops more than you can absorb. |
| `float8-dynamic-act-float8-weight` | Minimal — fp8 is already near-lossless. | PTQ. QAT is rarely justified for fp8. |

@@ -261,10 +261,57 @@ forgather finalize output_models/my_bf16_run out/my_bf16_run_int8_int4 \
This is the path used for the AMP-baseline / PTQ / QAT three-way
comparison: train the same model once in plain bf16, then run finalize
twice (one source bf16, one source QAT) with the same `--quantize`
recipe and compare the eval results. PTQ on a model that was not
QAT-trained typically pays a larger accuracy penalty than the
`~+0.017` eval-loss premium QAT does (see [Loss Trajectory](#loss-trajectory-1-chinchilla-tiny-llama)
above for the QAT numbers); how much depends on the model and recipe.
recipe and compare the eval results. See [Three-Way
Comparison](#three-way-comparison-bf16--ptq--qat) below for measured
numbers on Tiny Llama.

## Evaluating Quantized Models

`forgather eval` loads `--quantize`-finalized models with no extra flag:

```bash
# Same invocation as for a bf16 model — eval autodetects the quantized
# artifact and routes through HF `from_pretrained()` so the
# TorchAoHfQuantizer pre-process path installs the right linear modules.
forgather -p examples/tutorials/tiny_llama eval test tinystories \
-M /path/to/quantized_model
```

How it works: at finalize time, `--quantize` writes a
`quantization_config` block into `config.json` with the recipe. At eval
time, `scripts/eval_script.py:resolve_checkpoint()` reads that field
and forces the `from_pretrained()` load path, **ignoring** any
`--checkpoint` / `--no-checkpoint` flag (an explicit `--checkpoint
PATH` logs a warning before being overridden). The normal
checkpoint-resume path uses `from_config()` + `load_state_dict()`,
which does not run any quantizer hook and fails when the saved tensors
are torchao quantized subclasses. The same mechanism applies to the
inference server, so no caller-side changes are needed.

The check is purely additive: bf16 models keep using the
checkpoint-resume path. Pass `--no-checkpoint` to opt into
`from_pretrained()` manually for non-quantized models.
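
The load-path decision described above can be sketched in a few lines of standard-library Python. This is only an illustration under stated assumptions, not the code in `scripts/eval_script.py`: the names `is_quantized` and `choose_load_path` are hypothetical, the returned `(checkpoint_arg, use_checkpoint)` pair follows the convention documented for `resolve_checkpoint()`, and the empty `quantization_config` value in the demo is a placeholder rather than what finalize actually writes.

```python
import json
import os
import tempfile


def is_quantized(model_dir):
    """Return True if <model_dir>/config.json has a "quantization_config" key."""
    try:
        with open(os.path.join(model_dir, "config.json")) as f:
            config = json.load(f)
    except (OSError, ValueError):
        # Missing, unreadable, or malformed config: treat as not quantized.
        return False
    return isinstance(config, dict) and "quantization_config" in config


def choose_load_path(model_dir, checkpoint=None, no_checkpoint=False):
    """Return (checkpoint_arg, use_checkpoint) following the rules above.

    (False, False) means "load via from_pretrained()"; (True, True) means
    "auto-find the latest checkpoint"; (path, True) pins a checkpoint.
    """
    if is_quantized(model_dir):
        # Quantized artifacts override both flags.
        return False, False
    if no_checkpoint:
        return False, False
    if checkpoint:
        return checkpoint, True
    return True, True


# Small demo: the same flags give a different answer once the (placeholder)
# quantization_config key appears in config.json.
with tempfile.TemporaryDirectory() as model_dir:
    print(choose_load_path(model_dir, checkpoint="ckpt-100"))
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        json.dump({"quantization_config": {}}, f)
    print(choose_load_path(model_dir, checkpoint="ckpt-100"))
```

Checking the quantized case first is what makes the behavior additive: the non-quantized branches are unchanged from the flag handling that existed before.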

### Three-Way Comparison: bf16 / PTQ / QAT

Full eval against the Tiny Stories test split for the 4.43M Tiny Llama
trained at 1-Chinchilla on a single RTX 3090 (matching the [Loss
Trajectory](#loss-trajectory-1-chinchilla-tiny-llama) setup above), all
three using `int8-dynamic-act-int4-weight`:

| Model | eval_loss | perplexity | Δ vs bf16 |
|-------|-----------|------------|-----------|
| bf16 baseline | 1.3656 | 3.918 | — |
| PTQ on bf16 baseline | 1.3917 | 4.022 | +0.0262 |
| QAT-trained + converted | 1.3905 | 4.017 | +0.0249 |

At this scale (4.43M params, int4 weights with int8 dynamic
activations), QAT shaves about **0.0013 eval-loss** off PTQ: a real
but small gain. Whether QAT's ~1.67× training-time overhead is worth
that depends on how much of the +0.025 eval-loss premium imposed by
quantization itself you can tolerate; PTQ gets you almost all of the
benefit at zero training cost. QAT's advantage is expected to grow with
larger models and more aggressive recipes, so measure your own setup
before committing.
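
As a quick sanity check on the table, perplexity is `exp(eval_loss)` and the Δ column is the difference from the bf16 loss. The sketch below recomputes both from the rounded `eval_loss` values copied from the table; the dictionary keys are illustrative labels only. Because the inputs are already rounded to four decimals, a recomputed delta can differ from the table's Δ column in the last digit.

```python
import math

# eval_loss values copied from the table above (rounded to 4 decimals).
eval_loss = {"bf16": 1.3656, "ptq": 1.3917, "qat": 1.3905}

# Perplexity is the exponential of the mean cross-entropy loss.
perplexity = {name: math.exp(loss) for name, loss in eval_loss.items()}

# Loss premium of each variant relative to the bf16 baseline.
delta_vs_bf16 = {name: loss - eval_loss["bf16"] for name, loss in eval_loss.items()}

for name in eval_loss:
    print(
        f"{name:5s} loss={eval_loss[name]:.4f} "
        f"ppl={perplexity[name]:.3f} delta={delta_vs_bf16[name]:+.4f}"
    )
```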

## Programmatic Usage

57 changes: 52 additions & 5 deletions scripts/eval_script.py
@@ -129,19 +129,66 @@ def init_model():
    return init_model


def _model_is_quantized(model_path: str) -> bool:
    """True iff ``<model_path>/config.json`` carries a ``quantization_config`` block.

    Written by ``forgather finalize --quantize``; HF's
    ``from_pretrained()`` consumes it to install the TorchAoHfQuantizer
    pre-process path. Unreadable or malformed config.json -> False.
    """
    cfg_path = os.path.join(model_path, "config.json")
    if not os.path.isfile(cfg_path):
        return False
    try:
        with open(cfg_path) as f:
            return "quantization_config" in json.load(f)
    except (OSError, ValueError):
        return False


def resolve_checkpoint(args):
    """Return (checkpoint_arg, use_checkpoint) for the trainer.

    ``checkpoint_arg`` is what we pass as ``resume_from_checkpoint``:
    - True: auto-find latest
    - str path: explicit
    - False: do not resume

    Quantized models force ``(False, False)`` regardless of flags: only HF
    `from_pretrained()` runs the TorchAoHfQuantizer pre-process that
    installs the right quantized linear modules before weight load. The
    checkpoint-resume path uses `from_config()` + `load_state_dict()`,
    which has no quantizer hook and crashes with `'Parameter' object has
    no attribute 'tensor_data_names'` on torchao tensor subclasses.

    Cache the result on ``args`` so repeated calls don't re-log.
    """
    if args.no_checkpoint:
        return False, False
    if args.checkpoint:
        return args.checkpoint, True
    return True, True
    cached = getattr(args, "_resolved_checkpoint", None)
    if cached is not None:
        return cached

    if _model_is_quantized(args.model):
        if args.checkpoint:
            logger.warning(
                "Detected quantization_config in model config; ignoring "
                "--checkpoint %r and loading via from_pretrained().",
                args.checkpoint,
            )
        else:
            logger.info(
                "Detected quantization_config in model config; "
                "loading via from_pretrained() (overrides default checkpoint resume)."
            )
        result = (False, False)
    elif args.no_checkpoint:
        result = (False, False)
    elif args.checkpoint:
        result = (args.checkpoint, True)
    else:
        result = (True, True)

    args._resolved_checkpoint = result
    return result


def build_trainer(args, model_init, eval_dataset, data_collator, tokenizer, device):
15 changes: 15 additions & 0 deletions tools/finalize_model/finalize_model.py
@@ -432,11 +432,26 @@ def main(argv=None):
            logger.info(
                f"Would run quantize step with recipe '{args.quantize_recipe}'"
            )
            logger.info(
                "Would write 'quantization_config' block to config.json"
            )
        return 0

    # ---- 6. Quantize (optional) ----------------------------------------
    if args.quantize_recipe:
        _apply_quantize(model, args.quantize_recipe)
        # Record the recipe on the config so HF `from_pretrained()` runs
        # the TorchAoHfQuantizer pre-process path on reload — it installs
        # the right quantized linear modules before `load_state_dict`, so
        # the quantized tensor subclasses land in slots that know how to
        # hold them. Without this block, reload via `from_pretrained()`
        # fails with `'Parameter' object has no attribute 'tensor_data_names'`.
        from transformers import TorchAoConfig
        from forgather.ml.qat_recipes import recipe_to_base_config

        config.quantization_config = TorchAoConfig(
            quant_type=recipe_to_base_config(args.quantize_recipe)
        )
        if args.safetensors:
            # torchao's quantized tensor subclasses wrap multiple inner
            # tensors and do not expose a single .storage().data_ptr(),