From c0380cea20132e88d2364208f46088e5566269b8 Mon Sep 17 00:00:00 2001
From: GuoxiaWang
Date: Mon, 23 Sep 2024 16:27:50 +0800
Subject: [PATCH] skip right mask block each row first

---
 csrc/flash_attn/src/flash_fwd_kernel.h | 117 +++++++++++++++++--------
 1 file changed, 80 insertions(+), 37 deletions(-)

diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index ff2297040df..5db3a986e3f 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -1080,9 +1080,86 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
     if (Is_causal) {
         n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
-        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
-        //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
-        // }
+    }
+
+    const index_t row_offset_sparse_mask = (bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * params.seqlen_k + (n_block_max - 1) * kBlockN;
+    const index_t row_offset_sparsemask_nblock =
+        (bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * cute::ceil_div(params.seqlen_k, kBlockN);
+    Tensor gFlashMaskLTStart = make_tensor(make_gmem_ptr(reinterpret_cast<int *>(params.flashmask_downstart_ptr) + row_offset_sparse_mask),
+                                           Shape<Int<kBlockN>>{});
+    Tensor gFlashMaskLTEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int *>(params.flashmask_downend_ptr) + row_offset_sparse_mask),
+                                         Shape<Int<kBlockN>>{});
+    Tensor gFlashMaskUTStart = make_tensor(make_gmem_ptr(reinterpret_cast<int *>(params.flashmask_upstart_ptr) + row_offset_sparse_mask),
+                                           Shape<Int<kBlockN>>{});
+    Tensor gFlashMaskUTEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int *>(params.flashmask_upend_ptr) + row_offset_sparse_mask),
+                                         Shape<Int<kBlockN>>{});
+    const int* gFlashMaskLTStartMax = reinterpret_cast<const int*>(params.flashmask_downstart_nblockmax) + row_offset_sparsemask_nblock;
+    const int* gFlashMaskLTStartMin = reinterpret_cast<const int*>(params.flashmask_downstart_nblockmin) + row_offset_sparsemask_nblock;
+    const int* gFlashMaskLTEndMax = reinterpret_cast<const int*>(params.flashmask_downend_nblockmax) + row_offset_sparsemask_nblock;
+    const int* gFlashMaskLTEndMin = reinterpret_cast<const int*>(params.flashmask_downend_nblockmin) + row_offset_sparsemask_nblock;
+    const int* gFlashMaskUTStartMax = reinterpret_cast<const int*>(params.flashmask_upstart_nblockmax) + row_offset_sparsemask_nblock;
+    const int* gFlashMaskUTStartMin = reinterpret_cast<const int*>(params.flashmask_upstart_nblockmin) + row_offset_sparsemask_nblock;
+    const int* gFlashMaskUTEndMax = reinterpret_cast<const int*>(params.flashmask_upend_nblockmax) + row_offset_sparsemask_nblock;
+    const int* gFlashMaskUTEndMin = reinterpret_cast<const int*>(params.flashmask_upend_nblockmin) + row_offset_sparsemask_nblock;
+
+    const bool enable_mask_bypass = params.enable_mask_bypass;
+    const bool flashmask_lt_has_end = params.flashmask_downend_ptr != nullptr;
+    const bool flashmask_ut_has_start = params.flashmask_upstart_ptr != nullptr;
+
+#define SPARSE_MASKED_DOWN(N_BLOCK) \
+    (((m_block * kBlockM) >= gFlashMaskLTStartMax[(N_BLOCK)]) && (!flashmask_lt_has_end || (m_block + 1) * kBlockM <= gFlashMaskLTEndMin[(N_BLOCK)]))
+
+#define SPARSE_MASKED_UP(N_BLOCK) \
+    (!Is_causal && (m_block + 1) * kBlockM <= gFlashMaskUTEndMin[(N_BLOCK)] && (!flashmask_ut_has_start || m_block * kBlockM >= gFlashMaskUTStartMax[(N_BLOCK)]))
+
+#define SPARSE_MASKED(N_BLOCK) \
+    (SPARSE_MASKED_DOWN(N_BLOCK) || SPARSE_MASKED_UP(N_BLOCK))
+
+    for (--n_block_max; n_block_max >= 0; --n_block_max) {
+        if (true/*Is_flashmask*/ && n_block_max >= 0 && enable_mask_bypass && SPARSE_MASKED(n_block_max)) {
+            gFlashMaskLTStart.data() = gFlashMaskLTStart.data() + (-kBlockN);
+            gFlashMaskLTEnd.data() = gFlashMaskLTEnd.data() + (-kBlockN);
+            if (!Is_causal) {
+                gFlashMaskUTEnd.data() = gFlashMaskUTEnd.data() + (-kBlockN);
+                gFlashMaskUTStart.data() = gFlashMaskUTStart.data() + (-kBlockN);
+            }
+            continue;
+        } else {
+            n_block_max++;
+            break;
+        }
+    }
+
+    if (n_block_max <= 0) {
+        // Need to clear the O block if we skip the whole row; otherwise the elements in the corresponding block are left uninitialized.
+        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+        Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                make_stride(params.o_row_stride, _1{}));
+
+        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+
+        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+        Tensor tOrO = make_tensor<Element>(shape(tOgO));
+        cute::clear(tOrO);
+
+        // Construct identity layout for sO
+        Tensor cO = make_identity_tensor(make_shape(kBlockM, kHeadDim));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+        // Repeat the partitioning with identity layouts
+        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);                      // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+        if (!Is_even_K) {
+            #pragma unroll
+            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+        }
+        // Clear_OOB_K must be false since we don't want to write zeros to gmem
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+        );
+
+        return;
     }
 
     // We iterate over the blocks in reverse order. This is because the last block is the only one
@@ -1104,9 +1181,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
             + (m_block * kBlockM % params.mask_seq_q_mod_size)) * params.seqlen_k + (n_block_max - 1) * kBlockN;
 
-    const index_t row_offset_sparse_mask = (bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * params.seqlen_k + (n_block_max - 1) * kBlockN;
-    const index_t row_offset_sparsemask_nblock =
-        (bidb * params.h_sparsemask + bidh / params.h_h_sparsemask_ratio) * cute::ceil_div(params.seqlen_k, kBlockN);
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -1121,28 +1195,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
                             Shape<Int<kBlockM>, Int<kBlockN>>{},
                             make_stride(params.seqlen_k_rounded, _1{}));
 
-    Tensor gFlashMaskLTStart = make_tensor(make_gmem_ptr(reinterpret_cast<int *>(params.flashmask_downstart_ptr) + row_offset_sparse_mask),
-                                           Shape<Int<kBlockN>>{});
-    Tensor gFlashMaskLTEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int *>(params.flashmask_downend_ptr) + row_offset_sparse_mask),
-                                         Shape<Int<kBlockN>>{});
-    Tensor gFlashMaskUTStart = make_tensor(make_gmem_ptr(reinterpret_cast<int *>(params.flashmask_upstart_ptr) + row_offset_sparse_mask),
-                                           Shape<Int<kBlockN>>{});
-    Tensor gFlashMaskUTEnd = make_tensor(make_gmem_ptr(reinterpret_cast<int *>(params.flashmask_upend_ptr) + row_offset_sparse_mask),
-                                         Shape<Int<kBlockN>>{});
-    const int* gFlashMaskLTStartMax = reinterpret_cast<const int*>(params.flashmask_downstart_nblockmax) + row_offset_sparsemask_nblock;
-    const int* gFlashMaskLTStartMin = reinterpret_cast<const int*>(params.flashmask_downstart_nblockmin) + row_offset_sparsemask_nblock;
-    const int* gFlashMaskLTEndMax = reinterpret_cast<const int*>(params.flashmask_downend_nblockmax) + row_offset_sparsemask_nblock;
-    const int* gFlashMaskLTEndMin = reinterpret_cast<const int*>(params.flashmask_downend_nblockmin) + row_offset_sparsemask_nblock;
-    const int* gFlashMaskUTStartMax = reinterpret_cast<const int*>(params.flashmask_upstart_nblockmax) + row_offset_sparsemask_nblock;
-    const int* gFlashMaskUTStartMin = reinterpret_cast<const int*>(params.flashmask_upstart_nblockmin) + row_offset_sparsemask_nblock;
-    const int* gFlashMaskUTEndMax = reinterpret_cast<const int*>(params.flashmask_upend_nblockmax) + row_offset_sparsemask_nblock;
-    const int* gFlashMaskUTEndMin = reinterpret_cast<const int*>(params.flashmask_upend_nblockmin) + row_offset_sparsemask_nblock;
-
-
-    const bool enable_mask_bypass = params.enable_mask_bypass;
-    const bool flashmask_lt_has_end = params.flashmask_downend_ptr != nullptr;
-    const bool flashmask_ut_has_start = params.flashmask_upstart_ptr != nullptr;
-
     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQ{});
     // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
@@ -1299,15 +1351,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_flashmask(const Params &p
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
 
-#define SPARSE_MASKED_DOWN(N_BLOCK) \
-    (((m_block * kBlockM) >= gFlashMaskLTStartMax[(N_BLOCK)]) && (!flashmask_lt_has_end || (m_block + 1) * kBlockM <= gFlashMaskLTEndMin[(N_BLOCK)]))
-
-#define SPARSE_MASKED_UP(N_BLOCK) \
-    (!Is_causal && (m_block + 1) * kBlockM <= gFlashMaskUTEndMin[(N_BLOCK)] && (!flashmask_ut_has_start || m_block * kBlockM >= gFlashMaskUTStartMax[(N_BLOCK)]))
-
-#define SPARSE_MASKED(N_BLOCK) \
-    (SPARSE_MASKED_DOWN(N_BLOCK) || SPARSE_MASKED_UP(N_BLOCK))
-
     constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
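---

Note on the change: the patch hoists the FlashMask bypass test out of the main loop. Before any K/V tiles are loaded, it scans key blocks from the right edge of the row and drops every trailing block that SPARSE_MASKED proves is fully masked; if no block survives, it zero-fills the O row-block and returns early. The sketch below is a plain host-side C++ mirror of that trimming logic, useful for unit-testing it in isolation; the names (trim_n_block_max, BlockMaskSummary) are illustrative and not part of the patch, and the kernel additionally gates the whole pre-pass on params.enable_mask_bypass.

    #include <vector>

    // Per key-block summaries of the column-wise mask boundaries, as in
    // params.flashmask_*_nblockmax / *_nblockmin.
    struct BlockMaskSummary {
        std::vector<int> lt_start_max;  // per-block max of lower-triangle mask start row
        std::vector<int> lt_end_min;    // per-block min of lower-triangle mask end row
        std::vector<int> ut_end_min;    // per-block min of upper-triangle mask end row
        std::vector<int> ut_start_max;  // per-block max of upper-triangle mask start row
    };

    // Returns the trimmed n_block_max for one query row-block: key blocks at the
    // right edge whose every element is masked are dropped before the main loop.
    int trim_n_block_max(int n_block_max, int m_block, int kBlockM, bool is_causal,
                         bool lt_has_end, bool ut_has_start, const BlockMaskSummary& s) {
        for (int nb = n_block_max - 1; nb >= 0; --nb) {
            // SPARSE_MASKED_DOWN: the whole row-block lies at or below the
            // lower-triangle mask start (and above its end, if an end is given).
            bool masked_down = m_block * kBlockM >= s.lt_start_max[nb] &&
                               (!lt_has_end || (m_block + 1) * kBlockM <= s.lt_end_min[nb]);
            // SPARSE_MASKED_UP: only relevant in the non-causal case; the whole
            // row-block lies above the upper-triangle mask end (and at or below
            // its start, if a start is given).
            bool masked_up = !is_causal && (m_block + 1) * kBlockM <= s.ut_end_min[nb] &&
                             (!ut_has_start || m_block * kBlockM >= s.ut_start_max[nb]);
            // First unmasked block from the right bounds the loop.
            if (!(masked_down || masked_up)) return nb + 1;
        }
        return 0;  // every key block is masked; the caller zero-fills the O row-block
    }

The per-block min/max summaries are what make the bypass cheap: a single comparison per key block proves all kBlockN of its columns masked, instead of reading kBlockN individual mask boundary values, so fully masked tiles cost O(1) instead of a full K/V load and GEMM.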