15 changes: 15 additions & 0 deletions hopper/debug.hpp
@@ -0,0 +1,15 @@
#pragma once

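// Set DEBUG_PRINT to 1 to enable the device-side printf macros below;
// flash_fwd_kernel_sm90.h also skips warpgroup register reallocation while it is enabled.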
#define DEBUG_PRINT 0

#if DEBUG_PRINT
#define DPRINTF(fmt, ...) printf("%s:%d " fmt, __FILE__, __LINE__, ## __VA_ARGS__);
#define DPRINTF0(fmt, ...) if (threadIdx.x == 0) printf("%s:%d " fmt, __FILE__, __LINE__, ## __VA_ARGS__);
#define PRODUCER_DPRINTF0(fmt, ...) if (threadIdx.x == 0 && block0()) printf("%s:%d " fmt, __FILE__, __LINE__, ## __VA_ARGS__);
#define CONSUMER_DPRINTF0(fmt, ...) if (threadIdx.x == 128 && block0()) printf("%s:%d " fmt, __FILE__, __LINE__, ## __VA_ARGS__);
#else
#define DPRINTF(fmt, ...)
#define DPRINTF0(fmt, ...)
#define PRODUCER_DPRINTF0(fmt, ...)
#define CONSUMER_DPRINTF0(fmt, ...)
#endif
6 changes: 6 additions & 0 deletions hopper/flash.h
@@ -137,6 +137,12 @@ struct Flash_fwd_params : public Qkv_params {
int window_size_left, window_size_right;
int attention_chunk;

uint32_t const* sparse_masks;
int sparse_block_q;
int sparse_block_k;
int max_seqlen_q;
int max_seqlen_k;

// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;

37 changes: 37 additions & 0 deletions hopper/flash_api.cpp
@@ -523,6 +523,8 @@ mha_fwd_get_scheduler_metadata(
int64_t window_size_right,
int64_t attention_chunk,
bool has_softcap,
int64_t sparse_block_q,
int64_t sparse_block_k,
int64_t num_splits,
std::optional<bool> pack_gqa_,
int64_t sm_margin
@@ -616,6 +618,15 @@ mha_fwd_get_scheduler_metadata(

if (params.num_splits_dynamic_ptr) {
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
if (sparse_block_q != 0) {
if (sparse_block_q == 128 && sparse_block_k == 128) {
kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90_blocksparse<128, 128>();
} else if (sparse_block_q == 64 && sparse_block_k == 64) {
kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90_blocksparse<64, 64>();
} else {
TORCH_CHECK(false, "block sparse attention only supports sparse_block_q/sparse_block_k of 128/128 or 64/64");
}
}
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
@@ -659,13 +670,16 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
std::optional<at::Tensor> q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor> k_descale_, // (b, h_k)
std::optional<at::Tensor> v_descale_, // (b, h_k)
std::optional<at::Tensor> sparse_masks_, // (b, h, ceil_div(max_s, q_blk), ceil_div(max_t, k_blk*8)) with row compression; a 0 bit means the block is masked out
std::optional<double> softmax_scale_,
bool is_causal,
int64_t window_size_left,
int64_t window_size_right,
int64_t attention_chunk,
double softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int64_t sparse_block_q,
int64_t sparse_block_k,
std::optional<at::Tensor> scheduler_metadata_, // (b + 1)
int64_t num_splits,
std::optional<bool> pack_gqa_,
@@ -1094,6 +1108,17 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
}
}

if (sparse_masks_.has_value()) {
TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided when sparse_masks is used");
TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided when sparse_masks is used");
auto sparse_masks = sparse_masks_.value();
params.sparse_masks = reinterpret_cast<uint32_t *>(sparse_masks.data_ptr());
params.sparse_block_q = sparse_block_q;
params.sparse_block_k = sparse_block_k;
params.max_seqlen_q = max_seqlen_q_.value();
params.max_seqlen_k = max_seqlen_k_.value();
}

#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
#endif
@@ -1113,6 +1138,13 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
#endif

#ifndef FLASHATTENTION_DISABLE_PAGEDKV
TORCH_CHECK(!params.page_table || !params.sparse_masks, "block sparse flash attention does not support paged KV.");
#endif
#ifndef FLASHATTENTION_DISABLE_APPENDKV
TORCH_CHECK(!params.knew_ptr || !params.sparse_masks, "block sparse flash attention does not support appending KV.");
#endif

if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -1642,13 +1674,16 @@ TORCH_LIBRARY(flash_attn_3, m) {
"Tensor? q_descale = None,"
"Tensor? k_descale = None,"
"Tensor? v_descale = None,"
"Tensor? sparse_masks = None,"
"float? softmax_scale = None,"
"bool is_causal = False,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"int attention_chunk = 0,"
"float softcap = 0.0,"
"bool is_rotary_interleaved = False,"
"int sparse_block_q = 0,"
"int sparse_block_k = 0,"
"Tensor? scheduler_metadata = None,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
@@ -1703,6 +1738,8 @@ TORCH_LIBRARY(flash_attn_3, m) {
"int window_size_right = -1,"
"int attention_chunk = 0,"
"bool has_softcap = False,"
"int sparse_block_q = 0,"
"int sparse_block_k = 0,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
"int sm_margin = 0) -> Tensor");
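The `sparse_masks` comment in `mha_fwd` above describes a bit-packed layout: shape `(b, h, ceil_div(max_s, q_blk), ceil_div(max_t, k_blk*8))`, one bit per (query-block, key-block) pair, with a 0 bit masking the block out. A minimal packing sketch under those assumptions is shown below; the `uint8` storage, the little-endian bit order within each byte, and the helper name `pack_block_sparse_mask` are illustrative guesses, not something this diff specifies.

```python
import math
import torch

def pack_block_sparse_mask(block_mask: torch.Tensor) -> torch.Tensor:
    """Pack a boolean block mask of shape (b, h, n_q_blocks, n_k_blocks) into bytes.

    Assumed layout: each byte covers 8 consecutive key blocks, with bit i of byte j
    mapping to key block j * 8 + i; a set bit keeps the block, a cleared bit masks
    it out ("0 means the block is masked out").
    """
    b, h, n_q, n_k = block_mask.shape
    n_k_bytes = math.ceil(n_k / 8)
    padded = torch.zeros(b, h, n_q, n_k_bytes * 8, dtype=torch.bool, device=block_mask.device)
    padded[..., :n_k] = block_mask
    # Weight bit i by 2**i and sum within each group of 8 key blocks.
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=block_mask.device)
    packed = (padded.view(b, h, n_q, n_k_bytes, 8).long() * weights).sum(dim=-1)
    return packed.to(torch.uint8).contiguous()
```

The C++ side only reinterprets the tensor's storage as `uint32_t` words (see the `reinterpret_cast` above), so the mask should be contiguous and on the same CUDA device as the other inputs.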
42 changes: 42 additions & 0 deletions hopper/flash_attn_interface.py
@@ -41,12 +41,15 @@ def _flash_attn_forward(
q_descale,
k_descale,
v_descale,
sparse_masks,
softmax_scale,
causal,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
rotary_interleaved=True,
sparse_block_q=0,
sparse_block_k=0,
scheduler_metadata=None,
num_splits=1,
pack_gqa=None,
@@ -86,13 +89,16 @@ def _flash_attn_forward(
q_descale,
k_descale,
v_descale,
sparse_masks,
softmax_scale,
causal,
window_size[0],
window_size[1],
attention_chunk,
softcap,
rotary_interleaved,
sparse_block_q,
sparse_block_k,
scheduler_metadata,
num_splits,
pack_gqa,
@@ -192,6 +198,7 @@ def forward(
None, None, None, # page_table, kv_batch_idx, leftpad_k,
None, None, None, # rotary_cos/sin, seqlens_rotary
q_descale, k_descale, v_descale,
None, # sparse_masks
softmax_scale,
causal=causal,
window_size=window_size,
@@ -262,9 +269,12 @@ def forward(
causal,
qv=None,
q_descale=None, k_descale=None, v_descale=None,
sparse_masks=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
sparse_block_q=0,
sparse_block_k=0,
num_splits=1,
pack_gqa=None,
deterministic=False,
@@ -286,11 +296,14 @@ def forward(
None, None, None, # page_table, kv_batch_idx, leftpad_k,
None, None, None, # rotary_cos/sin, seqlens_rotary
q_descale, k_descale, v_descale,
sparse_masks,
softmax_scale,
causal=causal,
window_size=window_size,
attention_chunk=attention_chunk,
softcap=softcap,
sparse_block_q=sparse_block_q,
sparse_block_k=sparse_block_k,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
@@ -355,9 +368,12 @@ def forward(
causal,
qv=None,
q_descale=None, k_descale=None, v_descale=None,
sparse_masks=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
sparse_block_q=0,
sparse_block_k=0,
num_splits=1,
pack_gqa=None,
deterministic=False,
@@ -383,11 +399,14 @@ def forward(
None, None, None, # page_table, kv_batch_idx, leftpad_k,
None, None, None, # rotary_cos/sin, seqlens_rotary
q_descale, k_descale, v_descale,
sparse_masks,
softmax_scale,
causal=causal,
window_size=window_size,
attention_chunk=attention_chunk,
softcap=softcap,
sparse_block_q=sparse_block_q,
sparse_block_k=sparse_block_k,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
@@ -444,9 +463,12 @@ def flash_attn_qkvpacked_func(
softmax_scale=None,
causal=False,
q_descale=None, k_descale=None, v_descale=None,
sparse_masks=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
sparse_block_q=0,
sparse_block_k=0,
deterministic=False,
num_heads_q=None,
sm_margin=0,
@@ -490,9 +512,12 @@ def flash_attn_qkvpacked_func(
softmax_scale,
causal,
q_descale, k_descale, v_descale,
sparse_masks,
window_size,
attention_chunk,
softcap,
sparse_block_q,
sparse_block_k,
deterministic,
num_heads_q,
sm_margin,
@@ -507,9 +532,12 @@ def flash_attn_func(
causal=False,
qv=None,
q_descale=None, k_descale=None, v_descale=None,
sparse_masks=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
sparse_block_q=0,
sparse_block_k=0,
num_splits=1,
pack_gqa=None,
deterministic=False,
@@ -568,9 +596,12 @@ def flash_attn_func(
causal,
qv,
q_descale, k_descale, v_descale,
sparse_masks,
window_size,
attention_chunk,
softcap,
sparse_block_q,
sparse_block_k,
num_splits,
pack_gqa,
deterministic,
@@ -592,9 +623,12 @@ def flash_attn_varlen_func(
causal=False,
qv=None,
q_descale=None, k_descale=None, v_descale=None,
sparse_masks=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
sparse_block_q=0,
sparse_block_k=0,
num_splits=1,
pack_gqa=None,
deterministic=False,
@@ -614,9 +648,12 @@ def flash_attn_varlen_func(
causal,
qv,
q_descale, k_descale, v_descale,
sparse_masks,
window_size,
attention_chunk,
softcap,
sparse_block_q,
sparse_block_k,
num_splits,
pack_gqa,
deterministic,
@@ -776,6 +813,7 @@ def flash_attn_with_kvcache(
rotary_sin,
rotary_seqlens,
q_descale, k_descale, v_descale,
None, # sparse_masks
softmax_scale,
causal=causal,
window_size=window_size,
@@ -805,6 +843,8 @@ def get_scheduler_metadata(
window_size=(-1, -1), # -1 means infinite context window
attention_chunk=0,
has_softcap=False,
sparse_block_q=0,
sparse_block_k=0,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
@@ -827,6 +867,8 @@ def get_scheduler_metadata(
window_size[0], window_size[1],
attention_chunk,
has_softcap,
sparse_block_q,
sparse_block_k,
num_splits,
pack_gqa,
sm_margin,
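Putting the Python-side changes together, here is a hedged usage sketch of `flash_attn_func` with the new arguments. The import path, the all-ones `uint8` bitmask (relying on the packing assumptions sketched after the `flash_api.cpp` diff), and the concrete shapes are illustrative only.

```python
import math
import torch
from flash_attn_interface import flash_attn_func  # import path depends on how the hopper build is packaged

b, s_q, s_k, h, d = 2, 4096, 4096, 16, 128
q = torch.randn(b, s_q, h, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, s_k, h, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, s_k, h, d, device="cuda", dtype=torch.bfloat16)

# Only the 128/128 and 64/64 block-size pairs are dispatched in flash_api.cpp above.
sparse_block_q = sparse_block_k = 128
n_q_blocks = math.ceil(s_q / sparse_block_q)
n_k_blocks = math.ceil(s_k / sparse_block_k)

# 0xFF keeps all 8 key blocks covered by each byte; a real pattern would clear
# the bits of blocks to skip (a 0 bit masks the block out, per the mha_fwd comment).
sparse_masks = torch.full(
    (b, h, n_q_blocks, math.ceil(n_k_blocks / 8)), 0xFF,
    dtype=torch.uint8, device="cuda",
)

out = flash_attn_func(
    q, k, v,
    causal=True,
    sparse_masks=sparse_masks,
    sparse_block_q=sparse_block_q,
    sparse_block_k=sparse_block_k,
)  # return convention is unchanged by this diff
```

The same `sparse_block_q`/`sparse_block_k` values can also be passed to `get_scheduler_metadata`, whose signature gains the two arguments above.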
10 changes: 10 additions & 0 deletions hopper/flash_fwd_kernel_sm90.h
@@ -20,6 +20,8 @@
#include "utils.h"
#include "softmax.h"

#include "debug.hpp"

namespace flash {

using namespace cute;
@@ -101,6 +103,7 @@ class FlashAttnFwdSm90 {
// We want smem_o to line up with the start of smem_v
typename CollectiveEpilogue::TensorStorage epilogue;
};
cute::array_aligned<uint32_t, 512> sparse_masks; // 512 x 32 bits = 16384 key blocks, i.e. up to 1M context with blocksize 64
} tensors;
struct PipelineStorage : cute::aligned_struct<16, _1> {
alignas(16) BarrierQ barrier_Q;
@@ -306,7 +309,9 @@ class FlashAttnFwdSm90 {
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));

if (warp_group_idx == 0) { // Producer
#if !defined(DEBUG_PRINT) || !DEBUG_PRINT
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
#endif

// The pipelines for AppendKV and main attention are different, since e.g. main attention
// might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
@@ -330,6 +335,7 @@
work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {

auto block_coord = work_tile_info.get_block_coord(params.scheduler);
PRODUCER_DPRINTF0("Producer: work on q_block_idx=%d head_idx=%d batch_idx=%d split_idx=%d 0x%p\n", get<0>(block_coord), get<1>(block_coord), get<2>(block_coord), get<3>(block_coord), params.mainloop.sparse_masks);
SeqlenInfo_t seqlen_info{
get<2>(block_coord) /*bidb*/,
get<0>(params.mainloop.shape_Q),
@@ -358,7 +364,9 @@
}
mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
} else { // Consumer
#if !defined(DEBUG_PRINT) || !DEBUG_PRINT
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
#endif

// Initialize matmul objects.
TiledMmaPV tiled_mma_pv;
@@ -378,6 +386,8 @@
// get_next_work will be called before the epilogue
) {
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
CONSUMER_DPRINTF0("Consumer: work on q_block_idx=%d head_idx=%d batch_idx=%d split_idx=%d 0x%p\n", get<0>(block_coord), get<1>(block_coord), get<2>(block_coord), get<3>(block_coord), params.mainloop.sparse_masks);

int const bidb = get<2>(block_coord);
SeqlenInfo_t seqlen_info{
bidb,