remove scaled_mm fallback #1746
Conversation
Pull Request Overview
This PR removes the "_scaled_mm" fallback registration from the XPU backend, addressing duplicate registration issues as detailed in the PR description.
- Remove the `_scaled_mm` fallback entry to prevent duplicate registrations (see the dispatch-check sketch after this list)
- Align the XPU fallback registration with the updated operator registration in PyTorch core
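As a quick sanity check after the fallback removal, one way to confirm that `_scaled_mm` resolves to a native XPU kernel rather than a fallback is to query PyTorch's dispatcher. This is a minimal sketch using PyTorch's private dispatcher introspection helper; the operator name and dispatch-key strings are assumptions and may differ across PyTorch versions.

```python
import torch

# Sketch only: uses a private dispatcher introspection API, which may change
# between PyTorch releases. "aten::_scaled_mm" and the key names are assumed.
op = "aten::_scaled_mm"

has_xpu_kernel = torch._C._dispatch_has_kernel_for_dispatch_key(op, "XPU")
has_cpu_kernel = torch._C._dispatch_has_kernel_for_dispatch_key(op, "CPU")

print(f"{op}: XPU kernel registered = {has_xpu_kernel}, "
      f"CPU kernel registered = {has_cpu_kernel}")
```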
Hi, let's file a separate PR to a temp branch. We do these things to avoid breaking CI at both PyTorch and torch-xpu-ops.
guangyey left a comment
LGTM.
Let's merge this PR after pytorch/pytorch#140972 has landed, to avoid blocking internal CI.
This PR implements `scaled_mm` for XPU. It enables the following data types:
1. TensorWise Scaling: `fp8_e4m3` and `fp8_e5m2`
2. RowWise Scaling: `fp8_e4m3` and `fp8_e5m2`

BlockWise Scaling is left to a follow-up PR to keep the review effort smaller. This first PR only adds `scaled_mm_xpu` and does not register it; the registration is split out to reduce review effort. Secondly, there is a `scaled_mm_v2` API in #164141; we will align with it once v1 is cleaned up.

**Co-authors:** @yuchengliu1, @carsonwang

## PR stack:
- -> #165978: implementation of XPU scaled_mm and oneDNN kernel
- #167518: implementation of XPU scaled_mm_v2
- #166056: op registration

## Test Status:
1. Relies on the changes in intel/torch-xpu-ops#1746; otherwise the op will fall back to CPU.
2. This PR does not include tests; the tests are enabled in #166056.

## Credit:
This work is based on @yuchengliu1's work in #140972. We created a new PR to align the API and checks with CUDA, so there is less porting effort.

## FP8 task tracker:
We will track all scaled_mm related tasks in #167170.

Pull Request resolved: #165978
Approved by: https://github.com/liangan1, https://github.com/EikanWang
Co-authored-by: Eikan Wang <[email protected]>
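For context, a TensorWise-scaled FP8 matmul of the kind described above can be exercised through `torch._scaled_mm`. The sketch below is illustrative only: the device string, shapes, and scale values are assumptions, and it presumes a build where the XPU kernel is registered (pytorch/pytorch#166056); on older builds the call would fall back to CPU or fail.

```python
import torch

# Illustrative sketch of a TensorWise-scaled fp8_e4m3 matmul.
# Assumes an XPU device with a registered _scaled_mm kernel; the shapes,
# scales, and out_dtype choice are arbitrary example values.
device = "xpu"

a = torch.randn(64, 128, device=device).to(torch.float8_e4m3fn)       # (M, K)
b = torch.randn(32, 128, device=device).to(torch.float8_e4m3fn).t()   # (K, N), column-major as required
scale_a = torch.tensor(1.0, device=device)  # per-tensor (TensorWise) scale for a
scale_b = torch.tensor(1.0, device=device)  # per-tensor (TensorWise) scale for b

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([64, 32])
```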
Are we ready to merge this one now that pytorch/pytorch#165978 has been merged? @guangyey @Stonepia
Hi @carsonwang, I think this could be merged, but since the op is not yet registered (pytorch/pytorch#166056), I would like to delete the torch-xpu-ops fallback after that PR is merged. ETA would be within this week, or no later than next week.
`scaled_mm` is now supported in PyTorch: pytorch/pytorch#140972.
This fallback will cause a duplicate registration.
Remove this fallback after pytorch/pytorch#140972 is merged.
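To see why keeping the fallback would clash with the in-tree registration, the dispatcher's kernel table for the operator can be inspected. This is a hedged sketch relying on a private introspection helper; the operator name is an assumption and the output format is version-dependent.

```python
import torch

# Sketch: dump the dispatcher's registered kernels for _scaled_mm.
# If both the torch-xpu-ops fallback and the in-tree XPU kernel were
# registered for the same dispatch key, both entries would appear here,
# with one overriding the other. Private API; output format may vary.
print(torch._C._dispatch_dump("aten::_scaled_mm"))
```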