diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 7563d9ceda654..24986f169fcf2 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -1,16 +1,16 @@ #include "softmax.hpp" -template -static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, +template +static void soft_max_f32(const float * x, const T * mask, float * dst, int ncols, const int nrows_y, const float scale, const float max_bias, const float m0, - const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { - const int ncols = ncols_template == 0 ? ncols_par : ncols_template; + const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf, + const bool vals_smem, const bool check_columns_count) { const int tid = item_ct1.get_local_id(2); const int rowx = item_ct1.get_group(2); const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension - const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template; + const int block_size = check_columns_count ? item_ct1.get_local_range(2) : std::min(ncols, 1024); const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; @@ -35,7 +35,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; - if (ncols_template == 0 && col >= ncols) { + if (check_columns_count && col >= ncols) { break; } @@ -74,7 +74,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; - if (ncols_template == 0 && col >= ncols) { + if (check_columns_count && col >= ncols) { break; } @@ -113,7 +113,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; - if (ncols_template == 0 && col >= ncols) { + if (check_columns_count && col >= ncols) { return; } @@ -122,25 +122,6 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int } } -template -static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, - const int nrows_y, const float scale, const float max_bias, const float m0, - const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, - const size_t n_local_scratch, queue_ptr stream) { - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor local_buf_acc(n_local_scratch, cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - soft_max_f32(x, mask, dst, ncols_par, - nrows_y, scale, max_bias, m0, - m1, n_head_log2, item_ct1, - get_pointer(local_buf_acc)); - }); - }); -} - template static void soft_max_f32_sycl(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, @@ -163,64 +144,28 @@ static void soft_max_f32_sycl(const float * x, const T * mask, const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const size_t local_mem_size = stream->get_device().get_info(); - if (n_local_scratch*sizeof(float) < local_mem_size) { - if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - return; - } - switch (ncols_x) { - case 32: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 64: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 128: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 256: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 512: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 1024: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 2048: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 4096: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - default: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - } + + auto soft_max_f32_submit = [=](size_t scratch_size, bool vals_smem, bool check_columns_count) { + stream->submit([=](sycl::handler &cgh) { + sycl::local_accessor local_buf_acc(scratch_size, cgh); + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + soft_max_f32(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, + item_ct1, get_pointer(local_buf_acc), vals_smem, check_columns_count); + }); + }); + }; + + if (n_local_scratch*sizeof(float) >= local_mem_size) { + soft_max_f32_submit(WARP_SIZE, false, true); + } else if (ncols_x > max_block_size) { + soft_max_f32_submit(n_local_scratch, true, true); + } else if (ncols_x == 32 || ncols_x == 64 || ncols_x == 128 || ncols_x == 256 + || ncols_x == 512 || ncols_x == 1024 || ncols_x == 2048 || ncols_x == 4096) { + soft_max_f32_submit(n_local_scratch, true, false); } else { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, WARP_SIZE, stream); + soft_max_f32_submit(n_local_scratch, true, true); } }