diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
index 2a4f749e1..3de8ef1c9 100644
--- a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
+++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
@@ -199,6 +199,57 @@ static inline void get_wgroup_size_spatial(
   GroupRow = std::min(GroupRow, int(dim_size));
 }
 
+// Decides whether spatial_softmax_forward should use the vectorized kernel
+// (vec_size == max_vec_size) rather than the scalar one (vec_size == 1).
+// Returns true only when both dim_size and inner_size are multiples of
+// max_vec_size and the estimated launch requests more work-items than
+// syclMaxWorkItemsPerTile() reports.
+// NOTE(review): the row-doubling loop looks like it is meant to mirror
+// get_wgroup_size_spatial above -- confirm the two are kept in sync.
+static inline bool get_spatial_vec_choice(
+    int bs,
+    int dim_size,
+    int inner_size,
+    int scalar_size,
+    int max_vec_size) {
+  // scalar_size is currently unused by the heuristic.
+  (void)scalar_size;
+
+  // Degenerate extents: inner_size == 0 would divide by zero below, and
+  // max_vec_size == 0 would do the same in the divisibility checks.
+  if (bs <= 0 || dim_size <= 0 || inner_size <= 0 || max_vec_size <= 0)
+    return false;
+
+  // The call sites required inner_size % max_vec_size == 0 before this
+  // helper replaced that check; keep it, since the scalar kernel remains
+  // the fallback whenever this returns false.
+  if (inner_size % max_vec_size != 0 || dim_size % max_vec_size != 0)
+    return false;
+
+  const long long total_resource = syclMaxWorkItemsPerTile();
+  const int max_wg_size = syclDeviceMaxWorkGroupSize();
+  const int group_col_size = std::min(inner_size, SIMD32);
+  const int local_group_num =
+      (inner_size + group_col_size - 1) / group_col_size;
+
+  // Work-items requested for a given rows-per-group value. Evaluated in
+  // 64 bits so the four-factor product cannot overflow int.
+  const auto work_items = [&](int group_row_size) {
+    return static_cast<long long>(bs) * group_row_size * local_group_num *
+        group_col_size;
+  };
+
+  int group_row_size = 1;
+  while (work_items(group_row_size) < total_resource) {
+    group_row_size <<= 1;
+    if (group_row_size * SIMD32 == max_wg_size)
+      break;
+  }
+  group_row_size = std::min(group_row_size, dim_size);
+
+  return work_items(group_row_size) > total_resource;
+}
+
 template <
     int INNER_LOOP,
     int vec_size,
@@ -1544,6 +1595,7 @@ void spatial_softmax_forward(
 
   constexpr int float4_size = sizeof(float) * 4;
   constexpr int max_vec_size = float4_size / sizeof(inscalar_t);
+  constexpr int scalar_size = sizeof(inscalar_t);
   constexpr int INNER_LOOP = max_vec_size * 2;
 
   // decide vec_size: max_vec_size or 1
@@ -1695,14 +1747,16 @@ void spatial_softmax_forward(
     }
   } else {
     if (can_use_32bit_index) {
-      if (input_start == output_start && inner_size % max_vec_size == 0) {
+      if (input_start == output_start &&
+          get_spatial_vec_choice(outer_size, dim_size, inner_size, scalar_size, max_vec_size)) {
         SPATIAL_SOFTMAX_FORWARD_IMPL(
             /*vec_size*/ max_vec_size, /*IndexType*/ uint32_t);
       } else {
         SPATIAL_SOFTMAX_FORWARD_IMPL(/*vec_size*/ 1, /*IndexType*/ uint32_t);
       }
     } else {
-      if (input_start == output_start && inner_size % max_vec_size == 0) {
+      if (input_start == output_start &&
+          get_spatial_vec_choice(outer_size, dim_size, inner_size, scalar_size, max_vec_size)) {
         SPATIAL_SOFTMAX_FORWARD_IMPL(
             /*vec_size*/ max_vec_size, /*IndexType*/ uint64_t);
       } else {