Commit b61325e

pad input tensors if headdim is not multiple of 64
1 parent 89c6a49 commit b61325e

File tree

3 files changed (+29 -17 lines)

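As background for the diffs below: the commit title refers to the head dimension being rounded up to the kernel's 64-element granularity and padded before the forward call. A minimal sketch of that rounding-and-padding step, assuming a hypothetical helper name and zero padding (pad_headdim_to_64 is illustrative, not code from this commit):

#include <ATen/ATen.h>

// Hypothetical helper: round headdim up to the next multiple of 64 and
// zero-pad the last dimension with at::constant_pad_nd.
at::Tensor pad_headdim_to_64(const at::Tensor& t) {
  const int64_t d = t.size(-1);
  const int64_t d_pad = (d + 63) / 64 * 64; // next multiple of 64
  if (d_pad == d) {
    return t; // already aligned, nothing to do
  }
  // The pad spec is ordered from the last dimension inward: {left, right}.
  return at::constant_pad_nd(t, {0, d_pad - d}, /*value=*/0);
}

The side effect that the forward workaround below has to handle is that the padded result is freshly allocated, and therefore contiguous in the logical dimension order rather than in the original strided layout.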

src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp

Lines changed: 6 additions & 10 deletions
@@ -1406,19 +1406,15 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward_sycltla(
       to_string(layout),
       ", value with layout ",
       to_string(get_attn_tensor_layout(value)));
-  layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(out));
-  TORCH_CHECK(
-      ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout,
-      "FlashAttentionBackwardXPU: query and out must have the same layout, got query with layout ",
-      to_string(layout),
-      ", out with layout ",
-      to_string(get_attn_tensor_layout(out)));
   if (layout == ATTN_TENSOR_LAYOUT::BXD) {
     layout = ATTN_TENSOR_LAYOUT::BHSD;
   }
   TORCH_CHECK(logsumexp.is_contiguous(), "logsumexp must have BHS layout");
   // grad_out is created by autograd, may not have standard layout
-  auto contiguous_grad_out = attn_tensor_to_layout(grad_out, layout);
+  auto grad_out_ = attn_tensor_to_layout(grad_out, layout);
+  // TODO: This code block is temporary WA. Remove it after fwd supporting BHSD
+  // layouts
+  auto out_ = attn_tensor_to_layout(out, layout);
 
   auto sycl_queue = at::xpu::getCurrentXPUStream().queue();
   auto device_architecture =
@@ -1493,8 +1489,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward_sycltla(
   cute::run_mha_bwd<decltype(problem_shape), kMPad, kNPad>(
       sycl_queue,
       problem_shape,
-      contiguous_grad_out.data_ptr(),
-      out.data_ptr(),
+      grad_out_.data_ptr(),
+      out_.data_ptr(),
       query.data_ptr(),
       key.data_ptr(),
       value.data_ptr(),
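The backward change above drops the requirement that out already match the query layout and instead routes both grad_out and out through attn_tensor_to_layout. The difference between the layouts involved is purely one of memory order; a minimal sketch of what BHSD vs. BSHD appear to mean here, under the assumption that both are just stride patterns over a logical [B, H, S, D] tensor:

#include <ATen/ATen.h>

int main() {
  const int64_t B = 2, H = 4, S = 128, D = 64;

  // BHSD: plainly contiguous in the logical [B, H, S, D] order.
  auto bhsd = at::zeros({B, H, S, D});

  // BSHD: allocated as [B, S, H, D], then viewed as [B, H, S, D] via strides.
  auto bshd = at::zeros({B, S, H, D}).permute({0, 2, 1, 3});

  // Same logical shape, different physical layout.
  TORCH_CHECK(bhsd.sizes() == bshd.sizes());
  TORCH_CHECK(bhsd.is_contiguous() && !bshd.is_contiguous());
  return 0;
}

Converting between the two therefore amounts to a copy into the target order, which is what the grad_out_ and out_ temporaries hold until, per the TODO, the forward path supports BHSD layouts directly.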

src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp

Lines changed: 22 additions & 6 deletions
@@ -451,23 +451,39 @@ flash_attention_forward_sycltla(
   layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(key));
   TORCH_CHECK(
       ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout,
-      "FlashAttentionBackwardXPU: query and key must have the same layout, got query with layout ",
+      "FlashAttentionForwardXPU: query and key must have the same layout, got query with layout ",
       to_string(layout),
       ", key with layout ",
       to_string(get_attn_tensor_layout(key)));
   layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(value));
   TORCH_CHECK(
       ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout,
-      "FlashAttentionBackwardXPU: query and value must have the same layout, got query with layout ",
+      "FlashAttentionForwardXPU: query and value must have the same layout, got query with layout ",
       to_string(layout),
       ", value with layout ",
       to_string(get_attn_tensor_layout(value)));
   if (layout == ATTN_TENSOR_LAYOUT::BXD) {
     layout = ATTN_TENSOR_LAYOUT::BSHD;
   }
+
+  at::Tensor query_ = query, key_ = key, value_ = value;
+  {
+    // Currently fwd only supports BSHD layout.
+    // However, input headdim may be padded when headdim is not multiple of 64.
+    // The pad op will make input tensor become BHSD contiguous.
+    // TODO: This code block is temporary WA. Remove it after supporting BHSD
+    // layouts.
+    if (layout != ATTN_TENSOR_LAYOUT::BSHD) {
+      query_ = attn_tensor_to_layout(query, ATTN_TENSOR_LAYOUT::BSHD);
+      key_ = attn_tensor_to_layout(key, ATTN_TENSOR_LAYOUT::BSHD);
+      value_ = attn_tensor_to_layout(value, ATTN_TENSOR_LAYOUT::BSHD);
+      layout = ATTN_TENSOR_LAYOUT::BSHD;
+    }
+  }
+
   TORCH_CHECK(
       layout == ATTN_TENSOR_LAYOUT::BSHD,
-      "FlashAttentionBackwardXPU: currently only support BSHD layout");
+      "FlashAttentionForwardXPU: currently only support BSHD layout");
 
   auto opts = query.options();
   at::Tensor out;
@@ -516,9 +532,9 @@ flash_attention_forward_sycltla(
   cute::run_mha_fwd<decltype(problem_shape)>(
       sycl_queue,
       problem_shape,
-      query.data_ptr(),
-      key.data_ptr(),
-      value.data_ptr(),
+      query_.data_ptr(),
+      key_.data_ptr(),
+      value_.data_ptr(),
       out.data_ptr(),
       logsumexp.data_ptr(),
       is_causal,
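To make the forward workaround concrete: a BSHD-strided input loses its stride pattern when its head dimension is padded, because the pad produces a freshly allocated, contiguous result, which is exactly the case the new comment block describes. A sketch of that effect and of the conversion back to BSHD, where to_bshd() is only a stand-in assumption for attn_tensor_to_layout(t, ATTN_TENSOR_LAYOUT::BSHD):

#include <ATen/ATen.h>

// Stand-in for attn_tensor_to_layout(t, BSHD): reorder memory to
// [B, S, H, D], copy, then restore the logical [B, H, S, D] view.
at::Tensor to_bshd(const at::Tensor& t) {
  return t.permute({0, 2, 1, 3}).contiguous().permute({0, 2, 1, 3});
}

int main() {
  const int64_t B = 2, H = 4, S = 128, D = 96; // 96 is not a multiple of 64

  // Query in BSHD memory order, viewed logically as [B, H, S, D].
  auto q = at::zeros({B, S, H, D}).permute({0, 2, 1, 3});

  // Pad headdim 96 -> 128; the result is a new, plainly contiguous tensor,
  // i.e. BHSD, so the original BSHD stride pattern is gone.
  auto q_pad = at::constant_pad_nd(q, {0, 128 - D}, 0);
  TORCH_CHECK(q_pad.is_contiguous());

  // Copy back into BSHD order so the BSHD-only forward kernel can consume it.
  auto q_bshd = to_bshd(q_pad);
  TORCH_CHECK(!q_bshd.is_contiguous()); // contiguous in BSHD order instead
  return 0;
}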

src/ATen/native/transformers/xpu/flash_attn/utils.h

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ inline at::Tensor attn_tensor_to_layout(
   return output;
 }
 
-inline bool check_flash_attention_bshd_layout(
+inline bool check_flash_attention_layout(
     sdp::sdp_params const& params,
     bool debug) {
   sycltla::ATTN_TENSOR_LAYOUT layout =

0 commit comments