13 changes: 11 additions & 2 deletions modelopt/torch/quantization/calib/mse.py
@@ -79,8 +79,17 @@ def collect(self, x: torch.Tensor):
x = x.detach().to(dtype=torch.float32)

device = x.device
multipliers = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
# Split steps between _start_multiplier to 1.0 and 1.0 to _stop_multiplier
# to ensure balanced exploration on both sides of the original amax (1.0)
steps_first_half = self._num_steps // 2 + 1 # Include 1.0
steps_second_half = self._num_steps - self._num_steps // 2 # For second range
multipliers = torch.cat(
[
torch.linspace(self._start_multiplier, 1.0, steps=steps_first_half, device=device),
torch.linspace(1.0, self._stop_multiplier, steps=steps_second_half, device=device)[
1:
], # Skip duplicate 1.0
]
)

# Get reduce axis for per-channel quantization
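As an aside, here is a small standalone sketch of the multiplier grid the split above produces versus a single linspace, using made-up calibrator settings (start 0.25, stop 4.0, 5 steps); the split keeps candidates on both sides of the max-calibrated amax and always includes 1.0:

```python
import torch

# Hypothetical calibrator settings, chosen only to illustrate the split.
start_multiplier, stop_multiplier, num_steps = 0.25, 4.0, 5

steps_first_half = num_steps // 2 + 1            # includes 1.0
steps_second_half = num_steps - num_steps // 2   # for the second range

multipliers = torch.cat(
    [
        torch.linspace(start_multiplier, 1.0, steps=steps_first_half),
        torch.linspace(1.0, stop_multiplier, steps=steps_second_half)[1:],  # skip duplicate 1.0
    ]
)
print(multipliers)
# tensor([0.2500, 0.6250, 1.0000, 2.5000, 4.0000])

print(torch.linspace(start_multiplier, stop_multiplier, steps=num_steps))
# tensor([0.2500, 1.1875, 2.1250, 3.0625, 4.0000])  <- single linspace: only one candidate below 1.0
```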
69 changes: 61 additions & 8 deletions modelopt/torch/quantization/model_calib.py
@@ -32,6 +32,7 @@
from .calib import MseCalibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import QuantModule, SequentialQuantizer, TensorQuantizer
from .tensor_quant import scaled_e4m3_impl
from .utils import (
disable_calib,
enable_fake_quant,
@@ -41,6 +42,7 @@
is_quantized_linear,
is_quantized_row_parallel_linear,
quantizer_attr_names,
reduce_amax,
weight_attr_names,
)

@@ -216,14 +218,18 @@ def mse_calibrate(
max_calibrate(model, forward_loop, distributed_sync)

# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
# and identify weight quantizers
weight_quantizers = []
seen_modules = set()

for name, module in model.named_modules():
if isinstance(module, TensorQuantizer) and not module._disabled:
# Static block quantization is not supported by MseCalibrator
if module.is_static_block_quant:
raise ValueError(
f"MSE calibration does not support static block quantization. "
f"Found static block quantization at {name}."
)
# if module.is_static_block_quant:
# raise ValueError(
# f"MSE calibration does not support static block quantization. "
# f"Found static block quantization at {name}."
# )
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
# Get the initial amax from max calibration
initial_amax = module._amax.clone().detach()
@@ -237,7 +243,20 @@ def quant_func(x, amax, quantizer=module):
disable_calib(quantizer),
enable_fake_quant(quantizer),
):
quantizer._keep_shape = True
Review comment (@realAsma, Contributor, Dec 3, 2025):

line 233:

if is_nvfp4_static:  # static per-block
    scale = amax / 6.0
    global_scale = weight.amax() / 6.0
    scale_fp8 = scaled_e4m3_impl(amax / 6.0, weight.amax() / 6.0)  # FP8 quantization
    amax_equivalent = scale_fp8 * 6.0

xq = quantizer(x)
quantizer._keep_shape = False

# FP8 quantization of NVFP4 static per-block scales
if (
quantizer.is_static_block_quant
and quantizer._num_bits == (2, 1)
and quantizer._block_sizes.get("scale_bits") == (4, 3)
):
weight_amax = reduce_amax(
x, axis=None, keepdims=False, squeeze_scalar=True
)
quantizer._amax = scaled_e4m3_impl(amax, weight_amax)

if original_amax is not None:
quantizer._amax = original_amax
@@ -256,14 +275,48 @@ def quant_func(x, amax, quantizer=module):
quant_func=quant_func,
)

# Step 3: Collect data with MSE calibrators
# Identify weight quantizers by checking if they have corresponding weight parameters
for name, parent_module in model.named_modules():
if parent_module in seen_modules:
continue
for weight_name in weight_attr_names(parent_module):
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
if isinstance(weight_quantizer, TensorQuantizer) and not weight_quantizer._disabled:
if weight_quantizer._calibrator is not None:
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
seen_modules.add(parent_module)

# Step 3: Calibrate weight quantizers once with MSE calibration
# This ensures weights are only calibrated once, not during every forward pass
for parent_module, weight_name, weight_quantizer in weight_quantizers:
# Enable calibration mode for the weight quantizer
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()

with enable_weight_access_and_writeback(parent_module, model):
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)

# Step 4: Disable weight quantizers during forward loop
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.disable()

# Step 5: Collect data with MSE calibrators for activation quantizers only
enable_stats_collection(model)
if forward_loop is None:
weight_only_quantize(model)
# If no forward loop, nothing else to do since weights are already calibrated
pass
else:
# Run forward loop - only activation quantizers will collect data
forward_loop(model)

# Step 4: Compute optimal amax and load it
# Step 6: Re-enable weight quantizers before finalizing calibration
# This ensures finish_stats_collection processes them correctly
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.enable()

# Step 7: Compute optimal amax and load it for all quantizers (weights + activations)
finish_stats_collection(model, method="mse")

# TODO: Sync amax across distributed processes
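The inline review comment above describes NVFP4 static per-block quantization as a two-level scaling scheme: the per-block scale amax / 6.0 is itself quantized to FP8 (E4M3) against a per-tensor global scale, and the effective per-block amax becomes the FP8-rounded scale times 6. Below is a minimal standalone numeric sketch of that idea; it emulates the E4M3 step with a plain torch.float8_e4m3fn cast instead of modelopt's scaled_e4m3_impl, so the exact rounding and scaling conventions may differ.

```python
import torch

E4M3_MAX = 448.0  # largest finite E4M3 value


def e4m3_fake_quant(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    """Toy E4M3 fake-quant of x with a per-tensor amax (illustration only)."""
    s = amax / E4M3_MAX
    return (x / s).to(torch.float8_e4m3fn).to(torch.float32) * s


# Made-up per-block amax values from a max-calibration pass, plus the per-tensor amax.
block_amax = torch.tensor([0.30, 1.70, 6.20, 0.05])
weight_amax = block_amax.max()

block_scale = block_amax / 6.0        # NVFP4 per-block scale (FP4 max magnitude is 6.0)
global_scale = weight_amax / 6.0      # per-tensor scale for the FP8 scale factors

scale_fp8 = e4m3_fake_quant(block_scale, global_scale)   # FP8-quantized block scales
amax_equivalent = scale_fp8 * 6.0     # per-block amax actually representable after FP8 rounding
print(amax_equivalent)
```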
21 changes: 15 additions & 6 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -128,6 +128,7 @@ def __init__(
self._enable_pre_quant_scale = True
self._dequantize = False
self._input_dtype = None
self._keep_shape = False

# Lazy initialize the bias calibrator for KV cache quantization
self._bias_calibrator = None
@@ -653,6 +654,14 @@ def _fake_quantize(self, inputs):
getattr(self, "_onnx_quantizer_type", None),
self._pass_through_bwd,
)
elif self._num_bits == (2, 1) and self.is_static_block_quant:
from modelopt.torch.quantization.triton.fp4_kernel import (
launch_static_blockwise_fp4_fake_quant,
)

outputs = launch_static_blockwise_fp4_fake_quant(
inputs, amax / 6.0, out_dtype=inputs.dtype
)
elif isinstance(self._num_bits, tuple):
# Float-point quantization, e.g., FP8
E, M = self._num_bits # noqa: N806
@@ -783,11 +792,11 @@ def _process_for_blockquant(self, inputs: torch.Tensor):
if hasattr(self, "_padding"):
inputs = F.pad(inputs, self._padding, "constant", 0)

if inputs.shape != self._original_shape:
raise ValueError(
f"Input shape has changed from {self._original_shape} to {inputs.shape}."
" Block-quantization requires a fixed input shape."
)
# if inputs.shape != self._original_shape:
# print(
# f"Input shape has changed from {self._original_shape} to {inputs.shape}."
# " Block-quantization requires a fixed input shape."
# )
inputs = inputs.reshape(self._block_reshape_size)
return inputs

@@ -941,7 +950,7 @@ def forward(self, inputs):
"This case should have been handled."
)

if self.is_static_block_quant:
if self.is_static_block_quant and not self._keep_shape:
outputs = self._reset_to_original_shape(outputs)

return outputs
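For context on the amax / 6.0 passed to the Triton launcher in the new (2, 1) static-block branch above: 6.0 is the largest FP4 magnitude, so dividing the per-block amax by 6 yields a scale under which the top FP4 code maps back onto amax. A small worked example with made-up numbers (nearest-level rounding here; the kernel pins down tie-breaking at the exact midpoints with its threshold comparisons):

```python
import torch

fp4_levels = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

x = torch.tensor([[2.7, -0.2, 1.1, 3.0]])   # one toy block
amax = torch.tensor([3.0])                  # per-block amax from calibration
scale = amax / 6.0                          # 0.5, so FP4 code 6.0 dequantizes to 3.0

abs_scaled = x.abs() / scale                # [[5.4, 0.4, 2.2, 6.0]]
idx = (abs_scaled.unsqueeze(-1) - fp4_levels).abs().argmin(dim=-1)
x_fake = torch.sign(x) * fp4_levels[idx] * scale
print(x_fake)                               # tensor([[ 3.0000, -0.2500, 1.0000, 3.0000]])
```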
154 changes: 154 additions & 0 deletions modelopt/torch/quantization/triton/fp4_kernel.py
@@ -345,3 +345,157 @@ def fp4_dequantize(
)

return output


@triton.jit
def static_blockwise_fp4_fake_quant_kernel(
x_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE]
y_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE]
scale_ptr, # [NUM_FP4_BLOCKS]
NUM_FP4_BLOCKS,
BLOCK_SIZE: tl.constexpr,
OUT_DTYPE: tl.constexpr,
):
pid = tl.program_id(axis=0)
if pid >= NUM_FP4_BLOCKS:
return

block_offset = pid * BLOCK_SIZE
idx = block_offset + tl.arange(0, BLOCK_SIZE)

scale = tl.load(scale_ptr + pid).to(tl.float32)

x = tl.load(x_ptr + idx).to(tl.float32)

x_abs = tl.abs(x)
scale_safe = tl.where(scale >= 1e-5, scale, 1.0)
abs_scaled = x_abs / scale_safe

# FP4 values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
q_val = tl.where(
abs_scaled <= 0.25,
0.0,
tl.where(
abs_scaled < 0.75,
0.5,
tl.where(
abs_scaled <= 1.25,
1.0,
tl.where(
abs_scaled < 1.75,
1.5,
tl.where(
abs_scaled <= 2.5,
2.0,
tl.where(
abs_scaled < 3.5,
3.0,
tl.where(abs_scaled <= 5.0, 4.0, 6.0),
),
),
),
),
),
)

x_rescaled = q_val * scale_safe
x_dequant = tl.where(x >= 0, x_rescaled, -x_rescaled)

tl.store(y_ptr + idx, x_dequant.to(OUT_DTYPE))


def launch_static_blockwise_fp4_fake_quant(
x: torch.Tensor,
scale: torch.Tensor,
out_dtype: torch.dtype = torch.float16,
):
"""Launch Triton kernel for blockwise FP4 fake quantization.

x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.
scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA.
"""
assert x.ndim == 2
NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape

x_flat = x.contiguous().view(-1)
y_flat = torch.empty_like(x_flat, dtype=out_dtype)
scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous()

tl_out_dtype = _torch_dtype_to_tl(out_dtype)

grid = (NUM_FP4_BLOCKS,)

# Ensure we're running on the correct CUDA device
with torch.cuda.device(x.device):
static_blockwise_fp4_fake_quant_kernel[grid](
x_flat,
y_flat,
scale_flat,
NUM_FP4_BLOCKS,
BLOCK_SIZE,
OUT_DTYPE=tl_out_dtype,
)

return y_flat.view_as(x)


def blockwise_fp4_fake_quant_reference(
x: torch.Tensor,
scale: torch.Tensor,
out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""Reference implementation of blockwise FP4 fake quantization.

x: [NUM_FP4_BLOCKS, BLOCK_SIZE].
scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1].

Uses FP4 quantization levels: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0.
"""
assert x.ndim == 2
num_blocks, block_size = x.shape

if scale.ndim == 1:
scale = scale.view(num_blocks, 1)
assert scale.shape == (num_blocks, 1)

x_f = x.to(torch.float32)
s_f = scale.to(torch.float32)

s_f = torch.where(s_f >= 1e-5, s_f, torch.ones_like(s_f))

x_abs = torch.abs(x_f)
abs_scaled = x_abs / s_f

q_val = torch.where(
abs_scaled <= 0.25,
torch.zeros_like(abs_scaled),
torch.where(
abs_scaled < 0.75,
torch.full_like(abs_scaled, 0.5),
torch.where(
abs_scaled <= 1.25,
torch.ones_like(abs_scaled),
torch.where(
abs_scaled < 1.75,
torch.full_like(abs_scaled, 1.5),
torch.where(
abs_scaled <= 2.5,
torch.full_like(abs_scaled, 2.0),
torch.where(
abs_scaled < 3.5,
torch.full_like(abs_scaled, 3.0),
torch.where(
abs_scaled <= 5.0,
torch.full_like(abs_scaled, 4.0),
torch.full_like(abs_scaled, 6.0),
),
),
),
),
),
),
)

x_rescaled = q_val * s_f
x_dequant = torch.where(x_f >= 0, x_rescaled, -x_rescaled)
return x_dequant.to(out_dtype)
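And a possible usage sketch for the two entry points added above, checking the Triton launcher against the pure-PyTorch reference on random data; this assumes a CUDA device with Triton available and uses arbitrary block shapes:

```python
import torch

from modelopt.torch.quantization.triton.fp4_kernel import (
    blockwise_fp4_fake_quant_reference,
    launch_static_blockwise_fp4_fake_quant,
)

# 8 FP4 blocks of 16 elements each; the per-block scale is amax / 6.0.
x = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)
amax = x.abs().amax(dim=1, keepdim=True).to(torch.float32)   # [8, 1] per-block amax
scale = amax / 6.0

y_triton = launch_static_blockwise_fp4_fake_quant(x, scale, out_dtype=torch.bfloat16)
y_ref = blockwise_fp4_fake_quant_reference(x, scale, out_dtype=torch.bfloat16)

torch.testing.assert_close(y_triton, y_ref)
```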