[wip] Support premul_sum #1948
Conversation
Pull Request Overview
This PR adds support for premultiplied sum operations (`premul_sum`) in the XCCL distributed communication backend. The implementation includes a version check to ensure compatibility with oneCCL >= 2021.17, which is required for this feature.
- Adds `XCCLPreMulSumSupplement` struct to handle premul sum factors, both scalar and tensor (sketched below, after this list)
- Implements `makeXCCLPreMulSum` template function to create reduce operations with factors
- Updates the `getXcclReduceOp` function signature and adds PREMUL_SUM case handling
- Adds test coverage for premul sum operations in reduce and reduce_scatter scenarios
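
For orientation, here is a minimal sketch of what the supplement struct and factory from the first two items could look like. The names `XCCLPreMulSumSupplement` and `makeXCCLPreMulSum` come from the PR summary; the field names and the `ReduceOp` wiring are assumptions modeled on the analogous NCCL premul_sum support in c10d, not the PR's exact code.

```cpp
#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/Types.hpp>

// Sketch only: holds either a scalar or a single-element tensor factor that
// the backend multiplies into each rank's contribution before summing.
struct XCCLPreMulSumSupplement : c10d::_SupplementBase {
  double double_factor{0.0};
  at::Tensor tensor_factor;
  explicit XCCLPreMulSumSupplement(double f) : double_factor{f} {}
  explicit XCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {
    TORCH_CHECK_EQ(tensor_factor.numel(), 1);
  }
};

// Sketch only: wraps the factor in a ReduceOp tagged PREMUL_SUM so that
// getXcclReduceOp can later rebuild a custom oneCCL reduction from it.
template <typename T>
c10d::ReduceOp makeXCCLPreMulSum(const T& factor) {
  c10d::ReduceOp rop;
  rop.op_ = c10d::ReduceOp::PREMUL_SUM;
  rop.supplement_ = c10::make_intrusive<XCCLPreMulSumSupplement>(factor);
  return rop;
}
```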
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| test/xpu/distributed/test_c10d_ops_xccl.py | Fixes typos in class names and adds test cases for premul_sum operations |
| src/xccl/ProcessGroupXCCL.hpp | Adds XCCLPreMulSumSupplement struct and makeXCCLPreMulSum template function |
| src/xccl/ProcessGroupXCCL.cpp | Implements premul_sum support with version checking and updates function signatures |
```cpp
      case ccl::datatype::float32:
        return unpackPreMulSum<float, ccl::datatype::float32>(reduceOp, comm);
      case ccl::datatype::bfloat16:
        return unpackPreMulSum<float, ccl::datatype::bfloat16>(
```
For the bfloat16 data type, the template should use `at::BFloat16` instead of `float`. Using `float` for bfloat16 data will cause type-mismatch issues when accessing the tensor data.
Suggested change:

```diff
-        return unpackPreMulSum<float, ccl::datatype::bfloat16>(
+        return unpackPreMulSum<at::BFloat16, ccl::datatype::bfloat16>(
```
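
To make the comment concrete: `at::Tensor::const_data_ptr<T>()` verifies that `T` matches the tensor's scalar type, so reading a bfloat16 factor tensor through `float` fails at runtime. A hypothetical helper illustrating the access pattern (not code from the PR):

```cpp
#include <ATen/ATen.h>

// Hypothetical helper: reads the single-element premul factor through the
// template type. For a kBFloat16 tensor this only succeeds with
// T = at::BFloat16; const_data_ptr<float>() raises a type-mismatch error.
template <typename T>
T readPreMulFactor(const at::Tensor& tensor_factor) {
  return *tensor_factor.const_data_ptr<T>();
}
```

With that check in place, dispatching on `ccl::datatype::bfloat16` has to instantiate the helper with `at::BFloat16`, which is what the suggested change does.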
```cpp
      ccl::reduction_destroy(op_, *comm_);
    }
  }
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
```
The destructor is only defined when `ENABLE_XCCL_PREMUL_SUM_SUPPORT` is defined, but the class is used regardless of this macro. This will cause linking errors when the macro is not defined. The destructor should be defined unconditionally, with the conditional logic inside.
Suggested change:

```diff
-#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
+  ~xcclRedOpRAII() {
+#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
+    if (premul_sum_) {
+      ccl::reduction_destroy(op_, *comm_);
+    }
+#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
+  }
```
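
For context, a sketch of how the RAII wrapper might look with the suggested fix applied. The members `op_`, `comm_`, and `premul_sum_` and the `ccl::reduction_destroy` call appear in the quoted code; the constructors, the stored `ccl::reduction` handle, and the `xcclComm_t*` member type are assumptions, not the PR's exact class.

```cpp
#include <oneapi/ccl.hpp>

// Sketch: the destructor is compiled in every build, and only the body that
// frees the custom oneCCL reduction is guarded by the feature macro.
class xcclRedOpRAII {
 public:
  xcclRedOpRAII() = default;
  explicit xcclRedOpRAII(ccl::reduction op) : op_(op) {}
#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
  xcclRedOpRAII(ccl::reduction op, xcclComm_t* comm)
      : op_(op), comm_(comm), premul_sum_(true) {}
#endif
  operator ccl::reduction() const {
    return op_;
  }
  ~xcclRedOpRAII() {
#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
    if (premul_sum_) {
      ccl::reduction_destroy(op_, *comm_);
    }
#endif
  }

 private:
  ccl::reduction op_{};
  xcclComm_t* comm_{nullptr};
  bool premul_sum_{false};
};
```

A real implementation would also need to disable copying (or define move semantics) so the custom reduction is not destroyed twice.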
Refer to https://github.com/intel-innersource/libraries.performance.communication.oneccl/pull/3449.
Integrates premul_sum into XCCL.