Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,28 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
),
)

# Opt-in location for rolling sequential-calibration checkpoints.
# NOTE(review): per the field descriptions here and on
# sequential_checkpoint_interval, this is effective only when the interval
# field is also set — setting the directory alone saves nothing.
sequential_checkpoint_dir: str | None = ModeloptField(
default=None,
title="Directory for sequential calibration checkpoints.",
description=(
"If set (together with sequential_checkpoint_interval), sequential calibration "
"will save intermediate checkpoints to this directory. On resume, if a checkpoint "
"with seq_calib_progress metadata is found, calibration resumes from the last "
"completed layer. Uses a rolling checkpoint (overwrites on each save)."
),
)

# Checkpoint cadence in layers; gt=0 rejects zero/negative intervals at
# config-validation time. None (the default) disables checkpointing, and the
# description states sequential_checkpoint_dir must also be set for saves.
sequential_checkpoint_interval: int | None = ModeloptField(
default=None,
gt=0,
title="Checkpoint interval for sequential calibration (in layers).",
description=(
"Save a checkpoint every N layers during sequential calibration. "
"Requires sequential_checkpoint_dir to also be set. "
"If None, no checkpoints are saved."
),
)


class MaxCalibConfig(QuantizeAlgorithmConfig):
"""The config for max calibration algorithm.
Expand Down
12 changes: 12 additions & 0 deletions modelopt/torch/quantization/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
TensorQuantizer,
)
from .utils import is_quantized, is_quantized_linear
from .utils.checkpoint import SEQ_CALIB_PROGRESS_ATTR

__all__ = [
"register",
Expand Down Expand Up @@ -108,6 +109,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
details regarding how MCore sharded checkpoint is restored,
see modelopt.torch.opt.plugins.mcore_dist_checkpointing.restore_sharded_modelopt_state.
"""
# Propagate sequential calibration progress to the model for resume.
# This is global metadata (not per-module), so it must run before the
# MCore early return — it applies to both HF and MCore checkpoint paths.
if "seq_calib_progress" in metadata:
setattr(model, SEQ_CALIB_PROGRESS_ATTR, metadata["seq_calib_progress"])

if "quantizer_state" not in metadata:
# MCore sharded checkpoint (`torch-dist`) has its quantizer_state stored as the
# extra_state of `QuantModule`. The quantizer_state is resumed with
Expand Down Expand Up @@ -170,6 +177,11 @@ def update_quantize_metadata(
"""Update the quantizer state in the metadata dict."""
metadata["quantizer_state"] = quantizer_state(model)

# Propagate sequential calibration progress if present (for checkpoint save)
progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None)
if progress is not None:
metadata["seq_calib_progress"] = progress


def quantizer_state(model: nn.Module) -> dict[str, Any]:
"""Returns the quantizer state dict describing the quantizer states in the model."""
Expand Down
12 changes: 12 additions & 0 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""This module contains the mode descriptor for the quantization mode."""

import warnings
from abc import abstractmethod
from collections.abc import Callable

Expand Down Expand Up @@ -228,6 +229,15 @@ def wrapped_calib_func(
kwargs["algorithm"] = method

moe_calib_experts_ratio = kwargs.pop("moe_calib_experts_ratio", None)
checkpoint_dir = kwargs.pop("sequential_checkpoint_dir", None)
checkpoint_interval = kwargs.pop("sequential_checkpoint_interval", None)

if not sequential and (checkpoint_dir is not None or checkpoint_interval is not None):
warnings.warn(
"sequential_checkpoint_dir/sequential_checkpoint_interval are set but "
"use_sequential is False. Checkpoint settings will be ignored."
)

if moe_calib_experts_ratio is not None:
assert (
isinstance(moe_calib_experts_ratio, (int, float)) and 0 < moe_calib_experts_ratio <= 1
Expand All @@ -248,6 +258,8 @@ def wrapped_calib_func(
model,
forward_loop=forward_loop,
calib_func=func,
checkpoint_dir=checkpoint_dir,
checkpoint_interval=checkpoint_interval,
**kwargs,
)
else:
Expand Down
51 changes: 47 additions & 4 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
reduce_amax,
weight_attr_names,
)
from .utils.checkpoint import (
SEQ_CALIB_PROGRESS_ATTR,
detect_sequential_resume_layer,
save_sequential_checkpoint,
should_save_seq_calib_checkpoint,
)

__all__ = [
"awq",
Expand Down Expand Up @@ -1870,13 +1876,27 @@ def sequential_calibrate(
model: nn.Module,
forward_loop: ForwardLoop,
calib_func: Callable,
checkpoint_dir: str | None = None,
checkpoint_interval: int | None = None,
**calib_kwargs,
):
"""Sequential calibration - a sequential layer-by-layer calibration algorithm.

Runs the full model forward per layer but patches decoder layers with a
skip / run / capture strategy so that inter-layer logic in parent modules
(e.g. mask construction) executes naturally without model-specific hooks.

Args:
model: The model to calibrate.
forward_loop: Callable that runs calibration data through the model.
calib_func: Per-layer calibration function (e.g. ``max_calibrate``).
checkpoint_dir: If set (with *checkpoint_interval*), save a rolling
checkpoint every *checkpoint_interval* layers. On re-run with a
model restored from such a checkpoint, calibration resumes
automatically from the last completed layer.
checkpoint_interval: Save a checkpoint every N layers. Requires
*checkpoint_dir* to also be set.
**calib_kwargs: Extra arguments forwarded to *calib_func*.
"""
if forward_loop is None:
raise ValueError(
Expand All @@ -1891,14 +1911,23 @@ def sequential_calibrate(
"Sequential calibration requires a model with identifiable transformer layers."
)

print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")
num_layers = len(transformer_layers)
print_rank_0(f"Sequential calibration: Found {num_layers} transformer layers")

resume_from_layer, layer_output_metas = detect_sequential_resume_layer(model, num_layers)

input_getter = LayerActivationCollector(model)
input_getter._patch_all_layers(decoder_layers=transformer_layers)
input_getter._patch_all_layers(
decoder_layers=transformer_layers, layer_output_metas=layer_output_metas
)

try:
for layer_idx, layer in enumerate(transformer_layers):
print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}")
if resume_from_layer > 0:
input_getter.prepare_for_resume(resume_from_layer, forward_loop)

for layer_idx in range(resume_from_layer, num_layers):
layer = transformer_layers[layer_idx]
print_rank_0(f"Calibrating layer {layer_idx + 1}/{num_layers}")
layer_inputs = input_getter.get_input_activations(layer, forward_loop)

def _layer_forward_loop(m, _inputs=layer_inputs):
Expand All @@ -1909,5 +1938,19 @@ def _layer_forward_loop(m, _inputs=layer_inputs):

del layer_inputs
torch.cuda.empty_cache()

if should_save_seq_calib_checkpoint(
layer_idx, num_layers, checkpoint_dir, checkpoint_interval
):
assert checkpoint_dir is not None # narrowed by should_save_seq_calib_checkpoint
layer_output_metas = input_getter.get_layer_output_metas(layer_idx)
save_sequential_checkpoint(
model, layer_idx, num_layers, checkpoint_dir, layer_output_metas
)
finally:
# Sole owner of _seq_calib_progress cleanup. The attribute may be set
# by save_sequential_checkpoint (save path) or restore_quantizer_state
# (resume path); neither deletes it — this is the single cleanup point.
if hasattr(model, SEQ_CALIB_PROGRESS_ATTR):
delattr(model, SEQ_CALIB_PROGRESS_ATTR)
input_getter._unpatch_all_layers()
9 changes: 9 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE
from ..utils import replace_function, sync_moe_expert_amax
from ..utils.activation_collector import LayerActivationCollector
from ..utils.checkpoint import register_seq_calib_checkpoint_saver
from .attention import register_attention_for_kv_quant
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin

Expand Down Expand Up @@ -1472,6 +1473,14 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
)


def _save_hf_checkpoint(model: nn.Module, checkpoint_dir: str) -> None:
"""Save a HuggingFace model checkpoint using ``save_pretrained``."""
model.save_pretrained(checkpoint_dir)


register_seq_calib_checkpoint_saver(_is_supported_hf_model, _save_hf_checkpoint)


class _QuantMoELinear(QuantModule):
"""Quantization wrapper for Step3p5 MoELinear modules (fused expert weights).

Expand Down
Loading
Loading