Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions flash-attn2/build.toml
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,22 @@ src = [
"flash_attn_xpu/src/fmha_bwd.hpp",
"flash_attn_xpu/src/fmha_bwd_impl.hpp",
"flash_attn_xpu/src/fmha_bwd.cpp",
"flash_attn_xpu/src/flash_bwd_hdim32_varlen_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim32_varlen_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim64_varlen_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim64_varlen_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim96_varlen_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim96_varlen_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim128_varlen_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim128_varlen_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim160_varlen_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim160_varlen_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim192_varlen_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim192_varlen_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim256_varlen_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim256_varlen_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim512_varlen_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim512_varlen_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim32_fix_fp16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim32_fix_bf16.cpp",
"flash_attn_xpu/src/flash_bwd_hdim64_fix_fp16.cpp",
Expand Down
276 changes: 271 additions & 5 deletions flash-attn2/flash_attn_xpu/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ mha_varlen_fwd(
compat::select_device(device_idx);

// check inputs
TORCH_CHECK(p_dropout < 1.0f, "FlashAttention dropout probability must be less than 1.0");

q = ensure_contiguous(q);
const auto sizes = q.sizes();
const int total_q = sizes[0];
Expand Down Expand Up @@ -446,6 +448,31 @@ mha_varlen_fwd(

auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue();

uint64_t philox_seed = 0;
uint64_t philox_offset = 0;
at::Tensor rng_state = at::empty({2}, q.options().dtype(at::kLong));

if (p_dropout > 0.0f) {
int64_t counter_offset = batch_size * num_heads * 32;
auto [seed, offset] = get_philox_state(gen_, counter_offset);
philox_seed = seed;
philox_offset = offset;
rng_state[0] = static_cast<int64_t>(philox_seed);
rng_state[1] = static_cast<int64_t>(philox_offset);
}

const int q_tile_size = (head_size_og <= 32) ? 64 : (head_size_og <= 128) ? 128 : 256;
const int k_tile_size = (max_seqlen_q == 1 && !is_paged) ? 512 : 128;
const int seqlen_q_rounded = round_multiple(max_seqlen_q, q_tile_size);
const int seqlen_k_rounded = round_multiple(max_seqlen_k, k_tile_size);
at::Tensor S_dmask;
if (return_softmax) {
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
S_dmask = torch::empty({batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, q.options());
} else {
S_dmask = torch::empty({0}, q.options());
}

cutlass_fmha_fwd_varlen_impl(
queue,
q_padded, k_padded, v_padded, out_padded,
Expand All @@ -455,7 +482,10 @@ mha_varlen_fwd(
max_seqlen_q, max_seqlen_k,
softmax_scale,
window_size_left, window_size_right,
true, is_paged, is_causal, is_local);
true, is_paged, is_causal, is_local,
p_dropout, philox_seed, philox_offset, nullptr,
return_softmax ? S_dmask.data_ptr() : nullptr,
seqlen_q_rounded, seqlen_k_rounded);

// Strip padding from output back to original head_size
at::Tensor out = needs_padding
Expand All @@ -464,11 +494,214 @@ mha_varlen_fwd(
.contiguous()
: ensure_contiguous(out_padded);

at::Tensor S_dmask;
at::Tensor rng_state;
return {out, softmax_lse, S_dmask, rng_state};
}

/**
 * Backward pass of variable-length (packed) FlashAttention on XPU.
 *
 * Validates the inputs, pads the head dimension up to a multiple of 32,
 * launches cutlass_fmha_bwd_varlen_impl and post-processes the gradients:
 * when num_heads_k != num_heads the per-query-head dK/dV are summed over each
 * group of query heads, the head-dimension padding is stripped, and the
 * results are copied into any caller-supplied output buffers.
 *
 * @param dout, q, out   3-D tensors of identical shape
 *                       (total_q, num_heads, head_size), fp16 or bf16.
 * @param k, v           3-D tensors of identical shape
 *                       (total_k, num_heads_k, head_size), same dtype as q.
 * @param softmax_lse    (batch_size, num_heads, max_seqlen_q).
 * @param dq_, dk_, dv_  Optional output buffers. When given they must match
 *                       the dtype/device of q and have the un-padded gradient
 *                       shape; they receive the final gradients.
 * @param cu_seqlens_q, cu_seqlens_k  1-D int32 tensors of equal length;
 *                       batch_size is their length minus one.
 * @param zero_tensors   When true the gradient work buffers and softmax_d are
 *                       zero-filled before the kernel is launched.
 * @param rng_state      Tensor whose elements [0] and [1] hold the Philox seed
 *                       and offset; required when p_dropout > 0.
 * @return {dq, dk, dv, softmax_d}; softmax_d is a float tensor of shape
 *         (batch_size, num_heads, max_seqlen_q).
 *
 * NOTE(review): alibi_slopes_, softcap and gen_ are accepted for interface
 * parity but are not forwarded to cutlass_fmha_bwd_varlen_impl — confirm that
 * callers never rely on them for the XPU backward.
 */
std::vector<at::Tensor>
mha_varlen_bwd(
    const at::Tensor &dout,
    const at::Tensor &q,
    const at::Tensor &k,
    const at::Tensor &v,
    const at::Tensor &out,
    const at::Tensor &softmax_lse,
    std::optional<at::Tensor> &dq_,
    std::optional<at::Tensor> &dk_,
    std::optional<at::Tensor> &dv_,
    const at::Tensor &cu_seqlens_q,
    const at::Tensor &cu_seqlens_k,
    std::optional<at::Tensor> &alibi_slopes_,
    const int max_seqlen_q,
    const int max_seqlen_k,
    const float p_dropout,
    const float softmax_scale,
    const bool zero_tensors,
    const bool is_causal,
    int window_size_left,
    int window_size_right,
    const float softcap,
    const bool deterministic,
    std::optional<at::Generator> gen_,
    std::optional<at::Tensor> &rng_state) {

  auto device_idx = q.device().index();
  compat::select_device(device_idx);

  // ---- argument validation -------------------------------------------------
  TORCH_CHECK(p_dropout < 1.0f, "FlashAttention dropout probability must be less than 1.0");

  auto q_dtype = q.dtype();
  TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
              "FlashAttention varlen backward only supports fp16 and bf16 data type");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

  CHECK_DEVICE(q);
  CHECK_DEVICE(k);
  CHECK_DEVICE(v);
  CHECK_DEVICE(out);
  CHECK_DEVICE(dout);
  CHECK_DEVICE(softmax_lse);
  CHECK_DEVICE(cu_seqlens_q);
  CHECK_DEVICE(cu_seqlens_k);

  TORCH_CHECK(q.dim() == 3, "q must have shape (total_q, num_heads, head_size)");
  TORCH_CHECK(k.dim() == 3, "k must have shape (total_k, num_heads_k, head_size)");
  TORCH_CHECK(v.dim() == 3, "v must have shape (total_k, num_heads_k, head_size)");
  TORCH_CHECK(out.sizes() == q.sizes(), "out must have the same shape as q");
  TORCH_CHECK(dout.sizes() == q.sizes(), "dout must have the same shape as q");
  TORCH_CHECK(k.sizes() == v.sizes(), "k and v must have the same shape");
  TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_k.dim() == 1,
              "cu_seqlens_q and cu_seqlens_k must be 1D tensors");
  TORCH_CHECK(cu_seqlens_q.scalar_type() == at::ScalarType::Int &&
              cu_seqlens_k.scalar_type() == at::ScalarType::Int,
              "cu_seqlens_q and cu_seqlens_k must be int32 tensors");
  TORCH_CHECK(cu_seqlens_q.numel() == cu_seqlens_k.numel(),
              "cu_seqlens_q and cu_seqlens_k must have the same length");
  TORCH_CHECK(max_seqlen_q > 0 && max_seqlen_k > 0,
              "max_seqlen_q and max_seqlen_k must be positive");

  // Without the forward pass's Philox state the dropout mask cannot be
  // reproduced, so the gradients would silently be computed for a different
  // mask. Fail loudly instead of falling back to seed 0 / offset 0.
  TORCH_CHECK(p_dropout <= 0.0f || rng_state.has_value(),
              "rng_state from the forward pass is required when p_dropout > 0.0");

  at::Tensor q_contig = ensure_contiguous(q);
  at::Tensor k_contig = ensure_contiguous(k);
  at::Tensor v_contig = ensure_contiguous(v);
  at::Tensor out_contig = ensure_contiguous(out);
  at::Tensor dout_contig = ensure_contiguous(dout);
  at::Tensor softmax_lse_contig = ensure_contiguous(softmax_lse);
  at::Tensor cu_q_contig = ensure_contiguous(cu_seqlens_q);
  at::Tensor cu_k_contig = ensure_contiguous(cu_seqlens_k);

  TORCH_CHECK(q_contig.stride(-1) == 1, "q must have contiguous last dimension");
  TORCH_CHECK(k_contig.stride(-1) == 1, "k must have contiguous last dimension");
  TORCH_CHECK(v_contig.stride(-1) == 1, "v must have contiguous last dimension");
  TORCH_CHECK(out_contig.stride(-1) == 1, "out must have contiguous last dimension");
  TORCH_CHECK(dout_contig.stride(-1) == 1, "dout must have contiguous last dimension");

  const int total_q = q_contig.size(0);
  const int total_k = k_contig.size(0);
  const int num_heads = q_contig.size(1);
  const int head_size_og = q_contig.size(2);
  const int num_heads_k = k_contig.size(1);
  const int batch_size = cu_q_contig.numel() - 1;

  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size_og % 8 == 0, "head_size should be a multiple of 8");
  TORCH_CHECK(head_size_og <= 256 || head_size_og == 512,
              "FlashAttention XPU varlen backward only supports head dimension up to 256 or exactly 512. Got: " + std::to_string(head_size_og));
  TORCH_CHECK(k_contig.size(2) == head_size_og && v_contig.size(2) == head_size_og,
              "k and v head dimensions must match q");
  TORCH_CHECK(num_heads % num_heads_k == 0,
              "Number of heads in key/value must divide number of heads in query");
  TORCH_CHECK(softmax_lse_contig.dim() == 3 && softmax_lse_contig.size(0) == batch_size &&
              softmax_lse_contig.size(1) == num_heads && softmax_lse_contig.size(2) == max_seqlen_q,
              "softmax_lse must have shape (batch_size, num_heads, max_seqlen_q)");

  // Caller-supplied gradient buffers are either written by the kernel directly
  // or used as copy_ destinations, so reject anything whose dtype, device or
  // shape does not match the gradient that will be produced.
  auto check_grad_buffer = [&](const std::optional<at::Tensor> &opt,
                               const std::vector<int64_t> &expected,
                               const char *name) {
    if (!opt.has_value()) {
      return;
    }
    const at::Tensor &buf = opt.value();
    TORCH_CHECK(buf.dtype() == q_dtype, name, " must have the same dtype as q");
    TORCH_CHECK(buf.device() == q.device(), name, " must be on the same device as q");
    TORCH_CHECK(buf.sizes() == at::IntArrayRef(expected),
                name, " must have shape ", at::IntArrayRef(expected), ". Got: ", buf.sizes());
  };
  check_grad_buffer(dq_, {total_q, num_heads, head_size_og}, "dq");
  check_grad_buffer(dk_, {total_k, num_heads_k, head_size_og}, "dk");
  check_grad_buffer(dv_, {total_k, num_heads_k, head_size_og}, "dv");

  // ---- head-dimension padding ----------------------------------------------
  const int head_size_padded = round_multiple(head_size_og, 32);
  const bool needs_padding = (head_size_og != head_size_padded);
  const int pad_size = head_size_padded - head_size_og;

  auto maybe_pad = [&](const at::Tensor &t) -> at::Tensor {
    return needs_padding
        ? torch::nn::functional::pad(t, torch::nn::functional::PadFuncOptions({0, pad_size}))
        : t;
  };

  at::Tensor q_padded = maybe_pad(q_contig);
  at::Tensor k_padded = maybe_pad(k_contig);
  at::Tensor v_padded = maybe_pad(v_contig);
  at::Tensor out_padded = maybe_pad(out_contig);
  at::Tensor dout_padded = maybe_pad(dout_contig);

  // ---- gradient work buffers -----------------------------------------------
  auto opts = q_contig.options();
  at::Tensor dq, dk, dv;
  at::Tensor dq_work, dk_work, dv_work;
  bool dq_needs_copy = false, dk_needs_copy = false, dv_needs_copy = false;

  // Reuse a caller buffer when it is contiguous; otherwise allocate a scratch
  // tensor and remember that the result must be copied back afterwards.
  auto get_or_alloc = [&](const std::optional<at::Tensor> &opt,
                          const std::vector<int64_t> &sizes,
                          bool &needs_copy) -> at::Tensor {
    if (opt.has_value() && opt.value().is_contiguous()) {
      return opt.value();
    }
    needs_copy = opt.has_value();
    return torch::empty(sizes, opts);
  };

  if (!needs_padding && num_heads_k == num_heads) {
    dq_work = get_or_alloc(dq_, {total_q, num_heads, head_size_og}, dq_needs_copy);
    dk_work = get_or_alloc(dk_, {total_k, num_heads_k, head_size_og}, dk_needs_copy);
    dv_work = get_or_alloc(dv_, {total_k, num_heads_k, head_size_og}, dv_needs_copy);
  } else {
    // Padded and/or grouped-head layouts never match the caller buffers, so
    // always work in scratch tensors and copy back at the end.
    dq_work = torch::empty({total_q, num_heads, head_size_padded}, opts);
    dk_work = torch::empty({total_k, num_heads_k, head_size_padded}, opts);
    dv_work = torch::empty({total_k, num_heads_k, head_size_padded}, opts);
  }

  at::Tensor softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));

  // With fewer K/V heads than Q heads the kernel emits one dK/dV slice per
  // query head; those are reduced into dk_work/dv_work after the launch.
  at::Tensor dk_expanded, dv_expanded;
  if (num_heads_k != num_heads) {
    dk_expanded = torch::empty({total_k, num_heads, head_size_padded}, opts);
    dv_expanded = torch::empty({total_k, num_heads, head_size_padded}, opts);
  } else {
    dk_expanded = dk_work;
    dv_expanded = dv_work;
  }

  // Honour zero_tensors: clear every buffer the kernel writes into so that
  // positions it does not touch hold zeros instead of uninitialised memory.
  if (zero_tensors) {
    dq_work.zero_();
    dk_expanded.zero_();
    dv_expanded.zero_();
    softmax_d.zero_();
  }

  const bool is_local = (window_size_left != -1) || (window_size_right != -1);
  auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue();

  // ---- dropout RNG state ---------------------------------------------------
  uint64_t philox_seed = 0;
  uint64_t philox_offset = 0;
  if (p_dropout > 0.0f) {
    // One host transfer for both values instead of a device sync per element.
    const at::Tensor rng_host = rng_state.value().cpu();
    philox_seed = static_cast<uint64_t>(rng_host[0].item<int64_t>());
    philox_offset = static_cast<uint64_t>(rng_host[1].item<int64_t>());
  }

  // ---- kernel launch -------------------------------------------------------
  cutlass_fmha_bwd_varlen_impl(
      queue,
      dout_padded, q_padded, k_padded, v_padded, out_padded, softmax_lse_contig,
      cu_q_contig, cu_k_contig,
      dq_work, dk_expanded, dv_expanded, softmax_d,
      softmax_scale,
      max_seqlen_q, max_seqlen_k,
      window_size_left, window_size_right, is_causal, is_local,
      p_dropout, philox_seed, philox_offset, deterministic);

  // ---- post-processing -----------------------------------------------------
  if (num_heads_k != num_heads) {
    // Sum the gradients of all query heads that share a K/V head.
    at::sum_out(dk_work, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_padded}), {2});
    at::sum_out(dv_work, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_padded}), {2});
  }

  auto slice_head = [&](const at::Tensor &t) -> at::Tensor {
    return t.index({torch::indexing::Slice(), torch::indexing::Slice(),
                    torch::indexing::Slice(0, head_size_og)}).contiguous();
  };
  auto copy_if_provided = [&](const std::optional<at::Tensor> &opt, const at::Tensor &src) {
    if (opt.has_value()) {
      opt.value().copy_(src);
    }
  };

  if (needs_padding || num_heads_k != num_heads) {
    dq = slice_head(dq_work);
    dk = slice_head(dk_work);
    dv = slice_head(dv_work);
    copy_if_provided(dq_, dq);
    copy_if_provided(dk_, dk);
    copy_if_provided(dv_, dv);
  } else {
    dq = dq_work;
    dk = dk_work;
    dv = dv_work;
    // Only non-contiguous caller buffers need an explicit copy; contiguous
    // ones were written in place by the kernel.
    if (dq_needs_copy) dq_.value().copy_(dq);
    if (dk_needs_copy) dk_.value().copy_(dk);
    if (dv_needs_copy) dv_.value().copy_(dv);
  }

  return {dq, dk, dv, softmax_d};
}

std::vector<at::Tensor>
mha_fwd_kvcache(
at::Tensor &q,
Expand Down Expand Up @@ -749,8 +982,41 @@ mha_varlen_bwd(
const bool deterministic,
c10::optional<at::Generator> gen_,
const c10::optional<torch::Tensor> &rng_state) {
TORCH_CHECK(false, "FlashAttention varlen backward is not supported on XPU yet.");
return {};
auto to_std_opt = [](const c10::optional<at::Tensor>& opt) -> std::optional<at::Tensor> {
return opt.has_value() ? std::optional<at::Tensor>(opt.value()) : std::nullopt;
};

auto dq_opt = to_std_opt(dq_);
auto dk_opt = to_std_opt(dk_);
auto dv_opt = to_std_opt(dv_);
auto alibi_opt = to_std_opt(alibi_slopes_);
auto rng_opt = to_std_opt(rng_state);

return FLASH_NAMESPACE::mha_varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq_opt,
dk_opt,
dv_opt,
cu_seqlens_q,
cu_seqlens_k,
alibi_opt,
static_cast<int>(max_seqlen_q),
static_cast<int>(max_seqlen_k),
static_cast<float>(p_dropout),
static_cast<float>(softmax_scale),
zero_tensors,
is_causal,
static_cast<int>(window_size_left),
static_cast<int>(window_size_right),
static_cast<float>(softcap),
deterministic,
gen_,
rng_opt);
}

std::vector<torch::Tensor>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

// Fixed mode backward for head_dim=128, dtype=bf16

template void bwd_policy_dispatch_bf16<bwd_policy_head128, 0, 0>(
template void bwd_policy_dispatch_bf16<bwd_policy_head128, 0, 0, 0>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_bf16<bwd_policy_head128, 0, 1>(
template void bwd_policy_dispatch_bf16<bwd_policy_head128, 0, 1, 0>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_bf16<bwd_policy_head128, 1, 0>(
template void bwd_policy_dispatch_bf16<bwd_policy_head128, 1, 0, 0>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_bf16<bwd_policy_head128, 1, 1>(
template void bwd_policy_dispatch_bf16<bwd_policy_head128, 1, 1, 0>(
sycl::queue& queue, const fmha_bwd_args_t& args);
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

// Fixed mode backward for head_dim=128, dtype=fp16

template void bwd_policy_dispatch_fp16<bwd_policy_head128, 0, 0>(
template void bwd_policy_dispatch_fp16<bwd_policy_head128, 0, 0, 0>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_fp16<bwd_policy_head128, 0, 1>(
template void bwd_policy_dispatch_fp16<bwd_policy_head128, 0, 1, 0>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_fp16<bwd_policy_head128, 1, 0>(
template void bwd_policy_dispatch_fp16<bwd_policy_head128, 1, 0, 0>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_fp16<bwd_policy_head128, 1, 1>(
template void bwd_policy_dispatch_fp16<bwd_policy_head128, 1, 1, 0>(
sycl::queue& queue, const fmha_bwd_args_t& args);
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "fmha_bwd_impl.hpp"

// Varlen mode backward for head_dim=128, dtype=bf16
//
// Explicit instantiations of bwd_policy_dispatch_bf16 for bwd_policy_head128.
// The second and third template arguments enumerate all four 0/1 combinations;
// the fourth is fixed to 1 in every instantiation of this file.
// NOTE(review): the fourth argument presumably selects varlen mode (per the
// header comment above) and the second/third are compile-time feature flags
// declared in fmha_bwd_impl.hpp — confirm their meaning there.

template void bwd_policy_dispatch_bf16<bwd_policy_head128, 0, 0, 1>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_bf16<bwd_policy_head128, 0, 1, 1>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_bf16<bwd_policy_head128, 1, 0, 1>(
sycl::queue& queue, const fmha_bwd_args_t& args);

template void bwd_policy_dispatch_bf16<bwd_policy_head128, 1, 1, 1>(
sycl::queue& queue, const fmha_bwd_args_t& args);
Loading
Loading