From ddb9267a9f60fba1abe4b7eb6cc62b96985e45be Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Thu, 26 Jun 2025 09:05:50 +0800 Subject: [PATCH 1/2] softmax: adjust vectorization length according to shape --- src/ATen/native/xpu/sycl/SoftMaxKernels.cpp | 24 +++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp index 2a4f749e15..ac939ea9e9 100644 --- a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp +++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp @@ -199,6 +199,23 @@ static inline void get_wgroup_size_spatial( GroupRow = std::min(GroupRow, int(dim_size)); } +static bool inline can_use_vec( + int dim_size, + int scalar_size, + int max_vec_size, + bool is_same_dtype) { + if (dim_size % max_vec_size != 0) + return false; + if (is_same_dtype) { + if (dim_size <= 2048 && dim_size * scalar_size <= 8192) + return false; + } else { + if (dim_size <= 1024 && dim_size * scalar_size <= 4096) + return false; + } + return true; +} + template < int INNER_LOOP, int vec_size, @@ -1544,6 +1561,7 @@ void spatial_softmax_forward( constexpr int float4_size = sizeof(float) * 4; constexpr int max_vec_size = float4_size / sizeof(inscalar_t); + constexpr int scalar_size = sizeof(inscalar_t); constexpr int INNER_LOOP = max_vec_size * 2; // decide vec_size: max_vec_size or 1 @@ -1695,14 +1713,16 @@ void spatial_softmax_forward( } } else { if (can_use_32bit_index) { - if (input_start == output_start && inner_size % max_vec_size == 0) { + if (input_start == output_start && + can_use_vec(inner_size, scalar_size, max_vec_size, is_same_dtype)) { SPATIAL_SOFTMAX_FORWARD_IMPL( /*vec_size*/ max_vec_size, /*IndexType*/ uint32_t); } else { SPATIAL_SOFTMAX_FORWARD_IMPL(/*vec_size*/ 1, /*IndexType*/ uint32_t); } } else { - if (input_start == output_start && inner_size % max_vec_size == 0) { + if (input_start == output_start && + can_use_vec(inner_size, scalar_size, max_vec_size, 
is_same_dtype)) { SPATIAL_SOFTMAX_FORWARD_IMPL( /*vec_size*/ max_vec_size, /*IndexType*/ uint64_t); } else { From 0d4a4f92ec0bbcd649f259c1f2dfae24f0cd5058 Mon Sep 17 00:00:00 2001 From: "Deng, Weishi" Date: Wed, 9 Jul 2025 11:06:55 +0800 Subject: [PATCH 2/2] update heuristic for spatial softmax --- src/ATen/native/xpu/sycl/SoftMaxKernels.cpp | 31 +++++++++++++-------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp index ac939ea9e9..3de8ef1c99 100644 --- a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp +++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp @@ -199,20 +199,29 @@ static inline void get_wgroup_size_spatial( GroupRow = std::min(GroupRow, int(dim_size)); } -static bool inline can_use_vec( +static bool inline get_spatial_vec_choice( + int bs, int dim_size, + int inner_size, int scalar_size, - int max_vec_size, - bool is_same_dtype) { + int max_vec_size) { if (dim_size % max_vec_size != 0) return false; - if (is_same_dtype) { - if (dim_size <= 2048 && dim_size * scalar_size <= 8192) - return false; - } else { - if (dim_size <= 1024 && dim_size * scalar_size <= 4096) - return false; + int total_resource = syclMaxWorkItemsPerTile(); + int maxWGSize = syclDeviceMaxWorkGroupSize(); + int group_col_size = std::min(inner_size, SIMD32); + auto local_group_num = (inner_size + group_col_size - 1) / group_col_size; + int group_row_size = 1; + while (bs *group_row_size * local_group_num * group_col_size < + total_resource ) { + group_row_size = group_row_size << 1; + if (group_row_size * SIMD32 == maxWGSize) + break; } + group_row_size = std::min(group_row_size, int(dim_size)); + if(bs *group_row_size * local_group_num * group_col_size <= + total_resource) + return false; return true; } @@ -1714,7 +1723,7 @@ void spatial_softmax_forward( } else { if (can_use_32bit_index) { if (input_start == output_start && - can_use_vec(inner_size, scalar_size, max_vec_size, 
is_same_dtype)) { + get_spatial_vec_choice(outer_size, dim_size, inner_size, scalar_size, max_vec_size)) { SPATIAL_SOFTMAX_FORWARD_IMPL( /*vec_size*/ max_vec_size, /*IndexType*/ uint32_t); } else { @@ -1722,7 +1731,7 @@ void spatial_softmax_forward( } } else { if (input_start == output_start && - can_use_vec(inner_size, scalar_size, max_vec_size, is_same_dtype)) { + get_spatial_vec_choice(outer_size, dim_size, inner_size, scalar_size, max_vec_size)) { SPATIAL_SOFTMAX_FORWARD_IMPL( /*vec_size*/ max_vec_size, /*IndexType*/ uint64_t); } else {