From 07feecbadec63871dd7a3a6f510a18f54a9b84e3 Mon Sep 17 00:00:00 2001 From: Jason dinAlt Date: Fri, 15 May 2026 02:51:27 +0000 Subject: [PATCH 1/2] eval/finalize: load torchao-quantized models without a flag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Finalize now writes a `quantization_config` block into the saved config.json so HF `from_pretrained()` runs the `TorchAoHfQuantizer` pre-process path on reload — the quantized linear modules are installed before `load_state_dict`, so the saved tensor subclasses land in slots that know how to hold them. Eval autodetects the block at load time and forces the `from_pretrained()` path regardless of `--checkpoint` / `--no-checkpoint`. The default checkpoint-resume path uses `from_config()` + `load_state_dict()` which has no quantizer hook and fails with "'Parameter' object has no attribute 'tensor_data_names'" on torchao quantized tensor subclasses. No new CLI flag and no marker file — issue #41 originally proposed both, but HF's existing config-based autodetection makes them unnecessary. Same mechanism benefits the inference server (issue #42) for free. Three-way comparison verified on 4.43M Tiny Llama (RTX 3090, full Tiny Stories eval, recipe `int8-dynamic-act-int4-weight`): - bf16 baseline: eval_loss=1.366 - PTQ on bf16: eval_loss=1.392 (+0.026) - QAT converted: eval_loss=1.391 (+0.025; -0.001 vs PTQ) Numbers and the loader autodetection mechanism documented in docs/trainers/qat-training.md, with cross-links from docs/guides/finalize-model.md and docs/guides/evaluating-models.md. Closes #41. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/guides/evaluating-models.md | 10 +++++ docs/guides/finalize-model.md | 8 ++++ docs/trainers/qat-training.md | 53 ++++++++++++++++++++++++-- scripts/eval_script.py | 18 +++++++++ tools/finalize_model/finalize_model.py | 12 ++++++ 5 files changed, 97 insertions(+), 4 deletions(-) diff --git a/docs/guides/evaluating-models.md b/docs/guides/evaluating-models.md index e278aef6d..cfd564c7a 100644 --- a/docs/guides/evaluating-models.md +++ b/docs/guides/evaluating-models.md @@ -119,6 +119,16 @@ same mechanism `forgather train` uses). Pass an explicit path with `--checkpoint PATH` to pin a specific one, or `--no-checkpoint` to load via `AutoModelForCausalLM.from_pretrained` on the model directory. +**Quantized models** (artifacts produced by `forgather finalize --quantize`) +are autodetected: if `config.json` has a `quantization_config` block, eval +forces the `from_pretrained()` path regardless of `--checkpoint` / +`--no-checkpoint` flags. The checkpoint-resume path uses `from_config()` + +`load_state_dict()` which has no quantizer hook and fails on quantized +tensor subclasses; `from_pretrained()` runs HF's `TorchAoHfQuantizer` +pre-process so the right linear modules are in place before the weights +load. See [QAT Training § Evaluating Quantized +Models](../trainers/qat-training.md#evaluating-quantized-models). + The tokenizer is always loaded directly from `--model` via `AutoTokenizer.from_pretrained`. diff --git a/docs/guides/finalize-model.md b/docs/guides/finalize-model.md index 395955efb..47dd73d46 100644 --- a/docs/guides/finalize-model.md +++ b/docs/guides/finalize-model.md @@ -122,6 +122,14 @@ quantized tensor subclasses don't expose a single `.storage().data_ptr()`, which the safetensors writer requires. If `--safetensors` is passed alongside `--quantize`, it is silently disabled with a warning. +Finalize also writes a `quantization_config` block into `config.json` +with the recipe. This makes HF `AutoModelForCausalLM.from_pretrained()` +auto-detect the quantization on reload and run the +`TorchAoHfQuantizer` pre-process path — so `forgather eval` and the +inference server load the artifact correctly without any caller-side +flag. See [Evaluating Quantized +Models](../trainers/qat-training.md#evaluating-quantized-models). + ### Misc | Option | Description | diff --git a/docs/trainers/qat-training.md b/docs/trainers/qat-training.md index 59a9b7217..1b8022b40 100644 --- a/docs/trainers/qat-training.md +++ b/docs/trainers/qat-training.md @@ -261,10 +261,55 @@ forgather finalize output_models/my_bf16_run out/my_bf16_run_int8_int4 \ This is the path used for the AMP-baseline / PTQ / QAT three-way comparison: train the same model once in plain bf16, then run finalize twice (one source bf16, one source QAT) with the same `--quantize` -recipe and compare the eval results. PTQ on a model that was not -QAT-trained typically pays a larger accuracy penalty than the -`~+0.017` eval-loss premium QAT does (see [Loss Trajectory](#loss-trajectory-1-chinchilla-tiny-llama) -above for the QAT numbers); how much depends on the model and recipe. +recipe and compare the eval results. See [Three-Way +Comparison](#three-way-comparison-bf16--ptq--qat) below for measured +numbers on Tiny Llama. + +## Evaluating Quantized Models + +`forgather eval` loads `--quantize`-finalized models with no extra flag: + +```bash +# Same invocation as for a bf16 model — eval autodetects the quantized +# artifact and routes through HF `from_pretrained()` so the +# TorchAoHfQuantizer pre-process path installs the right linear modules. +forgather -p examples/tutorials/tiny_llama eval test tinystories \ + -M /path/to/quantized_model +``` + +How it works: at finalize time, `--quantize` writes a +`quantization_config` block into `config.json` with the recipe. At eval +time, `scripts/eval_script.py:resolve_checkpoint()` reads that field +and forces the `from_pretrained()` load path (the normal +checkpoint-resume path uses `from_config()` + `load_state_dict()`, which +does not run any quantizer hook and fails when the saved tensors are +torchao quantized subclasses). Same mechanism applies to the inference +server — no caller-side changes needed. + +The check is purely additive: bf16 models keep using the +checkpoint-resume path. Pass `--no-checkpoint` to opt into +`from_pretrained()` manually for non-quantized models. + +### Three-Way Comparison: bf16 / PTQ / QAT + +Full eval against the Tiny Stories test split for the 4.43M Tiny Llama +trained at 1-Chinchilla on a single RTX 3090 (matching the [Loss +Trajectory](#loss-trajectory-1-chinchilla-tiny-llama) setup above), all +three using `int8-dynamic-act-int4-weight`: + +| Model | eval_loss | perplexity | Δ vs bf16 | +|-------|-----------|------------|-----------| +| bf16 baseline | 1.3656 | 3.918 | — | +| PTQ on bf16 baseline | 1.3917 | 4.022 | +0.0262 | +| QAT-trained + converted | 1.3905 | 4.017 | +0.0249 | + +At this scale (4.43M params, recipe = int4 weights / int8 dynamic +activations), QAT shaves about **0.0013 eval-loss** off PTQ — a real +but small gain. Whether QAT's ~1.67× training-time overhead pays for +that depends on your tolerance for the +0.025 perplexity premium that +quantization itself imposes (PTQ buys you almost all of the win at zero +training cost). QAT is expected to scale better at larger models and +more aggressive recipes — measure your own setup before committing. ## Programmatic Usage diff --git a/scripts/eval_script.py b/scripts/eval_script.py index 6edfd0fad..4e4f211fb 100644 --- a/scripts/eval_script.py +++ b/scripts/eval_script.py @@ -141,6 +141,24 @@ def resolve_checkpoint(args): return False, False if args.checkpoint: return args.checkpoint, True + # Quantized models can only load through HF `from_pretrained()` — that + # path installs TorchAoHfQuantizer pre-process, which swaps in the + # quantized linear modules so the saved tensor subclasses land in + # slots that know how to hold them. The checkpoint-resume path uses + # `from_config()` + `load_state_dict()`, which has no quantizer hook + # and fails with `'Parameter' object has no attribute 'tensor_data_names'`. + cfg_path = os.path.join(args.model, "config.json") + if os.path.isfile(cfg_path): + try: + with open(cfg_path) as f: + if "quantization_config" in json.load(f): + logger.info( + "Detected quantization_config in model config; " + "loading via from_pretrained() instead of checkpoint resume." + ) + return False, False + except (OSError, ValueError): + pass return True, True diff --git a/tools/finalize_model/finalize_model.py b/tools/finalize_model/finalize_model.py index 79817573b..e6a89f471 100644 --- a/tools/finalize_model/finalize_model.py +++ b/tools/finalize_model/finalize_model.py @@ -437,6 +437,18 @@ def main(argv=None): # ---- 6. Quantize (optional) ---------------------------------------- if args.quantize_recipe: _apply_quantize(model, args.quantize_recipe) + # Record the recipe on the config so HF `from_pretrained()` runs + # the TorchAoHfQuantizer pre-process path on reload — it installs + # the right quantized linear modules before `load_state_dict`, so + # the quantized tensor subclasses land in slots that know how to + # hold them. Without this block, reload via `from_pretrained()` + # fails with `'Parameter' object has no attribute 'tensor_data_names'`. + from transformers import TorchAoConfig + from forgather.ml.qat_recipes import recipe_to_base_config + + config.quantization_config = TorchAoConfig( + quant_type=recipe_to_base_config(args.quantize_recipe) + ) if args.safetensors: # torchao's quantized tensor subclasses wrap multiple inner # tensors and do not expose a single .storage().data_ptr(), From be638c516ab13158a96aa8ff358fb1efdc496207 Mon Sep 17 00:00:00 2001 From: Jason dinAlt Date: Fri, 15 May 2026 03:01:19 +0000 Subject: [PATCH 2/2] review fixes: autodetect cache, --checkpoint warning, dry-run log, doc polish MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Cache `resolve_checkpoint()` result on `args` so the autodetect log fires once per run instead of twice. - Move the quantization probe ahead of the `--checkpoint` short-circuit: explicit `--checkpoint PATH` on a quantized model previously slid through to the failing `from_config()` path and crashed with the tensor_data_names error. Now it logs a WARNING naming the overridden path and routes through `from_pretrained()`. - Factor out `_model_is_quantized()` helper. - Dry-run logs that it would write a `quantization_config` block. - Docs: clarify that autodetect overrides both checkpoint flags (not just "checkpoint resume"); s/perplexity premium/eval-loss premium/ in the three-way comparison; forward-reference the measured Δ from the Choosing-a-Recipe table to the Three-Way Comparison section. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/guides/evaluating-models.md | 5 +- docs/trainers/qat-training.md | 16 +++--- scripts/eval_script.py | 75 ++++++++++++++++++-------- tools/finalize_model/finalize_model.py | 3 ++ 4 files changed, 67 insertions(+), 32 deletions(-) diff --git a/docs/guides/evaluating-models.md b/docs/guides/evaluating-models.md index cfd564c7a..7a95850ee 100644 --- a/docs/guides/evaluating-models.md +++ b/docs/guides/evaluating-models.md @@ -121,8 +121,9 @@ same mechanism `forgather train` uses). Pass an explicit path with **Quantized models** (artifacts produced by `forgather finalize --quantize`) are autodetected: if `config.json` has a `quantization_config` block, eval -forces the `from_pretrained()` path regardless of `--checkpoint` / -`--no-checkpoint` flags. The checkpoint-resume path uses `from_config()` + +forces the `from_pretrained()` path and ignores both `--checkpoint` and +`--no-checkpoint`. An explicit `--checkpoint PATH` logs a warning before +being overridden. The checkpoint-resume path uses `from_config()` + `load_state_dict()` which has no quantizer hook and fails on quantized tensor subclasses; `from_pretrained()` runs HF's `TorchAoHfQuantizer` pre-process so the right linear modules are in place before the weights diff --git a/docs/trainers/qat-training.md b/docs/trainers/qat-training.md index 1b8022b40..61befc1f3 100644 --- a/docs/trainers/qat-training.md +++ b/docs/trainers/qat-training.md @@ -117,7 +117,7 @@ Whether that's worth it depends heavily on the recipe: | Recipe | QAT vs PTQ delta (expected) | Recommendation | |--------|---------------------------|----------------| -| `int8-dynamic-act-int4-weight` | Largest QAT benefit — int4 weights are aggressive enough that plain PTQ can drift noticeably. | QAT if you care about the last point of eval loss; PTQ fine for prototyping. | +| `int8-dynamic-act-int4-weight` | Largest QAT benefit — int4 weights are aggressive enough that plain PTQ can drift noticeably. Measured Δ at 4.43M params is small (0.0013, see [Three-Way Comparison](#three-way-comparison-bf16--ptq--qat)); the gap is expected to widen at larger scale. | QAT if you care about the last point of eval loss; PTQ fine for prototyping. | | `int4-weight-only` | Moderate QAT benefit. Per-group int4 + bf16 matmul is already quite robust. | PTQ first; reach for QAT only if eval drops more than you can absorb. | | `float8-dynamic-act-float8-weight` | Minimal — fp8 is already near-lossless. | PTQ. QAT is rarely justified for fp8. | @@ -280,11 +280,13 @@ forgather -p examples/tutorials/tiny_llama eval test tinystories \ How it works: at finalize time, `--quantize` writes a `quantization_config` block into `config.json` with the recipe. At eval time, `scripts/eval_script.py:resolve_checkpoint()` reads that field -and forces the `from_pretrained()` load path (the normal -checkpoint-resume path uses `from_config()` + `load_state_dict()`, which -does not run any quantizer hook and fails when the saved tensors are -torchao quantized subclasses). Same mechanism applies to the inference -server — no caller-side changes needed. +and forces the `from_pretrained()` load path, **ignoring** any +`--checkpoint` / `--no-checkpoint` flag (an explicit `--checkpoint +PATH` logs a warning before being overridden). The normal +checkpoint-resume path uses `from_config()` + `load_state_dict()`, +which does not run any quantizer hook and fails when the saved tensors +are torchao quantized subclasses. Same mechanism applies to the +inference server — no caller-side changes needed. The check is purely additive: bf16 models keep using the checkpoint-resume path. Pass `--no-checkpoint` to opt into @@ -306,7 +308,7 @@ three using `int8-dynamic-act-int4-weight`: At this scale (4.43M params, recipe = int4 weights / int8 dynamic activations), QAT shaves about **0.0013 eval-loss** off PTQ — a real but small gain. Whether QAT's ~1.67× training-time overhead pays for -that depends on your tolerance for the +0.025 perplexity premium that +that depends on your tolerance for the +0.025 eval-loss premium that quantization itself imposes (PTQ buys you almost all of the win at zero training cost). QAT is expected to scale better at larger models and more aggressive recipes — measure your own setup before committing. diff --git a/scripts/eval_script.py b/scripts/eval_script.py index 4e4f211fb..61d8ff0b1 100644 --- a/scripts/eval_script.py +++ b/scripts/eval_script.py @@ -129,6 +129,23 @@ def init_model(): return init_model +def _model_is_quantized(model_path: str) -> bool: + """True iff ``/config.json`` carries a ``quantization_config`` block. + + Written by ``forgather finalize --quantize``; HF's + ``from_pretrained()`` consumes it to install the TorchAoHfQuantizer + pre-process path. Unreadable or malformed config.json -> False. + """ + cfg_path = os.path.join(model_path, "config.json") + if not os.path.isfile(cfg_path): + return False + try: + with open(cfg_path) as f: + return "quantization_config" in json.load(f) + except (OSError, ValueError): + return False + + def resolve_checkpoint(args): """Return (checkpoint_arg, use_checkpoint) for the trainer. @@ -136,30 +153,42 @@ def resolve_checkpoint(args): - True: auto-find latest - str path: explicit - False: do not resume + + Quantized models force ``(False, False)`` regardless of flags: only HF + `from_pretrained()` runs the TorchAoHfQuantizer pre-process that + installs the right quantized linear modules before weight load. The + checkpoint-resume path uses `from_config()` + `load_state_dict()`, + which has no quantizer hook and crashes with `'Parameter' object has + no attribute 'tensor_data_names'` on torchao tensor subclasses. + + Cache the result on ``args`` so repeated calls don't re-log. """ - if args.no_checkpoint: - return False, False - if args.checkpoint: - return args.checkpoint, True - # Quantized models can only load through HF `from_pretrained()` — that - # path installs TorchAoHfQuantizer pre-process, which swaps in the - # quantized linear modules so the saved tensor subclasses land in - # slots that know how to hold them. The checkpoint-resume path uses - # `from_config()` + `load_state_dict()`, which has no quantizer hook - # and fails with `'Parameter' object has no attribute 'tensor_data_names'`. - cfg_path = os.path.join(args.model, "config.json") - if os.path.isfile(cfg_path): - try: - with open(cfg_path) as f: - if "quantization_config" in json.load(f): - logger.info( - "Detected quantization_config in model config; " - "loading via from_pretrained() instead of checkpoint resume." - ) - return False, False - except (OSError, ValueError): - pass - return True, True + cached = getattr(args, "_resolved_checkpoint", None) + if cached is not None: + return cached + + if _model_is_quantized(args.model): + if args.checkpoint: + logger.warning( + "Detected quantization_config in model config; ignoring " + "--checkpoint %r and loading via from_pretrained().", + args.checkpoint, + ) + else: + logger.info( + "Detected quantization_config in model config; " + "loading via from_pretrained() (overrides default checkpoint resume)." + ) + result = (False, False) + elif args.no_checkpoint: + result = (False, False) + elif args.checkpoint: + result = (args.checkpoint, True) + else: + result = (True, True) + + args._resolved_checkpoint = result + return result def build_trainer(args, model_init, eval_dataset, data_collator, tokenizer, device): diff --git a/tools/finalize_model/finalize_model.py b/tools/finalize_model/finalize_model.py index e6a89f471..e889885ce 100644 --- a/tools/finalize_model/finalize_model.py +++ b/tools/finalize_model/finalize_model.py @@ -432,6 +432,9 @@ def main(argv=None): logger.info( f"Would run quantize step with recipe '{args.quantize_recipe}'" ) + logger.info( + "Would write 'quantization_config' block to config.json" + ) return 0 # ---- 6. Quantize (optional) ----------------------------------------