
[float8] Add fnuz fp8 dtypes to Float8Layout #2351


Merged: 2 commits into main, Jun 12, 2025

Conversation

@jcaip (Contributor) commented on Jun 10, 2025

This should give us AMD perf on vLLM. With Phi-4-mini-instruct on MI300x and TorchAO FP8 rowwise quant on the MLP, I see the following, which is about a 5% speedup over the bf16 baseline:

Avg latency: 1.080369415456274 seconds
10% percentile latency: 1.075335633114446 seconds
25% percentile latency: 1.0811904482543468 seconds
50% percentile latency: 1.082176529977005 seconds
75% percentile latency: 1.0826280051842332 seconds
90% percentile latency: 1.0831242799758911 seconds
99% percentile latency: 1.0836151059856638 seconds

For comparison, here is the baseline Phi-4-mini-instruct on MI300x:

Avg latency: 1.148340248184589 seconds
10% percentile latency: 1.1391733552212826 seconds
25% percentile latency: 1.14905939399614 seconds
50% percentile latency: 1.150204271019902 seconds
75% percentile latency: 1.1523984443047084 seconds
90% percentile latency: 1.1536207939614542 seconds
99% percentile latency: 1.1548575214319863 seconds
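
For reference, a quick back-of-the-envelope check of the headline number from the average latencies above (illustrative only):

```python
# Compare the average latencies reported above.
fp8_avg = 1.080369415456274   # TorchAO FP8 rowwise quant on the MLP
bf16_avg = 1.148340248184589  # bf16 baseline
print(f"latency reduction: {1 - fp8_avg / bf16_avg:.1%}")   # ~5.9%
print(f"throughput speedup: {bf16_avg / fp8_avg - 1:.1%}")  # ~6.3%
```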

Previously, these checks were failing on the fnuz (unsigned-zero) fp8 dtypes used on ROCm, causing us to call .dequantize() and then do a bfloat16 mm, which was slower than the bf16 baseline (~2s).
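
To make the failure mode concrete, here is a minimal sketch (illustrative, not the exact torchao code) of why a hard-coded dtype list misses the ROCm fnuz dtypes, and what a broadened check in the spirit of `_is_float8_type` looks like:

```python
import torch

# The old check only recognized the standard fp8 dtypes, so ROCm's fnuz
# variants fell through to the dequantize -> bf16 matmul fallback.
OLD_FP8_DTYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
print(torch.float8_e4m3fnuz in OLD_FP8_DTYPES)  # False -> slow fallback path

# Broadened predicate (sketch only): also accept the fnuz dtypes, which have
# no inf, a single NaN encoding, and no negative zero.
ALL_FP8_DTYPES = (
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
)

def is_float8_dtype(dtype: torch.dtype) -> bool:
    return dtype in ALL_FP8_DTYPES

print(is_float8_dtype(torch.float8_e4m3fnuz))  # True -> fused fp8 mm path
```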

pytorch-bot bot commented Jun 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2351

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 98eb0dc with merge base 16e2d0a:

NEW FAILURE - one job failed.
FLAKY - one job failed, likely due to flakiness present on trunk.
BROKEN TRUNK - one job failed, but the failure was already present on the merge base.

👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 10, 2025
@jcaip added the module: rocm and topic: bug fix labels on Jun 10, 2025
@jcaip requested review from jerryzh168 and drisspg on Jun 10, 2025 at 21:38
@jcaip (Contributor, Author) commented on Jun 10, 2025

cc @drisspg @jerryzh168

Trying to think what's the best way to test this, but I don't think it's that simple, since we try to dequantize -> do a dense matmul by default. That means testing correctness is not enough here - calling DynamicFloat8... on a model would fail on perf but would still output the correct result.

Any opinions on turning off the dequantize -> dense op fallback by default (or putting it behind a flag), or would that break a lot of things?

@drisspg (Contributor) commented on Jun 10, 2025

> Any opinions on turning off the dequantize -> dense op fallback by default (or putting it behind a flag), or would that break a lot of things?

I think there is a flag for specifying this @jerryzh168

```diff
@@ -442,7 +442,7 @@ def _linear_fp_act_fp8_weight_check(
         # weight is float8 quantized affine quantized tensor
         isinstance(weight_tensor, AffineQuantizedTensor)
         and isinstance(weight_tensor._layout, Float8Layout)
-        and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+        and _is_float8_type(weight_tensor.tensor_impl.dtype)
```
@jerryzh168 (Contributor) commented on Jun 12, 2025:

So previously it was using the fallback? We should probably have a way to check that the kernel is called, or just remove the fallback.
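
One possible way to check that the fp8 kernel path is taken rather than the dequantize fallback (a sketch only: `model`, `example_input`, and matching on the `_scaled_mm` op name are assumptions, and the relevant op may differ by backend):

```python
import torch
from torch.profiler import profile, ProfilerActivity

def assert_fp8_mm_used(model, example_input):
    """Fail if the profiled forward pass never hits a scaled-mm op,
    which would suggest the dequantize -> dense matmul fallback ran."""
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with torch.no_grad():
            model(example_input)
    op_names = {evt.key for evt in prof.key_averages()}
    assert any("_scaled_mm" in name for name in op_names), (
        "no _scaled_mm call found; fp8 weights may have been dequantized"
    )
```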

@jerryzh168 (Contributor) commented:

We can try removing the fallback in a PR, I think; it might be OK.

@jcaip merged commit aec0821 into main on Jun 12, 2025 (17 of 20 checks passed)
@jerryzh168 (Contributor) commented:

> I think there is a flag for specifying this @jerryzh168

The fallback is still the default behavior; there is a flag for a specific kernel choice as well, if people want to make sure they are testing a specific kernel path.

Labels: ciflow/rocm, CLA Signed, module: rocm, topic: bug fix