From f94e2e01f2ae4dc88ca7fd79e7c6a799c202097e Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Mon, 16 Jun 2025 03:46:30 -0400 Subject: [PATCH] [cpu int8 sdpa] use manual tranpose and pack --- .../codegen/cpp_int8_sdpa_template.py | 326 ++++++++---------- 1 file changed, 148 insertions(+), 178 deletions(-) diff --git a/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py b/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py index 1f8865356a..8a79067e79 100644 --- a/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py +++ b/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py @@ -13,7 +13,7 @@ from .utils import expand USEFUL_FUNCTIONS = r""" -inline float {{kernel_name}}_calculate_scale( +inline float calculate_scale( int64_t headSize, std::optional scale) { return scale.has_value() @@ -22,7 +22,7 @@ } template -inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) { +inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { const int32_t vec_size = at::vec::Vectorized::size(); auto data_vec = at::vec::Vectorized(val); int64_t d = 0; @@ -35,13 +35,13 @@ } template -inline void {{kernel_name}}_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { +inline void store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { src.store(dst, size); } template inline typename std::enable_if_t || std::is_same_v, void> -{{kernel_name}}_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { +store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { auto res = at::vec::convert(src); res.store(dst, size); } @@ -52,7 +52,7 @@ 3. max reduce for softmax */ template -inline void {{kernel_name}}_dequant_mask_max_fusion_kernel( +inline void dequant_mask_max_fusion_kernel( const int32_t* in, const mask_t* mask_ptr, const int32_t* sum_a_ptr, @@ -90,7 +90,7 @@ auto tmp7 = at::vec::convert(tmp6); auto tmp8 = tmp5 + tmp7; vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp8); - {{kernel_name}}_store(tmp_out + col, tmp8); + store(tmp_out + col, tmp8); } if (col < N) { auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); @@ -103,7 +103,7 @@ auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col, N - col); auto tmp7 = at::vec::convert(tmp6); auto tmp8 = tmp5 + tmp7; - {{kernel_name}}_store(tmp_out + col, tmp8, N - col); + store(tmp_out + col, tmp8, N - col); vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp8), N - col); } sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); @@ -114,7 +114,7 @@ 1. dequant 2. 
max reduce for softmax */ -inline void {{kernel_name}}_dequant_max_fusion_kernel( +inline void dequant_max_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int32_t* sum_b_ptr, @@ -146,7 +146,7 @@ auto tmp4 = at::vec::convert(tmp3); auto tmp5 = tmp4 * vec_alpha; vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp5); - {{kernel_name}}_store(tmp_out + col, tmp5); + store(tmp_out + col, tmp5); } if (col < N) { auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); @@ -156,7 +156,7 @@ auto tmp3 = tmp2 + vec_beta; auto tmp4 = at::vec::convert(tmp3); auto tmp5 = tmp4 * vec_alpha; - {{kernel_name}}_store(tmp_out + col, tmp5, N - col); + store(tmp_out + col, tmp5, N - col); vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp5), N - col); } sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); @@ -169,7 +169,7 @@ 3. sum for attention */ template -inline void {{kernel_name}}_sub_exp_sum_div_quant_sum_fusion_kernel( +inline void sub_exp_sum_div_quant_sum_fusion_kernel( const float* in, const int64_t& M, const int64_t& N_step, @@ -214,13 +214,13 @@ auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); vec_tmp_sum += tmp2; - {{kernel_name}}_store(tmp_out + col, tmp2); + store(tmp_out + col, tmp2); } if (col < kvBlockSize) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); - {{kernel_name}}_store(tmp_out + col, tmp2, kvBlockSize - col); + store(tmp_out + col, tmp2, kvBlockSize - col); vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); } sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); @@ -243,7 +243,7 @@ auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp4); + store(tmp_out + col, tmp4); auto tmp6 = at::vec::convert(tmp4); vec_tmp_sum += tmp6; } @@ -253,7 +253,7 @@ auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp4, kvBlockSize - col); + store(tmp_out + col, tmp4, kvBlockSize - col); auto tmp6 = at::vec::convert(tmp4); vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col); } @@ -261,10 +261,10 @@ // set zero col = kvBlockSize; for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { - {{kernel_name}}_store(tmp_out + col, vec_zero); + store(tmp_out + col, vec_zero); } if (col < av_gemm_K) { - {{kernel_name}}_store(tmp_out + col, vec_zero, av_gemm_K - col); + store(tmp_out + col, vec_zero, av_gemm_K - col); } } } @@ -275,7 +275,7 @@ 2. 
quant */ template -inline void {{kernel_name}}_sub_exp_sum_div_quant_fusion_kernel( +inline void sub_exp_sum_div_quant_fusion_kernel( const float* in, const int64_t& M, const int64_t& N_step, @@ -318,14 +318,14 @@ auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); vec_tmp_sum += tmp2; - {{kernel_name}}_store(tmp_out + col, tmp2); + store(tmp_out + col, tmp2); } if (col < kvBlockSize) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); - {{kernel_name}}_store(tmp_out + col, tmp2, kvBlockSize - col); + store(tmp_out + col, tmp2, kvBlockSize - col); } sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); } @@ -345,7 +345,7 @@ auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp4); + store(tmp_out + col, tmp4); } if (col < kvBlockSize) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); @@ -353,15 +353,15 @@ auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp4, kvBlockSize - col); + store(tmp_out + col, tmp4, kvBlockSize - col); } // set zero col = kvBlockSize; for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { - {{kernel_name}}_store(tmp_out + col, vec_zero); + store(tmp_out + col, vec_zero); } if (col < av_gemm_K) { - {{kernel_name}}_store(tmp_out + col, vec_zero, av_gemm_K - col); + store(tmp_out + col, vec_zero, av_gemm_K - col); } } } @@ -372,7 +372,7 @@ 2. quant */ template -inline void {{kernel_name}}_dequant_quant_fusion_kernel( +inline void dequant_quant_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int32_t* sum_b_ptr, @@ -410,7 +410,7 @@ auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp8); + store(tmp_out + col, tmp8); } if (col < N) { auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); @@ -423,7 +423,7 @@ auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp8, N - col); + store(tmp_out + col, tmp8, N - col); } } } @@ -433,7 +433,7 @@ 2. 
quant */ template -inline void {{kernel_name}}_dequant_quant_fusion_kernel( +inline void dequant_quant_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int& M, @@ -467,7 +467,7 @@ auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp8); + store(tmp_out + col, tmp8); } if (col < N) { auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col, N - col); @@ -477,13 +477,13 @@ auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp8, N - col); + store(tmp_out + col, tmp8, N - col); } } } template -inline void {{kernel_name}}_int_sum_b_contiguous_kernel_helper( +inline void int_sum_b_contiguous_kernel_helper( const scalar_t* in, int32_t* out, const int& N, @@ -507,7 +507,7 @@ // reduce along dim b for shape [a, b], with sum shape [a] template -inline void {{kernel_name}}_int_sum_b_contiguous_kernel( +inline void int_sum_b_contiguous_kernel( const scalar_t* in, int32_t* out, const int& M, @@ -515,13 +515,13 @@ const int& ld, const int32_t& scale) { for (long r = 0; r < M; r += 1) { - {{kernel_name}}_int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); + int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); } } // reduce along dim a for shape [a, b], with sum shape [b] template -inline void {{kernel_name}}_int_sum_a_contiguous_kernel( +inline void int_sum_a_contiguous_kernel( const scalar_t* in, int32_t* out, const int& M, @@ -535,10 +535,10 @@ auto vec_zero = at::vec::Vectorized(zero); long i = 0; for (; i < vec_size * (M / vec_size); i += vec_size) { - {{kernel_name}}_store(out + i, vec_zero); + store(out + i, vec_zero); } if (i < M) { - {{kernel_name}}_store(out + i, vec_zero, M - i); + store(out + i, vec_zero, M - i); } // sum for (long j = 0; j < N; j++) { @@ -549,14 +549,14 @@ auto tmp1 = at::vec::Vectorized::loadu(out + k); auto tmp2 = at::vec::convert(tmp0); auto tmp3 = tmp1 + tmp2; - {{kernel_name}}_store(out + k, tmp3); + store(out + k, tmp3); } if (k < M) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + k, M - k); auto tmp1 = at::vec::Vectorized::loadu(out + k, M - k); auto tmp2 = at::vec::convert(tmp0); auto tmp3 = tmp1 + tmp2; - {{kernel_name}}_store(out + k, tmp3, M - k); + store(out + k, tmp3, M - k); } } // scale @@ -564,18 +564,18 @@ for (; i < vec_size * (M / vec_size); i += vec_size) { auto tmp0 = at::vec::Vectorized::loadu(out + i); auto tmp1 = tmp0 * vec_scale; - {{kernel_name}}_store(out + i, tmp1); + store(out + i, tmp1); } if (i < M) { auto tmp0 = at::vec::Vectorized::loadu(out + i, M - i); auto tmp1 = tmp0 * vec_scale; - {{kernel_name}}_store(out + i, tmp1, M - i); + store(out + i, tmp1, M - i); } } // do the transpose: [in_rows, in_cols] -> [in_cols, in_rows] template -inline void {{kernel_name}}_do_transpose( +inline void do_transpose( const scalar_t* src, scalar_t* dst, int64_t in_rows, @@ -591,7 +591,7 @@ // padding with pad_val: [rows, cols] -> [prows, pcols] template -inline void {{kernel_name}}_pad_remain_row_col( +inline void pad_remain_row_col( scalar_t* value_ptr, int rows, int cols, @@ -630,7 +630,7 @@ // copy value_ptr to dst_ptr with padding: [rows, cols] -> [prows, pcols] template -inline void {{kernel_name}}_copy_value_with_pad( +inline void copy_value_with_pad( const scalar_t* value_ptr, scalar_t* dst_ptr, int rows, @@ -694,6 +694,9 @@ INT8_SDPA_ONE_LOOP_TEMPLATE = r""" +#ifndef HEADER_DEFINED +#define 
HEADER_DEFINED + {{template.header().getvalue()}} #include #include @@ -701,6 +704,7 @@ #include #include #include +#include #include #include #include @@ -721,6 +725,8 @@ {{template.codegen_useful_function(kernel.kernel_name)}} +#endif + {%- if has_attention_mask %} {%- set kernel_args = {"query": query, "key": key, "value": value, "attention_mask": attention_mask} %} @@ -746,7 +752,7 @@ int64_t num_head = {{kernel.size(query, 2)}}; int64_t headSize = {{kernel.size(query, 3)}}; float scaling_factor = - {{kernel.kernel_name}}_calculate_scale(headSize, {{scale}}); + calculate_scale(headSize, {{scale}}); // Strides int64_t qStrideB = {{kernel.stride(query, 0)}}; @@ -873,16 +879,16 @@ // sum k and v {%- if q_zp == 0 %} - {{kernel.kernel_name}}_fill_stub(k_sum_ptr, static_cast(0), kvSize); + fill_stub(k_sum_ptr, static_cast(0), kvSize); {%- else %} - {{kernel.kernel_name}}_int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, k_sum_ptr, kvSize, headSize, kStrideN, {{q_zp}}); {%- endif %} {%- if a_zp == 0 %} - {{kernel.kernel_name}}_fill_stub(v_sum_ptr, static_cast(0), headSize); + fill_stub(v_sum_ptr, static_cast(0), headSize); {%- else %} - {{kernel.kernel_name}}_int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, v_sum_ptr, headSize, kvSize, vStrideN, {{a_zp}}); {%- endif %} @@ -893,15 +899,15 @@ for (int64_t b = 0; b < kvBlockSize; b += block_64) { bool istail = kvBlockSize - b < block_64; int64_t trans_rows = istail ? kvBlockSize - b : block_64; - {{kernel.kernel_name}}_do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - trans_rows, + at::native::utils::transpose( headSize, + trans_rows, + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, kStrideN, + B_blocked_xform_u8, block_64); if (!headSize_mul64 || istail) { - {{kernel.kernel_name}}_pad_remain_row_col( + pad_remain_row_col( B_blocked_xform_u8, headSize, trans_rows, @@ -910,30 +916,24 @@ block_64 ); } - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); + at::vec::pack_vnni4( + /* src */ B_blocked_xform_u8, + /* dst */ key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + /* ld_src */ block_64, + /* K */ qk_gemm_K, + /* N */ block_64); } // split headSize to block_64, block_64, block_64 ... // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] 
for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); + at::vec::pack_vnni4( + /* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + /* dst */ value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b, + /* ld_src */ vStrideN, + /* K */ av_gemm_K, + /* N */ block_64); } } @@ -942,14 +942,14 @@ int64_t m = k * qSplitSize; int64_t qBlockSize = std::min(qSplitSize, qSize - m); // Initialize sum and max - {{kernel.kernel_name}}_fill_stub( + fill_stub( sfm_sum_ptr, static_cast(0), qSplitSize); - {{kernel.kernel_name}}_fill_stub( + fill_stub( a_sum_ptr, static_cast(0), qSplitSize); - {{kernel.kernel_name}}_fill_stub( + fill_stub( sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); int64_t num_keys = kvSize; - {{kernel.kernel_name}}_copy_value_with_pad( + copy_value_with_pad( q_data + i * qStrideB + j * qStrideH + m * qStrideM, query_t_padding_ptr, qBlockSize, @@ -959,10 +959,10 @@ qStrideM); // sum q {%- if k_zp != 0 %} - {{kernel.kernel_name}}_int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, q_sum_ptr, qBlockSize, headSize, qStrideM, {{k_zp}}); {%- else %} - {{kernel.kernel_name}}_fill_stub( + fill_stub( q_sum_ptr, static_cast(0), qSplitSize); {%- endif %} const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; @@ -986,7 +986,7 @@ accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; {%- if has_attention_mask %} const mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); - {{kernel.kernel_name}}_dequant_mask_max_fusion_kernel( + dequant_mask_max_fusion_kernel( qk_s32_data, //in mask_data_offset, //mask_ptr q_sum_ptr, //sum_a_ptr @@ -1002,7 +1002,7 @@ sfm_max_ptr //sfm_max_ptr ); {%- else %} - {{kernel.kernel_name}}_dequant_max_fusion_kernel( + dequant_max_fusion_kernel( qk_s32_data, //in q_sum_ptr, //sum_a_ptr k_sum_ptr + n, //sum_b_ptr @@ -1021,7 +1021,7 @@ // and quant // and sum for attention {%- if v_zp == 0 %} - {{kernel.kernel_name}}_sub_exp_sum_div_quant_fusion_kernel( + sub_exp_sum_div_quant_fusion_kernel( qk_data, //in qBlockSize, //M kvSplitSize, //N_step @@ -1039,7 +1039,7 @@ sfm_sum_ptr //sfm_sum_ptr ); {%- else %} - {{kernel.kernel_name}}_sub_exp_sum_div_quant_sum_fusion_kernel( + sub_exp_sum_div_quant_sum_fusion_kernel( qk_data, //in qBlockSize, //M kvSplitSize, //N_step @@ -1079,7 +1079,7 @@ // After the last gemm, // do dequant compensation, quant and convert from s32 to int8 {%- if a_zp == 0 %} - {{kernel.kernel_name}}_dequant_quant_fusion_kernel( + dequant_quant_fusion_kernel( dst_s32_data, //in a_sum_ptr, //sum_a_ptr qBlockSize, //M @@ -1091,7 +1091,7 @@ out_data + i * oStrideB + j * oStrideH + m * oStrideM //out ); {%- else %} - {{kernel.kernel_name}}_dequant_quant_fusion_kernel( + dequant_quant_fusion_kernel( dst_s32_data, //in a_sum_ptr, //sum_a_ptr v_sum_ptr, //sum_b_ptr @@ -1118,6 +1118,9 @@ INT8_SDPA_SEVERAL_LOOPS_TEMPLATE = r""" +#ifndef HEADER_DEFINED +#define HEADER_DEFINED + {{template.header().getvalue()}} #include #include @@ -1125,6 +1128,7 @@ #include #include #include +#include #include #include #include @@ -1145,6 +1149,8 @@ {{template.codegen_useful_function(kernel.kernel_name)}} +#endif + {%- if has_attention_mask %} {%- set kernel_args = {"query": query, "key": key, "value": value, "attention_mask": attention_mask} %} @@ -1170,7 +1176,7 @@ int64_t num_head = {{kernel.size(query, 2)}}; int64_t headSize = {{kernel.size(query, 3)}}; float scaling_factor = - {{kernel.kernel_name}}_calculate_scale(headSize, {{scale}}); + calculate_scale(headSize, {{scale}}); // Strides int64_t qStrideB = {{kernel.stride(query, 0)}}; @@ -1275,16 +1281,16 @@ int32_t* k_sum_ptr = kv_sum_ptr; int32_t* v_sum_ptr = kv_sum_ptr + kvSize; {%- if q_zp == 0 %} - {{kernel.kernel_name}}_fill_stub(k_sum_ptr, static_cast(0), kvSize); + fill_stub(k_sum_ptr, static_cast(0), kvSize); {%- else %} - {{kernel.kernel_name}}_int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, k_sum_ptr, kvSize, headSize, kStrideN, {{q_zp}}); {%- endif %} {%- if a_zp == 0 %} - {{kernel.kernel_name}}_fill_stub(v_sum_ptr, static_cast(0), headSize); + fill_stub(v_sum_ptr, static_cast(0), headSize); {%- else %} - {{kernel.kernel_name}}_int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, v_sum_ptr, headSize, kvSize, vStrideN, {{a_zp}}); {%- endif %} @@ -1299,7 +1305,7 @@ int64_t i = 0, j = 0, l = 0, n = 0; at::native::data_index_init( begin, i, batchSize, j, num_head, l, kvSlice); - uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * block_64]; + uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * kvSplitSize]; for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable n = l * kvSplitSize; @@ -1309,49 +1315,25 @@ i * num_head * kvSlice * v_reorder_strideL + j * kvSlice * v_reorder_strideL + n * rndHeadSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - for (int64_t b = 0; b < 
kvBlockSize; b += block_64) { - bool istail = kvBlockSize - b < block_64; - int64_t trans_rows = istail ? kvBlockSize - b : block_64; - {{kernel.kernel_name}}_do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - trans_rows, + at::native::utils::transpose( + kvBlockSize, headSize, + k_data + i * kStrideB + j * kStrideH + n * kStrideN, kStrideN, - block_64); - if (!headSize_mul64 || istail) { - {{kernel.kernel_name}}_pad_remain_row_col( - B_blocked_xform_u8, - headSize, - trans_rows, - qk_gemm_K, - block_64, - block_64 - ); - } - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - v_reorder + av_gemm_K * b); - } + B_blocked_xform_u8, + kvBlockSize); + at::vec::pack_vnni4( + /* src */ B_blocked_xform_u8, + /* dst */ k_reorder, + /* ld_src */ kvBlockSize, + /* K */ qk_gemm_K, + /* N */ kvBlockSize); + at::vec::pack_vnni4( + /* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN, + /* dst */ v_reorder, + /* ld_src */ vStrideN, + /* K */ av_gemm_K, + /* N */ rndHeadSize); // Move to the next query at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); } @@ -1382,8 +1364,8 @@ int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); offset += qSplitSize * 4; accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + //offset += qSplitSize * 4; + //scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable @@ -1398,53 +1380,45 @@ int64_t m = k * qSplitSize; int64_t qBlockSize = std::min(qSplitSize, qSize - m); // Initialize sum and max - {{kernel.kernel_name}}_fill_stub( + fill_stub( sfm_sum_ptr, static_cast(0), qSplitSize); - {{kernel.kernel_name}}_fill_stub( + fill_stub( a_sum_ptr, static_cast(0), qSplitSize); - {{kernel.kernel_name}}_fill_stub( + fill_stub( sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); int64_t num_keys = kvSize; - {{kernel.kernel_name}}_copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - qBlockSize, - headSize, - qBlockSize, - qk_gemm_K, - qStrideM); // sum q + const scalar_t* q_tmp = q_data + i * qStrideB + j * qStrideH + m * qStrideM; {%- if k_zp != 0 %} - {{kernel.kernel_name}}_int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + int_sum_b_contiguous_kernel(q_tmp, q_sum_ptr, qBlockSize, headSize, qStrideM, {{k_zp}}); {%- else %} - {{kernel.kernel_name}}_fill_stub( + fill_stub( q_sum_ptr, static_cast(0), qSplitSize); {%- endif %} const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { int64_t n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; // Calculate q @ k.T - for (int64_t b = 0; b < kvBlockSize; b += block_64) { - 
at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb + at::native::cpublas::brgemm( + qSplitSize, kvBlockSize, headSize, + qStrideM, // lda + kvBlockSize, //ldb rndkvSplitSize, //ldc, false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } + q_tmp, + k_reorder, + qk_s32_data); // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; {%- if has_attention_mask %} const mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); - {{kernel.kernel_name}}_dequant_mask_max_fusion_kernel( + dequant_mask_max_fusion_kernel( qk_s32_data, //in mask_data_offset, //mask_ptr q_sum_ptr, //sum_a_ptr @@ -1460,7 +1434,7 @@ sfm_max_ptr //sfm_max_ptr ); {%- else %} - {{kernel.kernel_name}}_dequant_max_fusion_kernel( + dequant_max_fusion_kernel( qk_s32_data, //in q_sum_ptr, //sum_a_ptr k_sum_ptr + n, //sum_b_ptr @@ -1479,7 +1453,7 @@ // and quant // and sum for attention {%- if v_zp == 0 %} - {{kernel.kernel_name}}_sub_exp_sum_div_quant_fusion_kernel( + sub_exp_sum_div_quant_fusion_kernel( qk_data, //in qBlockSize, //M kvSplitSize, //N_step @@ -1497,7 +1471,7 @@ sfm_sum_ptr //sfm_sum_ptr ); {%- else %} - {{kernel.kernel_name}}_sub_exp_sum_div_quant_sum_fusion_kernel( + sub_exp_sum_div_quant_sum_fusion_kernel( qk_data, //in qBlockSize, //M kvSplitSize, //N_step @@ -1521,26 +1495,22 @@ auto v_reorder = value_reorder_ptr + i * num_head * kvSlice * v_reorder_strideL + j * kvSlice * v_reorder_strideL; - for (int64_t b = 0; b < headSize; b += block_64) { - auto value_reorder_b = v_reorder + b * av_gemm_K; - auto dst_s32_b = dst_s32_data + b; - for (int64_t s = 0; s < kvSlice; s++) { - at::native::cpublas::brgemm( - qSplitSize, block_64, av_gemm_K, - av_gemm_K, // lda - rndHeadSize, //ldb - rndHeadSize, //ldc - s != 0, - qk_reduced_data + s * qk_reduce_strideL, - value_reorder_b + s * v_reorder_strideL, - dst_s32_b); - } + for (int64_t s = 0; s < kvSlice; s++) { + at::native::cpublas::brgemm( + qSplitSize, headSize, av_gemm_K, + av_gemm_K, // lda + rndHeadSize, //ldb + rndHeadSize, //ldc + s != 0, + qk_reduced_data + s * qk_reduce_strideL, + v_reorder + s * v_reorder_strideL, + dst_s32_data); } // After the last gemm, // do dequant compensation, quant and convert from s32 to int8 {%- if a_zp == 0 %} - {{kernel.kernel_name}}_dequant_quant_fusion_kernel( + dequant_quant_fusion_kernel( dst_s32_data, //in a_sum_ptr, //sum_a_ptr qBlockSize, //M @@ -1552,7 +1522,7 @@ out_data + i * oStrideB + j * oStrideH + m * oStrideM //out ); {%- else %} - {{kernel.kernel_name}}_dequant_quant_fusion_kernel( + dequant_quant_fusion_kernel( dst_s32_data, //in a_sum_ptr, //sum_a_ptr v_sum_ptr, //sum_b_ptr @@ -1704,8 +1674,8 @@ def get_options( if qSize >= 768: q_split_size = 256 elif qSize >= 192: - q_split_size = 64 - kv_split_size = 64 + q_split_size = 128 + kv_split_size = 512 qSplitSize = min(qSize, q_split_size) l2_cache_size = torch._C._cpu._L2_cache_size()
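Note on the packing change: the patch swaps the opaque at::native::cpublas::pack() calls for an explicit transpose (at::native::utils::transpose) followed by at::vec::pack_vnni4, so the B operands of the int8 brgemm are produced in the 4-wide VNNI layout directly from the K/V buffers. As a rough orientation only, the sketch below shows how such a VNNI4 re-layout is commonly defined for a row-major uint8 matrix; the helper name pack_vnni4_reference and the assumption that K is already padded to a multiple of 4 (as the rounded qk_gemm_K / av_gemm_K are here) are illustrative and not taken from the patch.

// Illustrative sketch, not the ATen implementation: re-layout of an int8
// matrix B[K x N] (row-major, leading dimension ld_src) into dst[K/4 x N x 4],
// the 4-wide grouping consumed by AVX512-VNNI int8 dot-product kernels.
// Assumes K % 4 == 0 (padding is the caller's responsibility).
#include <cstdint>
#include <cstddef>

void pack_vnni4_reference(const uint8_t* src, uint8_t* dst,
                          std::size_t ld_src, std::size_t K, std::size_t N) {
  for (std::size_t k = 0; k < K; k += 4) {
    for (std::size_t n = 0; n < N; ++n) {
      for (std::size_t i = 0; i < 4; ++i) {
        // Four consecutive K values of one column become contiguous,
        // matching the 4-element lanes the VNNI instructions reduce over.
        dst[(k / 4) * N * 4 + n * 4 + i] = src[(k + i) * ld_src + n];
      }
    }
  }
}

With this layout, each output column's K dimension is grouped in blocks of four contiguous bytes, which is the operand shape the vpdpbusd-style int8 dot-product instructions expect, so the brgemm calls in the template can consume the reordered key/value buffers without a separate packing pass inside the GEMM.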