@@ -451,23 +451,39 @@ flash_attention_forward_sycltla(
   layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(key));
   TORCH_CHECK(
       ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout,
-      "FlashAttentionBackwardXPU: query and key must have the same layout, got query with layout ",
+      "FlashAttentionForwardXPU: query and key must have the same layout, got query with layout ",
       to_string(layout),
       ", key with layout ",
       to_string(get_attn_tensor_layout(key)));
   layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(value));
   TORCH_CHECK(
       ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout,
-      "FlashAttentionBackwardXPU: query and value must have the same layout, got query with layout ",
+      "FlashAttentionForwardXPU: query and value must have the same layout, got query with layout ",
       to_string(layout),
       ", value with layout ",
       to_string(get_attn_tensor_layout(value)));
   if (layout == ATTN_TENSOR_LAYOUT::BXD) {
     layout = ATTN_TENSOR_LAYOUT::BSHD;
   }
+
+  at::Tensor query_ = query, key_ = key, value_ = value;
+  {
+    // Currently fwd only supports the BSHD layout.
+    // However, the input head dim may be padded when it is not a multiple of
+    // 64; the pad op makes the input tensor BHSD contiguous.
+    // TODO: This code block is a temporary WA. Remove it after supporting
+    // BHSD layouts.
+    if (layout != ATTN_TENSOR_LAYOUT::BSHD) {
+      query_ = attn_tensor_to_layout(query, ATTN_TENSOR_LAYOUT::BSHD);
+      key_ = attn_tensor_to_layout(key, ATTN_TENSOR_LAYOUT::BSHD);
+      value_ = attn_tensor_to_layout(value, ATTN_TENSOR_LAYOUT::BSHD);
+      layout = ATTN_TENSOR_LAYOUT::BSHD;
+    }
+  }
+
   TORCH_CHECK(
       layout == ATTN_TENSOR_LAYOUT::BSHD,
-      "FlashAttentionBackwardXPU: currently only support BSHD layout");
+      "FlashAttentionForwardXPU: currently only support BSHD layout");
 
   auto opts = query.options();
   at::Tensor out;
@@ -516,9 +532,9 @@ flash_attention_forward_sycltla(
   cute::run_mha_fwd<decltype(problem_shape)>(
       sycl_queue,
       problem_shape,
-      query.data_ptr(),
-      key.data_ptr(),
-      value.data_ptr(),
+      query_.data_ptr(),
+      key_.data_ptr(),
+      value_.data_ptr(),
       out.data_ptr(),
       logsumexp.data_ptr(),
       is_causal,
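
For context, here is a minimal sketch of what the BSHD conversion added in the first hunk amounts to. to_bshd_layout below is a hypothetical stand-in for attn_tensor_to_layout, whose implementation is not part of this diff; the shapes, the padded head dim (72 -> 128), and the assumption that q/k/v are indexed as [batch, num_heads, seq_len, head_dim] while "BSHD layout" means batch/seq/heads/dim memory order are all illustrative.

#include <ATen/ATen.h>

// Hypothetical stand-in for attn_tensor_to_layout(); not the actual helper.
// Assumes t is logically [B, H, S, D] and that "BSHD layout" means the memory
// order is batch, seq, heads, dim.
static at::Tensor to_bshd_layout(const at::Tensor& t) {
  // View as [B, S, H, D], materialize that order with contiguous(), then
  // transpose back so the logical shape stays [B, H, S, D] while the strides
  // now match a BSHD allocation.
  return t.transpose(1, 2).contiguous().transpose(1, 2);
}

int main() {
  // Head dim 72 is not a multiple of 64, so the forward path pads it to 128;
  // constant_pad_nd returns a plain BHSD-contiguous tensor, which is what
  // triggers the workaround in the diff.
  auto q = at::randn({2, 8, 1024, 72});                  // [B, H, S, D]
  auto q_padded = at::constant_pad_nd(q, {0, 128 - 72}); // BHSD contiguous
  auto q_bshd = to_bshd_layout(q_padded);                // memory reordered to BSHD
  TORCH_CHECK(q_bshd.stride(1) == q_bshd.size(3));       // heads stride == head_dim
  return 0;
}

The double transpose keeps the logical [B, H, S, D] indexing the caller expects; the contiguous() in the middle is what actually moves the data into seq-major (BSHD) order so the kernel's stride assumptions hold.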