9 changes: 9 additions & 0 deletions olive/olive_config.json
@@ -28,6 +28,15 @@
"dataset": "dataset_optional",
"module_dependencies": [ "autoawq" ]
},
"AutoClip": {
"module_path": "olive.passes.pytorch.autoclip.AutoClip",
"supported_providers": [ "*" ],
"supported_accelerators": [ "*" ],
"supported_precisions": [ ],
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ],
"dataset": "dataset_optional"
},
"CaptureSplitInfo": {
"module_path": "olive.passes.pytorch.capture_split_info.CaptureSplitInfo",
"supported_providers": [ "*" ],
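Once registered, the pass can be selected by its type name in a workflow. A minimal sketch of invoking it from Python follows; only the pass type and its parameters come from this diff, while the model path, output dir, and surrounding workflow fields are assumptions for illustration:

# Hypothetical workflow using the new AutoClip pass; parameter names match
# the pass's _default_config below, everything else is illustrative.
from olive.workflows import run as olive_run

workflow = {
    "input_model": {"type": "HfModel", "model_path": "facebook/opt-125m"},
    "passes": {
        "autoclip": {
            "type": "AutoClip",
            "n_grid": 20,  # clip-search grid resolution
            "max_shrink": 0.5,  # search down to half the original range
            "n_sample_token": 512,  # tokens sampled when scoring clip bounds
        }
    },
    "output_dir": "models/autoclip",
}
olive_run(workflow)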
229 changes: 229 additions & 0 deletions olive/passes/pytorch/autoclip.py
@@ -0,0 +1,229 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Based on original implementation at
# https://github.com/OpenBitSys/BitDistiller/blob/main/quantization/autoclip.py
# --------------------------------------------------------------------------
from __future__ import annotations

import logging
from functools import partial
from typing import TYPE_CHECKING, Union

import torch

from olive.data.config import DataConfig
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam
from olive.passes.pytorch.common import inherit_hf_from_hf
from olive.passes.pytorch.quant_utils import (
get_quantizer_config,
prepare_model,
run_layerwise_quantization,
)
from olive.passes.pytorch.train_utils import get_calibration_data_config

if TYPE_CHECKING:
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import HfModelHandler


logger = logging.getLogger(__name__)


class AutoClip(Pass):
"""AutoClip quantization-aware clipping for weights."""

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]:
return {
**get_quantizer_config(),
"n_grid": PassConfigParam(
type_=int,
default_value=20,
description="Number of grid steps to search for clipping.",
),
"max_shrink": PassConfigParam(
type_=float,
default_value=0.5,
description="Maximum shrink ratio for clip bounds.",
),
"n_sample_token": PassConfigParam(
type_=int,
default_value=512,
description="Number of token samples for input feature selection.",
),
"data_config": PassConfigParam(
type_=Union[DataConfig, dict],
default_value=None,
description=(
"Data config for clipping calibration. If not provided, wikitest train data will be used for"
" HfModels. Required for PyTorch models."
),
),
}

@classmethod
def validate_config(
cls,
config: type[BasePassConfig],
accelerator_spec: AcceleratorSpec,
) -> bool:
if not super().validate_config(config, accelerator_spec):
return False

if config.group_size <= 0 and config.group_size != -1:
logger.info("group_size must be -1 or greater than 0")
return False

bits = config.bits.value if hasattr(config.bits, "value") else config.bits
if bits not in [2, 4, 8]:
logger.info("bits must be 2, 4, or 8")
return False

return True

@torch.no_grad()
def _run_for_config(
self, model: HfModelHandler, config: type[BasePassConfig], output_model_path: str
) -> HfModelHandler:
wrapper, _, _ = prepare_model(model, config, exclude_attn_inputs=True)

data_config = config.data_config or get_calibration_data_config(
model.model_name_or_path,
trust_remote_code=model.get_load_kwargs().get("trust_remote_code", None),
data_name="mit-han-lab/pile-val-backup",
subset=None,
split="validation[:1000]",
max_seq_len=1024,
max_samples=128,
)
process_module = partial(
self.process_module,
n_grid=config.n_grid,
max_shrink=config.max_shrink,
n_sample_token=config.n_sample_token,
)
run_layerwise_quantization(
model,
wrapper,
data_config,
input_hook=self.accumulate_inputs,
process_module=process_module,
update_before_process=True,
include_lm_head=config.lm_head,
)

# TODO(jambayk): explore whether we should tie the embedding with lm_head after lm_head is clipped

wrapper.model.save_pretrained(output_model_path)
model.save_metadata(output_model_path)

return inherit_hf_from_hf(model, output_model_path, adapter_path=model.adapter_path)

    @staticmethod
    def _get_oc_batch_size(out_features: int) -> int:
        # largest power-of-two batch (up to 256) that evenly divides the output
        # channels; fall back to processing all output channels at once
        for candidate in [256, 128, 64, 32, 16]:
            if out_features % candidate == 0:
                return candidate
        return out_features

    @staticmethod
    def accumulate_inputs(module: torch.nn.Module, inputs: tuple, _: torch.Tensor) -> None:
        # input hook: cache each calibration input on CPU until the module is processed
        if module.quant_info.data is None:
            module.quant_info.data = {"inputs": []}
        module.quant_info.data["inputs"].append(inputs[0].detach().cpu())

@classmethod
def process_module(
cls,
module: torch.nn.Module,
device: str,
n_grid: int,
max_shrink: float,
n_sample_token: int,
) -> None:
if module.quant_info.data is None or not module.quant_info.data.get("inputs"):
raise ValueError(f"Module {module} does not have cached inputs initialized!")

input_feat = torch.cat(module.quant_info.data["inputs"], dim=0)
module.quant_info.data = None

module.to(device)
cls._auto_clip_layer(
module,
input_feat,
n_grid,
max_shrink,
n_sample_token,
)
module.to("cpu")

@classmethod
def _auto_clip_layer(
cls,
module: torch.nn.Module,
input_feat: torch.Tensor,
n_grid: int,
max_shrink: float,
n_sample_token: int,
) -> None:
weight = module.weight.data
if weight.dim() != 2:
raise ValueError("AutoClip expects a 2D linear weight tensor.")

quantizer = module.quant_info.quantizer
effective_group_size = weight.shape[1] if quantizer.group_size <= 0 else quantizer.group_size
if weight.shape[1] % effective_group_size != 0:
raise ValueError("Weight in_features must be divisible by group_size.")

        # flatten to (n_tokens, in_features) and stride-sample at most ~n_sample_token tokens
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        if input_feat.shape[0] > n_sample_token:
            step = max(1, input_feat.shape[0] // n_sample_token)
            input_feat = input_feat[::step]
        # reshape for broadcasting: inputs (1, n_tokens, n_group, group_size),
        # weights (out_features, 1, n_group, group_size)
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, effective_group_size)

        weight = weight.reshape(weight.shape[0], 1, -1, effective_group_size)

oc_batch_size = cls._get_oc_batch_size(weight.shape[0])
best_max_val_all = []
best_min_val_all = []

input_feat = input_feat.to(weight.device)
for i_b in range(0, weight.shape[0], oc_batch_size):
w_block = weight[i_b : i_b + oc_batch_size]

org_max_val = w_block.amax(dim=-1, keepdim=True)
org_min_val = w_block.amin(dim=-1, keepdim=True)

best_max_val = org_max_val.clone()
best_min_val = org_min_val.clone()
min_errs = torch.full_like(org_max_val, 1e9)

org_out = (input_feat * w_block).sum(dim=-1)

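            # grid-search shrunken (max, min) clip pairs; keep, per output
            # channel and per group, the pair with the lowest output MSE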
for i_s_p in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s_p / n_grid)
for i_s_n in range(int(max_shrink * n_grid)):
min_val = org_min_val * (1 - i_s_n / n_grid)
cur_w = torch.clamp(w_block, min_val, max_val)
q_w = quantizer.fake_quantize(cur_w.reshape(cur_w.shape[0], -1)).reshape(cur_w.shape)

cur_out = (input_feat * q_w).sum(dim=-1)
err = (cur_out - org_out).pow(2).mean(dim=1).reshape(min_errs.shape)

cur_best = err < min_errs
min_errs[cur_best] = err[cur_best]
best_max_val[cur_best] = max_val[cur_best]
best_min_val[cur_best] = min_val[cur_best]

best_max_val_all.append(best_max_val)
best_min_val_all.append(best_min_val)

best_max_val = torch.cat(best_max_val_all, dim=0).squeeze(1)
best_min_val = torch.cat(best_min_val_all, dim=0).squeeze(1)
        # apply the best per-channel, per-group bounds to the full weight tensor
        original_shape = module.weight.data.shape
clipped = module.weight.data.reshape(best_max_val.shape[0], best_max_val.shape[1], -1)
clipped = torch.clamp(clipped, best_min_val, best_max_val)
module.weight.data = clipped.reshape(original_shape)
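
For intuition, here is a self-contained sketch of the same asymmetric grid search on random tensors. The toy uniform fake-quantizer stands in for quantizer.fake_quantize; all shapes and names are assumptions for illustration, not Olive's API:

import torch


def toy_fake_quantize(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # asymmetric uniform fake quantization over the last (group) dim
    qmax = 2**bits - 1
    wmin = w.amin(dim=-1, keepdim=True)
    wmax = w.amax(dim=-1, keepdim=True)
    scale = (wmax - wmin).clamp(min=1e-8) / qmax
    return torch.clamp(torch.round((w - wmin) / scale), 0, qmax) * scale + wmin


torch.manual_seed(0)
group_size = 64
weight = torch.randn(8, 1, 4, group_size)  # (out_channels, 1, n_group, group_size)
inputs = torch.randn(1, 32, 4, group_size)  # (1, n_tokens, n_group, group_size)
org_out = (inputs * weight).sum(dim=-1)  # float reference output per group

org_max = weight.amax(dim=-1, keepdim=True)
org_min = weight.amin(dim=-1, keepdim=True)
best_max, best_min = org_max.clone(), org_min.clone()
min_errs = torch.full_like(org_max, 1e9)

n_grid, max_shrink = 20, 0.5
for i in range(int(max_shrink * n_grid)):
    max_val = org_max * (1 - i / n_grid)
    for j in range(int(max_shrink * n_grid)):
        min_val = org_min * (1 - j / n_grid)
        q_w = toy_fake_quantize(torch.clamp(weight, min_val, max_val))
        err = ((inputs * q_w).sum(dim=-1) - org_out).pow(2).mean(dim=1).reshape(min_errs.shape)
        better = err < min_errs
        min_errs[better] = err[better]
        best_max[better] = max_val[better]
        best_min[better] = min_val[better]

weight_clipped = torch.clamp(weight, best_min, best_max)  # final clipped weights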