Commit 7c75f20
Add NVFP4 QAT

**Summary:** This commit adds a QAT flow for NVFP4, following the numerics in `NVFP4Tensor` closely but without the dtype casting, swizzling, and packing/unpacking. Users can call this flow as follows:

```
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_qat_nvfp4
```

ghstack-source-id: fe592ca
Pull Request resolved: #2666
1 parent 255c0b2 commit 7c75f20
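
For readers skimming the diff, here is a minimal, self-contained sketch of the single-level fake-quantization numerics this commit wires up. The constant values and the amax-to-scale choice are assumptions mirroring `NVFP4Tensor`, not part of this diff:

```python
import torch

# Assumed constants, mirroring torchao.prototype.mx_formats:
F4_E2M1_MAX = 6.0     # largest fp4 (e2m1) magnitude
F8E4M3_MAX = 448.0    # largest fp8 (e4m3) magnitude
E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny  # smallest normal e4m3 value

def nvfp4_fake_quantize(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Single-level NVFP4 fake quantization: per-block fp8-range scaling plus
    clamping to the fp4 range, with no dtype casting or bit packing (the
    `skip_dtype_cast_and_packing=True` path added in this commit)."""
    orig_shape = x.shape
    blocks = x.view(-1, block_size).float()
    # Choose the block scale so the block amax maps to the largest fp4 value
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    block_scale = torch.clamp(amax / F4_E2M1_MAX, min=E4M3_EPS, max=F8E4M3_MAX)
    scaled = torch.clamp(blocks / block_scale, -F4_E2M1_MAX, F4_E2M1_MAX)
    # Dequantize immediately: multiply back by the same (uncast) scale
    return (scaled * block_scale).view(orig_shape).to(x.dtype)

x = torch.randn(4, 64)
print(nvfp4_fake_quantize(x).shape)  # torch.Size([4, 64])
```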

File tree

6 files changed: +126, -26 lines

docs/source/api_ref_qat.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@ Custom QAT APIs
 
     FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
+    NVFP4FakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizer
```

test/quantization/test_qat.py

Lines changed: 30 additions & 2 deletions
```diff
@@ -12,6 +12,7 @@
 import warnings
 from typing import List
 
+import pytest
 import torch
 import torch.nn.functional as F
 from parameterized import parameterized
@@ -44,6 +45,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
@@ -112,8 +114,8 @@ def __init__(self):
         self.sub = Sub()
         self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
 
-    def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+    def example_inputs(self, device: torch.device = None):
+        return (torch.randn((1, 512), device=device).to(torch.float),)
 
     def _get_all_weight_qparams(self) -> List[torch.Tensor]:
         return [
@@ -1884,6 +1886,32 @@ def test_qat_api_deprecation(self):
                 str(w.message),
             )
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @pytest.mark.parametrize("use_per_tensor_scale", [True, False])
+    def test_qat_nvfp4(self, use_per_tensor_scale: bool = False):
+        """
+        Test QAT with `NVFP4FakeQuantizeConfig`.
+        """
+        torch.manual_seed(self.SEED)
+        m = M().cuda()
+        baseline_model = copy.deepcopy(m)
+        qat_config = QATConfig(
+            activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
+            weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
+            step="prepare",
+        )
+        quantize_(m, qat_config)
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs("cuda")
+        out = m(*x)
+        baseline_out = baseline_model(*x)
+        sqnr = compute_error(out, baseline_out).item()
+        # Use same SQNR threshold as `test_nvfp4_reconstruction`
+        # TODO: why is this 0.0 when `use_per_tensor_scale=True`?
+        self.assertGreater(sqnr, 8.0)
+
 
 if __name__ == "__main__":
     unittest.main()
```

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 23 additions & 9 deletions
```diff
@@ -764,6 +764,15 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
+    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
+
+
+def _nvfp4_quantize(
+    data_hp: torch.Tensor,
+    block_size: int = 16,
+    per_tensor_scale: Optional[torch.Tensor] = None,
+    skip_dtype_cast_and_packing: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
@@ -782,9 +791,9 @@ def nvfp4_quantize(
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
-            torch.float8_e4m3fn
-        )
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
+        if not skip_dtype_cast_and_packing:
+            block_scale_fp8 = block_scale_fp8.to(torch.float8_e4m3fn)
         block_scale_fp32 = block_scale_fp8.to(torch.float32)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
@@ -797,7 +806,9 @@ def nvfp4_quantize(
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        ).to(torch.float8_e4m3fn)
+        )
+        if not skip_dtype_cast_and_packing:
+            scaled_block_scales_fp8 = scaled_block_scales_fp8.to(torch.float8_e4m3fn)
         scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
@@ -807,8 +818,11 @@ def nvfp4_quantize(
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    data_lp = f32_to_f4_unpacked(data_scaled)
-    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-    data_lp = pack_uint4(data_lp)
-    return out_scales, data_lp
+    if skip_dtype_cast_and_packing:
+        return out_scales, data_scaled
+    else:
+        data_lp = f32_to_f4_unpacked(data_scaled)
+        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+        data_lp = pack_uint4(data_lp)
+        return out_scales, data_lp
```
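
The public `nvfp4_quantize` keeps its original signature and behavior; the new `skip_dtype_cast_and_packing` flag exists only on the private `_nvfp4_quantize`. A hedged usage sketch of the two paths (the scale layout follows from the reshaping in the QAT forward below, and CPU float32 execution is assumed to be supported):

```python
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import _nvfp4_quantize

x = torch.randn(128, 64)

# QAT path: scales stay in clamped fp32 and data stays unpacked in the
# input dtype, so the op can be used for differentiable fake quantization.
scales, data = _nvfp4_quantize(x, block_size=16, skip_dtype_cast_and_packing=True)
assert data.dtype == x.dtype and data.shape == x.shape
assert scales.dtype == torch.float32
assert scales.numel() == x.numel() // 16  # one scale per 16-element block

# Default path is unchanged: e4m3 block scales and packed fp4 data.
scales_fp8, data_packed = _nvfp4_quantize(x, block_size=16)
assert scales_fp8.dtype == torch.float8_e4m3fn
```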

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -36,6 +36,22 @@ class FakeQuantizeConfigBase(abc.ABC):
     pass
 
 
+@dataclass
+class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
+    """
+    Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
+    according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
+
+    Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.
+
+    Args:
+        use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
+            after the initial fp8 (e4m3) block-wise scaling.
+    """
+
+    use_per_tensor_scale: bool = False
+
+
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """

torchao/quantization/qat/fake_quantizer.py

Lines changed: 53 additions & 14 deletions
```diff
@@ -29,6 +29,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from .utils import (
     _fake_quantize_per_channel_group,
@@ -46,13 +47,14 @@ def __init__(self, config: FakeQuantizeConfigBase):
         super().__init__()
         self.config = config
         self.enabled = True
-        self.scale: Optional[torch.Tensor] = None
-        self.zero_point: Optional[torch.Tensor] = None
 
-        # For range learning only
-        # TODO: make this configurable?
-        self._scale_eps = 1e-9
-        self._initialized = False
+        if isinstance(self.config, IntxFakeQuantizeConfig):
+            self.scale: Optional[torch.Tensor] = None
+            self.zero_point: Optional[torch.Tensor] = None
+            # For range learning only
+            # TODO: make this configurable?
+            self._scale_eps = 1e-9
+            self._initialized = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -62,9 +64,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if not self.enabled:
             return x
 
-        if not isinstance(self.config, IntxFakeQuantizeConfig):
-            raise ValueError("Only IntxFakeQuantizeConfig is supported currently")
+        if isinstance(self.config, NVFP4FakeQuantizeConfig):
+            return self._nvfp4_forward(x)
+        elif isinstance(self.config, IntxFakeQuantizeConfig):
+            return self._intx_forward(x)
+        else:
+            raise ValueError(f"Unexpected config type {self.config}")
+
+    def _nvfp4_forward(self, x: torch.Tensor):
+        """
+        Apply NVFP4 fake quantization to the tensor following `NVFP4Tensor`.
+        """
+        from torchao.prototype.mx_formats.nvfp4_tensor import (
+            _nvfp4_quantize,
+            per_tensor_amax_to_scale,
+        )
 
+        block_size = 16
+        if self.config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(x))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        scale, q = _nvfp4_quantize(
+            x,
+            block_size=block_size,
+            per_tensor_scale=per_tensor_scale,
+            skip_dtype_cast_and_packing=True,
+        )
+        assert q.dtype == x.dtype
+        assert scale.dtype == torch.float32
+        M, K = q.shape[0], q.shape[1]
+        q = q.view(M, K // block_size, block_size)
+        scale = scale.view(M, K // block_size, 1)
+        dq = q * scale
+        return dq.view(x.shape)
+
+    def _intx_forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Apply intx fake quantization to the tensor.
+        """
         if (
             self.config.range_learning
             and not self._initialized
@@ -77,15 +116,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
 
         if isinstance(self.config.granularity, PerToken):
-            return self._per_token_forward(x)
+            return self._intx_per_token_forward(x)
         elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
-            return self._per_channel_or_group_forward(x)
+            return self._intx_per_channel_or_group_forward(x)
         else:
             raise ValueError("Unknown granularity '%s'" % self.config.granularity)
 
-    def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
+    def _intx_per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Perform per token fake quantization on the tensor.
+        Perform intx per token fake quantization on the tensor.
         """
         if self.config.is_symmetric:
             raise NotImplementedError("Symmetric per token is not supported yet")
@@ -105,9 +144,9 @@ def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
         self._maybe_update_qparams_for_range_learning()
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
 
-    def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
+    def _intx_per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Perform per channel or per group fake quantization on the tensor.
+        Perform intx per channel or per group fake quantization on the tensor.
         We express per channel using per group where the group size is the size
         of the last dimension of the tensor.
         """
```

torchao/quantization/qat/linear.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -90,7 +90,9 @@ def __init__(
 
         # initialize weight fake quantizer
         if weight_config is not None:
-            if isinstance(weight_config.granularity, PerGroup):
+            if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance(
+                weight_config.granularity, PerGroup
+            ):
                 group_size = weight_config.group_size
                 if group_size is not None and in_features % group_size != 0:
                     raise ValueError(
```
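
Note: the added `isinstance(weight_config, IntxFakeQuantizeConfig)` check is what lets `NVFP4FakeQuantizeConfig` pass through this constructor: the NVFP4 config defines only `use_per_tensor_scale`, so accessing `weight_config.granularity` unconditionally would otherwise fail for it.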
