Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -554,11 +554,13 @@ def _real_quantize(self, inputs):
if self._num_bits == (4, 3):
# FP8 quantization
# For per-tensor/per-channel quantization, we might need amax which is synced across all ranks
# For blockwise quantization, amax will be recomputed in the kernel
use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1)
outputs, _scale = FP8QTensor.quantize(
inputs,
axis=self._axis,
block_sizes=self._block_sizes,
scales=self.amax / 448.0 if self.amax is not None else None,
scales=self.amax / 448.0 if use_amax else None,
)
buffer_to_register["_scale"] = _scale
elif self._num_bits == 8:
Expand Down
33 changes: 33 additions & 0 deletions tests/gpu/torch/quantization/test_qtensor_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,36 @@ def test_nvfp4_dequantize_fast(self, shape, input_dtype):
f"Fast and standard dequantization differ: "
f"max diff = {(dequant_fast - dequant_standard).abs().max()}"
)

@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
("input_shape", "block_sizes"),
[
((128, 1152), {-1: 128}),
((256, 256), {-1: 64, -2: 64}), # 2D block sizes
],
)
def test_fp8_with_amax_and_block_sizes(self, device, input_dtype, input_shape, block_sizes):
"""Test FP8 quantization with both amax and block_sizes specified."""
quant_cfg = QuantizerAttributeConfig(
num_bits=(4, 3),
block_sizes=block_sizes,
fake_quant=False,
)
quantizer = TensorQuantizer(quant_cfg).to(device)

# Set a mock amax (scalar) - this was causing the bug
mock_amax = torch.tensor(1.5, device=device)
quantizer.amax = mock_amax

# Create input tensor
x = torch.randn(input_shape, dtype=input_dtype, device=device)

# QDQ
q_x = quantizer(x)
deq_x = quantizer(q_x)

assert torch.allclose(deq_x, x, rtol=1e-1, atol=1e-1)
assert hasattr(quantizer, "_scale")
assert quantizer._scale.numel() > 1