Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch.nn as nn

import modelopt.torch.opt as mto
from modelopt.torch.quantization.config import RotateConfig
from modelopt.torch.quantization.conversion import quantizer_state
from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer
from modelopt.torch.quantization.utils import get_quantizer_state_dict
Expand All @@ -28,6 +29,15 @@
__all__ = ["export_hf_vllm_fq_checkpoint"]


def disable_rotate(quantizer: TensorQuantizer):
    """Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type.

    The returned value mirrors whatever form ``quantizer._rotate`` currently takes:
    a ``RotateConfig`` yields a fresh ``RotateConfig(enable=False)``, a plain dict
    (the serialization used by older checkpoints) yields a shallow copy with
    ``enable`` forced to ``False``, and anything else collapses to ``False``.
    The quantizer itself is not mutated.
    """
    rotate = quantizer._rotate
    if isinstance(rotate, RotateConfig):
        return RotateConfig(enable=False)
    # Backward compat: old checkpoints stored the rotate config as a plain dict.
    if isinstance(rotate, dict):
        return {**rotate, "enable": False}
    return False


def export_hf_vllm_fq_checkpoint(
model: nn.Module,
export_dir: Path | str,
Expand Down Expand Up @@ -104,6 +114,8 @@ def export_hf_vllm_fq_checkpoint(
# dict, then re-enable. The _disabled=True flag is captured in modelopt_state
# so that on vLLM reload weight quantizers stay off while input/output/
# attention quantizers remain active.
# Rotation is also cleared: the weight was already folded with rotation applied,
# so if fold_weight is called on reload it must not re-rotate the exported weight.
wqs_to_restore = []
for _, module in model.named_modules():
if isinstance(module, QuantModule):
Expand All @@ -114,7 +126,10 @@ def export_hf_vllm_fq_checkpoint(
and quantizer.is_enabled
):
quantizer.disable()
wqs_to_restore.append(quantizer)
orig_rotate = quantizer._rotate
if quantizer.rotate_is_enabled:
quantizer._rotate = disable_rotate(quantizer)
wqs_to_restore.append((quantizer, orig_rotate))

quantizer_state_dict = get_quantizer_state_dict(model)
for key in list(quantizer_state_dict):
Expand Down Expand Up @@ -149,5 +164,6 @@ def export_hf_vllm_fq_checkpoint(
# Step 3: Save HF weights using the pre-built folded state dict.
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)

for wq in wqs_to_restore:
for wq, orig_rotate in wqs_to_restore:
wq.enable()
wq._rotate = orig_rotate
Loading