From c923e03396010a6c6a32eb163729cdaf8c097e7d Mon Sep 17 00:00:00 2001 From: Jason dinAlt Date: Fri, 15 May 2026 06:03:52 +0000 Subject: [PATCH 1/2] native loader: detect and handle torchao quantization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit forgather.ml.sharded_checkpoint.load_checkpoint() now detects torchao-quantized checkpoints (via config.json's quantization_config block, falling back to a state_dict scan of the first shard) and installs the matching quantized linear modules on the constructed model before load_state_dict runs. Every tool that loads through the native checkpoint loader inherits the support: - forgather eval test ... -M MODEL - forgather inf server -m MODEL --from-checkpoint - Trainer resume_from_checkpoint No CLI flag, no marker file: the recipe is recovered from config.json's TorchAoConfig.quant_type (written by finalize) or from the saved tensor subclass attributes (covers checkpoints that lack the config hint). New module src/forgather/ml/quantization_detect.py centralises: - detect_torchao_quantization() — two-tier detection - install_torchao_quantization() — prepare→convert + safe-globals registration so torch.load(weights_only=True) accepts the subclasses Mechanics on the loader side: - when quantization is detected, force assign=True (Tensor.copy_ does not handle quantized-to-quantized copies cleanly) - override the load device to the module's existing device so the assigned tensors don't migrate the model to the caller-passed staging device (e.g. cpu when running on cuda:0) Reverts PR #46's eval-side shim that detected the same config and forced HF's from_pretrained() path. The shim violated the project's "-c CHECKPOINT_PATH means native loader" contract; the fix now lives in the native loader instead. 
Test coverage: - tests/test_quantization_detect.py (11 unit tests, all recipes) - tests/integration/specs/tiny_llama_inference_quantized.yaml + FinalizeSpec hook in tests/integration/test_inference.py Verified end-to-end on the chinchilla baseline: - bf16 eval_loss=1.367 (unchanged regression) - quantized eval_loss=1.392 via native loader (matches PR #46's HF loader result within noise: 1.392) - inference server -c serves coherent completions at 61.9 tok/s vs bf16 379.9 tok/s on a 4.43M model (slowdown documented in inference README — quantization is memory-bound; throughput wins appear at larger scale) Closes #41, #42. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/guides/evaluating-models.md | 14 +- docs/guides/finalize-model.md | 12 +- docs/trainers/qat-training.md | 33 ++- scripts/eval_script.py | 46 +--- src/forgather/ml/quantization_detect.py | 233 ++++++++++++++++++ src/forgather/ml/sharded_checkpoint.py | 116 +++++++++ tests/integration/spec.py | 19 +- .../specs/tiny_llama_inference_quantized.yaml | 35 +++ tests/integration/test_inference.py | 24 ++ tests/test_quantization_detect.py | 155 ++++++++++++ tools/finalize_model/finalize_model.py | 49 ++-- tools/inference_server/README.md | 42 ++++ 12 files changed, 679 insertions(+), 99 deletions(-) create mode 100644 src/forgather/ml/quantization_detect.py create mode 100644 tests/integration/specs/tiny_llama_inference_quantized.yaml create mode 100644 tests/test_quantization_detect.py diff --git a/docs/guides/evaluating-models.md b/docs/guides/evaluating-models.md index 7a95850ee..f4c784724 100644 --- a/docs/guides/evaluating-models.md +++ b/docs/guides/evaluating-models.md @@ -120,14 +120,12 @@ same mechanism `forgather train` uses). Pass an explicit path with `AutoModelForCausalLM.from_pretrained` on the model directory. 
**Quantized models** (artifacts produced by `forgather finalize --quantize`) -are autodetected: if `config.json` has a `quantization_config` block, eval -forces the `from_pretrained()` path and ignores both `--checkpoint` and -`--no-checkpoint`. An explicit `--checkpoint PATH` logs a warning before -being overridden. The checkpoint-resume path uses `from_config()` + -`load_state_dict()` which has no quantizer hook and fails on quantized -tensor subclasses; `from_pretrained()` runs HF's `TorchAoHfQuantizer` -pre-process so the right linear modules are in place before the weights -load. See [QAT Training § Evaluating Quantized +load transparently through the standard checkpoint-resume path. Forgather's +native loader (`forgather.ml.sharded_checkpoint.load_checkpoint`) detects +torchao quantization from `config.json`'s `quantization_config` block (or +falls back to scanning the saved state_dict) and installs the matching +quantized linear modules before `load_state_dict` runs. No extra flag, +no caller-side recipe argument. See [QAT Training § Evaluating Quantized Models](../trainers/qat-training.md#evaluating-quantized-models). The tokenizer is always loaded directly from `--model` via diff --git a/docs/guides/finalize-model.md b/docs/guides/finalize-model.md index 47dd73d46..403e29bf0 100644 --- a/docs/guides/finalize-model.md +++ b/docs/guides/finalize-model.md @@ -123,11 +123,13 @@ which the safetensors writer requires. If `--safetensors` is passed alongside `--quantize`, it is silently disabled with a warning. Finalize also writes a `quantization_config` block into `config.json` -with the recipe. This makes HF `AutoModelForCausalLM.from_pretrained()` -auto-detect the quantization on reload and run the -`TorchAoHfQuantizer` pre-process path — so `forgather eval` and the -inference server load the artifact correctly without any caller-side -flag. See [Evaluating Quantized +with the recipe. 
Forgather's native checkpoint loader consumes this +hint (with a state_dict scan as fallback) and installs the matching +quantized linear modules before weights load — so `forgather eval`, +the inference server, and any other tool using the native loader +handle the artifact transparently with no caller-side flag. The same +block also enables HF `AutoModelForCausalLM.from_pretrained()` +auto-detection for non-Forgather consumers. See [Evaluating Quantized Models](../trainers/qat-training.md#evaluating-quantized-models). ### Misc diff --git a/docs/trainers/qat-training.md b/docs/trainers/qat-training.md index 61befc1f3..049688b4c 100644 --- a/docs/trainers/qat-training.md +++ b/docs/trainers/qat-training.md @@ -278,19 +278,26 @@ forgather -p examples/tutorials/tiny_llama eval test tinystories \ ``` How it works: at finalize time, `--quantize` writes a -`quantization_config` block into `config.json` with the recipe. At eval -time, `scripts/eval_script.py:resolve_checkpoint()` reads that field -and forces the `from_pretrained()` load path, **ignoring** any -`--checkpoint` / `--no-checkpoint` flag (an explicit `--checkpoint -PATH` logs a warning before being overridden). The normal -checkpoint-resume path uses `from_config()` + `load_state_dict()`, -which does not run any quantizer hook and fails when the saved tensors -are torchao quantized subclasses. Same mechanism applies to the -inference server — no caller-side changes needed. - -The check is purely additive: bf16 models keep using the -checkpoint-resume path. Pass `--no-checkpoint` to opt into -`from_pretrained()` manually for non-quantized models. +`quantization_config` block into `config.json` with the recipe. +Forgather's native checkpoint loader +(`forgather.ml.sharded_checkpoint.load_checkpoint`) reads that block, +or — as a fallback when the block is absent — scans the first shard for +torchao tensor subclasses. 
When quantization is detected, the loader +installs the matching quantized linear modules (`quantize_` with a +`QATConfig` at `step="prepare"`, then `step="convert"`) on the constructed model +*before* `load_state_dict` runs, so the saved tensor subclasses land in +slots that know how to hold them. + +This is built into the native loader, so it applies uniformly to every +tool that loads via Forgather checkpoints (`-c`): + +- `forgather eval test ... -M MODEL` (and its `--checkpoint PATH` variant) +- `forgather inf server -m MODEL --from-checkpoint` +- Trainer resume (`resume_from_checkpoint`) + +No caller-side recipe flag, no marker file. The check is purely +additive — bf16 models load through the exact same path with no +quantization step. ### Three-Way Comparison: bf16 / PTQ / QAT diff --git a/scripts/eval_script.py b/scripts/eval_script.py index 61d8ff0b1..1276c2348 100644 --- a/scripts/eval_script.py +++ b/scripts/eval_script.py @@ -129,23 +129,6 @@ def init_model(): return init_model -def _model_is_quantized(model_path: str) -> bool: - """True iff ``/config.json`` carries a ``quantization_config`` block. - - Written by ``forgather finalize --quantize``; HF's - ``from_pretrained()`` consumes it to install the TorchAoHfQuantizer - pre-process path. Unreadable or malformed config.json -> False. - """ - cfg_path = os.path.join(model_path, "config.json") - if not os.path.isfile(cfg_path): - return False - try: - with open(cfg_path) as f: - return "quantization_config" in json.load(f) - except (OSError, ValueError): - return False - - def resolve_checkpoint(args): """Return (checkpoint_arg, use_checkpoint) for the trainer. @@ -154,33 +137,20 @@ def resolve_checkpoint(args): - str path: explicit - False: do not resume - Quantized models force ``(False, False)`` regardless of flags: only HF - `from_pretrained()` runs the TorchAoHfQuantizer pre-process that - installs the right quantized linear modules before weight load. 
The - checkpoint-resume path uses `from_config()` + `load_state_dict()`, - which has no quantizer hook and crashes with `'Parameter' object has - no attribute 'tensor_data_names'` on torchao tensor subclasses. - Cache the result on ``args`` so repeated calls don't re-log. + + Quantization is handled transparently downstream: when the native + loader at ``forgather.ml.sharded_checkpoint.load_checkpoint`` detects + torchao quantization (via ``config.json`` or a state_dict scan), it + installs the matching quantized linear modules before + ``load_state_dict``. Eval doesn't need to special-case quantized + models here. """ cached = getattr(args, "_resolved_checkpoint", None) if cached is not None: return cached - if _model_is_quantized(args.model): - if args.checkpoint: - logger.warning( - "Detected quantization_config in model config; ignoring " - "--checkpoint %r and loading via from_pretrained().", - args.checkpoint, - ) - else: - logger.info( - "Detected quantization_config in model config; " - "loading via from_pretrained() (overrides default checkpoint resume)." - ) - result = (False, False) - elif args.no_checkpoint: + if args.no_checkpoint: result = (False, False) elif args.checkpoint: result = (args.checkpoint, True) diff --git a/src/forgather/ml/quantization_detect.py b/src/forgather/ml/quantization_detect.py new file mode 100644 index 000000000..cf397bc1f --- /dev/null +++ b/src/forgather/ml/quantization_detect.py @@ -0,0 +1,233 @@ +"""Detect and install torchao quantization on a model at load time. + +The native checkpoint loader (:func:`forgather.ml.sharded_checkpoint.load_checkpoint`) +uses these helpers to make ``-c CHECKPOINT_PATH`` work transparently on +torchao-quantized artifacts produced by ``forgather finalize --quantize``. +Detection is two-tier: + +1. ``/config.json`` carries a ``quantization_config`` block + (written by finalize). 
HF's ``TorchAoConfig`` deserializes it directly + into an ``AOBaseConfig`` — the fast path; all Forgather-finalized + artifacts hit it. +2. The saved state_dict contains torchao tensor subclasses. We reconstruct + the base config from the subclass type and its attributes. v1 supports + :data:`IntxUnpackedToInt8Tensor` (``int8-dynamic-act-int4-weight``) and + :data:`Int4Tensor` (``int4-weight-only``); the ``float8`` recipe is + only available via the config.json hint (saving its tensor subclass + for state_dict-only reload would need SM ≥8.9 hardware at finalize + time, which we can't test on common GPUs). + +Once a base config is in hand, :func:`install_torchao_quantization` runs +``quantize_(module, QATConfig(base_config, step="prepare"))`` followed by +``step="convert"``. The convert step swaps each ``nn.Linear`` for the +corresponding torchao quantized linear class, so that a subsequent +``load_state_dict`` lands the saved quantized tensor subclasses in slots +that know how to hold them. +""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any + +logger = logging.getLogger(__name__) + + +def install_torchao_quantization(module, base_config) -> None: + """Install torchao quantized linear modules in-place via prepare→convert. + + Mirrors what ``tools/finalize_model/finalize_model.py:_apply_quantize`` + does for QAT/PTQ at finalize time, but called at *load* time on an + already-constructed model right before ``load_state_dict``. The scales + computed during convert get overwritten by the loaded tensors; only + the module type swap matters. + + Also registers torchao tensor subclasses with PyTorch's + ``add_safe_globals`` so ``torch.load(weights_only=True)`` accepts them + in the subsequent state_dict load. Without this, the load fails with + ``UnpicklingError``: torchao classes aren't on PyTorch's default + allowlist, and HF's ``TorchAoHfQuantizer`` does the same registration + on its own load path. 
+ """ + from torchao.quantization import quantize_ + from torchao.quantization.qat import QATConfig + + _register_torchao_safe_globals() + quantize_(module, QATConfig(base_config, step="prepare")) + quantize_(module, QATConfig(base_config, step="convert")) + + +_TORCHAO_SAFE_GLOBALS_REGISTERED = False + + +def _register_torchao_safe_globals() -> None: + """Make torchao tensor subclasses safe to ``torch.load(weights_only=True)``. + + Idempotent. We scan ``torchao.quantization``, ``torchao.dtypes``, and a + handful of enum modules for tensor-subclass / enum types and pass them + to ``torch.serialization.add_safe_globals``. PyTorch's allowlist applies + process-wide; subsequent ``torch.load`` calls with ``weights_only=True`` + will accept these classes. + """ + global _TORCHAO_SAFE_GLOBALS_REGISTERED + if _TORCHAO_SAFE_GLOBALS_REGISTERED: + return + + import torch + + safe = [] + + def _collect_from(module_name: str) -> None: + try: + mod = __import__(module_name, fromlist=["*"]) + except ImportError: + return + for name in dir(mod): + if name.startswith("_"): + continue + obj = getattr(mod, name, None) + if not isinstance(obj, type): + continue + # tensor subclasses + enums used by their __setstate__ + mod_path = getattr(obj, "__module__", "") + if mod_path.startswith("torchao"): + safe.append(obj) + + _collect_from("torchao.quantization") + _collect_from("torchao.dtypes") + _collect_from("torchao.quantization.granularity") + _collect_from("torchao.quantization.quant_primitives") + + if safe: + # Deduplicate; add_safe_globals warns on duplicates. + torch.serialization.add_safe_globals(list(set(safe))) + _TORCHAO_SAFE_GLOBALS_REGISTERED = True + + +def _base_config_from_config_json(model_dir: str): + """Parse ``/config.json`` for a torchao ``quantization_config`` block. + + Returns the underlying ``AOBaseConfig`` (``TorchAoConfig.quant_type``) or + None when the block is absent / malformed / not torchao. 
+ + ``model_dir`` may point at the model root (which has ``config.json``) *or* + at a ``checkpoints/checkpoint-N`` subdir (which doesn't). For the latter + we walk up two parents to find the root. + """ + candidates = [model_dir] + parent = os.path.dirname(model_dir) + if os.path.basename(parent) == "checkpoints": + candidates.append(os.path.dirname(parent)) + + for d in candidates: + cfg_path = os.path.join(d, "config.json") + if not os.path.isfile(cfg_path): + continue + try: + with open(cfg_path) as f: + cfg = json.load(f) + except (OSError, ValueError): + continue + block = cfg.get("quantization_config") + if not isinstance(block, dict): + continue + if block.get("quant_method") != "torchao": + continue + try: + from transformers import TorchAoConfig + + return TorchAoConfig.from_dict(block).quant_type + except Exception as e: + logger.warning( + "quantization_config present in %s but TorchAoConfig.from_dict failed: %s", + cfg_path, + e, + ) + continue + return None + + +def _base_config_from_tensor(t): + """Reconstruct an ``AOBaseConfig`` from a saved torchao tensor subclass. + + Recognises the v1 recipes in :data:`QAT_RECIPES`. Returns None for + unknown subclasses so callers can produce their own error message. + """ + import torch + + cls_name = type(t).__name__ + + if cls_name == "IntxUnpackedToInt8Tensor": + # int8-dynamic-act-int4-weight: weight is int4 per-group; group_size + # comes from block_size = (1, group_size). + from torchao.quantization import Int8DynamicActivationIntxWeightConfig + from torchao.quantization.granularity import PerGroup + + block = getattr(t, "block_size", None) + group_size = block[1] if block is not None and len(block) == 2 else 32 + return Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(group_size=int(group_size)), + ) + + if cls_name == "Int4Tensor": + # int4-weight-only: group_size is dim-1 of block_size. 
+ from torchao.quantization import Int4WeightOnlyConfig + + block = getattr(t, "block_size", None) + group_size = block[1] if block is not None and len(block) == 2 else 128 + return Int4WeightOnlyConfig(group_size=int(group_size)) + + return None + + +def _detect_quantized_tensor(state_dict): + """Return the first torchao tensor subclass instance in ``state_dict``, or None.""" + for v in state_dict.values(): + # nn.Parameter wraps regular tensors; quantized weights show up as + # bare tensor subclasses whose module path starts with "torchao". + if type(v).__module__.startswith("torchao"): + return v + return None + + +def detect_torchao_quantization( + *, + model_dir: str | None = None, + state_dict: dict[str, Any] | None = None, +): + """Return a torchao ``AOBaseConfig`` if the model is quantized, else None. + + Both signals are optional; pass whichever you have. ``model_dir`` is + the fast path (no shard load needed). ``state_dict`` enables detection + for artifacts whose config.json lacks the ``quantization_config`` block. + + Raises: + ValueError: If ``state_dict`` contains a torchao tensor subclass + the v1 reverse-lookup doesn't recognise. The error names the + class and points the user at ``forgather finalize --quantize`` + to restore the metadata. + """ + if model_dir is not None: + cfg = _base_config_from_config_json(model_dir) + if cfg is not None: + return cfg + + if state_dict is not None: + sample = _detect_quantized_tensor(state_dict) + if sample is not None: + cfg = _base_config_from_tensor(sample) + if cfg is not None: + return cfg + raise ValueError( + f"State dict contains a torchao quantized tensor subclass " + f"{type(sample).__name__!r} that this version of Forgather " + f"doesn't know how to reverse-engineer into a base config. " + f"Re-finalize the model with `forgather finalize --quantize " + f"` to write a `quantization_config` block into " + f"`config.json`, then reload." 
+ ) + + return None diff --git a/src/forgather/ml/sharded_checkpoint.py b/src/forgather/ml/sharded_checkpoint.py index 167ffad06..b0454841b 100644 --- a/src/forgather/ml/sharded_checkpoint.py +++ b/src/forgather/ml/sharded_checkpoint.py @@ -11,6 +11,8 @@ from pprint import pp from typing import Dict, List, Optional, Set, TypeAlias, Union, overload +import pickle + import torch from safetensors.torch import load_file as safetensors_load from safetensors.torch import save_file as safetensors_save @@ -612,6 +614,20 @@ def load_checkpoint( if checkpoint_meta.is_index: shard_index = load_shard_index(model_dir, checkpoint_meta.file_name) + if module is not None and _maybe_install_torchao_quantization( + model_dir, module, shard_index=shard_index, + safetensors=checkpoint_meta.safetensors, device=device, + ): + # Quantized weights are tensor subclasses; ``Tensor.copy_`` + # between two quantized subclasses fails with a metadata + # mismatch. ``assign=True`` rebinds the Parameter directly, + # bypassing copy_. The flip-side: assigned tensors keep + # their map_location device, so we must load to wherever + # the (already-constructed) module lives — not to the + # caller-passed ``device`` (which may be a staging area + # like CPU). + assign = True + device = _module_device(module, fallback=device) return load_sharded_checkpoint( model_dir, shard_index, @@ -638,12 +654,112 @@ def load_checkpoint( return {k: v for k, v in state_dict.items() if k in keys} return state_dict + if _maybe_install_torchao_quantization(model_dir, module, state_dict=state_dict): + assign = True + # Move the loaded tensors onto the module's existing device before + # assigning them in (see the sharded branch above for the reason). + target = _module_device(module, fallback=device) + if str(target) != str(device): + state_dict = {k: v.to(target) for k, v in state_dict.items()} + # TODO: Properly handle strict, in this case? 
# We wish to ensure that all model weights were loaded, but ignore any other weights, like we do in load_sharded_checkpoint() module.load_state_dict(state_dict, strict=strict, assign=assign) return None +def _module_device(module: Module, *, fallback) -> "torch.device | str": + """Return the device of the module's first parameter, or ``fallback``.""" + for p in module.parameters(): + return p.device + for b in module.buffers(): + return b.device + return fallback + + +def _maybe_install_torchao_quantization( + model_dir: str, + module: Module, + *, + shard_index: "ShardIndex | None" = None, + state_dict: "Dict[str, Tensor] | None" = None, + safetensors: bool = False, + device: str = "cpu", +) -> bool: + """Detect torchao quantization on a checkpoint and install matching linear modules. + + Detection prefers ``/config.json``'s ``quantization_config`` + block (cheap, no shard load). Falls back to scanning the first shard + (or the supplied single-file state_dict) for torchao tensor subclasses. + Without this step, ``load_state_dict`` would try to copy quantized + tensor subclasses into plain ``nn.Linear.weight`` slots and fail with + ``'Parameter' object has no attribute 'tensor_data_names'``. + + Returns True if quantization was detected and installed; the caller + should force ``assign=True`` on the subsequent ``load_state_dict`` + call to bypass ``Tensor.copy_`` (which doesn't handle + quantized-to-quantized copies cleanly). 
+ """ + from forgather.ml.quantization_detect import ( + detect_torchao_quantization, + install_torchao_quantization, + ) + + base_config = detect_torchao_quantization(model_dir=model_dir) + if base_config is None: + if state_dict is None and shard_index is not None: + state_dict = _peek_first_shard( + model_dir, shard_index, safetensors=safetensors, device=device, + ) + if state_dict is not None: + base_config = detect_torchao_quantization(state_dict=state_dict) + if base_config is None: + return False + + logger.info( + "load_checkpoint: detected torchao quantization (%s); " + "installing quantized linear modules before load_state_dict", + type(base_config).__name__, + ) + install_torchao_quantization(module, base_config) + return True + + +def _peek_first_shard( + model_dir: str, + shard_index: "ShardIndex", + *, + safetensors: bool = False, + device: str = "cpu", +) -> "Dict[str, Tensor]": + """Load just one shard's state_dict for quantization detection. + + Try ``weights_only=True`` first; fall back to ``weights_only=False`` on + ``UnpicklingError`` (raised when the shard contains torchao tensor + subclasses that aren't yet on PyTorch's safe-globals allowlist). + Once :func:`install_torchao_quantization` runs in the caller, the + safe-globals get registered process-wide and subsequent loads (the + main shard loop) succeed with the default ``weights_only=True``. + """ + weight_map = shard_index["weight_map"] + # Unique shard files; pick the first deterministically (sorted) so + # tests are reproducible. 
+ shard_files = sorted(set(weight_map.values())) + if not shard_files: + return {} + shard_file_path = os.path.join(model_dir, shard_files[0]) + if safetensors: + return safetensors_load(shard_file_path, device=device) + try: + return torch.load( + shard_file_path, map_location=device, weights_only=True, mmap=True + ) + except pickle.UnpicklingError: + return torch.load( + shard_file_path, map_location=device, weights_only=False, mmap=True + ) + + @overload def load_sharded_checkpoint( model_dir: str, diff --git a/tests/integration/spec.py b/tests/integration/spec.py index f968e592a..6deb2c474 100644 --- a/tests/integration/spec.py +++ b/tests/integration/spec.py @@ -37,6 +37,17 @@ class InferenceSpec: server_timeout: int = 60 +@dataclass +class FinalizeSpec: + """Run ``forgather finalize`` between train and inference. + + Lets a test exercise the inference / eval load paths against finalized + artifacts (e.g. ``--quantize`` for torchao quantized models). + """ + + quantize: str | None = None + + @dataclass class IntegrationSpec: """Complete specification for an integration test.""" @@ -50,6 +61,7 @@ class IntegrationSpec: expected_files: list[str] = field(default_factory=lambda: ["trainer_logs.json"]) min_steps_logged: int = 1 inference: InferenceSpec | None = None + finalize: FinalizeSpec | None = None gpu_requirement: int = 1 timeout: int = 300 markers: list[str] = field(default_factory=lambda: ["integration"]) @@ -69,7 +81,12 @@ def _load_spec(path: Path) -> IntegrationSpec: inference_raw = raw.pop("inference", None) inference = InferenceSpec(**inference_raw) if inference_raw else None - return IntegrationSpec(loss=loss, stderr=stderr, inference=inference, **raw) + finalize_raw = raw.pop("finalize", None) + finalize = FinalizeSpec(**finalize_raw) if finalize_raw else None + + return IntegrationSpec( + loss=loss, stderr=stderr, inference=inference, finalize=finalize, **raw + ) def load_all_specs(specs_dir: Path) -> list[IntegrationSpec]: diff --git 
a/tests/integration/specs/tiny_llama_inference_quantized.yaml b/tests/integration/specs/tiny_llama_inference_quantized.yaml new file mode 100644 index 000000000..bb1aaf2c3 --- /dev/null +++ b/tests/integration/specs/tiny_llama_inference_quantized.yaml @@ -0,0 +1,35 @@ +test_id: tiny_llama_inference_quantized +project_dir: examples/tutorials/tiny_llama +config: train_tiny_llama.yaml +dynamic_args: + max_steps: 500 + save_strategy: "steps" +loss: + final_max: 5.0 + no_nan: true +stderr: + forbidden_patterns: + - "RuntimeError" + - "CUDA error" + - "Traceback" + - "tensor_data_names" + warn_patterns: [] +expected_files: + - trainer_logs.json +min_steps_logged: 5 +finalize: + quantize: int8-dynamic-act-int4-weight +inference: + prompt: "Once upon a time" + max_tokens: 50 + temperature: 0.7 + # int4-weight quantized output is noisier than bf16 baseline; + # relax the perplexity ceiling slightly while still gating against + # gibberish (the bf16 sibling test uses 500.0). + perplexity_max: 750.0 + server_timeout: 90 +timeout: 1200 +gpu_requirement: 1 +markers: + - integration + - slow diff --git a/tests/integration/test_inference.py b/tests/integration/test_inference.py index 3b81543ba..8a4718bac 100644 --- a/tests/integration/test_inference.py +++ b/tests/integration/test_inference.py @@ -85,6 +85,30 @@ def test_inference_with_perplexity(spec, output_dir): # 2. Find model directory model_dir = _find_model_dir(output_dir) + # 2b. Optional finalize step (e.g. --quantize for torchao quantized + # artifacts). Produces a sibling directory under output_dir; the + # server then loads from there instead of the raw training output. 
+ if spec.finalize and spec.finalize.quantize: + finalize_dest = output_dir / "finalized" + finalize_proc = subprocess.run( + [ + "forgather", + "finalize", + "--quantize", + spec.finalize.quantize, + str(model_dir), + str(finalize_dest), + ], + capture_output=True, + text=True, + ) + assert finalize_proc.returncode == 0, ( + f"finalize --quantize failed: rc={finalize_proc.returncode}\n" + f"stdout: {finalize_proc.stdout[-2000:]}\n" + f"stderr: {finalize_proc.stderr[-2000:]}" + ) + model_dir = finalize_dest + # 3. Start inference server (load from checkpoint since training # saves checkpoints, not standalone model weights) port = _find_free_port() diff --git a/tests/test_quantization_detect.py b/tests/test_quantization_detect.py new file mode 100644 index 000000000..9dc465256 --- /dev/null +++ b/tests/test_quantization_detect.py @@ -0,0 +1,155 @@ +"""Unit tests for forgather.ml.quantization_detect. + +Covers both detection paths (config.json hint, state_dict scan) and the +install helper. Tests use real torchao on tiny synthetic models so we +don't have to ship pickle fixtures; per-recipe cost is well under a +second on CPU. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +import torch + +from forgather.ml.qat_recipes import recipe_to_base_config +from forgather.ml.quantization_detect import ( + detect_torchao_quantization, + install_torchao_quantization, +) + + +# Build a model -> install one recipe -> harvest its state_dict. +# Cached at module scope to avoid running quantize_ once per test. 
+_FIXTURES: dict[str, dict] = {} + + +def _state_dict_for(recipe: str) -> dict: + if recipe in _FIXTURES: + return _FIXTURES[recipe] + model = torch.nn.Sequential( + torch.nn.Linear(128, 256, bias=False), + torch.nn.Linear(256, 128, bias=False), + ) + install_torchao_quantization(model, recipe_to_base_config(recipe)) + _FIXTURES[recipe] = model.state_dict() + return _FIXTURES[recipe] + + +@pytest.mark.parametrize( + "recipe", + [ + "int8-dynamic-act-int4-weight", + "int4-weight-only", + # float8-* skipped: requires SM >= 8.9 to actually convert. + ], +) +def test_detect_from_state_dict(recipe: str): + """State_dict scan recovers a base config matching the recipe.""" + sd = _state_dict_for(recipe) + base = detect_torchao_quantization(state_dict=sd) + assert base is not None, f"failed to detect {recipe}" + expected_type = type(recipe_to_base_config(recipe)) + assert type(base) is expected_type, ( + f"expected {expected_type.__name__}, got {type(base).__name__}" + ) + + +@pytest.mark.parametrize( + "recipe", + [ + "int8-dynamic-act-int4-weight", + "int4-weight-only", + "float8-dynamic-act-float8-weight", + ], +) +def test_detect_from_config_json(tmp_path: Path, recipe: str): + """config.json with a torchao quantization_config block resolves to the right base config. + + Includes float8: TorchAoConfig.from_dict doesn't need hardware to + deserialize — only the quantize_ install step requires SM >= 8.9. 
+ """ + from transformers import TorchAoConfig + + block = TorchAoConfig(quant_type=recipe_to_base_config(recipe)).to_dict() + (tmp_path / "config.json").write_text( + json.dumps({"architectures": ["dummy"], "quantization_config": block}) + ) + base = detect_torchao_quantization(model_dir=str(tmp_path)) + assert base is not None + assert type(base) is type(recipe_to_base_config(recipe)) + + +def test_detect_returns_none_on_bf16(tmp_path: Path): + """Plain config.json without a quantization_config block returns None.""" + (tmp_path / "config.json").write_text(json.dumps({"architectures": ["dummy"]})) + assert detect_torchao_quantization(model_dir=str(tmp_path)) is None + + +def test_detect_returns_none_on_missing_config(tmp_path: Path): + """No config.json anywhere → None (don't crash).""" + assert detect_torchao_quantization(model_dir=str(tmp_path)) is None + + +def test_detect_walks_up_from_checkpoint_subdir(tmp_path: Path): + """When passed /checkpoints/checkpoint-N, look for config.json at .""" + from transformers import TorchAoConfig + + recipe = "int4-weight-only" + block = TorchAoConfig(quant_type=recipe_to_base_config(recipe)).to_dict() + (tmp_path / "config.json").write_text( + json.dumps({"architectures": ["dummy"], "quantization_config": block}) + ) + ckpt = tmp_path / "checkpoints" / "checkpoint-100" + ckpt.mkdir(parents=True) + base = detect_torchao_quantization(model_dir=str(ckpt)) + assert base is not None + assert type(base) is type(recipe_to_base_config(recipe)) + + +def test_detect_from_plain_state_dict_returns_none(): + """state_dict of plain bf16 tensors → no detection.""" + sd = {"a.weight": torch.randn(4, 4), "b.weight": torch.randn(8, 8)} + assert detect_torchao_quantization(state_dict=sd) is None + + +def test_detect_unknown_subclass_raises(monkeypatch): + """A torchao-namespaced tensor subclass we don't know about → ValueError.""" + import torch + + # Synthesize a subclass that lives in the torchao namespace but isn't on + # our 
reverse-lookup table. + class _UnknownTorchaoTensor(torch.Tensor): + pass + + _UnknownTorchaoTensor.__module__ = "torchao.fake.module" + fake = _UnknownTorchaoTensor() + with pytest.raises(ValueError, match="Re-finalize the model"): + detect_torchao_quantization(state_dict={"x.weight": fake}) + + +def test_install_swaps_linear_for_quantized(): + """After install, the model's Linear modules carry torchao quantized weights.""" + model = torch.nn.Sequential( + torch.nn.Linear(64, 128, bias=False), + torch.nn.Linear(128, 64, bias=False), + ) + install_torchao_quantization( + model, recipe_to_base_config("int8-dynamic-act-int4-weight") + ) + # Either the module class or the weight class becomes a torchao subclass — + # depending on the recipe, torchao may swap the module or just wrap the + # weight tensor. Verify *something* downstream is torchao-flavored. + has_torchao_artifact = False + for m in model.modules(): + if type(m).__module__.startswith("torchao"): + has_torchao_artifact = True + break + if isinstance(m, torch.nn.Linear) and type(m.weight).__module__.startswith( + "torchao" + ): + has_torchao_artifact = True + break + assert has_torchao_artifact, "install did not produce a torchao tensor or module" diff --git a/tools/finalize_model/finalize_model.py b/tools/finalize_model/finalize_model.py index e889885ce..be47adac9 100644 --- a/tools/finalize_model/finalize_model.py +++ b/tools/finalize_model/finalize_model.py @@ -217,11 +217,10 @@ def _apply_quantize(model, recipe: str) -> None: The loaded model's state_dict contains plain float weights (Forgather's sharded saver doesn't persist FakeQuantizedLinear inner state, even for - QAT-trained models), so we run ``step="prepare"`` to install fake - quantizers on top of the trained weights, then swap them for the real - low-bit quantized linear ops via ``step="convert"``. The - scales/zero-points the convert step computes are derived from the - loaded weight statistics. 
+ QAT-trained models), so we run prepare to install fake quantizers on + top of the trained weights, then convert to swap them for real low-bit + quantized linear ops. Scales/zero-points come from the loaded weight + statistics. Works on any input: @@ -232,47 +231,29 @@ def _apply_quantize(model, recipe: str) -> None: for the AMP-baseline vs PTQ vs QAT comparison. """ import torch - from torchao.quantization import quantize_ - from torchao.quantization.qat import FakeQuantizedLinear, QATConfig from forgather.ml.qat_recipes import QAT_RECIPES, recipe_to_base_config + from forgather.ml.quantization_detect import install_torchao_quantization if recipe not in QAT_RECIPES: raise ValueError( f"--quantize must be one of {QAT_RECIPES}, got {recipe!r}" ) - base_config = recipe_to_base_config(recipe) - - # If the loaded model already has FakeQuantizedLinear modules (e.g. a - # future Forgather saver that preserves them), skip prepare and go - # straight to convert. - fq_count = sum(1 for m in model.modules() if isinstance(m, FakeQuantizedLinear)) - if fq_count == 0: - linear_count = sum( - 1 for m in model.modules() if isinstance(m, torch.nn.Linear) - ) - if linear_count == 0: - logger.warning( - "--quantize %r requested but model has no nn.Linear " - "modules to quantize; skipping quantize step.", - recipe, - ) - return - logger.info( - f"Quantize ({recipe}): installing fake quantizers on " - f"{linear_count} nn.Linear modules before convert" - ) - quantize_(model, QATConfig(base_config, step="prepare")) - fq_count = sum( - 1 for m in model.modules() if isinstance(m, FakeQuantizedLinear) + linear_count = sum(1 for m in model.modules() if isinstance(m, torch.nn.Linear)) + if linear_count == 0: + logger.warning( + "--quantize %r requested but model has no nn.Linear " + "modules to quantize; skipping quantize step.", + recipe, ) + return - quantize_(model, QATConfig(base_config, step="convert")) logger.info( - f"Quantize ({recipe}): converted {fq_count} FakeQuantizedLinear " - 
f"modules to real quantized linear ops" + f"Quantize ({recipe}): running torchao prepare→convert on " + f"{linear_count} nn.Linear modules" ) + install_torchao_quantization(model, recipe_to_base_config(recipe)) def main(argv=None): diff --git a/tools/inference_server/README.md b/tools/inference_server/README.md index a3ac28049..e48b339dd 100644 --- a/tools/inference_server/README.md +++ b/tools/inference_server/README.md @@ -321,6 +321,48 @@ forgather inf server --model /path/to/model --dtype float32 forgather inf server --model /path/to/model --stop-sequences "<|im_end|>" "" ``` +### Loading Quantized Models + +The server transparently loads torchao-quantized artifacts produced by +`forgather finalize --quantize`. No extra flag — the native checkpoint +loader detects quantization (from `config.json`'s `quantization_config` +block, or by scanning the saved state_dict) and installs the matching +quantized linear modules before weights load. See [QAT Training § +Evaluating Quantized Models](../../docs/trainers/qat-training.md#evaluating-quantized-models) +for the underlying mechanism. + +```bash +# Quantize a finalized model (one-time, after training) +forgather finalize --quantize int8-dynamic-act-int4-weight \ + output_models/my_run /serve/my_run_int4 + +# Serve it — same invocation shape as for bf16 models +forgather inf server -m /serve/my_run_int4 --from-checkpoint +``` + +**Throughput.** Quantized serving is currently slower than bf16 on +small/medium models at batch size 1. Measured on the 4.43M Tiny Llama, +single RTX 3090, greedy 64-token completion: + +| Variant | tok/s | +|---------|-------| +| bf16 baseline | 379.9 | +| `int8-dynamic-act-int4-weight` | 61.9 | + +The slowdown is the per-matmul dequant overhead being a large fraction +of the work in a small model. Quantization wins (memory footprint, +longer context, larger batch) appear at scale; benchmark your own +setup before deploying. 
+ +**Dtype interaction.** `--dtype` controls the unquantized layers (norms, +embeddings, residuals) and the dequant target. The recipe controls +activation/weight precision for the quantized linears. The default +`bfloat16` is the right choice unless you have a specific reason to +override. + +**Device placement.** Quantized linears are CUDA-bound for the v1 +recipes. CPU serving of quantized models is not currently supported. + ### Stop Sequences The server supports flexible stop sequence configuration to control when generation should halt: From 8c6ca1442a3472f9166e94d10a08c91ddf959e46 Mon Sep 17 00:00:00 2001 From: Jason dinAlt Date: Fri, 15 May 2026 06:15:22 +0000 Subject: [PATCH 2/2] review fixes: eager safe-globals, doc clarifications, forward/reverse test - Drop the `weights_only=False` pickle fallback in `_peek_first_shard`. Register torchao safe-globals eagerly before the load so `weights_only=True` keeps PyTorch's safe-default posture on .bin shards. The previous fallback allowed arbitrary pickled code on any shard whose contents happened to not load under the strict allowlist. - Document the silent device override in `load_checkpoint`'s docstring and note that tied weights are restored post-load via the trainer's retie step (eval/inference paths don't grad-update tied weights, so the missing retie is behaviourally invisible). - Guard `_module_device` against meta-device modules: skip meta and defer to the caller's fallback, otherwise `assign=True` would silently produce a fully-meta model. - Document the "canonical Forgather recipe" contract on `_base_config_from_tensor` so future readers know the state_dict reverse-lookup coerces non-default packing formats / mapping types to defaults. Config.json path preserves them; state_dict path is best-effort for hand-crafted torchao configs. - Widen the unknown-subclass ValueError with a float8-specific hint and point at restoring the config.json block as the cheaper alternative to re-finalizing. 
- Add `test_forward_reverse_roundtrip` (parametrized over runnable recipes) asserting `_base_config_from_tensor` recovers the same fields `recipe_to_base_config` produced. Locks down drift between the forward and reverse maps. - Fix the stale "routes through HF `from_pretrained()`" comment in the Quick Start code block of `qat-training.md` left over from PR #46. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/trainers/qat-training.md | 3 +- src/forgather/ml/quantization_detect.py | 28 ++++++++++--- src/forgather/ml/sharded_checkpoint.py | 54 ++++++++++++++++--------- tests/test_quantization_detect.py | 44 +++++++++++++++++++- 4 files changed, 102 insertions(+), 27 deletions(-) diff --git a/docs/trainers/qat-training.md b/docs/trainers/qat-training.md index 049688b4c..0777e98e1 100644 --- a/docs/trainers/qat-training.md +++ b/docs/trainers/qat-training.md @@ -271,8 +271,7 @@ numbers on Tiny Llama. ```bash # Same invocation as for a bf16 model — eval autodetects the quantized -# artifact and routes through HF `from_pretrained()` so the -# TorchAoHfQuantizer pre-process path installs the right linear modules. +# artifact and installs the right linear modules via the native loader. forgather -p examples/tutorials/tiny_llama eval test tinystories \ -M /path/to/quantized_model ``` diff --git a/src/forgather/ml/quantization_detect.py b/src/forgather/ml/quantization_detect.py index cf397bc1f..e935dcf7a 100644 --- a/src/forgather/ml/quantization_detect.py +++ b/src/forgather/ml/quantization_detect.py @@ -154,6 +154,16 @@ def _base_config_from_tensor(t): Recognises the v1 recipes in :data:`QAT_RECIPES`. Returns None for unknown subclasses so callers can produce their own error message. + + Assumes the **canonical Forgather recipe** was used at finalize time + (defaults in :func:`recipe_to_base_config`). 
Non-default packing + formats, mapping types, or qparams algorithms saved into the tensor + will be silently coerced to defaults here; the config.json path + preserves them faithfully. For artifacts produced by Forgather's + ``--quantize`` flag, the two recover the same config. For + hand-crafted torchao configs without a ``quantization_config`` block + in ``config.json``, prefer to restore the block rather than rely on + this reverse-lookup. """ import torch @@ -221,13 +231,21 @@ class and points the user at ``forgather finalize --quantize`` cfg = _base_config_from_tensor(sample) if cfg is not None: return cfg + cls_name = type(sample).__name__ + float8_hint = "" + if "Float8" in cls_name: + float8_hint = ( + " (the v1 reverse-lookup does not cover float8 — float8 " + "checkpoints must carry a `quantization_config` block.)" + ) raise ValueError( f"State dict contains a torchao quantized tensor subclass " - f"{type(sample).__name__!r} that this version of Forgather " - f"doesn't know how to reverse-engineer into a base config. " - f"Re-finalize the model with `forgather finalize --quantize " - f"` to write a `quantization_config` block into " - f"`config.json`, then reload." + f"{cls_name!r} that this version of Forgather doesn't know " + f"how to reverse-engineer into a base config.{float8_hint} " + f"Restore the `quantization_config` block in " + f"`/config.json` (written by `forgather finalize " + f"--quantize `), or re-finalize the source model " + f"with `--quantize`." 
) return None diff --git a/src/forgather/ml/sharded_checkpoint.py b/src/forgather/ml/sharded_checkpoint.py index b0454841b..2c767eddd 100644 --- a/src/forgather/ml/sharded_checkpoint.py +++ b/src/forgather/ml/sharded_checkpoint.py @@ -11,8 +11,6 @@ from pprint import pp from typing import Dict, List, Optional, Set, TypeAlias, Union, overload -import pickle - import torch from safetensors.torch import load_file as safetensors_load from safetensors.torch import save_file as safetensors_save @@ -606,6 +604,17 @@ def load_checkpoint( See `torch.nn.Module.load_state_dict `_ for the semantics of the ``strict`` and ``assign`` flags. + + When the checkpoint is torchao-quantized, this function installs the + matching quantized linear modules on ``module`` before + ``load_state_dict`` runs and forces ``assign=True`` (``Tensor.copy_`` + does not handle quantized-to-quantized copies). In that branch the + ``device`` argument is silently overridden to the module's existing + device, so the ``assign``-rebound tensors don't migrate the model + off the caller's compute device. Tied weights are restored + post-load by the trainer's ``retie_parameters()`` step; eval / + inference paths that don't re-tie still produce correct outputs + because quantized inference doesn't grad-update tied tensors. """ checkpoint_meta = get_checkpoint_metadata(model_dir) @@ -669,11 +678,19 @@ def load_checkpoint( def _module_device(module: Module, *, fallback) -> "torch.device | str": - """Return the device of the module's first parameter, or ``fallback``.""" + """Return the device of the module's first parameter, or ``fallback``. + + Skips ``meta`` devices: a module constructed under ``torch.device("meta")`` + has no storage, and inheriting that device would silently produce a + fully-meta model after ``load_state_dict(assign=True)``. In that + case, defer to the caller-passed ``fallback`` instead. 
+ """ for p in module.parameters(): - return p.device + if p.device.type != "meta": + return p.device for b in module.buffers(): - return b.device + if b.device.type != "meta": + return b.device return fallback @@ -734,13 +751,16 @@ def _peek_first_shard( ) -> "Dict[str, Tensor]": """Load just one shard's state_dict for quantization detection. - Try ``weights_only=True`` first; fall back to ``weights_only=False`` on - ``UnpicklingError`` (raised when the shard contains torchao tensor - subclasses that aren't yet on PyTorch's safe-globals allowlist). - Once :func:`install_torchao_quantization` runs in the caller, the - safe-globals get registered process-wide and subsequent loads (the - main shard loop) succeed with the default ``weights_only=True``. + Eagerly register torchao tensor subclasses with PyTorch's + ``add_safe_globals`` before loading, so ``weights_only=True`` accepts + quantized shards. Registration is idempotent and cheap (a single + import + dedup against PyTorch's allowlist), and importantly keeps + the loader's safe-default posture: we never fall back to + ``weights_only=False`` on a `.bin` shard, which would permit + arbitrary pickled code execution. """ + from forgather.ml.quantization_detect import _register_torchao_safe_globals + weight_map = shard_index["weight_map"] # Unique shard files; pick the first deterministically (sorted) so # tests are reproducible. 
@@ -750,14 +770,10 @@ def _peek_first_shard( shard_file_path = os.path.join(model_dir, shard_files[0]) if safetensors: return safetensors_load(shard_file_path, device=device) - try: - return torch.load( - shard_file_path, map_location=device, weights_only=True, mmap=True - ) - except pickle.UnpicklingError: - return torch.load( - shard_file_path, map_location=device, weights_only=False, mmap=True - ) + _register_torchao_safe_globals() + return torch.load( + shard_file_path, map_location=device, weights_only=True, mmap=True + ) @overload diff --git a/tests/test_quantization_detect.py b/tests/test_quantization_detect.py index 9dc465256..28f990b85 100644 --- a/tests/test_quantization_detect.py +++ b/tests/test_quantization_detect.py @@ -126,10 +126,52 @@ class _UnknownTorchaoTensor(torch.Tensor): _UnknownTorchaoTensor.__module__ = "torchao.fake.module" fake = _UnknownTorchaoTensor() - with pytest.raises(ValueError, match="Re-finalize the model"): + with pytest.raises(ValueError, match="re-finalize the source model"): detect_torchao_quantization(state_dict={"x.weight": fake}) +@pytest.mark.parametrize( + "recipe", + [ + "int8-dynamic-act-int4-weight", + "int4-weight-only", + # float8 omitted: requires SM ≥8.9 to run quantize_, and its + # state_dict reverse-lookup is intentionally not covered in v1. + ], +) +def test_forward_reverse_roundtrip(recipe: str): + """Reverse-lookup recovers the same base config that produced the tensor. + + Locks down drift between :func:`recipe_to_base_config` (forward map, + used by trainer + finalize) and :func:`_base_config_from_tensor` + (reverse map, used by the native loader's state_dict scan). If + either side adds a parameter the other doesn't read, this test + flags it. 
+ """ + from forgather.ml.quantization_detect import _base_config_from_tensor + + forward = recipe_to_base_config(recipe) + sd = _state_dict_for(recipe) + sample = next( + v for v in sd.values() if type(v).__module__.startswith("torchao") + ) + reverse = _base_config_from_tensor(sample) + + assert reverse is not None + assert type(reverse) is type(forward) + # Spot-check the user-facing fields the reverse lookup actually reads. + # We don't assert every attribute equals — torchao adds derived / + # internal state — but the constructor-visible ones must match. + if recipe == "int8-dynamic-act-int4-weight": + assert reverse.weight_dtype == forward.weight_dtype + assert ( + reverse.weight_granularity.group_size + == forward.weight_granularity.group_size + ) + elif recipe == "int4-weight-only": + assert reverse.group_size == forward.group_size + + def test_install_swaps_linear_for_quantized(): """After install, the model's Linear modules carry torchao quantized weights.""" model = torch.nn.Sequential(