From eda66077755faad7b85901b5e26680d834cfb429 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Wed, 18 Jun 2025 21:54:43 +0800 Subject: [PATCH 1/7] add nan check for xccl --- src/xccl/NanCheck.cpp | 233 ++++++++++++++++++++++++++++++++++++++++++ src/xccl/NanCheck.hpp | 16 +++ 2 files changed, 249 insertions(+) create mode 100644 src/xccl/NanCheck.cpp create mode 100644 src/xccl/NanCheck.hpp diff --git a/src/xccl/NanCheck.cpp b/src/xccl/NanCheck.cpp new file mode 100644 index 000000000..de027d8ac --- /dev/null +++ b/src/xccl/NanCheck.cpp @@ -0,0 +1,233 @@ +#ifdef USE_C10D_XCCL + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10d { + +using BytePack = memory::aligned_vector; + +template +struct CheckBytePack { + static void check(BytePack* tmp) { + T* data = (T*)tmp; + #pragma unroll 8 + for (int i = 0; i < EltPerPack; i++) { + if (isnan(data[i])) assert(0); + } + } +}; + +template +struct CheckBytePack { + static void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1])) assert(0); + } +}; + +template +struct CheckBytePack { + static void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3])) assert(0); + } +}; + + +template +struct CheckBytePack { + static void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3]) || + isnan(data[4]) || isnan(data[5]) || isnan(data[6]) || isnan(data[7])) { + assert(0); + } + } +}; + +template +struct HasNanFP8x8 { + static bool check(uint64_t fp8x8) = delete; + /* + { + // `static_assert` in template definition requires c++23 onwards. + // But the error message still applies if you find yourself here. + static_assert( + false, + "You should never call this template definition because it is empty. You " + "can follow the example of Float8_e4m3fn below to implement the check for " + "your new datatype." 
+ ); + } + */ +}; + +template<> +struct HasNanFP8x8 { + static bool check(uint64_t fp8x8) { + auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; + auto incremented = t + 0x0101010101010101ULL; + auto overflow = incremented & 0x8080808080808080ULL; + return overflow != 0; + } +}; + +template<> +struct HasNanFP8x8 { + static bool check(uint64_t fp8x8) { + auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; + auto incremented = t + 0x0303030303030303ULL; + auto overflow = incremented & 0x8080808080808080ULL; + return overflow != 0; + } +}; + +template +struct CheckBytePack { + static void check(BytePack* tmp) { + if (HasNanFP8x8::check(tmp->val[0]) || HasNanFP8x8::check(tmp->val[1])) + assert(0); + } +}; + + +#define UNROLL 8 + +template +void checkChunk(BytePack* ptr, int nWorkers) { + BytePack tmp[UNROLL]; + // int nWorkers = num_group_ * max_group_size_; + #pragma unroll 8 + for (int j = 0; j < UNROLL; j++) { + tmp[j] = ptr[nWorkers * j]; + } + // Then check each BytePack in the tmp buffer + #pragma unroll 8 + for (int j = 0; j < UNROLL; j++) { + CheckBytePack::check(tmp + j); + } + // Note: we separate the check from the load for efficient loading +} + +// Align address of `ptr` up, to the alignment of `T` +#define ALIGN_UP(ptr, T) (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T)) + + +template +struct checkForNaN { + void operator()(sycl::nd_item<1> item) const { + constexpr int EltPerPack = sizeof(BytePack) / sizeof(T); + + size_t offset = item.get_global_id()[2]; + + // Align input address up to BytePack in case it is not + T* ptrAlign = (T*)ALIGN_UP(data, BytePack); + // Pre-process the data before alignment + size_t preProcElts = min(ptrAlign - data, size); + // Read memory by T (slow). One iter is enough bc the number of threads would + // be bigger than `preProcElts` + if (offset < preProcElts) { + if (isnan(data[offset])) assert(0); + } + // We have processes this amount of data + size -= preProcElts; + + // Start BytePack processing + BytePack* ptr = (BytePack*)ptrAlign; + // Size of input data in unit of BytePack + size_t sizeInBP = size * sizeof(T) / sizeof(BytePack); + // Number of BytePacks processed in one fast-path iteration + size_t loopSize = num_group_ * max_group_size_ * UNROLL; + + // Fast path + // The condition below makes sure there is enough data to process (`loopSize`) + for (; offset + loopSize <= sizeInBP; offset += loopSize) { + checkChunk(ptr + offset, num_group_ * max_group_size_); + } + + // The rest data goes on slow path + // We just do regular load and check + for (; offset < sizeInBP; offset += num_group_ * max_group_size_) { + BytePack tmp = ptr[offset]; + CheckBytePack::check(&tmp); + } + + // We can still have a tail smaller than 1 BytePack + // TODO: merge this tail check with head check to make them concurrent + if (item.get_local_id(1) < size % EltPerPack) { + T* tailPtr = (T*)(ptr + sizeInBP); + if (isnan(tailPtr[item.get_local_id(1)])) assert(0); + } + } + checkForNaN(T* data, size_t size, int64_t num_group, int64_t max_group_size) + : data(data), size(size), num_group_(num_group), max_group_size_(max_group_size) {} + private: + T* data; + size_t size; + int64_t num_group_; + int64_t max_group_size_; +} + +template +void checkfornan_impl_xpu( + const at::Tensor& tensor, + at::xpu::XPUStream& stream) { + // skip check for non float types + if (!torch::is_floating_point(tensor)) { + return; + } + + int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize>(); + + const size_t numThreadsPerBlock = + std::min(maxNumThreadsPerBlock, tensor.numel()); + + if 
(!(numThreadsPerBlock > 0)) { + return; + } + + int64_t numBlocks = (tensor.numel() + maxNumThreadsPerBlock - 1) / maxNumThreadsPerBlock; + auto global_range{numBlocks * maxNumThreadsPerBlock}; + auto local_range{maxNumThreadsPerBlock}; + + using Kernel = checkForNaN; + auto knf = Kernel( + tensor.data_ptr(), + tensor.numel(), + numBlocks, + maxNumThreadsPerBlock); + + sycl_kernel_submit(global_range, local_range, stream.queue(), knf); +} + + + +// CHECK if a Tensor contains NAN in any of its element +void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) { + + AT_DISPATCH_FLOATING_TYPES_AND4( + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::Float8_e4m3fn, + at::ScalarType::Float8_e5m2, + tensor.scalar_type(), + "checkForNaN_XPU", + [&] () { + checkfornan_impl_xpu( + tensor, + stream); + }); +} + +} // namespace c10d + +#endif // USE_C10D_XCCL + + diff --git a/src/xccl/NanCheck.hpp b/src/xccl/NanCheck.hpp new file mode 100644 index 000000000..0cfa9a2af --- /dev/null +++ b/src/xccl/NanCheck.hpp @@ -0,0 +1,16 @@ +#pragma once + +#ifdef USE_C10D_XCCL + +#include +#include + +namespace c10d { + +void checkForNan(const at::Tensor& tensor, c10::xpu::XPUStream& stream); + +} // namespace c10d + +#endif // USE_C10D_XCCL + + From 595a1dc0d4bf09e5c4e00cf4483d8218be1e5fcb Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Wed, 18 Jun 2025 23:07:02 +0800 Subject: [PATCH 2/7] cmake and format --- src/xccl/CMakeLists.txt | 3 + src/xccl/NanCheck.cpp | 131 ++++++++++++++++++---------------------- src/xccl/NanCheck.hpp | 2 - 3 files changed, 62 insertions(+), 74 deletions(-) diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt index f147b55ca..df1210a06 100644 --- a/src/xccl/CMakeLists.txt +++ b/src/xccl/CMakeLists.txt @@ -2,10 +2,13 @@ file(GLOB xccl_h "*.hpp") file(GLOB xccl_cpp "*.cpp") +list(REMOVE_ITEM xccl_cpp "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck.cpp") list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp}) +list(APPEND ATen_XPU_SYCL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck.cpp") set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE) +set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE) # Why copy the header file to the build directory? # We want register XCCL backend to PyTorch c10d in torch/csrc/distributed/c10d/init.cpp#L27-L29. 
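Note on the CMake change above: NanCheck.cpp is removed from the plain xccl_cpp list and appended to ATen_XPU_SYCL_SRCS so the file is built by the SYCL device compiler rather than the host-only C++ compiler, since the checkForNaN functor it defines is submitted as a device kernel. A minimal standalone sketch of that launch pattern, using plain SYCL 2020 APIs and a hypothetical NaiveNanCheck functor instead of the repo's sycl_kernel_submit helper (names here are illustrative, not from the patch):

#include <sycl/sycl.hpp>
#include <cassert>

// Functor form mirrors checkForNaN: state is captured by value and the
// device entry point is a const operator()(sycl::nd_item<1>).
struct NaiveNanCheck {
  void operator()(sycl::nd_item<1> item) const {
    size_t i = item.get_global_id(0);
    if (i < size && sycl::isnan(data[i]))
      assert(0); // device-side trap, same error path as the patch
  }
  const float* data;
  size_t size;
};

int main() {
  sycl::queue q{sycl::gpu_selector_v};
  size_t n = 1 << 20;
  float* buf = sycl::malloc_shared<float>(n, q);
  for (size_t i = 0; i < n; ++i)
    buf[i] = 1.0f; // no NaNs, so the kernel completes silently
  size_t local = 256;
  size_t global = (n + local - 1) / local * local; // round up to a multiple of local
  q.parallel_for(sycl::nd_range<1>{global, local}, NaiveNanCheck{buf, n}).wait();
  sycl::free(buf, q);
  return 0;
}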
diff --git a/src/xccl/NanCheck.cpp b/src/xccl/NanCheck.cpp index de027d8ac..1969e3d3d 100644 --- a/src/xccl/NanCheck.cpp +++ b/src/xccl/NanCheck.cpp @@ -1,58 +1,60 @@ #ifdef USE_C10D_XCCL #include +#include #include -#include +#include +#include #include -#include #include -#include -#include +#include namespace c10d { -using BytePack = memory::aligned_vector; +using BytePack = at::native::memory::aligned_vector; template struct CheckBytePack { static void check(BytePack* tmp) { T* data = (T*)tmp; - #pragma unroll 8 +#pragma unroll 8 for (int i = 0; i < EltPerPack; i++) { - if (isnan(data[i])) assert(0); + if (isnan(data[i])) + assert(0); } } }; template -struct CheckBytePack { +struct CheckBytePack { static void check(BytePack* tmp) { T* data = (T*)tmp; - if (isnan(data[0]) || isnan(data[1])) assert(0); + if (isnan(data[0]) || isnan(data[1])) + assert(0); } }; template -struct CheckBytePack { +struct CheckBytePack { static void check(BytePack* tmp) { T* data = (T*)tmp; - if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3])) assert(0); + if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3])) + assert(0); } }; - template -struct CheckBytePack { +struct CheckBytePack { static void check(BytePack* tmp) { T* data = (T*)tmp; if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3]) || isnan(data[4]) || isnan(data[5]) || isnan(data[6]) || isnan(data[7])) { - assert(0); + assert(0); } } }; -template +template struct HasNanFP8x8 { static bool check(uint64_t fp8x8) = delete; /* @@ -62,14 +64,14 @@ struct HasNanFP8x8 { static_assert( false, "You should never call this template definition because it is empty. You " - "can follow the example of Float8_e4m3fn below to implement the check for " - "your new datatype." + "can follow the example of Float8_e4m3fn below to implement the check for + " "your new datatype." 
); } */ }; -template<> +template <> struct HasNanFP8x8 { static bool check(uint64_t fp8x8) { auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; @@ -79,7 +81,7 @@ struct HasNanFP8x8 { } }; -template<> +template <> struct HasNanFP8x8 { static bool check(uint64_t fp8x8) { auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; @@ -89,36 +91,36 @@ struct HasNanFP8x8 { } }; -template -struct CheckBytePack { +template +struct CheckBytePack { static void check(BytePack* tmp) { - if (HasNanFP8x8::check(tmp->val[0]) || HasNanFP8x8::check(tmp->val[1])) - assert(0); + if (HasNanFP8x8::check(tmp->val[0]) || + HasNanFP8x8::check(tmp->val[1])) + assert(0); } }; - #define UNROLL 8 template void checkChunk(BytePack* ptr, int nWorkers) { BytePack tmp[UNROLL]; - // int nWorkers = num_group_ * max_group_size_; - #pragma unroll 8 + +#pragma unroll 8 for (int j = 0; j < UNROLL; j++) { tmp[j] = ptr[nWorkers * j]; } - // Then check each BytePack in the tmp buffer - #pragma unroll 8 +// Then check each BytePack in the tmp buffer +#pragma unroll 8 for (int j = 0; j < UNROLL; j++) { - CheckBytePack::check(tmp + j); + CheckBytePack::check(tmp + j); } // Note: we separate the check from the load for efficient loading } // Align address of `ptr` up, to the alignment of `T` -#define ALIGN_UP(ptr, T) (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T)) - +#define ALIGN_UP(ptr, T) \ + (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T)) template struct checkForNaN { @@ -129,51 +131,47 @@ struct checkForNaN { // Align input address up to BytePack in case it is not T* ptrAlign = (T*)ALIGN_UP(data, BytePack); - // Pre-process the data before alignment size_t preProcElts = min(ptrAlign - data, size); - // Read memory by T (slow). One iter is enough bc the number of threads would - // be bigger than `preProcElts` + + size_t size_left = size; + if (offset < preProcElts) { - if (isnan(data[offset])) assert(0); + if (isnan(data[offset])) + assert(0); } - // We have processes this amount of data - size -= preProcElts; + size_left -= preProcElts; - // Start BytePack processing BytePack* ptr = (BytePack*)ptrAlign; - // Size of input data in unit of BytePack - size_t sizeInBP = size * sizeof(T) / sizeof(BytePack); - // Number of BytePacks processed in one fast-path iteration + size_t sizeInBP = size_left * sizeof(T) / sizeof(BytePack); size_t loopSize = num_group_ * max_group_size_ * UNROLL; - // Fast path - // The condition below makes sure there is enough data to process (`loopSize`) for (; offset + loopSize <= sizeInBP; offset += loopSize) { checkChunk(ptr + offset, num_group_ * max_group_size_); } - // The rest data goes on slow path - // We just do regular load and check for (; offset < sizeInBP; offset += num_group_ * max_group_size_) { BytePack tmp = ptr[offset]; CheckBytePack::check(&tmp); } - // We can still have a tail smaller than 1 BytePack - // TODO: merge this tail check with head check to make them concurrent - if (item.get_local_id(1) < size % EltPerPack) { + if (item.get_local_id(1) < size_left % EltPerPack) { T* tailPtr = (T*)(ptr + sizeInBP); - if (isnan(tailPtr[item.get_local_id(1)])) assert(0); + if (isnan(tailPtr[item.get_local_id(1)])) + assert(0); } } checkForNaN(T* data, size_t size, int64_t num_group, int64_t max_group_size) - : data(data), size(size), num_group_(num_group), max_group_size_(max_group_size) {} - private: - T* data; - size_t size; - int64_t num_group_; - int64_t max_group_size_; -} + : data(data), + size(size), + num_group_(num_group), + max_group_size_(max_group_size) {} + + private: + T* data; + size_t 
size; + int64_t num_group_; + int64_t max_group_size_; +}; template void checkfornan_impl_xpu( @@ -187,31 +185,26 @@ void checkfornan_impl_xpu( int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize>(); const size_t numThreadsPerBlock = - std::min(maxNumThreadsPerBlock, tensor.numel()); + std::min(maxNumThreadsPerBlock, tensor.numel()); if (!(numThreadsPerBlock > 0)) { return; } - int64_t numBlocks = (tensor.numel() + maxNumThreadsPerBlock - 1) / maxNumThreadsPerBlock; + int64_t numBlocks = + (tensor.numel() + maxNumThreadsPerBlock - 1) / maxNumThreadsPerBlock; auto global_range{numBlocks * maxNumThreadsPerBlock}; auto local_range{maxNumThreadsPerBlock}; using Kernel = checkForNaN; auto knf = Kernel( - tensor.data_ptr(), - tensor.numel(), - numBlocks, - maxNumThreadsPerBlock); + tensor.data_ptr(), tensor.numel(), numBlocks, maxNumThreadsPerBlock); sycl_kernel_submit(global_range, local_range, stream.queue(), knf); } - - // CHECK if a Tensor contains NAN in any of its element void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) { - AT_DISPATCH_FLOATING_TYPES_AND4( at::ScalarType::Half, at::ScalarType::BFloat16, @@ -219,15 +212,9 @@ void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) { at::ScalarType::Float8_e5m2, tensor.scalar_type(), "checkForNaN_XPU", - [&] () { - checkfornan_impl_xpu( - tensor, - stream); - }); + [&]() { checkfornan_impl_xpu(tensor, stream); }); } } // namespace c10d #endif // USE_C10D_XCCL - - diff --git a/src/xccl/NanCheck.hpp b/src/xccl/NanCheck.hpp index 0cfa9a2af..3bf4cd96d 100644 --- a/src/xccl/NanCheck.hpp +++ b/src/xccl/NanCheck.hpp @@ -12,5 +12,3 @@ void checkForNan(const at::Tensor& tensor, c10::xpu::XPUStream& stream); } // namespace c10d #endif // USE_C10D_XCCL - - From aef4be31d6ecf51535f22cf78300692c8d64f808 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Thu, 19 Jun 2025 18:28:00 +0800 Subject: [PATCH 3/7] add nan check --- src/xccl/CMakeLists.txt | 4 +-- src/xccl/{NanCheck.cpp => NanCheck_XPU.cpp} | 2 +- src/xccl/{NanCheck.hpp => NanCheck_XPU.hpp} | 2 +- src/xccl/ProcessGroupXCCL.cpp | 35 ++++++++++++++++--- src/xccl/ProcessGroupXCCL.hpp | 37 ++++++++++++++++----- 5 files changed, 63 insertions(+), 17 deletions(-) rename src/xccl/{NanCheck.cpp => NanCheck_XPU.cpp} (99%) rename src/xccl/{NanCheck.hpp => NanCheck_XPU.hpp} (67%) diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt index df1210a06..74ece226c 100644 --- a/src/xccl/CMakeLists.txt +++ b/src/xccl/CMakeLists.txt @@ -2,10 +2,10 @@ file(GLOB xccl_h "*.hpp") file(GLOB xccl_cpp "*.cpp") -list(REMOVE_ITEM xccl_cpp "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck.cpp") +list(REMOVE_ITEM xccl_cpp "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp") list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp}) -list(APPEND ATen_XPU_SYCL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck.cpp") +list(APPEND ATen_XPU_SYCL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp") set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE) set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE) diff --git a/src/xccl/NanCheck.cpp b/src/xccl/NanCheck_XPU.cpp similarity index 99% rename from src/xccl/NanCheck.cpp rename to src/xccl/NanCheck_XPU.cpp index 1969e3d3d..2a982101d 100644 --- a/src/xccl/NanCheck.cpp +++ b/src/xccl/NanCheck_XPU.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include namespace c10d { diff --git a/src/xccl/NanCheck.hpp b/src/xccl/NanCheck_XPU.hpp similarity index 67% rename from src/xccl/NanCheck.hpp rename to src/xccl/NanCheck_XPU.hpp index 3bf4cd96d..8c8222f36 
100644 --- a/src/xccl/NanCheck.hpp +++ b/src/xccl/NanCheck_XPU.hpp @@ -7,7 +7,7 @@ namespace c10d { -void checkForNan(const at::Tensor& tensor, c10::xpu::XPUStream& stream); +void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream); } // namespace c10d diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 0b16c348c..0488a6b88 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -1,6 +1,7 @@ #ifdef USE_C10D_XCCL #include +#include #include namespace c10d { @@ -338,6 +339,7 @@ ProcessGroupXCCL::ProcessGroupXCCL( local_id_(process_group_id++) { logPrefix_ = createLogPrefix(); blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); + enableNanCheck_ = getCvarBool(TORCH_XCCL_NAN_CHECK, false); init(); const std::string OFF = "OFF"; std::string torch_distributed_debug = @@ -349,7 +351,8 @@ ProcessGroupXCCL::ProcessGroupXCCL( LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: " << "XCCL version: " << XcclVersion << ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_ - << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug; + << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug + << ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_; } ProcessGroupXCCL::~ProcessGroupXCCL() = default; @@ -360,6 +363,10 @@ uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() { return seqCollective_; } +void ProcessGroupXCCL::setEnableNanCheck(bool enableNanCheck) { + enableNanCheck_ = enableNanCheck; +} + c10::intrusive_ptr ProcessGroupXCCL::initWork( at::Device& device, int rank, @@ -553,7 +560,9 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( PostProcess post, OpType opType, bool asyncOp, - const char* profilingTitle) { + const char* profilingTitle, + bool nanCheck) { + nanCheck &= enableNanCheck_; seqCollective_++; auto device = inputs[0].device(); const auto key = std::to_string(device.index()); @@ -620,6 +629,12 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( c10::OptionalDeviceGuard gpuGuard(device); + if (nanCheck) { + for (const auto& input : inputs) { + checkForNan(input, stream); + } + } + pre(stream, work); for (const auto i : c10::irange(inputs.size())) { @@ -697,6 +712,10 @@ c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( auto cclstream = xcclStreamsMap_.at(key).second; syncStream(device, xcclEventsMap_[key], stream); + if (enableNanCheck_ && opType == OpType::SEND) { + checkForNan(tensor, stream); + } + if (!coalescing_state_) { auto work = initWork(device, rank_, opType, true, profilingTitle, {tensor}, {}); @@ -1006,6 +1025,7 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( this->getSize()); // worldSize const auto root = opts.rootRank; + bool nanCheck = (rank_ == root); auto outputs = std::vector{outputTensor}; return collective( @@ -1059,7 +1079,8 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( }, OpType::SCATTER, opts.asyncOp, - "xccl:scatter"); + "xccl:scatter", + nanCheck); } c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( @@ -1222,6 +1243,7 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( this->getSize()); // worldSize const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); return collective( tensor, @@ -1243,7 +1265,8 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( }, OpType::BROADCAST, opts.asyncOp, - "xccl:broadcast"); + "xccl:broadcast", + nanCheck); } c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( @@ -1256,6 +1279,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( "Tensor input and output of _broadcast_oop must have the same number of 
elements "); } const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); return collective( inputTensor, outputTensor, @@ -1277,7 +1301,8 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( }, OpType::BROADCAST, opts.asyncOp, - "xccl:_broadcast_oop"); + "xccl:_broadcast_oop", + nanCheck); } c10::intrusive_ptr ProcessGroupXCCL::reduce( diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 8fc765f69..ebf8f837a 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -27,6 +27,9 @@ static std::vector TORCH_XCCL_BLOCKING_WAIT = { "XCCL_BLOCKING_WAIT"}; using xcclComm_t = ccl::communicator; + +static std::vector TORCH_XCCL_NAN_CHECK = {"TORCH_XCCL_NAN_CHECK"}; + constexpr const char* XCCL_BACKEND_NAME = "xccl"; class TensorShelf { @@ -153,7 +156,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { Fn fn, OpType opType, bool asyncOp, - const char* profilingTitle = nullptr) { + const char* profilingTitle = nullptr, + bool nanCheck = true) { return collective( input, output, @@ -164,7 +168,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr&) {}, opType, asyncOp, - profilingTitle); + profilingTitle, + nanCheck); } template @@ -176,11 +181,20 @@ class TORCH_API ProcessGroupXCCL : public Backend { PostProcess post, OpType opType, bool asyncOp, - const char* profilingTitle = nullptr) { + const char* profilingTitle = nullptr, + bool nanCheck = true) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; return collective( - inputs, outputs, fn, pre, post, opType, asyncOp, profilingTitle); + inputs, + outputs, + fn, + pre, + post, + opType, + asyncOp, + profilingTitle, + nanCheck); } template @@ -190,7 +204,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { Fn fn, OpType opType, bool asyncOp, - const char* profilingTitle = nullptr) { + const char* profilingTitle = nullptr, + bool nanCheck = true) { return collective( inputs, outputs, @@ -201,7 +216,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr&) {}, opType, asyncOp, - profilingTitle); + profilingTitle, + nanCheck); } template @@ -213,7 +229,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { PostProcess post, OpType opType, bool asyncOp, - const char* profilingTitle = nullptr); + const char* profilingTitle = nullptr, + bool nanCheck = true); template c10::intrusive_ptr collectiveCoalesced( @@ -247,7 +264,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { }, opType, asyncOp, - profilingTitle); + profilingTitle, + /*nanCheck =*/false); } template @@ -367,6 +385,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { const std::string& logPrefix() const; + void setEnableNanCheck(bool enableNanCheck); + protected: std::unordered_map> xcclStreamsMap_; @@ -387,6 +407,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { uint64_t seqP2P_{0}; size_t local_id_; std::string logPrefix_; + bool enableNanCheck_; private: std::mutex kvs_mutex; From 38c1c107f5cb793cf306a65d146c12f49a559005 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 20 Jun 2025 16:25:38 +0800 Subject: [PATCH 4/7] update --- src/xccl/NanCheck_XPU.cpp | 42 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/xccl/NanCheck_XPU.cpp b/src/xccl/NanCheck_XPU.cpp index 2a982101d..3f6f88f98 100644 --- a/src/xccl/NanCheck_XPU.cpp +++ b/src/xccl/NanCheck_XPU.cpp @@ -1,9 +1,8 @@ -#ifdef USE_C10D_XCCL - #include +#include #include #include -#include +#include #include 
#include #include @@ -19,7 +18,7 @@ struct CheckBytePack { T* data = (T*)tmp; #pragma unroll 8 for (int i = 0; i < EltPerPack; i++) { - if (isnan(data[i])) + if (at::_isnan(data[i])) assert(0); } } @@ -29,7 +28,7 @@ template struct CheckBytePack { static void check(BytePack* tmp) { T* data = (T*)tmp; - if (isnan(data[0]) || isnan(data[1])) + if (at::_isnan(data[0]) || at::_isnan(data[1])) assert(0); } }; @@ -38,7 +37,8 @@ template struct CheckBytePack { static void check(BytePack* tmp) { T* data = (T*)tmp; - if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3])) + if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) || + at::_isnan(data[3])) assert(0); } }; @@ -47,8 +47,9 @@ template struct CheckBytePack { static void check(BytePack* tmp) { T* data = (T*)tmp; - if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3]) || - isnan(data[4]) || isnan(data[5]) || isnan(data[6]) || isnan(data[7])) { + if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) || + at::_isnan(data[3]) || at::_isnan(data[4]) || at::_isnan(data[5]) || + at::_isnan(data[6]) || at::_isnan(data[7])) { assert(0); } } @@ -127,36 +128,37 @@ struct checkForNaN { void operator()(sycl::nd_item<1> item) const { constexpr int EltPerPack = sizeof(BytePack) / sizeof(T); - size_t offset = item.get_global_id()[2]; + size_t offset = item.get_global_id(0); // Align input address up to BytePack in case it is not T* ptrAlign = (T*)ALIGN_UP(data, BytePack); - size_t preProcElts = min(ptrAlign - data, size); + size_t preProcElts = + std::min(static_cast(ptrAlign - data), size); size_t size_left = size; if (offset < preProcElts) { - if (isnan(data[offset])) + if (at::_isnan(data[offset])) assert(0); } size_left -= preProcElts; BytePack* ptr = (BytePack*)ptrAlign; size_t sizeInBP = size_left * sizeof(T) / sizeof(BytePack); - size_t loopSize = num_group_ * max_group_size_ * UNROLL; + size_t loopSize = item.get_global_range(0) * UNROLL; for (; offset + loopSize <= sizeInBP; offset += loopSize) { - checkChunk(ptr + offset, num_group_ * max_group_size_); + checkChunk(ptr + offset, item.get_global_range(0)); } - for (; offset < sizeInBP; offset += num_group_ * max_group_size_) { + for (; offset < sizeInBP; offset += item.get_global_range(0)) { BytePack tmp = ptr[offset]; CheckBytePack::check(&tmp); } - if (item.get_local_id(1) < size_left % EltPerPack) { + if (item.get_local_id(0) < size_left % EltPerPack) { T* tailPtr = (T*)(ptr + sizeInBP); - if (isnan(tailPtr[item.get_local_id(1)])) + if (at::_isnan(tailPtr[item.get_local_id(0)])) assert(0); } } @@ -182,7 +184,7 @@ void checkfornan_impl_xpu( return; } - int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize>(); + int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize>(); const size_t numThreadsPerBlock = std::min(maxNumThreadsPerBlock, tensor.numel()); @@ -197,10 +199,10 @@ void checkfornan_impl_xpu( auto local_range{maxNumThreadsPerBlock}; using Kernel = checkForNaN; - auto knf = Kernel( + auto kfn = Kernel( tensor.data_ptr(), tensor.numel(), numBlocks, maxNumThreadsPerBlock); - sycl_kernel_submit(global_range, local_range, stream.queue(), knf); + sycl_kernel_submit(global_range, local_range, stream.queue(), kfn); } // CHECK if a Tensor contains NAN in any of its element @@ -216,5 +218,3 @@ void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) { } } // namespace c10d - -#endif // USE_C10D_XCCL From 1a7febb3847bc49898b5727b62766f3c144383ee Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 23 Jun 2025 15:18:20 
+0800 Subject: [PATCH 5/7] add test case --- test/xpu/distributed/test_c10d_xccl.py | 79 ++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py index 0625a6993..916524073 100644 --- a/test/xpu/distributed/test_c10d_xccl.py +++ b/test/xpu/distributed/test_c10d_xccl.py @@ -2,6 +2,8 @@ import math import os +import random +import signal import sys import time from datetime import timedelta @@ -19,6 +21,9 @@ import torch.testing._internal.common_utils as common from torch.testing._internal.common_distributed import MultiProcessTestCase from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + IS_SANDCASTLE, + parametrize, retry_on_connect_failures, run_tests, skip_but_pass_in_sandcastle_if, @@ -210,6 +215,15 @@ def _create_process_group_xccl( def setUp(self): super().setUp() + TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else -signal.SIGABRT + self.special_return_code_checks = { + self.test_nan_assert_float16.__wrapped__: TEST_NAN_ASSERT_RETURN, + self.test_nan_assert_float32.__wrapped__: TEST_NAN_ASSERT_RETURN, + self.test_nan_assert_float64.__wrapped__: TEST_NAN_ASSERT_RETURN, + self.test_nan_assert_bfloat16.__wrapped__: TEST_NAN_ASSERT_RETURN, + self.test_nan_assert_float8_e4m3fn.__wrapped__: TEST_NAN_ASSERT_RETURN, + self.test_nan_assert_float8_e5m2.__wrapped__: TEST_NAN_ASSERT_RETURN, + } self._spawn_processes() def tearDown(self): @@ -288,6 +302,68 @@ def test_set_process_group_desc(self): pg_2 = c10d.new_group([0, 1]) self.assertEqual(pg_2.group_desc, "undefined") + @requires_xccl() + @parametrize( + "type", + [ + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], + ) + def test_nan_assert(self, type): + # Expecting a device-side error when NaN is detected + os.environ["TORCH_XCCL_NAN_CHECK"] = "1" + pg = self._create_process_group_xccl() + device = self.rank_to_GPU[self.rank][0] + # Cover different buffer sizes + if type == torch.float64: + size = (1024,) # 1K elements + elif type == torch.float32: + size = (1024, 1024) # 1M elements + elif type == torch.float16: + size = (1024, 1024, 1024) # 1G elements + else: + size = (1,) # 1 element + + # Note: currently we cannot fill values into a FP8 tensor, thus we + # create the NaN tensor in float32 type and cast it to FP8 + if type == torch.float8_e4m3fn or type == torch.float8_e5m2: + init_type = torch.float32 + else: + init_type = type + + nan_tensor = torch.zeros(*size, dtype=init_type, device=device) + # randomly pick an nan element + index = tuple([random.randrange(size[i]) for i in range(len(size))]) + nan_tensor[index] = float("nan") + if init_type != type: + # Now cast to the targeted dtype + nan_tensor = nan_tensor.to(type) + + output = torch.empty(self.world_size, *size, dtype=type, device=device) + + # # confirm enable/disable flag works + # backend._set_enable_nan_check(False) + # # Note: using all-gather here bc some NCCL/SM version does not support + # # FP8 reduction + # # temporarily skip due to https://github.com/pytorch/pytorch/issues/153479 + # # pg._allgather_base(output, nan_tensor) + + # backend._set_enable_nan_check(True) + try: + pg._allgather_base(output, nan_tensor) + except Exception: + sys.exit(signal.SIGABRT) + + dist.destroy_process_group() + + # reset env + os.environ["TORCH_XCCL_NAN_CHECK"] = "0" + class CommTest(MultiProcessTestCase): @property @@ -551,6 +627,9 @@ def test_all_gather_into_tensor(self): ) 
+instantiate_parametrized_tests(ProcessGroupXCCLTest) + + class SetDeviceMethod(Enum): TORCH_XPU_SET = auto() # torch.xpu.set_device COLLECTIVE_ARGUMENT = auto() # broadcast_object_list(device=) From 1b5d5be8852040832a3e18fb7dc6fca2fd584dc2 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Wed, 25 Jun 2025 13:49:39 +0800 Subject: [PATCH 6/7] clean code --- src/xccl/NanCheck_XPU.cpp | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/xccl/NanCheck_XPU.cpp b/src/xccl/NanCheck_XPU.cpp index 3f6f88f98..1f4f81df1 100644 --- a/src/xccl/NanCheck_XPU.cpp +++ b/src/xccl/NanCheck_XPU.cpp @@ -162,17 +162,11 @@ struct checkForNaN { assert(0); } } - checkForNaN(T* data, size_t size, int64_t num_group, int64_t max_group_size) - : data(data), - size(size), - num_group_(num_group), - max_group_size_(max_group_size) {} + checkForNaN(T* data, size_t size) : data(data), size(size) {} private: T* data; size_t size; - int64_t num_group_; - int64_t max_group_size_; }; template @@ -199,8 +193,7 @@ void checkfornan_impl_xpu( auto local_range{maxNumThreadsPerBlock}; using Kernel = checkForNaN; - auto kfn = Kernel( - tensor.data_ptr(), tensor.numel(), numBlocks, maxNumThreadsPerBlock); + auto kfn = Kernel(tensor.data_ptr(), tensor.numel()); sycl_kernel_submit(global_range, local_range, stream.queue(), kfn); } From 62263aa99d497ebe20b42f24585f6f92068599ab Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 30 Jun 2025 14:25:42 +0800 Subject: [PATCH 7/7] update --- src/xccl/NanCheck_XPU.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xccl/NanCheck_XPU.cpp b/src/xccl/NanCheck_XPU.cpp index 1f4f81df1..341e41610 100644 --- a/src/xccl/NanCheck_XPU.cpp +++ b/src/xccl/NanCheck_XPU.cpp @@ -188,9 +188,9 @@ void checkfornan_impl_xpu( } int64_t numBlocks = - (tensor.numel() + maxNumThreadsPerBlock - 1) / maxNumThreadsPerBlock; - auto global_range{numBlocks * maxNumThreadsPerBlock}; - auto local_range{maxNumThreadsPerBlock}; + (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock; + auto global_range{numBlocks * numThreadsPerBlock}; + auto local_range{numThreadsPerBlock}; using Kernel = checkForNaN; auto kfn = Kernel(tensor.data_ptr(), tensor.numel());
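A closing note on the HasNanFP8x8 bit trick used in the earlier patches: it can be checked on the host with a standalone C++ sketch (not part of the patch series; it assumes the usual Float8_e4m3fn and Float8_e5m2 bit layouts). After masking the sign bit, an e4m3fn byte is NaN only when it equals 0x7F, and an e5m2 byte is NaN only when it is 0x7D through 0x7F (0x7C is infinity); adding 0x01 or 0x03 per byte therefore overflows into bit 7 exactly for NaN bytes, and the additions never carry across byte boundaries because every masked byte is at most 0x7F.

#include <cstdint>
#include <cstdio>

// e4m3fn: a sign-masked byte is NaN iff it equals 0x7F, so +0x01 sets bit 7.
bool has_nan_e4m3fn_x8(uint64_t fp8x8) {
  uint64_t t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; // drop sign bits
  return ((t + 0x0101010101010101ULL) & 0x8080808080808080ULL) != 0;
}

// e5m2: a sign-masked byte is NaN iff it is >= 0x7D, so +0x03 sets bit 7.
bool has_nan_e5m2_x8(uint64_t fp8x8) {
  uint64_t t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
  return ((t + 0x0303030303030303ULL) & 0x8080808080808080ULL) != 0;
}

int main() {
  // Eight e4m3fn bytes, one of them NaN (0xFF is -NaN): expect "1".
  uint64_t v = 0x3800FF0040201000ULL;
  std::printf("%d\n", has_nan_e4m3fn_x8(v));
  // Eight e5m2 bytes where 0x7C is +inf, not NaN: expect "0".
  uint64_t w = 0x7C00010203040506ULL;
  std::printf("%d\n", has_nan_e5m2_x8(w));
  return 0;
}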