diff --git a/flash-attn2/build.toml b/flash-attn2/build.toml index c9a63c4f..6bd4e8a5 100644 --- a/flash-attn2/build.toml +++ b/flash-attn2/build.toml @@ -230,6 +230,22 @@ src = [ "flash_attn_xpu/src/fmha_bwd.hpp", "flash_attn_xpu/src/fmha_bwd_impl.hpp", "flash_attn_xpu/src/fmha_bwd.cpp", + "flash_attn_xpu/src/flash_bwd_hdim32_varlen_fp16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim32_varlen_bf16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim64_varlen_fp16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim64_varlen_bf16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim96_varlen_fp16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim96_varlen_bf16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim128_varlen_fp16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim128_varlen_bf16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim160_varlen_fp16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim160_varlen_bf16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim192_varlen_fp16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim192_varlen_bf16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim256_varlen_fp16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim256_varlen_bf16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim512_varlen_fp16.cpp", + "flash_attn_xpu/src/flash_bwd_hdim512_varlen_bf16.cpp", "flash_attn_xpu/src/flash_bwd_hdim32_fix_fp16.cpp", "flash_attn_xpu/src/flash_bwd_hdim32_fix_bf16.cpp", "flash_attn_xpu/src/flash_bwd_hdim64_fix_fp16.cpp", diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp index 5923a7dc..8701dcc5 100644 --- a/flash-attn2/flash_attn_xpu/flash_api.cpp +++ b/flash-attn2/flash_attn_xpu/flash_api.cpp @@ -392,6 +392,8 @@ mha_varlen_fwd( compat::select_device(device_idx); // check inputs + TORCH_CHECK(p_dropout < 1.0f, "FlashAttention dropout probability must be less than 1.0"); + q = ensure_contiguous(q); const auto sizes = q.sizes(); const int total_q = sizes[0]; @@ -446,6 +448,31 @@ mha_varlen_fwd( auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue(); + uint64_t philox_seed = 0; + uint64_t philox_offset = 0; + at::Tensor rng_state = at::empty({2}, q.options().dtype(at::kLong)); + + if (p_dropout > 0.0f) { + int64_t counter_offset = batch_size * num_heads * 32; + auto [seed, offset] = get_philox_state(gen_, counter_offset); + philox_seed = seed; + philox_offset = offset; + rng_state[0] = static_cast(philox_seed); + rng_state[1] = static_cast(philox_offset); + } + + const int q_tile_size = (head_size_og <= 32) ? 64 : (head_size_og <= 128) ? 128 : 256; + const int k_tile_size = (max_seqlen_q == 1 && !is_paged) ? 512 : 128; + const int seqlen_q_rounded = round_multiple(max_seqlen_q, q_tile_size); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, k_tile_size); + at::Tensor S_dmask; + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + S_dmask = torch::empty({batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, q.options()); + } else { + S_dmask = torch::empty({0}, q.options()); + } + cutlass_fmha_fwd_varlen_impl( queue, q_padded, k_padded, v_padded, out_padded, @@ -455,7 +482,10 @@ mha_varlen_fwd( max_seqlen_q, max_seqlen_k, softmax_scale, window_size_left, window_size_right, - true, is_paged, is_causal, is_local); + true, is_paged, is_causal, is_local, + p_dropout, philox_seed, philox_offset, nullptr, + return_softmax ? S_dmask.data_ptr() : nullptr, + seqlen_q_rounded, seqlen_k_rounded); // Strip padding from output back to original head_size at::Tensor out = needs_padding @@ -464,11 +494,214 @@ mha_varlen_fwd( .contiguous() : ensure_contiguous(out_padded); - at::Tensor S_dmask; - at::Tensor rng_state; return {out, softmax_lse, S_dmask, rng_state}; } +std::vector +mha_varlen_bwd( + const at::Tensor &dout, + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + const at::Tensor &softmax_lse, + std::optional &dq_, + std::optional &dk_, + std::optional &dv_, + const at::Tensor &cu_seqlens_q, + const at::Tensor &cu_seqlens_k, + std::optional &alibi_slopes_, + const int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + std::optional gen_, + std::optional &rng_state) { + + auto device_idx = q.device().index(); + compat::select_device(device_idx); + + TORCH_CHECK(p_dropout < 1.0f, "FlashAttention dropout probability must be less than 1.0"); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention varlen backward only supports fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); + CHECK_DEVICE(k); + CHECK_DEVICE(v); + CHECK_DEVICE(out); + CHECK_DEVICE(dout); + CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.dim() == 3, "q must have shape (total_q, num_heads, head_size)"); + TORCH_CHECK(k.dim() == 3, "k must have shape (total_k, num_heads_k, head_size)"); + TORCH_CHECK(v.dim() == 3, "v must have shape (total_k, num_heads_k, head_size)"); + TORCH_CHECK(out.sizes() == q.sizes(), "out must have the same shape as q"); + TORCH_CHECK(dout.sizes() == q.sizes(), "dout must have the same shape as q"); + TORCH_CHECK(k.sizes() == v.sizes(), "k and v must have the same shape"); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_k.dim() == 1, + "cu_seqlens_q and cu_seqlens_k must be 1D tensors"); + TORCH_CHECK(cu_seqlens_q.scalar_type() == at::ScalarType::Int && + cu_seqlens_k.scalar_type() == at::ScalarType::Int, + "cu_seqlens_q and cu_seqlens_k must be int32 tensors"); + TORCH_CHECK(cu_seqlens_q.numel() == cu_seqlens_k.numel(), + "cu_seqlens_q and cu_seqlens_k must have the same length"); + TORCH_CHECK(max_seqlen_q > 0 && max_seqlen_k > 0, + "max_seqlen_q and max_seqlen_k must be positive"); + + at::Tensor q_contig = ensure_contiguous(q); + at::Tensor k_contig = ensure_contiguous(k); + at::Tensor v_contig = ensure_contiguous(v); + at::Tensor out_contig = ensure_contiguous(out); + at::Tensor dout_contig = ensure_contiguous(dout); + at::Tensor softmax_lse_contig = ensure_contiguous(softmax_lse); + at::Tensor cu_q_contig = ensure_contiguous(cu_seqlens_q); + at::Tensor cu_k_contig = ensure_contiguous(cu_seqlens_k); + + TORCH_CHECK(q_contig.stride(-1) == 1, "q must have contiguous last dimension"); + TORCH_CHECK(k_contig.stride(-1) == 1, "k must have contiguous last dimension"); + TORCH_CHECK(v_contig.stride(-1) == 1, "v must have contiguous last dimension"); + TORCH_CHECK(out_contig.stride(-1) == 1, "out must have contiguous last dimension"); + TORCH_CHECK(dout_contig.stride(-1) == 1, "dout must have contiguous last dimension"); + + const int total_q = q_contig.size(0); + const int total_k = k_contig.size(0); + const int num_heads = q_contig.size(1); + const int head_size_og = q_contig.size(2); + const int num_heads_k = k_contig.size(1); + const int batch_size = cu_q_contig.numel() - 1; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_og <= 256 || head_size_og == 512, + "FlashAttention XPU varlen backward only supports head dimension up to 256 or exactly 512. Got: " + std::to_string(head_size_og)); + TORCH_CHECK(k_contig.size(2) == head_size_og && v_contig.size(2) == head_size_og, + "k and v head dimensions must match q"); + TORCH_CHECK(num_heads % num_heads_k == 0, + "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(softmax_lse_contig.dim() == 3 && softmax_lse_contig.size(0) == batch_size && + softmax_lse_contig.size(1) == num_heads && softmax_lse_contig.size(2) == max_seqlen_q, + "softmax_lse must have shape (batch_size, num_heads, max_seqlen_q)"); + const int head_size_padded = round_multiple(head_size_og, 32); + const bool needs_padding = (head_size_og != head_size_padded); + const int pad_size = head_size_padded - head_size_og; + + auto maybe_pad = [&](const at::Tensor &t) -> at::Tensor { + return needs_padding + ? torch::nn::functional::pad(t, torch::nn::functional::PadFuncOptions({0, pad_size})) + : t; + }; + + at::Tensor q_padded = maybe_pad(q_contig); + at::Tensor k_padded = maybe_pad(k_contig); + at::Tensor v_padded = maybe_pad(v_contig); + at::Tensor out_padded = maybe_pad(out_contig); + at::Tensor dout_padded = maybe_pad(dout_contig); + + auto opts = q_contig.options(); + at::Tensor dq, dk, dv; + at::Tensor dq_work, dk_work, dv_work; + bool dq_needs_copy = false, dk_needs_copy = false, dv_needs_copy = false; + + auto get_or_alloc = [&](const c10::optional &opt, + const std::vector &sizes, + bool &needs_copy) -> at::Tensor { + if (opt.has_value() && opt.value().is_contiguous()) { + return opt.value(); + } + needs_copy = opt.has_value(); + return torch::empty(sizes, opts); + }; + + if (!needs_padding && num_heads_k == num_heads) { + dq_work = get_or_alloc(dq_, {total_q, num_heads, head_size_og}, dq_needs_copy); + dk_work = get_or_alloc(dk_, {total_k, num_heads_k, head_size_og}, dk_needs_copy); + dv_work = get_or_alloc(dv_, {total_k, num_heads_k, head_size_og}, dv_needs_copy); + } else { + dq_work = torch::empty({total_q, num_heads, head_size_padded}, opts); + dk_work = torch::empty({total_k, num_heads_k, head_size_padded}, opts); + dv_work = torch::empty({total_k, num_heads_k, head_size_padded}, opts); + } + + at::Tensor softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { + dk_expanded = torch::empty({total_k, num_heads, head_size_padded}, opts); + dv_expanded = torch::empty({total_k, num_heads, head_size_padded}, opts); + } else { + dk_expanded = dk_work; + dv_expanded = dv_work; + } + + const bool is_local = (window_size_left != -1) || (window_size_right != -1); + auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue(); + + uint64_t philox_seed = 0; + uint64_t philox_offset = 0; + if (p_dropout > 0.0f && rng_state.has_value()) { + auto rng_state_val = rng_state.value(); + philox_seed = static_cast(rng_state_val[0].item()); + philox_offset = static_cast(rng_state_val[1].item()); + } + + cutlass_fmha_bwd_varlen_impl( + queue, + dout_padded, q_padded, k_padded, v_padded, out_padded, softmax_lse_contig, + cu_q_contig, cu_k_contig, + dq_work, dk_expanded, dv_expanded, softmax_d, + softmax_scale, + max_seqlen_q, max_seqlen_k, + window_size_left, window_size_right, is_causal, is_local, + p_dropout, philox_seed, philox_offset, deterministic); + + if (num_heads_k != num_heads) { + at::sum_out(dk_work, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_padded}), {2}); + at::sum_out(dv_work, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_padded}), {2}); + } + + auto slice_head = [&](const at::Tensor &t) -> at::Tensor { + return t.index({torch::indexing::Slice(), torch::indexing::Slice(), + torch::indexing::Slice(0, head_size_og)}).contiguous(); + }; + auto copy_if_provided = [&](const c10::optional &opt, const at::Tensor &src) { + if (opt.has_value()) { + opt.value().copy_(src); + } + }; + + if (needs_padding || num_heads_k != num_heads) { + dq = slice_head(dq_work); + dk = slice_head(dk_work); + dv = slice_head(dv_work); + copy_if_provided(dq_, dq); + copy_if_provided(dk_, dk); + copy_if_provided(dv_, dv); + } else { + dq = dq_work; + dk = dk_work; + dv = dv_work; + if (dq_needs_copy) dq_.value().copy_(dq); + if (dk_needs_copy) dk_.value().copy_(dk); + if (dv_needs_copy) dv_.value().copy_(dv); + } + + return {dq, dk, dv, softmax_d}; +} + std::vector mha_fwd_kvcache( at::Tensor &q, @@ -749,8 +982,41 @@ mha_varlen_bwd( const bool deterministic, c10::optional gen_, const c10::optional &rng_state) { - TORCH_CHECK(false, "FlashAttention varlen backward is not supported on XPU yet."); - return {}; + auto to_std_opt = [](const c10::optional& opt) -> std::optional { + return opt.has_value() ? std::optional(opt.value()) : std::nullopt; + }; + + auto dq_opt = to_std_opt(dq_); + auto dk_opt = to_std_opt(dk_); + auto dv_opt = to_std_opt(dv_); + auto alibi_opt = to_std_opt(alibi_slopes_); + auto rng_opt = to_std_opt(rng_state); + + return FLASH_NAMESPACE::mha_varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq_opt, + dk_opt, + dv_opt, + cu_seqlens_q, + cu_seqlens_k, + alibi_opt, + static_cast(max_seqlen_q), + static_cast(max_seqlen_k), + static_cast(p_dropout), + static_cast(softmax_scale), + zero_tensors, + is_causal, + static_cast(window_size_left), + static_cast(window_size_right), + static_cast(softcap), + deterministic, + gen_, + rng_opt); } std::vector diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_fix_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_fix_bf16.cpp index 3263f32d..c1506333 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_fix_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_fix_bf16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=128, dtype=bf16 -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_fix_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_fix_fp16.cpp index 05cc7c0d..4b9f0195 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_fix_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_fix_fp16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=128, dtype=fp16 -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_varlen_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_varlen_bf16.cpp new file mode 100644 index 00000000..f7cfab57 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_varlen_bf16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=128, dtype=bf16 + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_varlen_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_varlen_fp16.cpp new file mode 100644 index 00000000..a6489eba --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim128_varlen_fp16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=128, dtype=fp16 + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_fix_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_fix_bf16.cpp index b8f5019b..b686fef0 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_fix_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_fix_bf16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=160, dtype=bf16 -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_fix_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_fix_fp16.cpp index 24f50970..e5d24855 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_fix_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_fix_fp16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=160, dtype=fp16 -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_varlen_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_varlen_bf16.cpp new file mode 100644 index 00000000..23d355be --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_varlen_bf16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=160, dtype=bf16 + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_varlen_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_varlen_fp16.cpp new file mode 100644 index 00000000..f3ae5931 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim160_varlen_fp16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=160, dtype=fp16 + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_fix_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_fix_bf16.cpp index 982eaec5..6a7980a1 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_fix_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_fix_bf16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=192, dtype=bf16 -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_fix_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_fix_fp16.cpp index 2eda0f08..beab6f3d 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_fix_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_fix_fp16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=192, dtype=fp16 -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_varlen_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_varlen_bf16.cpp new file mode 100644 index 00000000..890587e8 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_varlen_bf16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=192, dtype=bf16 + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_varlen_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_varlen_fp16.cpp new file mode 100644 index 00000000..1c609800 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim192_varlen_fp16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=192, dtype=fp16 + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_fix_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_fix_bf16.cpp index 14c13316..de9798f6 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_fix_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_fix_bf16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=256, dtype=bf16 -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_fix_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_fix_fp16.cpp index 64ba3151..0f9c5693 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_fix_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_fix_fp16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=256, dtype=fp16 -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_varlen_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_varlen_bf16.cpp new file mode 100644 index 00000000..1abe4926 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_varlen_bf16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=256, dtype=bf16 + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_varlen_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_varlen_fp16.cpp new file mode 100644 index 00000000..36ae3d4c --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim256_varlen_fp16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=256, dtype=fp16 + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_fix_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_fix_bf16.cpp index b4806d08..6aca0ebd 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_fix_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_fix_bf16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=32, dtype=bf16 -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_fix_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_fix_fp16.cpp index 09a40031..2a1b776f 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_fix_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_fix_fp16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=32, dtype=fp16 -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_varlen_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_varlen_bf16.cpp new file mode 100644 index 00000000..e7befecb --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_varlen_bf16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=32, dtype=bf16 + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_varlen_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_varlen_fp16.cpp new file mode 100644 index 00000000..8a518402 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim32_varlen_fp16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=32, dtype=fp16 + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_fix_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_fix_bf16.cpp index 195bbdb6..7fdb24fe 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_fix_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_fix_bf16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=512, dtype=bf16 -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_fix_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_fix_fp16.cpp index 671b0aba..0c38875d 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_fix_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_fix_fp16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=512, dtype=fp16 -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_varlen_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_varlen_bf16.cpp new file mode 100644 index 00000000..caf75ffe --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_varlen_bf16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=512, dtype=bf16 + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_varlen_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_varlen_fp16.cpp new file mode 100644 index 00000000..d4a37a02 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim512_varlen_fp16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=512, dtype=fp16 + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_fix_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_fix_bf16.cpp index b17abd51..6f9ec100 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_fix_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_fix_bf16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=64, dtype=bf16 -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_fix_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_fix_fp16.cpp index cc1b81ca..d7c6c82b 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_fix_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_fix_fp16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=64, dtype=fp16 -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_varlen_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_varlen_bf16.cpp new file mode 100644 index 00000000..2ed4038d --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_varlen_bf16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=64, dtype=bf16 + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_varlen_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_varlen_fp16.cpp new file mode 100644 index 00000000..18f65cd3 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim64_varlen_fp16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=64, dtype=fp16 + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_fix_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_fix_bf16.cpp index 8a7a6ec4..5bcb3b46 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_fix_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_fix_bf16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=96, dtype=bf16 -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_bf16( +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_fix_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_fix_fp16.cpp index e40fa670..7eb1c7eb 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_fix_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_fix_fp16.cpp @@ -2,14 +2,14 @@ // Fixed mode backward for head_dim=96, dtype=fp16 -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template void bwd_policy_dispatch_fp16( +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_varlen_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_varlen_bf16.cpp new file mode 100644 index 00000000..7b7d43c7 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_varlen_bf16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=96, dtype=bf16 + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_bf16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_varlen_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_varlen_fp16.cpp new file mode 100644 index 00000000..536a03e1 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_bwd_hdim96_varlen_fp16.cpp @@ -0,0 +1,15 @@ +#include "fmha_bwd_impl.hpp" + +// Varlen mode backward for head_dim=96, dtype=fp16 + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); + +template void bwd_policy_dispatch_fp16( + sycl::queue& queue, const fmha_bwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/fmha_bwd.cpp b/flash-attn2/flash_attn_xpu/src/fmha_bwd.cpp index 9eea4d0b..1644b93f 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_bwd.cpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_bwd.cpp @@ -20,34 +20,62 @@ constexpr int round_up(int x, int m) { return (x + m - 1) / m * m; } +[[nodiscard]] constexpr int bwd_block_m_for_head(int head_size) { + return (head_size > 96 && head_size <= 128) ? 128 : 64; +} + +[[nodiscard]] constexpr int bwd_block_n_for_head(int head_size) { + if (head_size <= 64) { + return 32; + } + if (head_size <= 96) { + return 64; + } + if (head_size <= 128) { + return 128; + } + return 32; +} + /// Dispatch backward kernel by head_size with pre-resolved causal/local flags. -template +template void dispatch_bwd_causal_local(sycl::queue& queue, BwdCutlassType cuType, const fmha_bwd_args_t& args, bool is_causal, bool is_local) { if (is_causal && is_local) { - bwd_policy_dispatch(queue, cuType, args); + bwd_policy_dispatch(queue, cuType, args); } else if (is_causal) { - bwd_policy_dispatch(queue, cuType, args); + bwd_policy_dispatch(queue, cuType, args); } else if (is_local) { - bwd_policy_dispatch(queue, cuType, args); + bwd_policy_dispatch(queue, cuType, args); + } else { + bwd_policy_dispatch(queue, cuType, args); + } +} + +template +void dispatch_bwd_causal_local(sycl::queue& queue, BwdCutlassType cuType, + const fmha_bwd_args_t& args, + bool is_causal, bool is_local, bool is_varlen) { + if (is_varlen) { + dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local); } else { - bwd_policy_dispatch(queue, cuType, args); + dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local); } } /// Dispatch backward kernel by head dimension. void dispatch_bwd_by_head(sycl::queue& queue, BwdCutlassType cuType, const fmha_bwd_args_t& args, int head_size, - bool is_causal, bool is_local) { - if (head_size <= 32) dispatch_bwd_causal_local (queue, cuType, args, is_causal, is_local); - else if (head_size <= 64) dispatch_bwd_causal_local (queue, cuType, args, is_causal, is_local); - else if (head_size <= 96) dispatch_bwd_causal_local (queue, cuType, args, is_causal, is_local); - else if (head_size <= 128) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local); - else if (head_size <= 160) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local); - else if (head_size <= 192) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local); - else if (head_size <= 256) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local); - else if (head_size == 512) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local); + bool is_causal, bool is_local, bool is_varlen) { + if (head_size <= 32) dispatch_bwd_causal_local (queue, cuType, args, is_causal, is_local, is_varlen); + else if (head_size <= 64) dispatch_bwd_causal_local (queue, cuType, args, is_causal, is_local, is_varlen); + else if (head_size <= 96) dispatch_bwd_causal_local (queue, cuType, args, is_causal, is_local, is_varlen); + else if (head_size <= 128) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local, is_varlen); + else if (head_size <= 160) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local, is_varlen); + else if (head_size <= 192) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local, is_varlen); + else if (head_size <= 256) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local, is_varlen); + else if (head_size == 512) dispatch_bwd_causal_local(queue, cuType, args, is_causal, is_local, is_varlen); else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); } @@ -84,9 +112,12 @@ void cutlass_fmha_bwd_fix_impl( const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - // Round up sequence lengths for internal buffers - const int seqlen_q_rounded = round_up(seqlen_q, 64); - const int seqlen_k_rounded = round_up(seqlen_k, 64); + const int kBlockM_bwd = bwd_block_m_for_head(head_size); + const int kBlockN_bwd = bwd_block_n_for_head(head_size); + + // Round up sequence lengths for internal buffers using the selected policy. + const int seqlen_q_rounded = round_up(seqlen_q, kBlockM_bwd); + const int seqlen_k_rounded = round_up(seqlen_k, kBlockN_bwd); // Allocate dq_accum buffer (float) // For deterministic mode, allocate separate splits to avoid atomicAdd races @@ -94,7 +125,7 @@ void cutlass_fmha_bwd_fix_impl( int dq_accum_split_stride = 0; at::Tensor dq_accum; if (!deterministic) { - dq_accum = at::zeros({batch_size, seqlen_q_rounded, num_heads_q, head_size}, + dq_accum = at::zeros({batch_size, num_heads_q, seqlen_q_rounded, head_size}, q.options().dtype(at::kFloat)); } else { // Each work-group gets its own split to write dQ accumulator @@ -107,15 +138,13 @@ void cutlass_fmha_bwd_fix_impl( // The minimum kBlockN across all head size policies is 32. const int max_n_blocks = std::max((seqlen_k + 31) / 32, 1); nsplits = std::min(nsplits, max_n_blocks); - dq_accum = at::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads_q, head_size}, + dq_accum = at::zeros({nsplits, batch_size, num_heads_q, seqlen_q_rounded, head_size}, q.options().dtype(at::kFloat)); dq_accum_split_stride = batch_size * seqlen_q_rounded * num_heads_q * head_size; } // Pre-allocate intermediate buffer using PyTorch caching allocator // (avoids expensive compat::malloc/free per backward call). - // All backward policies use kBlockM=64. TODO hard code - constexpr int kBlockM_bwd = 64; at::Tensor pbuff_tensor = at::empty( {batch_size * num_heads_q * seqlen_k_rounded * 2 * kBlockM_bwd}, q.options()); @@ -128,6 +157,8 @@ void cutlass_fmha_bwd_fix_impl( k.data_ptr(), v.data_ptr(), softmax_lse.data_ptr(), + nullptr, + nullptr, dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), @@ -142,9 +173,6 @@ void cutlass_fmha_bwd_fix_impl( seqlen_q_rounded, seqlen_k_rounded, sm_scale, - is_causal, - is_local, - q.scalar_type() == at::ScalarType::BFloat16, deterministic, nsplits, dq_accum_split_stride, @@ -157,5 +185,99 @@ void cutlass_fmha_bwd_fix_impl( }; const BwdCutlassType cuType = aten_to_Bwd_Cutlass_dtype(q); - dispatch_bwd_by_head(queue, cuType, args, args.head_size, is_causal, is_local); + dispatch_bwd_by_head(queue, cuType, args, args.head_size, is_causal, is_local, false); +} + +void cutlass_fmha_bwd_varlen_impl( + sycl::queue& queue, + const at::Tensor& dout, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const at::Tensor& out, + const at::Tensor& softmax_lse, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + at::Tensor& dq, + at::Tensor& dk, + at::Tensor& dv, + at::Tensor& softmax_d, + float sm_scale, + int max_seqlen_q, + int max_seqlen_k, + int window_size_left, + int window_size_right, + bool is_causal, + bool is_local, + float p_dropout, + uint64_t philox_seed, + uint64_t philox_offset, + bool deterministic) { + + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads_q = q.size(1); + const int num_heads_k = k.size(1); + const int head_size = q.size(2); + + const int kBlockM_bwd = bwd_block_m_for_head(head_size); + const int kBlockN_bwd = bwd_block_n_for_head(head_size); + const int seqlen_q_rounded = round_up(max_seqlen_q, kBlockM_bwd); + const int seqlen_k_rounded = round_up(max_seqlen_k, kBlockN_bwd); + + int nsplits = 1; + int dq_accum_split_stride = 0; + at::Tensor dq_accum; + if (!deterministic) { + dq_accum = at::zeros({batch_size, num_heads_q, seqlen_q_rounded, head_size}, + q.options().dtype(at::kFloat)); + } else { + const int num_compute_units = static_cast(queue.get_device().get_info()); + nsplits = std::max((num_compute_units + batch_size * num_heads_q - 1) / (batch_size * num_heads_q), 1); + const int max_n_blocks = std::max((max_seqlen_k + 31) / 32, 1); + nsplits = std::min(nsplits, max_n_blocks); + dq_accum = at::zeros({nsplits, batch_size, num_heads_q, seqlen_q_rounded, head_size}, + q.options().dtype(at::kFloat)); + dq_accum_split_stride = batch_size * seqlen_q_rounded * num_heads_q * head_size; + } + + at::Tensor pbuff_tensor = at::empty( + {batch_size * num_heads_q * seqlen_k_rounded * 2 * kBlockM_bwd}, + q.options()); + + fmha_bwd_args_t args = { + dout.data_ptr(), + out.data_ptr(), + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + softmax_lse.data_ptr(), + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + softmax_d.data_ptr(), + dq_accum.data_ptr(), + batch_size, + num_heads_q, + num_heads_k, + max_seqlen_q, + max_seqlen_k, + head_size, + seqlen_q_rounded, + seqlen_k_rounded, + sm_scale, + deterministic, + nsplits, + dq_accum_split_stride, + window_size_left, + window_size_right, + p_dropout, + philox_seed, + philox_offset, + pbuff_tensor.data_ptr() + }; + + const BwdCutlassType cuType = aten_to_Bwd_Cutlass_dtype(q); + dispatch_bwd_by_head(queue, cuType, args, args.head_size, is_causal, is_local, true); } diff --git a/flash-attn2/flash_attn_xpu/src/fmha_bwd.hpp b/flash-attn2/flash_attn_xpu/src/fmha_bwd.hpp index 1e041626..313123e8 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_bwd.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_bwd.hpp @@ -23,40 +23,48 @@ struct bwd_policy_head256; struct bwd_policy_head512; // Dtype-specific bwd dispatch functions (instantiated in per-head TUs) -template +template void bwd_policy_dispatch_fp16( sycl::queue& queue, const fmha_bwd_args_t& args); -template +template void bwd_policy_dispatch_bf16( sycl::queue& queue, const fmha_bwd_args_t& args); // Combined bwd dispatch (delegates to fp16/bf16 based on cuType) // Defined inline in header so callers (fmha_bwd.cpp) can see the template body. -template +template inline void bwd_policy_dispatch( sycl::queue& queue, BwdCutlassType cuType, const fmha_bwd_args_t& args) { if (cuType == BwdCutlassType::half) { - bwd_policy_dispatch_fp16(queue, args); + bwd_policy_dispatch_fp16(queue, args); } else { - bwd_policy_dispatch_bf16(queue, args); + bwd_policy_dispatch_bf16(queue, args); } } // Extern template declarations for all head dimensions (dtype-split) #define EXTERN_BWD_DISPATCH(HDIM) \ - extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ - extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ - extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ - extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ - extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ - extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ - extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ - extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); + extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_fp16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); \ + extern template void bwd_policy_dispatch_bf16(sycl::queue&, const fmha_bwd_args_t&); EXTERN_BWD_DISPATCH(32) EXTERN_BWD_DISPATCH(64) @@ -106,3 +114,29 @@ void cutlass_fmha_bwd_fix_impl( uint64_t philox_seed = 0, uint64_t philox_offset = 0, bool deterministic = false); + +void cutlass_fmha_bwd_varlen_impl( + sycl::queue& queue, + const at::Tensor& dout, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const at::Tensor& out, + const at::Tensor& softmax_lse, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + at::Tensor& dq, + at::Tensor& dk, + at::Tensor& dv, + at::Tensor& softmax_d, + float sm_scale, + int max_seqlen_q, + int max_seqlen_k, + int window_size_left, + int window_size_right, + bool is_causal, + bool is_local, + float p_dropout = 0.0f, + uint64_t philox_seed = 0, + uint64_t philox_offset = 0, + bool deterministic = false); diff --git a/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp b/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp index ea7ac3ba..c4b54141 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp @@ -185,10 +185,10 @@ softmax_backward(Tensor &P, CUTLASS_PRAGMA_UNROLL for (int ni = 0; ni < size<1>(dP); ++ni) { int n = get<1>(rC_2d(0, ni)) + sg_local_id; - const float dpsum = dP_sum(n); + const float neg_dpsum_scaled = -(dP_sum(n) * scale); CUTLASS_PRAGMA_UNROLL for (int mi = 0; mi < size<0>(dP); ++mi) { - dP(mi, ni) = P(mi, ni) * (dP(mi, ni) - dpsum) * scale; + dP(mi, ni) = P(mi, ni) * fmaf(dP(mi, ni), scale, neg_dpsum_scaled); } } } else { @@ -197,10 +197,10 @@ softmax_backward(Tensor &P, int n = get<1>(rC_2d(0, ni)) + sg_local_id; // For tail case: skip out-of-bounds rows if (n < tail_m) { - const float dpsum = dP_sum(n); + const float neg_dpsum_scaled = -(dP_sum(n) * scale); CUTLASS_PRAGMA_UNROLL for (int mi = 0; mi < size<0>(dP); ++mi) { - dP(mi, ni) = P(mi, ni) * (dP(mi, ni) - dpsum) * scale; + dP(mi, ni) = P(mi, ni) * fmaf(dP(mi, ni), scale, neg_dpsum_scaled); } } } @@ -310,8 +310,9 @@ gemm_kernel(Trait &trait, if constexpr(clear_acc) clear(acc); + int prefetch_warmup = prefetch_dist < k_tile_count ? prefetch_dist : k_tile_count; CUTE_UNROLL - for (; k_tile_prefetch < prefetch_dist; k_tile_prefetch++) { + for (; k_tile_prefetch < prefetch_warmup; k_tile_prefetch++) { prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); } @@ -322,8 +323,10 @@ gemm_kernel(Trait &trait, copy(copy_a, tAgA(_,_,_,k_tile), tArA); copy(copy_b, tBgB(_,_,_,k_tile), tBrB); - prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); - prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); + if (k_tile_prefetch < k_tile_count) { + prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); + } reorder(tArA, tCrA); reorder(tBrB, tCrB); @@ -436,10 +439,28 @@ constexpr int round_up(int x, int m) { return (x + m - 1) / m * m; } +template +CUTLASS_DEVICE BwdParam +resolve_varlen_batch(BwdParam param, const int bidb) { + if constexpr (VarLen) { + const int q_start = param.cu_seqlens_q[bidb]; + const int q_end = param.cu_seqlens_q[bidb + 1]; + const int k_start = param.cu_seqlens_k[bidb]; + const int k_end = param.cu_seqlens_k[bidb + 1]; + param.seq_len_q = q_end - q_start; + param.seq_len_kv = k_end - k_start; + param.m_block = ceil_div(param.seq_len_q, kBlockM); + param.n_block = ceil_div(param.seq_len_kv, kBlockN); + param.tail_m = param.seq_len_q % kBlockM; + param.tail_n = param.seq_len_kv % kBlockN; + } + return param; +} + // Main 1-col-block backward computation template void -dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, +dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, const int bidb, const int bidh, const int bidhkv, const int n_block, const int tail_n = 0) { using T = typename Trait::DType; @@ -451,12 +472,13 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, constexpr int SubgroupSize = Trait::SubgroupSize; constexpr bool is_causal = Trait::is_causal; constexpr bool is_local = Trait::is_local; + constexpr bool has_dropout = Trait::has_dropout; auto sg = compat::get_nd_item<1>().get_sub_group(); auto group = compat::get_nd_item<1>().get_group(); const int local_id = sg.get_local_id(); auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; - auto bofst = BwdOffset(param); + auto bofst = BwdOffset(param); const index_t q_offset = bofst.q_offset(bidb, bidh, 0); const index_t k_offset = bofst.k_offset(bidb, bidhkv, n_block * kBlockN); @@ -464,7 +486,7 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, const index_t dk_offset = bofst.dk_offset(bidb, bidh, n_block * kBlockN); const index_t dv_offset = bofst.dv_offset(bidb, bidh, n_block * kBlockN); const index_t o_offset = bofst.o_offset(bidb, bidh, 0); - const index_t dq_offset = bofst.dq_offset(bidb, bidh, 0); + const index_t dqaccum_offset = bofst.dqaccum_offset(bidb, bidh, 0); const index_t lse_offset = bofst.lse_offset(bidb, bidh, 0); // Buffer offset for intermediate P and dS @@ -527,8 +549,8 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, make_layout(shapeKtVt, make_stride(param.dv_r_stride, _1{}))); Tensor mdP = make_tensor(make_gmem_ptr(param.pb_ptr + dsb_offset), make_layout(shapeSP, make_stride(_1{}, Int{}))); - Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), - make_layout(shapedQ, make_stride(param.dq_r_stride, _1{}))); + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dqaccum_offset), + make_layout(shapedQ, make_stride(param.dqaccum_r_stride, _1{}))); Tensor mdK = make_tensor(make_gmem_ptr(param.dk_ptr + dk_offset), make_layout(shapeKtVt, make_stride(param.dk_r_stride, _1{}))); @@ -564,7 +586,7 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, make_layout(make_shape(Int{}, tail_m), make_stride(_1{}, param.o_r_stride))); mdQaccum = make_tensor(make_gmem_ptr(mdQaccum.data()), - make_layout(shapedQ, make_stride(param.dq_r_stride, _1{}))); + make_layout(shapedQ, make_stride(param.dqaccum_r_stride, _1{}))); mQt = make_tensor(make_gmem_ptr(mQt.data()), make_layout(make_shape(Int{}, tail_m), make_stride(_1{}, param.q_r_stride))); @@ -599,17 +621,16 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, gemm_SdP(trait, mVt, mdO, rdP, tiled_mma_sdp); Tensor dS = make_tensor(rdP.data(), scores.layout()); - // Dropout backward: compute mask once, apply to dP first, then to P - // after softmax_backward. Caches mask to avoid redundant Philox RNG. - // - // Math: dS = scale * P * (mask * rp * dP_dropped - dpsum) - // where P is the original softmax output, dP_dropped = dO * V^T, - // and dpsum = sum(dO * O) with O having rp scaling from forward. - constexpr int kDropMaskMax = decltype(size<0>(scores))::value * decltype(size<1>(scores))::value; - bool drop_keep[kDropMaskMax]; - int drop_mask_count = 0; - - if (param.dropout.is_enabled) { + if constexpr(has_dropout) { + // Dropout backward: compute mask once, apply to dP first, then to P + // after softmax_backward. Caches mask to avoid redundant Philox RNG. + // + // Math: dS = scale * P * (mask * rp * dP_dropped - dpsum) + // where P is the original softmax output, dP_dropped = dO * V^T, + // and dpsum = sum(dO * O) with O having rp scaling from forward. + constexpr int kDropMaskMax = decltype(size<0>(scores))::value * decltype(size<1>(scores))::value; + bool drop_keep[kDropMaskMax]; + int drop_mask_count = 0; int sg_local_id = sg.get_local_id(); uint32_t batch_head = bidb * param.num_head_q + bidh; float rp_dropout = param.dropout.get_scale(); @@ -637,19 +658,13 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, } } } - } - - // dS = P * (dP_masked - dpsum) * scale - // Uses ORIGINAL P from scores, and dropout-masked gradient from dS - if (Is_even_M) { - softmax_backward(scores, mdPsum, dS, taccScS_rt, param.scale_softmax); - } else { - softmax_backward(scores, mdPsum, dS, taccScS_rt, param.scale_softmax, tail_m); - } - - // Step 2: Apply cached dropout mask to P (scores) for dV computation - if (param.dropout.is_enabled) { - float rp_dropout = param.dropout.get_scale(); + + if (Is_even_M) { + softmax_backward(scores, mdPsum, dS, taccScS_rt, param.scale_softmax); + } else { + softmax_backward(scores, mdPsum, dS, taccScS_rt, param.scale_softmax, tail_m); + } + int mask_idx = 0; CUTLASS_PRAGMA_UNROLL @@ -663,6 +678,12 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, } } } + } else { + if (Is_even_M) { + softmax_backward(scores, mdPsum, dS, taccScS_rt, param.scale_softmax); + } else { + softmax_backward(scores, mdPsum, dS, taccScS_rt, param.scale_softmax, tail_m); + } } // Mask out elements beyond seqlen_k to prevent NaN in dQ computation @@ -688,7 +709,7 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, mQ.data() = mQ.data() + int(kBlockM * param.q_r_stride); mdO.data() = mdO.data() + int(kBlockM * param.o_r_stride); mdOt.data() = mdOt.data() + int(kBlockM * param.o_r_stride); - mdQaccum.data() = mdQaccum.data() + int(kBlockM * param.dq_r_stride); + mdQaccum.data() = mdQaccum.data() + int(kBlockM * param.dqaccum_r_stride); mQt.data() = mQt.data() + int(kBlockM * param.q_r_stride); mLSE.data() = mLSE.data() + int(kBlockM); mdPsum.data() = mdPsum.data() + int(kBlockM); @@ -701,7 +722,7 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, // Compute O * dO (dot product) template void -compute_o_dot_do(T &trait, BwdParam ¶m, +compute_o_dot_do(T &trait, BwdParam ¶m, const int m_block, const int bidb, const int bidh) { constexpr int kBlockM = T::kBlockM; constexpr int kHeadDim = T::kHeadDim; @@ -713,10 +734,10 @@ compute_o_dot_do(T &trait, BwdParam ¶m, auto sg = compat::get_nd_item<1>().get_sub_group(); auto group = compat::get_nd_item<1>().get_group(); auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; - auto bofst = BwdOffset(param); + auto bofst = BwdOffset(param); const index_t o_offset = bofst.o_offset(bidb, bidh, m_block * kBlockM); - const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t dqaccum_offset = bofst.dqaccum_offset(bidb, bidh, m_block * kBlockM); const index_t dpsum_offset = bofst.lse_offset(bidb, bidh, m_block * kBlockM); using ShapeO = Shape, int>, Int>; @@ -736,9 +757,9 @@ compute_o_dot_do(T &trait, BwdParam ¶m, make_layout(O_shape, make_stride(param.o_r_stride, _1{}))); Tensor mO = make_tensor(make_gmem_ptr(param.o_ptr + o_offset), make_layout(O_shape, make_stride(param.o_r_stride, _1{}))); - Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dqaccum_offset), make_layout(make_shape(Int{}, Int{}), - make_stride(param.dq_r_stride, _1{}))); + make_stride(param.dqaccum_r_stride, _1{}))); Tensor mdPsum = make_tensor(make_gmem_ptr(param.odo_ptr + dpsum_offset), make_layout(dP_shape, Stride<_1>{})); @@ -823,7 +844,7 @@ compute_o_dot_do(T &trait, BwdParam ¶m, // Convert dQ from float accumulator to target type template void -convert_dq(T &trait, BwdParam ¶m, int m_block, int bidb, int bidh) { +convert_dq(T &trait, BwdParam ¶m, int m_block, int bidb, int bidh) { constexpr int kBlockM = T::kBlockM; constexpr int kHeadDim = T::kHeadDim; using DType = typename T::DType; @@ -831,9 +852,9 @@ convert_dq(T &trait, BwdParam ¶m, int m_block, int bidb, auto sg = compat::get_nd_item<1>().get_sub_group(); auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; - auto bofst = BwdOffset(param); + auto bofst = BwdOffset(param); const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); - const index_t q_offset = bofst.q_offset(bidb, bidh, m_block * kBlockM); + const index_t dqaccum_offset = bofst.dqaccum_offset(bidb, bidh, m_block * kBlockM); using ShapeQ = Shape, int>, Int>; ShapeQ shapeQ; @@ -843,11 +864,11 @@ convert_dq(T &trait, BwdParam ¶m, int m_block, int bidb, shapeQ = make_shape(param.tail_m, Int{}); } - Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dqaccum_offset), make_layout(Shape, Int>{}, - make_stride(param.dq_r_stride, _1{}))); - Tensor mdQ = make_tensor(make_gmem_ptr(param.dq_ptr + q_offset), - make_layout(shapeQ, make_stride(param.q_r_stride, _1{}))); + make_stride(param.dqaccum_r_stride, _1{}))); + Tensor mdQ = make_tensor(make_gmem_ptr(param.dq_ptr + dq_offset), + make_layout(shapeQ, make_stride(param.dq_r_stride, _1{}))); typename T::TiledMmadQ tiled_mma_dq; auto thr_mma_dq = tiled_mma_dq.get_slice(first_thread_in_sg_idx); @@ -882,9 +903,9 @@ convert_dq(T &trait, BwdParam ¶m, int m_block, int bidb, for (int s = 1; s < nsplits; ++s) { // Create tensor view for split s with the correct base pointer Tensor mdQaccum_s = make_tensor( - make_gmem_ptr(param.dqaccum_ptr + dq_offset + (index_t)s * param.dq_accum_split_stride), + make_gmem_ptr(param.dqaccum_ptr + dqaccum_offset + (index_t)s * param.dq_accum_split_stride), make_layout(Shape, Int>{}, - make_stride(param.dq_r_stride, _1{}))); + make_stride(param.dqaccum_r_stride, _1{}))); auto tileloaddQ_s = make_block_2d_copy_C(tiled_mma_dq, mdQaccum_s); auto thr_load_dQ_s = tileloaddQ_s.get_slice(first_thread_in_sg_idx); copy(tileloaddQ_s, thr_load_dQ_s.partition_S(gdQaccum), tdQrdQaccum_tmp); @@ -901,46 +922,64 @@ convert_dq(T &trait, BwdParam ¶m, int m_block, int bidb, // Kernel entry points template void -mha_dot_do_o(T trait, BwdParam param) { +mha_dot_do_o(T trait, BwdParam param) { const int m_block = BlockIdxX(); const int bidb = BlockIdxZ(); const int bidh = BlockIdxY(); - if (m_block == param.m_block - 1 and param.tail_m > 0) { - compute_o_dot_do(trait, param, m_block, bidb, bidh); + auto batch_param = resolve_varlen_batch(param, bidb); + if constexpr (T::is_varlen) { + if (m_block >= batch_param.m_block) { + return; + } + } + if (m_block == batch_param.m_block - 1 and batch_param.tail_m > 0) { + compute_o_dot_do(trait, batch_param, m_block, bidb, bidh); } else { - compute_o_dot_do(trait, param, m_block, bidb, bidh); + compute_o_dot_do(trait, batch_param, m_block, bidb, bidh); } } template void -mha_backward_seq(T trait, BwdParam param) { +mha_backward_seq(T trait, BwdParam param) { const int bidb = BlockIdxZ(); const int bidhq = BlockIdxY(); const int bidnblk = BlockIdxX(); - const int bidhkv = bidhq / param.num_qh_per_kvh; + auto batch_param = resolve_varlen_batch(param, bidb); + if constexpr (T::is_varlen) { + if (bidnblk >= batch_param.n_block || batch_param.m_block == 0) { + return; + } + } + const int bidhkv = bidhq / batch_param.num_qh_per_kvh; // For deterministic mode, each work-group writes to its own dq_accum split - if (param.deterministic) { - param.dqaccum_ptr = param.dqaccum_ptr + bidnblk * param.dq_accum_split_stride; + if (batch_param.deterministic) { + batch_param.dqaccum_ptr = batch_param.dqaccum_ptr + bidnblk * batch_param.dq_accum_split_stride; } - for (int n_block = bidnblk; n_block < param.n_block; n_block += GridDimX()) { - if (param.tail_n > 0 and n_block == param.n_block - 1) - dq_dk_dv_1colblock(trait, param, bidb, bidhq, bidhkv, param.n_block - 1, param.tail_n); + for (int n_block = bidnblk; n_block < batch_param.n_block; n_block += GridDimX()) { + if (batch_param.tail_n > 0 and n_block == batch_param.n_block - 1) + dq_dk_dv_1colblock(trait, batch_param, bidb, bidhq, bidhkv, batch_param.n_block - 1, batch_param.tail_n); else - dq_dk_dv_1colblock(trait, param, bidb, bidhq, bidhkv, n_block); + dq_dk_dv_1colblock(trait, batch_param, bidb, bidhq, bidhkv, n_block); } } template void -mhd_convert_dq(T trait, BwdParam param) { +mhd_convert_dq(T trait, BwdParam param) { const int m_block = BlockIdxX(); const int bidb = BlockIdxZ(); const int bidh = BlockIdxY(); - if (param.tail_m > 0 and m_block == param.m_block - 1) { - convert_dq(trait, param, m_block, bidb, bidh); + auto batch_param = resolve_varlen_batch(param, bidb); + if constexpr (T::is_varlen) { + if (m_block >= batch_param.m_block) { + return; + } + } + if (batch_param.tail_m > 0 and m_block == batch_param.m_block - 1) { + convert_dq(trait, batch_param, m_block, bidb, bidh); } else { - convert_dq(trait, param, m_block, bidb, bidh); + convert_dq(trait, batch_param, m_block, bidb, bidh); } } @@ -980,7 +1019,7 @@ struct BwdKernelLauncher { // Use pre-allocated intermediate buffer from caller (PyTorch caching allocator) DType* pbuff = reinterpret_cast(args.pbuff); - auto param = BwdParam( + auto param = BwdParam( reinterpret_cast(args.dout), reinterpret_cast(args.out), reinterpret_cast(args.query), @@ -1001,6 +1040,7 @@ struct BwdKernelLauncher { param.num_qh_per_kvh = NUM_HEAD_Q / NUM_HEAD_KV; param.num_nb_per_blk = std::max(N_BLOCK * NUM_HEAD_Q * BATCH / 1024, 1); param.seq_len_q = SEQ_LEN_Q; + param.max_seqlen_q = SEQ_LEN_Q; param.seq_len_kv = SEQ_LEN_KV; param.head_dim = kHeadDim; param.n_block = N_BLOCK; @@ -1011,7 +1051,10 @@ struct BwdKernelLauncher { param.seq_len_q_pad = args.seqlen_q_rounded; param.window_size_left = args.window_size_left; param.window_size_right = args.window_size_right; - param.is_local = args.is_local; + if constexpr (FABwdKernel::is_varlen) { + param.cu_seqlens_q = reinterpret_cast(args.cu_seqlens_q); + param.cu_seqlens_k = reinterpret_cast(args.cu_seqlens_k); + } param.deterministic = args.deterministic; param.nsplits = args.nsplits; param.dq_accum_split_stride = args.dq_accum_split_stride; @@ -1088,8 +1131,8 @@ template < struct FMHABwdConfig { using DType = ElementQ; - template - static void run(sycl::queue& queue, const fmha_bwd_args_t& args) { + template + static void run_impl(sycl::queue& queue, const fmha_bwd_args_t& args) { using FABwdKernelType = FABwdKernel< DType, bwd_policy::kHeadDim, @@ -1100,46 +1143,40 @@ struct FMHABwdConfig { bwd_policy::AtomLayoutNdKV, bwd_policy::AtomLayoutMdQ, Causal, - Local>; + Local, + VarLen, + HasDropout>; BwdKernelLauncher launcher; launcher.run(queue, args); } + template + static void run(sycl::queue& queue, const fmha_bwd_args_t& args) { + if (args.p_dropout > 0.0f) { + return run_impl(queue, args); + } + return run_impl(queue, args); + } + template static void kernel_dispatch(sycl::queue& queue, const fmha_bwd_args_t& args) { return run(queue, args); } - template - static void kernel_dispatch(sycl::queue& queue, const fmha_bwd_args_t& args, bool b, Ts... ts) { - if (b) { - kernel_dispatch(queue, args, ts...); - } else { - kernel_dispatch(queue, args, ts...); - } - } }; // Single-dtype bwd dispatch: only instantiates one dtype path per TU. -template +template void bwd_policy_dispatch_fp16(sycl::queue& queue, const fmha_bwd_args_t& args) { using Config = FMHABwdConfig; - if constexpr (IsCausal != -1 && IsLocal != -1) { - return Config::template kernel_dispatch(queue, args); - } else { - return Config::kernel_dispatch(queue, args, args.is_causal, args.is_local); - } + return Config::template kernel_dispatch(queue, args); } -template +template void bwd_policy_dispatch_bf16(sycl::queue& queue, const fmha_bwd_args_t& args) { using Config = FMHABwdConfig; - if constexpr (IsCausal != -1 && IsLocal != -1) { - return Config::template kernel_dispatch(queue, args); - } else { - return Config::kernel_dispatch(queue, args, args.is_causal, args.is_local); - } + return Config::template kernel_dispatch(queue, args); } // Combined bwd_policy_dispatch is now defined inline in fmha_bwd.hpp diff --git a/flash-attn2/flash_attn_xpu/src/fmha_bwd_types.hpp b/flash-attn2/flash_attn_xpu/src/fmha_bwd_types.hpp index 8a225271..130edd2d 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_bwd_types.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_bwd_types.hpp @@ -10,6 +10,8 @@ struct fmha_bwd_args_t { void* key; void* value; void* softmax_lse; // logsumexp from forward + void* cu_seqlens_q = nullptr; + void* cu_seqlens_k = nullptr; // Output gradient tensors void* dq; @@ -36,9 +38,6 @@ struct fmha_bwd_args_t { float sm_scale; // Flags - bool is_causal = false; - bool is_local = false; - bool is_bf16 = false; bool deterministic = false; // Deterministic mode parameters @@ -97,12 +96,12 @@ struct bwd_policy_head96 { }; struct bwd_policy_head128 { - static constexpr int kBlockM = 64; - static constexpr int kBlockN = 64; + static constexpr int kBlockM = 128; + static constexpr int kBlockN = 128; static constexpr int kHeadDim = 128; - static constexpr int kNSGs = 8; - static constexpr int AtomLayoutMSdP = 2; - static constexpr int AtomLayoutNdKV = 2; + static constexpr int kNSGs = 32; + static constexpr int AtomLayoutMSdP = 4; + static constexpr int AtomLayoutNdKV = 4; static constexpr int AtomLayoutMdQ = 4; }; diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp index 27d81d0e..f4cb5ca6 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp @@ -218,7 +218,14 @@ void cutlass_fmha_fwd_varlen_impl( bool is_varlen, bool is_paged, bool is_causal, - bool is_local) { + bool is_local, + float p_dropout, + uint64_t philox_seed, + uint64_t philox_offset, + void* rng_state, + void* s_dmask, + int seqlen_q_rounded, + int seqlen_k_rounded) { int batch_size, num_heads_q, num_heads_kv, head_size; int total_seqlen_q, total_seqlen_k; int num_blocks, block_size, max_blocks_per_seq; @@ -293,7 +300,14 @@ void cutlass_fmha_fwd_varlen_impl( is_varlen, is_paged, is_causal, - is_local}; + is_local, + p_dropout, + philox_seed, + philox_offset, + rng_state, + s_dmask, + seqlen_q_rounded, + seqlen_k_rounded}; const CutlassType cuType = aten_to_Cutlass_dtype(query); if (args.max_queries == 1) { diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp index ed1576f1..bc1c7e42 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp @@ -157,7 +157,14 @@ void cutlass_fmha_fwd_varlen_impl( bool is_varlen, bool is_paged, bool is_causal, - bool is_local); + bool is_local, + float p_dropout = 0.0f, + uint64_t philox_seed = 0, + uint64_t philox_offset = 0, + void* rng_state = nullptr, + void* s_dmask = nullptr, + int seqlen_q_rounded = 0, + int seqlen_k_rounded = 0); void cutlass_fmha_fwd_fix_impl( sycl::queue& queue, diff --git a/flash-attn2/flash_attn_xpu/src/kernel/fmha_bwd_kernel.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fmha_bwd_kernel.hpp index cbadf972..238e6232 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/fmha_bwd_kernel.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fmha_bwd_kernel.hpp @@ -6,7 +6,8 @@ using namespace cute; template + bool is_causal_ = false, bool is_local_ = false, bool is_varlen_ = false, + bool has_dropout_ = false> struct FABwdKernel { /* Q BATCH,NUM_HEAD_Q,SEQ_LEN_QO,HEAD_SIZE_QK @@ -26,6 +27,8 @@ struct FABwdKernel { static constexpr int AtomLayoutMdQ = AtomLayoutMdQ_; static constexpr bool is_causal = is_causal_; static constexpr bool is_local = is_local_; + static constexpr bool is_varlen = is_varlen_; + static constexpr bool has_dropout = has_dropout_; using MMA_Atom_ARCH = XE_DPAS_TT<8, VType, DType>; using _K = Int; using SubgroupLayoutSdP = Layout, Int, _1>>; @@ -55,8 +58,17 @@ struct FABwdKernel { using index_t = uint64_t; -template -struct BwdParam { +template +struct BwdVarLenParam {}; + +template<> +struct BwdVarLenParam { + const int *cu_seqlens_q = nullptr; + const int *cu_seqlens_k = nullptr; +}; + +template +struct BwdParam : BwdVarLenParam { BwdParam(const T *dO, const T *o, const T *q, @@ -83,8 +95,7 @@ struct BwdParam { dv_ptr(dv), pb_ptr(pb), scale_softmax(softmax_scale), - scale_softmax_log2(softmax_scale * M_LOG2E), - is_bhsd(true) {} + scale_softmax_log2(softmax_scale * M_LOG2E) {} // read only const T *do_ptr; @@ -109,6 +120,7 @@ struct BwdParam { int num_head_q; int num_head_kv; int seq_len_q; + int max_seqlen_q; int seq_len_q_pad; int seq_len_kv; int seq_len_kv_pad; @@ -150,22 +162,20 @@ struct BwdParam { int o_h_stride; int o_b_stride; - // Strides for S (softmax) - int s_r_stride; - int s_s_stride; - int s_b_stride; - // Strides for dQ int dq_r_stride; int dq_h_stride; int dq_b_stride; + // Strides for intermediate dQ accumulator (stored as B,H,S,D) + int dqaccum_r_stride; + int dqaccum_h_stride; + int dqaccum_b_stride; + // Window size for local attention int window_size_left; int window_size_right; - bool is_bhsd; - bool is_local; bool deterministic; int nsplits; int dq_accum_split_stride; @@ -175,41 +185,72 @@ struct BwdParam { }; /// Computes linear offsets into the Q/K/V/dQ/dK/dV/O/LSE buffers. -template +template struct BwdOffset { - explicit BwdOffset(const BwdParam ¶m_) : param(param_) {} + explicit BwdOffset(const BwdParam ¶m_) : param(param_) {} [[nodiscard]] index_t q_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { + if constexpr (VarLen) { + return (static_cast(param.cu_seqlens_q[b_id]) + s_id) * param.num_head_q * param.head_dim + + h_id * param.head_dim; + } return b_id * param.q_b_stride + h_id * param.q_h_stride + s_id * param.q_r_stride; } [[nodiscard]] index_t k_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { + if constexpr (VarLen) { + return (static_cast(param.cu_seqlens_k[b_id]) + s_id) * param.num_head_kv * param.head_dim + + h_id * param.head_dim; + } return b_id * param.k_b_stride + h_id * param.k_h_stride + s_id * param.k_r_stride; } [[nodiscard]] index_t v_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { + if constexpr (VarLen) { + return (static_cast(param.cu_seqlens_k[b_id]) + s_id) * param.num_head_kv * param.head_dim + + h_id * param.head_dim; + } return b_id * param.v_b_stride + h_id * param.v_h_stride + s_id * param.v_r_stride; } [[nodiscard]] index_t dk_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { + if constexpr (VarLen) { + return (static_cast(param.cu_seqlens_k[b_id]) + s_id) * param.num_head_q * param.head_dim + + h_id * param.head_dim; + } return b_id * param.dk_b_stride + h_id * param.dk_h_stride + s_id * param.dk_r_stride; } [[nodiscard]] index_t dv_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { + if constexpr (VarLen) { + return (static_cast(param.cu_seqlens_k[b_id]) + s_id) * param.num_head_q * param.head_dim + + h_id * param.head_dim; + } return b_id * param.dv_b_stride + h_id * param.dv_h_stride + s_id * param.dv_r_stride; } [[nodiscard]] index_t lse_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { - return b_id * param.seq_len_q * param.num_head_q + h_id * param.seq_len_q + s_id; + return b_id * param.max_seqlen_q * param.num_head_q + h_id * param.max_seqlen_q + s_id; } [[nodiscard]] index_t o_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { + if constexpr (VarLen) { + return (static_cast(param.cu_seqlens_q[b_id]) + s_id) * param.num_head_q * param.head_dim + + h_id * param.head_dim; + } return b_id * param.o_b_stride + h_id * param.o_h_stride + s_id * param.o_r_stride; } [[nodiscard]] index_t dq_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { + if constexpr (VarLen) { + return (static_cast(param.cu_seqlens_q[b_id]) + s_id) * param.num_head_q * param.head_dim + + h_id * param.head_dim; + } return b_id * param.dq_b_stride + h_id * param.dq_h_stride + s_id * param.dq_r_stride; } + [[nodiscard]] index_t dqaccum_offset(const index_t b_id, const index_t h_id, const index_t s_id) const { + return b_id * param.dqaccum_b_stride + h_id * param.dqaccum_h_stride + s_id * param.dqaccum_r_stride; + } - const BwdParam ¶m; + const BwdParam ¶m; }; // Setup strides for BHSD layout (batch, heads, seq, dim) -template -void setup_bhsd_stride_bwd(BwdParam ¶m) { +template +void setup_bhsd_stride_bwd(BwdParam ¶m) { param.q_r_stride = param.head_dim; param.q_h_stride = param.seq_len_q * param.head_dim; param.q_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; @@ -235,13 +276,17 @@ void setup_bhsd_stride_bwd(BwdParam ¶m) { param.o_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; param.dq_r_stride = param.head_dim; - param.dq_h_stride = param.seq_len_q_pad * param.head_dim; - param.dq_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; + param.dq_h_stride = param.seq_len_q * param.head_dim; + param.dq_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.dqaccum_r_stride = param.head_dim; + param.dqaccum_h_stride = param.seq_len_q_pad * param.head_dim; + param.dqaccum_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; } // Setup strides for BSHD layout (batch, seq, heads, dim) -template -void setup_bshd_stride_bwd(BwdParam ¶m) { +template +void setup_bshd_stride_bwd(BwdParam ¶m) { param.q_r_stride = param.num_head_q * param.head_dim; param.q_h_stride = param.head_dim; param.q_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; @@ -268,5 +313,9 @@ void setup_bshd_stride_bwd(BwdParam ¶m) { param.dq_r_stride = param.num_head_q * param.head_dim; param.dq_h_stride = param.head_dim; - param.dq_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; + param.dq_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.dqaccum_r_stride = param.head_dim; + param.dqaccum_h_stride = param.seq_len_q_pad * param.head_dim; + param.dqaccum_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; } diff --git a/flash-attn2/tests/test_flash_attn.py b/flash-attn2/tests/test_flash_attn.py index 6fc65305..ee8ff5f9 100644 --- a/flash-attn2/tests/test_flash_attn.py +++ b/flash-attn2/tests/test_flash_attn.py @@ -510,7 +510,9 @@ def normalize_flash_attn_S( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias.to(dtype=scores.dtype) - block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) + block_size_n = 64 if scores.device.type == "xpu" and seqlen_q == 1 else _get_block_size_n( + scores.device, head_dim, is_dropout, causal + ) scores_block = scores.split(block_size_n, dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) lse = torch.logsumexp(lse_block, dim=-1) @@ -520,7 +522,18 @@ def normalize_flash_attn_S( scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) # CUDA iterates K blocks right-to-left → suffix running max; # XPU iterates K blocks left-to-right → prefix running max. - if scores.device.type == "xpu": + if scores.device.type == "xpu" and seqlen_q == 1: + blocks_per_tile = 8 + num_blocks = scores_max_block.size(-1) + scores_max_by_sg = F.pad( + scores_max_block, (0, (-num_blocks) % blocks_per_tile), value=float("-inf") + ).reshape( + *scores_max_block.shape[:-1], -1, blocks_per_tile + ) + cummax_block = torch.cummax(scores_max_by_sg, dim=-2).values.flatten(-2)[ + ..., :num_blocks + ].unbind(dim=-1) + elif scores.device.type == "xpu": cummax_block = torch.cummax(scores_max_block, dim=-1).values.unbind(dim=-1) else: cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) @@ -760,8 +773,6 @@ def test_flash_attn_varlen_qkvpacked( if device == "xpu": if alibi: pytest.skip("alibi not supported on xpu currently") - if dropout_p != 0.0: - pytest.skip("dropout for varlen not supported on xpu currently") # set seed torch.random.manual_seed(0) @@ -859,13 +870,13 @@ def test_flash_attn_varlen_qkvpacked( print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") - if device in ["xpu", "cpu"]: + if device == "cpu": assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - print("XPU and CPU do not support backward currently, skipping grad check.") + print("CPU do not support backward currently, skipping grad check.") return g = torch.randn_like(out) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90) or device == "xpu": (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) dqkv = dqkv_pad_fn(dqkv_unpad) (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) @@ -889,7 +900,7 @@ def test_flash_attn_varlen_qkvpacked( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90) or device == "xpu": assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -1230,8 +1241,6 @@ def test_flash_attn_varlen_output( pytest.skip("alibi not supported on xpu currently") if softcap != 0.0: pytest.skip("softcap not supported on xpu currently") - if dropout_p != 0.0: - pytest.skip("dropout for varlen not supported on xpu currently") if device == "cpu": if alibi: pytest.skip("alibi not supported on CPU") @@ -1451,13 +1460,13 @@ def test_flash_attn_varlen_output( print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") - if device in ["xpu", "cpu"]: + if device == "cpu": assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - print("XPU and CPU do not support backward currently, skipping grad check.") + print("CPU do not support backward currently, skipping grad check.") return g = torch.randn_like(out) - if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)): + if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90) or device == "xpu"): if kvpacked: ( dq_unpad, @@ -1514,9 +1523,10 @@ def test_flash_attn_varlen_output( assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: - assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) + dropout_fraction_tol = 0.04 if local else (0.02 if device == "xpu" and seqlen_q == 1 else 0.01) + assert abs(dropout_fraction - dropout_p) <= dropout_fraction_tol - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90) or device == "xpu": assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() @@ -1769,9 +1779,9 @@ def test_flash_attn_varlen_causal( print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - if device in ["xpu", "cpu"]: + if device == "cpu": assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - print("XPU and CPU do not support backward currently, skipping grad check.") + print("CPU do not support backward currently, skipping grad check.") return g = torch.randn_like(out) @@ -2479,8 +2489,6 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype, device): """ if device == "cpu": pytest.skip("backward not supported on CPU") - if device == "xpu": - pytest.skip("bwd test not supported on xpu currently") # set seed torch.random.manual_seed(0) @@ -2599,8 +2607,6 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype, device): if device == "cpu": pytest.skip("backward not supported on CPU") - if device == "xpu": - pytest.skip("varlen backward not supported on XPU currently") if ( device == "cuda" and max(seqlen_q, seqlen_k) >= 2048