Added MXFP6 packing and fused unpack-dequantise kernel #1
base: main
Conversation
balancap left a comment
@alex-titterton Just a few small comments. I am keeping it at a high level since you'll get a more detailed code review on the upstream repo anyway (and they will know better what style they prefer for integration).
A few additional things not directly in the code:
- The `ruff` format linter is failing; worth fixing before opening the PR upstream.
- The `cutlass` submodule hash seems to be different from the `main` branch one. They'll ask you to revert back to the one on `main`.
- It is not a small/simple PR, so I think it is worth motivating a bit more in the main comment why this additional complexity should be added to the repo. The main motivation for me is: FP6 is as good as FP8 for accuracy, but saves memory. We should support FP6 packing to save memory.

Additionally, it would help to document in the PR which FP6 packing we are using here (as we know there are multiple options). They may ask you if it aligns with Blackwell hardware specs.
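For reference, a minimal sketch of one possible 4-into-3 byte layout (big-endian within each byte). This is purely illustrative and not necessarily the layout used in this PR; the helper names `pack_fp6_reference` / `unpack_fp6_reference` are hypothetical:

```python
# Illustrative only: one possible bit layout for packing 4 x fp6 (stored as the
# low 6 bits of a uint8) into 3 bytes. The layout in the PR may differ.
import torch

def pack_fp6_reference(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor, trailing dim a multiple of 4, values < 64 (6 significant bits)
    v = x.reshape(*x.shape[:-1], -1, 4).to(torch.int32)
    a, b, c, d = v.unbind(-1)
    byte0 = (a << 2) | (b >> 4)          # a[5:0] | b[5:4]
    byte1 = ((b & 0xF) << 4) | (c >> 2)  # b[3:0] | c[5:2]
    byte2 = ((c & 0x3) << 6) | d         # c[1:0] | d[5:0]
    packed = torch.stack([byte0, byte1, byte2], dim=-1).to(torch.uint8)
    return packed.reshape(*x.shape[:-1], 3 * x.shape[-1] // 4)

def unpack_fp6_reference(p: torch.Tensor) -> torch.Tensor:
    v = p.reshape(*p.shape[:-1], -1, 3).to(torch.int32)
    b0, b1, b2 = v.unbind(-1)
    a = b0 >> 2
    b = ((b0 & 0x3) << 4) | (b1 >> 4)
    c = ((b1 & 0xF) << 2) | (b2 >> 6)
    d = b2 & 0x3F
    out = torch.stack([a, b, c, d], dim=-1).to(torch.uint8)
    return out.reshape(*p.shape[:-1], 4 * p.shape[-1] // 3)

# Round trip on random 6-bit payloads.
x = torch.randint(0, 64, (2, 32), dtype=torch.uint8)
assert torch.equal(unpack_fp6_reference(pack_fp6_reference(x)), x)
```

Documenting the chosen layout in this form (which bits of which value land in which byte) would make it easy for the upstream reviewers to compare against the Blackwell packing.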
```python
@pack_uint6.register_fake
def _(uint8_data):
    out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4)
    return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8)
```
You'll need some test coverage of the functions added to `custom_cast.py` to be accepted upstream, i.e. for `pack_uint6`, `triton_f6_e3m2_to_scaled_bf16`, `triton_f6_e2m3_to_scaled_bf16`.
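A minimal shape test for `pack_uint6` could look something like the sketch below; the import path is an assumption, it assumes `pack_uint6` is callable directly as a custom op, and similar cases would still be needed for the two unpack-dequantise kernels:

```python
# Hypothetical test sketch: import path and call convention are assumptions.
import pytest
import torch

from torchao.prototype.mx_formats.custom_cast import pack_uint6  # assumed path

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
@pytest.mark.parametrize("shape", [(4, 64), (2, 3, 32)])
def test_pack_uint6_shape(shape):
    x = torch.randint(0, 64, shape, dtype=torch.uint8, device="cuda")
    packed = pack_uint6(x)
    # 4 six-bit values fit into 3 bytes, so the trailing dim shrinks by a factor of 3/4
    assert packed.shape == (*shape[:-1], 3 * shape[-1] // 4)
    assert packed.dtype == torch.uint8
```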
```python
y_ref = m(x)
y_mx = m_mx(x)
sqnr = compute_error(y_ref, y_mx)
print(sqnr)
```
To remove.
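Rather than deleting the signal entirely, one option is to assert the SQNR against a threshold so regressions fail loudly. A sketch continuing the snippet above, assuming `compute_error` returns an SQNR in dB; the 20 dB bound is purely a placeholder, not a value from the PR:

```python
sqnr = compute_error(y_ref, y_mx)
assert sqnr >= 20.0, f"MXFP6 SQNR too low: {sqnr:.1f} dB"  # placeholder threshold
```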
Initial PR to support MXFP6 packing, whereby the bit representations of 4 x `fp6` values are packed into 3 x `uint8` containers via a custom Triton kernel. The `mx_formats` pytests have been amended to ensure the trailing tensor dimension and M block size are both multiples of 4 whenever `fp6` packing is performed, since this is required in order to pack 4N values into 3N elements. This shouldn't cause any issues, since any FP8-or-lower HW implementation (e.g. tensor core) typically expects a minimum trailing dim size of 16/32/...

Main changes:
- Custom Triton kernel packing 4 x `uint8` containing `fp6` bits into 3 x `uint8`
- `fp6` --> `bfloat16` unpack-dequantise fused Triton kernel
- `torch` custom ops to call these kernels, plus `FakeTensor` shapes in order to support `torch.compile` (see the sketch after this list)
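To illustrate the last bullet: registering the op together with a fake (meta) implementation is what lets `torch.compile` trace the packed shapes without launching the Triton kernel. A minimal sketch mirroring the `register_fake` snippet quoted in the review above; the `mx_demo` namespace and the placeholder eager body are assumptions, not the PR's actual code:

```python
# Sketch only: namespace and eager body are placeholders; the PR's real op
# launches the Triton packing kernel in the eager path.
import torch
from torch.library import custom_op

@custom_op("mx_demo::pack_uint6", mutates_args=())
def pack_uint6_demo(uint8_data: torch.Tensor) -> torch.Tensor:
    # Placeholder eager implementation: only produces the correct packed shape.
    out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4)
    return torch.zeros(out_shape, device=uint8_data.device, dtype=torch.uint8)

@pack_uint6_demo.register_fake
def _(uint8_data):
    # Shape-only (meta) implementation used by FakeTensor during torch.compile
    # tracing: 4 x fp6 (one per uint8) pack into 3 x uint8.
    out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4)
    return torch.empty(out_shape, device=uint8_data.device, dtype=torch.uint8)

def pack(x):
    return pack_uint6_demo(x)

compiled = torch.compile(pack)
x = torch.randint(0, 64, (8, 64), dtype=torch.uint8)
assert compiled(x).shape == (8, 48)  # trailing dim shrinks by 3/4
```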