From 9913295b0c32da0fff4da29dd3e318dc75fd999c Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 20 May 2025 22:28:56 -0700 Subject: [PATCH 1/6] optimize work group reduce --- src/ATen/native/xpu/sycl/GroupNormKernels.cpp | 6 ++++-- src/ATen/native/xpu/sycl/GroupReduceUtils.h | 16 ++++++++-------- src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 3 ++- src/ATen/native/xpu/sycl/TensorModeKernel.cpp | 3 ++- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp index 518a4cbefc..e14f829ed5 100644 --- a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp @@ -71,6 +71,7 @@ struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { const int64_t i = item.get_group(0); WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item}; WelfordType val(0, 0, 0, 0); + WelfordType identity_element(0, 0, 0, 0); for (int64_t j = item.get_local_id(0); j < N_; j += item.get_local_range(0)) { const int64_t index = i * N_ + j; @@ -78,7 +79,7 @@ struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { } val = GroupReduceWithoutBroadcast( - item, val, welford_op, shared_); + item, val, welford_op, identity_element, shared_); if (item.get_local_id(0) == 0) { T_ACC m1; @@ -117,6 +118,7 @@ struct GNRowwiseMomentsVectorizedFunctor [[intel::reqd_sub_group_size(SIMD)]] void operator()( sycl::nd_item<1> item) const { WelfordType val[VEC_SIZE]; + WelfordType identity_element(0, 0, 0, 0); WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item}; auto g_start = item.get_group(0) * VEC_SIZE; @@ -139,7 +141,7 @@ struct GNRowwiseMomentsVectorizedFunctor #pragma unroll for (int v = 0; v < VEC_SIZE; ++v) { val[v] = GroupReduceWithoutBroadcast( - item, val[v], welford_op, shared_); + item, val[v], welford_op, identity_element, shared_); } if (item.get_local_id(0) == 0) { diff --git a/src/ATen/native/xpu/sycl/GroupReduceUtils.h b/src/ATen/native/xpu/sycl/GroupReduceUtils.h index 07aecd092e..6d9ba24385 100644 --- a/src/ATen/native/xpu/sycl/GroupReduceUtils.h +++ b/src/ATen/native/xpu/sycl/GroupReduceUtils.h @@ -120,9 +120,8 @@ inline T& SubgroupReduceWithoutBroadcast( auto sg_tid = sg.get_local_linear_id(); #pragma unroll for (int offset = 1; offset < SIMD; offset <<= 1) { - T temp = sycl::shift_group_left(sg, val, offset); if (sg_tid < SIMD - offset) { - val = op.combine(val, temp); + val = op.combine(val, sycl::shift_group_left(sg, val, offset)); } } return val; @@ -133,10 +132,12 @@ inline T& GroupReduceWithoutBroadcast( sycl::nd_item& item, T& val, const ReduceOp& op, + const T& identity_element, shared_t shared) { auto sg = item.get_sub_group(); int sg_tid = sg.get_local_linear_id(); - int sg_id = sg.get_group_linear_id(); + int sg_lid = sg_tid % SIMD; + int sg_wid = sg.get_group_linear_id(); int n_sg = get_local_linear_range(item) / SIMD; val = SubgroupReduceWithoutBroadcast(item, val, op); item.barrier(sycl_local_fence); // prevent races when GroupReduce @@ -145,13 +146,12 @@ inline T& GroupReduceWithoutBroadcast( return val; } if (sg_tid == 0) { - shared[sg_id] = val; + shared[sg_wid] = val; } item.barrier(sycl_local_fence); - if (sg_id == 0) { - for (int i = 1; i < n_sg; i++) { - val = op.combine(val, shared[i]); - } + val = (sg_tid < n_sg) ? shared[sg_lid] : identity_element; + if (sg_wid == 0) { + val = SubgroupReduceWithoutBroadcast(item, val, op); } return val; } diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp index 448c1b6770..7bd0ee1d59 100644 --- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp @@ -191,6 +191,7 @@ struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { const int64_t i = item_id.get_group(0); WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; WelfordType val(0, 0, 0, 0); + WelfordType identity_element(0, 0, 0, 0); for (int64_t j = item_id.get_local_id(0); j < N_; j += item_id.get_local_range(0)) { const int64_t index = i * N_ + j; @@ -198,7 +199,7 @@ struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { } val = GroupReduceWithoutBroadcast( - item_id, val, welford_op, shared_); + item_id, val, welford_op, identity_element, shared_); if (item_id.get_local_id(0) == 0) { T_ACC m1; diff --git a/src/ATen/native/xpu/sycl/TensorModeKernel.cpp b/src/ATen/native/xpu/sycl/TensorModeKernel.cpp index 7ae95e36b8..1e5a2fc6c4 100644 --- a/src/ATen/native/xpu/sycl/TensorModeKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorModeKernel.cpp @@ -217,6 +217,7 @@ inline T reduceGroupWithNThreadLocalReductions( T init) { int offset = item.get_local_id(2) * N; T local = offset < numVals ? threadVals[0] : init; + T identity_element = init; #pragma unroll for (int i = 1; i < N; ++i) { @@ -226,7 +227,7 @@ inline T reduceGroupWithNThreadLocalReductions( } return GroupReduceWithoutBroadcast( - item, local, reduceOp, smem); + item, local, reduceOp, identity_element, smem); } template From a4e50c8615db7d103c4514b0d6cb8fdbcf6b0b0d Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Thu, 22 May 2025 11:10:16 +0800 Subject: [PATCH 2/6] Update GroupReduceUtils.h --- src/ATen/native/xpu/sycl/GroupReduceUtils.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/GroupReduceUtils.h b/src/ATen/native/xpu/sycl/GroupReduceUtils.h index 6d9ba24385..84e54682c5 100644 --- a/src/ATen/native/xpu/sycl/GroupReduceUtils.h +++ b/src/ATen/native/xpu/sycl/GroupReduceUtils.h @@ -55,9 +55,6 @@ inline T& GroupReduceSumWithoutBroadcast( val = SubgroupReduceSumWithoutBroadcast(item, val); item.barrier(sycl_local_fence); // prevent races when GroupReduceSum are // called in a row. - if (n_sg == 1) { - return val; - } if (sg_tid == 0) { shared[sg_id] = val; } From 6a2e7feed31b14bc1b8f89dc78ee7d371a9061e3 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Fri, 23 May 2025 10:41:49 +0800 Subject: [PATCH 3/6] Update GroupReduceUtils.h --- src/ATen/native/xpu/sycl/GroupReduceUtils.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/ATen/native/xpu/sycl/GroupReduceUtils.h b/src/ATen/native/xpu/sycl/GroupReduceUtils.h index 84e54682c5..44d804dbcc 100644 --- a/src/ATen/native/xpu/sycl/GroupReduceUtils.h +++ b/src/ATen/native/xpu/sycl/GroupReduceUtils.h @@ -133,8 +133,7 @@ inline T& GroupReduceWithoutBroadcast( shared_t shared) { auto sg = item.get_sub_group(); int sg_tid = sg.get_local_linear_id(); - int sg_lid = sg_tid % SIMD; - int sg_wid = sg.get_group_linear_id(); + int sg_id = sg.get_group_linear_id(); int n_sg = get_local_linear_range(item) / SIMD; val = SubgroupReduceWithoutBroadcast(item, val, op); item.barrier(sycl_local_fence); // prevent races when GroupReduce @@ -143,11 +142,11 @@ inline T& GroupReduceWithoutBroadcast( return val; } if (sg_tid == 0) { - shared[sg_wid] = val; + shared[sg_id] = val; } item.barrier(sycl_local_fence); - val = (sg_tid < n_sg) ? shared[sg_lid] : identity_element; - if (sg_wid == 0) { + val = (sg_id < n_sg) ? shared[sg_id] : identity_element; + if (sg_id == 0) { val = SubgroupReduceWithoutBroadcast(item, val, op); } return val; From 0f29a6e6cb2345dcfd678c3c53e1c813a22138b6 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Fri, 23 May 2025 10:44:09 +0800 Subject: [PATCH 4/6] Update GroupReduceUtils.h --- src/ATen/native/xpu/sycl/GroupReduceUtils.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/ATen/native/xpu/sycl/GroupReduceUtils.h b/src/ATen/native/xpu/sycl/GroupReduceUtils.h index 44d804dbcc..b81479dbc7 100644 --- a/src/ATen/native/xpu/sycl/GroupReduceUtils.h +++ b/src/ATen/native/xpu/sycl/GroupReduceUtils.h @@ -55,6 +55,9 @@ inline T& GroupReduceSumWithoutBroadcast( val = SubgroupReduceSumWithoutBroadcast(item, val); item.barrier(sycl_local_fence); // prevent races when GroupReduceSum are // called in a row. + if (n_sg == 1) { + return val; + } if (sg_tid == 0) { shared[sg_id] = val; } From 33770273b8b89de6caa2761a897e79b90e35236f Mon Sep 17 00:00:00 2001 From: yucai Date: Mon, 26 May 2025 01:20:51 -0700 Subject: [PATCH 5/6] update --- src/ATen/native/xpu/sycl/GroupNormKernels.cpp | 30 ++++++++++++------- src/ATen/native/xpu/sycl/GroupReduceUtils.h | 21 +++++++++---- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp index e14f829ed5..4b624dc60b 100644 --- a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp @@ -63,13 +63,12 @@ template struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { using T_ACC = acc_type_device; using WelfordType = WelfordData; - using WelfordOp = - WelfordOpsXPU>; + using WelfordOp = WelfordOps>; [[intel::reqd_sub_group_size(SIMD)]] void operator()( sycl::nd_item<1> item) const { const int64_t i = item.get_group(0); - WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item}; + WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; WelfordType val(0, 0, 0, 0); WelfordType identity_element(0, 0, 0, 0); for (int64_t j = item.get_local_id(0); j < N_; @@ -78,8 +77,13 @@ struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { val = welford_op.reduce(val, static_cast(X_[index]), index); } - val = GroupReduceWithoutBroadcast( - item, val, welford_op, identity_element, shared_); + if (item.get_local_range(0) <= SIMD) { + val = SubgroupReduceWithoutBroadcast( + item, val, welford_op); + } else { + val = GroupReduceWithoutBroadcast( + item, val, welford_op, identity_element, shared_); + } if (item.get_local_id(0) == 0) { T_ACC m1; @@ -111,15 +115,14 @@ struct GNRowwiseMomentsVectorizedFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { using T_ACC = acc_type_device; using WelfordType = WelfordData; - using WelfordOp = - WelfordOpsXPU>; + using WelfordOp = WelfordOps>; using vec_t = memory::aligned_vector; [[intel::reqd_sub_group_size(SIMD)]] void operator()( sycl::nd_item<1> item) const { WelfordType val[VEC_SIZE]; WelfordType identity_element(0, 0, 0, 0); - WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item}; + WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; auto g_start = item.get_group(0) * VEC_SIZE; #pragma unroll @@ -140,8 +143,15 @@ struct GNRowwiseMomentsVectorizedFunctor #pragma unroll for (int v = 0; v < VEC_SIZE; ++v) { - val[v] = GroupReduceWithoutBroadcast( - item, val[v], welford_op, identity_element, shared_); + // val[v] = GroupReduceWithoutBroadcast( + // item, val[v], welford_op, identity_element, shared_); + if (item.get_local_range(0) <= SIMD) { + val[v] = SubgroupReduceWithoutBroadcast( + item, val[v], welford_op); + } else { + val[v] = GroupReduceWithoutBroadcast( + item, val[v], welford_op, identity_element, shared_); + } } if (item.get_local_id(0) == 0) { diff --git a/src/ATen/native/xpu/sycl/GroupReduceUtils.h b/src/ATen/native/xpu/sycl/GroupReduceUtils.h index b81479dbc7..9577468f78 100644 --- a/src/ATen/native/xpu/sycl/GroupReduceUtils.h +++ b/src/ATen/native/xpu/sycl/GroupReduceUtils.h @@ -119,10 +119,12 @@ inline T& SubgroupReduceWithoutBroadcast( auto sg = item.get_sub_group(); auto sg_tid = sg.get_local_linear_id(); #pragma unroll - for (int offset = 1; offset < SIMD; offset <<= 1) { - if (sg_tid < SIMD - offset) { - val = op.combine(val, sycl::shift_group_left(sg, val, offset)); - } + for (int offset = (SIMD >> 1); offset > 0; offset >>= 1) { + // for (int offset = 1; offset < SIMD; offset <<= 1) { + T temp = sycl::shift_group_left(sg, val, offset); + // if (sg_tid < SIMD - offset) { + val = op.combine(val, temp); + // } } return val; } @@ -135,6 +137,7 @@ inline T& GroupReduceWithoutBroadcast( const T& identity_element, shared_t shared) { auto sg = item.get_sub_group(); + int g_tid = item.get_local_linear_id(); int sg_tid = sg.get_local_linear_id(); int sg_id = sg.get_group_linear_id(); int n_sg = get_local_linear_range(item) / SIMD; @@ -148,9 +151,17 @@ inline T& GroupReduceWithoutBroadcast( shared[sg_id] = val; } item.barrier(sycl_local_fence); - val = (sg_id < n_sg) ? shared[sg_id] : identity_element; + // val = (g_tid < n_sg) ? shared[sg_id] : identity_element; + val = identity_element; + if (sg_id == 0) { + for (int i = sg_tid; i < n_sg; i += SIMD) { + val = op.combine(val, shared[i]); + } val = SubgroupReduceWithoutBroadcast(item, val, op); + // for (int i = 1; i < n_sg; i++) { + // val = op.combine(val, shared[i]); + // } } return val; } From b9681686a4ba917227860f554c5c5f78254e3357 Mon Sep 17 00:00:00 2001 From: yucai Date: Mon, 26 May 2025 23:38:06 -0700 Subject: [PATCH 6/6] fix err --- src/ATen/native/xpu/sycl/GroupNormKernels.cpp | 2 -- src/ATen/native/xpu/sycl/GroupReduceUtils.h | 13 ++++--------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp index 4b624dc60b..a7f6c942db 100644 --- a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp @@ -143,8 +143,6 @@ struct GNRowwiseMomentsVectorizedFunctor #pragma unroll for (int v = 0; v < VEC_SIZE; ++v) { - // val[v] = GroupReduceWithoutBroadcast( - // item, val[v], welford_op, identity_element, shared_); if (item.get_local_range(0) <= SIMD) { val[v] = SubgroupReduceWithoutBroadcast( item, val[v], welford_op); diff --git a/src/ATen/native/xpu/sycl/GroupReduceUtils.h b/src/ATen/native/xpu/sycl/GroupReduceUtils.h index 9577468f78..1e1bb1f487 100644 --- a/src/ATen/native/xpu/sycl/GroupReduceUtils.h +++ b/src/ATen/native/xpu/sycl/GroupReduceUtils.h @@ -119,12 +119,11 @@ inline T& SubgroupReduceWithoutBroadcast( auto sg = item.get_sub_group(); auto sg_tid = sg.get_local_linear_id(); #pragma unroll - for (int offset = (SIMD >> 1); offset > 0; offset >>= 1) { - // for (int offset = 1; offset < SIMD; offset <<= 1) { + for (int offset = 1; offset < SIMD; offset <<= 1) { T temp = sycl::shift_group_left(sg, val, offset); - // if (sg_tid < SIMD - offset) { - val = op.combine(val, temp); - // } + if (sg_tid < SIMD - offset) { + val = op.combine(val, temp); + } } return val; } @@ -151,7 +150,6 @@ inline T& GroupReduceWithoutBroadcast( shared[sg_id] = val; } item.barrier(sycl_local_fence); - // val = (g_tid < n_sg) ? shared[sg_id] : identity_element; val = identity_element; if (sg_id == 0) { @@ -159,9 +157,6 @@ inline T& GroupReduceWithoutBroadcast( val = op.combine(val, shared[i]); } val = SubgroupReduceWithoutBroadcast(item, val, op); - // for (int i = 1; i < n_sg; i++) { - // val = op.combine(val, shared[i]); - // } } return val; }