Skip to content

Add NVFP4 QAT #2666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: gh/andrewor14/16/base
Choose a base branch
from
Open

Conversation

andrewor14
Copy link
Contributor

@andrewor14 andrewor14 commented Aug 1, 2025

Stack from ghstack (oldest at bottom):

Summary: This commit adds a QAT flow for NVFP4, following the
numerics in NVFP4Tensor closely but without the dtyping casting,
swizzling, and the packing/unpacking. Users can call this flow as follows:

from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)

Test Plan:

python test/quantization/test_qat.py -k test_qat_nvfp4

**Summary:** This commit adds a QAT flow for NVFP4, following the
numerics in `NVFP4Tensor` closely but without the dtyping casting,
swizzling, and the packing/unpacking. Users can call this flow as follows:

```
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)
```

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
```

[ghstack-poisoned]
This was referenced Aug 1, 2025
Copy link

pytorch-bot bot commented Aug 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2666

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 87175e9 with merge base 97b090d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

andrewor14 added a commit that referenced this pull request Aug 1, 2025
**Summary:** This commit adds a QAT flow for NVFP4, following the
numerics in `NVFP4Tensor` closely but without the dtyping casting,
swizzling, and the packing/unpacking. Users can call this flow as follows:

```
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)
```

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
```

ghstack-source-id: fe592ca
Pull Request resolved: #2666
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 1, 2025
@andrewor14 andrewor14 added the topic: new feature Use this tag if this PR adds a new feature label Aug 1, 2025
baseline_out = baseline_model(*x)
sqnr = compute_error(out, baseline_out).item()
# Use same SQNR threshold as `test_nvfp4_reconstruction`
# TODO: why is this 0.0 when `use_per_tensor_scale=True`?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a seems to be a bug it use_per_tensor_scale should be higher, probably supposed to be 10.0

after the initial fp8 (e4m3) block-wise scaling.
"""

use_per_tensor_scale: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should default to true

@drisspg
Copy link
Contributor

drisspg commented Aug 4, 2025

Any numeric studies on how / if this improves quant error? Even if its pretty trivial setup

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants