diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py
index 9710d3a4b..37f4540af 100644
--- a/modelopt/torch/quantization/calib/mse.py
+++ b/modelopt/torch/quantization/calib/mse.py
@@ -79,8 +79,17 @@ def collect(self, x: torch.Tensor):
         x = x.detach().to(dtype=torch.float32)
         device = x.device
 
-        multipliers = torch.linspace(
-            self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
+        # Split steps between _start_multiplier to 1.0 and 1.0 to _stop_multiplier
+        # to ensure balanced exploration on both sides of the original amax (1.0)
+        steps_first_half = self._num_steps // 2 + 1  # Include 1.0
+        steps_second_half = self._num_steps - self._num_steps // 2  # For second range
+        multipliers = torch.cat(
+            [
+                torch.linspace(self._start_multiplier, 1.0, steps=steps_first_half, device=device),
+                torch.linspace(1.0, self._stop_multiplier, steps=steps_second_half, device=device)[
+                    1:
+                ],  # Skip duplicate 1.0
+            ]
         )
 
         # Get reduce axis for per-channel quantization
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index c8e2b044c..5ea7b9599 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -32,6 +32,7 @@
 from .calib import MseCalibrator
 from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
 from .nn import QuantModule, SequentialQuantizer, TensorQuantizer
+from .tensor_quant import scaled_e4m3_impl
 from .utils import (
     disable_calib,
     enable_fake_quant,
@@ -41,6 +42,7 @@
     is_quantized_linear,
     is_quantized_row_parallel_linear,
     quantizer_attr_names,
+    reduce_amax,
     weight_attr_names,
 )
 
@@ -216,14 +218,18 @@ def mse_calibrate(
     max_calibrate(model, forward_loop, distributed_sync)
 
     # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
+    # and identify weight quantizers
+    weight_quantizers = []
+    seen_modules = set()
+
     for name, module in model.named_modules():
         if isinstance(module, TensorQuantizer) and not module._disabled:
             # Static block quantization is not supported by MseCalibrator
-            if module.is_static_block_quant:
-                raise ValueError(
-                    f"MSE calibration does not support static block quantization. "
-                    f"Found static block quantization at {name}."
-                )
+            # if module.is_static_block_quant:
+            #     raise ValueError(
+            #         f"MSE calibration does not support static block quantization. "
+            #         f"Found static block quantization at {name}."
+            #     )
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
                 # Get the initial amax from max calibration
                 initial_amax = module._amax.clone().detach()
@@ -237,7 +243,20 @@ def quant_func(x, amax, quantizer=module):
                         disable_calib(quantizer),
                         enable_fake_quant(quantizer),
                     ):
+                        quantizer._keep_shape = True
                         xq = quantizer(x)
+                        quantizer._keep_shape = False
+
+                        # FP8 quantization of NVFP4 static per-block scales
+                        if (
+                            quantizer.is_static_block_quant
+                            and quantizer._num_bits == (2, 1)
+                            and quantizer._block_sizes.get("scale_bits") == (4, 3)
+                        ):
+                            weight_amax = reduce_amax(
+                                x, axis=None, keepdims=False, squeeze_scalar=True
+                            )
+                            quantizer._amax = scaled_e4m3_impl(amax, weight_amax)
 
                     if original_amax is not None:
                         quantizer._amax = original_amax
@@ -256,14 +275,48 @@ def quant_func(x, amax, quantizer=module):
                     quant_func=quant_func,
                 )
 
-    # Step 3: Collect data with MSE calibrators
+    # Identify weight quantizers by checking if they have corresponding weight parameters
+    for name, parent_module in model.named_modules():
+        if parent_module in seen_modules:
+            continue
+        for weight_name in weight_attr_names(parent_module):
+            weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
+            weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
+            if isinstance(weight_quantizer, TensorQuantizer) and not weight_quantizer._disabled:
+                if weight_quantizer._calibrator is not None:
+                    weight_quantizers.append((parent_module, weight_name, weight_quantizer))
+                    seen_modules.add(parent_module)
+
+    # Step 3: Calibrate weight quantizers once with MSE calibration
+    # This ensures weights are only calibrated once, not during every forward pass
+    for parent_module, weight_name, weight_quantizer in weight_quantizers:
+        # Enable calibration mode for the weight quantizer
+        weight_quantizer.disable_quant()
+        weight_quantizer.enable_calib()
+
+        with enable_weight_access_and_writeback(parent_module, model):
+            weight = getattr(parent_module, weight_name)
+            weight_quantizer(weight)
+
+    # Step 4: Disable weight quantizers during forward loop
+    for _, _, weight_quantizer in weight_quantizers:
+        weight_quantizer.disable()
+
+    # Step 5: Collect data with MSE calibrators for activation quantizers only
     enable_stats_collection(model)
     if forward_loop is None:
-        weight_only_quantize(model)
+        # If no forward loop, nothing else to do since weights are already calibrated
+        pass
     else:
+        # Run forward loop - only activation quantizers will collect data
         forward_loop(model)
 
-    # Step 4: Compute optimal amax and load it
+    # Step 6: Re-enable weight quantizers before finalizing calibration
+    # This ensures finish_stats_collection processes them correctly
+    for _, _, weight_quantizer in weight_quantizers:
+        weight_quantizer.enable()
+
+    # Step 7: Compute optimal amax and load it for all quantizers (weights + activations)
     finish_stats_collection(model, method="mse")
 
     # TODO: Sync amax across distributed processes
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 1688c7fa7..71695a634 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -128,6 +128,7 @@ def __init__(
         self._enable_pre_quant_scale = True
         self._dequantize = False
         self._input_dtype = None
+        self._keep_shape = False
 
         # Lazy initialize the bias calibrator for KV cache quantization
         self._bias_calibrator = None
@@ -653,6 +654,14 @@ def _fake_quantize(self, inputs):
                 getattr(self, "_onnx_quantizer_type", None),
                 self._pass_through_bwd,
             )
+        elif self._num_bits == (2, 1) and self.is_static_block_quant:
+            from modelopt.torch.quantization.triton.fp4_kernel import (
+                launch_static_blockwise_fp4_fake_quant,
+            )
+
+            outputs = launch_static_blockwise_fp4_fake_quant(
+                inputs, amax / 6.0, out_dtype=inputs.dtype
+            )
         elif isinstance(self._num_bits, tuple):
             # Float-point quantization, e.g., FP8
             E, M = self._num_bits  # noqa: N806
@@ -783,11 +792,11 @@ def _process_for_blockquant(self, inputs: torch.Tensor):
         if hasattr(self, "_padding"):
             inputs = F.pad(inputs, self._padding, "constant", 0)
 
-        if inputs.shape != self._original_shape:
-            raise ValueError(
-                f"Input shape has changed from {self._original_shape} to {inputs.shape}."
-                " Block-quantization requires a fixed input shape."
-            )
+        # if inputs.shape != self._original_shape:
+        #     print(
+        #         f"Input shape has changed from {self._original_shape} to {inputs.shape}."
+        #         " Block-quantization requires a fixed input shape."
+        #     )
 
         inputs = inputs.reshape(self._block_reshape_size)
         return inputs
@@ -941,7 +950,7 @@ def forward(self, inputs):
                     "This case should have been handled."
                 )
 
-        if self.is_static_block_quant:
+        if self.is_static_block_quant and not self._keep_shape:
             outputs = self._reset_to_original_shape(outputs)
 
         return outputs
diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py
index f2f9bd077..6049882e4 100644
--- a/modelopt/torch/quantization/triton/fp4_kernel.py
+++ b/modelopt/torch/quantization/triton/fp4_kernel.py
@@ -345,3 +345,157 @@ def fp4_dequantize(
     )
 
     return output
+
+
+@triton.jit
+def static_blockwise_fp4_fake_quant_kernel(
+    x_ptr,  # [NUM_FP4_BLOCKS * BLOCK_SIZE]
+    y_ptr,  # [NUM_FP4_BLOCKS * BLOCK_SIZE]
+    scale_ptr,  # [NUM_FP4_BLOCKS]
+    NUM_FP4_BLOCKS,
+    BLOCK_SIZE: tl.constexpr,
+    OUT_DTYPE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    if pid >= NUM_FP4_BLOCKS:
+        return
+
+    block_offset = pid * BLOCK_SIZE
+    idx = block_offset + tl.arange(0, BLOCK_SIZE)
+
+    scale = tl.load(scale_ptr + pid).to(tl.float32)
+
+    x = tl.load(x_ptr + idx).to(tl.float32)
+
+    x_abs = tl.abs(x)
+    scale_safe = tl.where(scale >= 1e-5, scale, 1.0)
+    abs_scaled = x_abs / scale_safe
+
+    # FP4 values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
+    q_val = tl.where(
+        abs_scaled <= 0.25,
+        0.0,
+        tl.where(
+            abs_scaled < 0.75,
+            0.5,
+            tl.where(
+                abs_scaled <= 1.25,
+                1.0,
+                tl.where(
+                    abs_scaled < 1.75,
+                    1.5,
+                    tl.where(
+                        abs_scaled <= 2.5,
+                        2.0,
+                        tl.where(
+                            abs_scaled < 3.5,
+                            3.0,
+                            tl.where(abs_scaled <= 5.0, 4.0, 6.0),
+                        ),
+                    ),
+                ),
+            ),
+        ),
+    )
+
+    x_rescaled = q_val * scale_safe
+    x_dequant = tl.where(x >= 0, x_rescaled, -x_rescaled)
+
+    tl.store(y_ptr + idx, x_dequant.to(OUT_DTYPE))
+
+
+def launch_static_blockwise_fp4_fake_quant(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    out_dtype: torch.dtype = torch.float16,
+):
+    """Launch Triton kernel for blockwise FP4 fake quantization.
+
+    x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.
+    scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA.
+    """
+    assert x.ndim == 2
+    NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape
+
+    x_flat = x.contiguous().view(-1)
+    y_flat = torch.empty_like(x_flat, dtype=out_dtype)
+    scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous()
+
+    tl_out_dtype = _torch_dtype_to_tl(out_dtype)
+
+    grid = (NUM_FP4_BLOCKS,)
+
+    # Ensure we're running on the correct CUDA device
+    with torch.cuda.device(x.device):
+        static_blockwise_fp4_fake_quant_kernel[grid](
+            x_flat,
+            y_flat,
+            scale_flat,
+            NUM_FP4_BLOCKS,
+            BLOCK_SIZE,
+            OUT_DTYPE=tl_out_dtype,
+        )
+
+    return y_flat.view_as(x)
+
+
+def blockwise_fp4_fake_quant_reference(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    out_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """Reference implementation of blockwise FP4 fake quantization.
+
+    x: [NUM_FP4_BLOCKS, BLOCK_SIZE].
+    scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1].
+
+    Uses FP4 quantization levels: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0.
+    """
+    assert x.ndim == 2
+    num_blocks, block_size = x.shape
+
+    if scale.ndim == 1:
+        scale = scale.view(num_blocks, 1)
+    assert scale.shape == (num_blocks, 1)
+
+    x_f = x.to(torch.float32)
+    s_f = scale.to(torch.float32)
+
+    s_f = torch.where(s_f >= 1e-5, s_f, torch.ones_like(s_f))
+
+    x_abs = torch.abs(x_f)
+    abs_scaled = x_abs / s_f
+
+    q_val = torch.where(
+        abs_scaled <= 0.25,
+        torch.zeros_like(abs_scaled),
+        torch.where(
+            abs_scaled < 0.75,
+            torch.full_like(abs_scaled, 0.5),
+            torch.where(
+                abs_scaled <= 1.25,
+                torch.ones_like(abs_scaled),
+                torch.where(
+                    abs_scaled < 1.75,
+                    torch.full_like(abs_scaled, 1.5),
+                    torch.where(
+                        abs_scaled <= 2.5,
+                        torch.full_like(abs_scaled, 2.0),
+                        torch.where(
+                            abs_scaled < 3.5,
+                            torch.full_like(abs_scaled, 3.0),
+                            torch.where(
+                                abs_scaled <= 5.0,
+                                torch.full_like(abs_scaled, 4.0),
+                                torch.full_like(abs_scaled, 6.0),
+                            ),
+                        ),
+                    ),
+                ),
+            ),
+        ),
+    )
+
+    x_rescaled = q_val * s_f
+    x_dequant = torch.where(x_f >= 0, x_rescaled, -x_rescaled)
+    return x_dequant.to(out_dtype)
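
Not part of the diff: a minimal sanity-check sketch for the two FP4 fake-quant helpers added to fp4_kernel.py above. It assumes a CUDA device is available and uses an illustrative block size of 16 (the NVFP4 block size); the per-block scale of amax / 6.0 mirrors how _fake_quantize calls the launcher in this change.

    # Sketch only; exercises the functions introduced in fp4_kernel.py above.
    import torch

    from modelopt.torch.quantization.triton.fp4_kernel import (
        blockwise_fp4_fake_quant_reference,
        launch_static_blockwise_fp4_fake_quant,
    )

    # Illustrative sizes: 128 FP4 blocks of 16 elements each.
    num_blocks, block_size = 128, 16
    x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32)
    # Per-block scale = amax / 6.0, matching the call site in _fake_quantize.
    scale = x.abs().amax(dim=-1, keepdim=True) / 6.0

    # The Triton kernel and the pure-PyTorch reference should agree.
    y_triton = launch_static_blockwise_fp4_fake_quant(x, scale, out_dtype=torch.float16)
    y_ref = blockwise_fp4_fake_quant_reference(x, scale, out_dtype=torch.float16)
    torch.testing.assert_close(y_triton, y_ref)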