
Commit 07bfd2f

YangKai0616 and vasqu authored
[XPU] Add flash_attn2 support for XPU (#41956)
* Add flash_attention_2 and kernels-community/flash-attn support for XPU
* Add flash-attn-2 support for XPU
* Delete deterministic algorithm for xpu
* Fix code style
* Modify repo_id to match the latest kernels-community/flash-attn2
* Fix code style
* Update
* Make quality
* Use kernels loading
* Update
* Delete invalid import
* Update comment

---------

Co-authored-by: Anton Vlasjuk <[email protected]>
1 parent 9162e19 commit 07bfd2f

23 files changed: +69 additions, -49 deletions
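For context, a minimal usage sketch of what this commit enables (illustrative only: the checkpoint name is a placeholder, and an XPU build of PyTorch plus the `kernels` package are assumed to be installed). Requesting `flash_attention_2` on an XPU device now resolves to the hub kernel `kernels-community/flash-attn2` instead of requiring the `flash_attn` package:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint, any causal LM works

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.float16,
    attn_implementation="flash_attention_2",  # dispatched via the hub kernel on XPU
).to("xpu")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("xpu")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))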

src/transformers/modeling_flash_attention_utils.py

Lines changed: 8 additions & 2 deletions
@@ -24,6 +24,7 @@
     is_flash_attn_3_available,
     is_flash_attn_greater_or_equal_2_10,
     is_torch_npu_available,
+    is_torch_xpu_available,
     logging,
 )
 
@@ -45,7 +46,12 @@ def flash_attn_supports_top_left_mask():
 
 # TODO Deprecate when all models have the attention interface
 def is_flash_attn_available():
-    return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
+    return (
+        is_flash_attn_3_available()
+        or is_flash_attn_2_available()
+        or is_torch_npu_available()
+        or is_torch_xpu_available()
+    )
 
 
 # `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves
@@ -97,7 +103,7 @@ def _lazy_imports(implementation: Optional[str]):
     if flash_attn_varlen_func is None or flash_attn_func is None:
         raise ValueError(
             f"Could not find the currently requested flash attention implementation at `{implementation}`."
-            f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
+            f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn2`."
         )
 
     return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input
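A small sketch of how the broadened helper can be used to pick an attention backend (illustrative only; the checkpoint name is a placeholder):

from transformers import AutoModelForCausalLM
from transformers.modeling_flash_attention_utils import is_flash_attn_available

# True if any flash-attention backend is detected: FA3, FA2, NPU, or (after this commit) XPU
attn = "flash_attention_2" if is_flash_attn_available() else "sdpa"
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation=attn)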

src/transformers/modeling_utils.py

Lines changed: 9 additions & 1 deletion
@@ -116,6 +116,7 @@
     is_torch_greater_or_equal,
     is_torch_mlu_available,
     is_torch_npu_available,
+    is_torch_xpu_available,
     logging,
 )
 from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
@@ -1575,6 +1576,10 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
             logger.info("Detect using FlashAttention2 on Ascend NPU.")
             return True
 
+        if is_torch_xpu_available():
+            logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
+            return True
+
         if importlib.util.find_spec("flash_attn") is None:
             raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
         else:
@@ -1800,7 +1805,10 @@ def _check_and_adjust_attn_implementation(
             and not is_torch_npu_available()
         ):
             if attn_implementation.endswith("2"):
-                applicable_attn_implementation = "kernels-community/flash-attn"
+                applicable_attn_implementation = "kernels-community/flash-attn2"
+                if is_torch_xpu_available():
+                    # On XPU, kernels library is the native implementation. Rename variable to avoid "fallback" warning and irrelevant checks.
+                    attn_implementation = "kernels-community/flash-attn2"
             else:
                 applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
 
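The dispatch change above can be summarized with a standalone sketch (simplified; the real `_flash_attn_2_can_dispatch` also validates dtype, device placement, and initialization state):

import importlib.util

from transformers.utils import is_torch_npu_available, is_torch_xpu_available


def fa2_can_dispatch_sketch() -> bool:
    # Ascend NPU ships its own FlashAttention2 support
    if is_torch_npu_available():
        return True
    # New in this commit: XPU dispatches through kernels-community/flash-attn2
    if is_torch_xpu_available():
        return True
    # Everywhere else, the flash_attn package itself must be installed
    return importlib.util.find_spec("flash_attn") is not None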
src/transformers/testing_utils.py

Lines changed: 1 addition & 1 deletion
@@ -593,7 +593,7 @@ def require_flash_attn(test_case):
     try:
         from kernels import get_kernel
 
-        get_kernel("kernels-community/flash-attn")
+        get_kernel("kernels-community/flash-attn2")
     except Exception as _:
         kernels_available = False
 
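The same probe the decorator performs can be run standalone to check whether the FA2 hub kernel is loadable on the current machine (sketch, mirroring the snippet above):

try:
    from kernels import get_kernel

    get_kernel("kernels-community/flash-attn2")
    kernels_available = True
except Exception:
    kernels_available = False

print("flash-attn2 kernel available:", kernels_available)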

tests/causal_lm_tester.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
     _COMMON_MODEL_NAMES_MAP,
     is_flaky,
     require_flash_attn,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
 )
 
@@ -550,7 +550,7 @@ def test_model_rope_scaling_frequencies(self):
         torch.testing.assert_close(yarn_sin_long, original_sin_long)
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @is_flaky()
     @slow
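The same decorator swap is applied throughout the remaining test files below: tests now require any supported accelerator (CUDA, XPU, NPU, ...) instead of a CUDA GPU specifically, and target `torch_device` rather than a hard-coded "cuda" string. A schematic example of the resulting pattern (the test body is a placeholder):

import pytest

from transformers.testing_utils import (
    require_flash_attn,
    require_torch_accelerator,  # previously require_torch_gpu
    slow,
    torch_device,
)


@require_flash_attn
@require_torch_accelerator
@pytest.mark.flash_attn_test
@slow
def test_fa2_runs_on_any_accelerator():
    # placeholder body; the real tests build a model and compare FA2 vs. eager outputs
    assert torch_device is not None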

tests/generation/test_utils.py

Lines changed: 4 additions & 4 deletions
@@ -1848,7 +1848,7 @@ def test_eager_matches_sdpa_generate(self):
 
     @pytest.mark.flash_attn_test
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_eager_matches_fa2_generate(self):
         """Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
@@ -1863,7 +1863,7 @@ def test_eager_matches_fa3_generate(self):
         self._test_attention_implementation("flash_attention_3")
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @pytest.mark.flash_attn_test
     def test_flash_attention_2_continue_generate_with_position_ids(self):
         """
@@ -2065,14 +2065,14 @@ def test_sdpa_padding_matches_padding_free_with_position_ids(self):
         self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @slow
     def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
         self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @slow
     def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):

tests/models/bamba/test_modeling_bamba.py

Lines changed: 1 addition & 2 deletions
@@ -34,7 +34,6 @@
     require_flash_attn,
     require_torch,
     require_torch_accelerator,
-    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -444,7 +443,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa
         pass
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @mark.flash_attn_test
     @slow
     @unittest.skip(

tests/models/diffllama/test_modeling_diffllama.py

Lines changed: 3 additions & 4 deletions
@@ -29,7 +29,6 @@
     require_read_token,
     require_torch,
     require_torch_accelerator,
-    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -324,7 +323,7 @@ def _reinitialize_config(base_config, new_kwargs):
         )  # missing "factor"
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_bitsandbytes
     @pytest.mark.flash_attn_test
     @require_read_token
@@ -364,7 +363,7 @@ def test_flash_attn_2_generate_padding_right(self):
         self.assertListEqual(output_native, output_fa_2)
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     @pytest.mark.flash_attn_test
     def test_use_flash_attention_2_true(self):
@@ -379,7 +378,7 @@ def test_use_flash_attention_2_true(self):
 
             new_model = DiffLlamaForCausalLM.from_pretrained(
                 tmp_dir, attn_implementation="flash_attention_2", dtype=torch.float16
-            ).to("cuda")
+            ).to(torch_device)
 
             self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
 

tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torch_large_accelerator,
     require_torch_multi_accelerator,
     slow,
@@ -56,7 +56,7 @@ class Ernie4_5_MoeModelTest(CausalLMModelTest, unittest.TestCase):
     model_tester_class = Ernie4_5_MoeModelTester
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @is_flaky()
     @slow

tests/models/esm/test_modeling_esm.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -306,7 +306,7 @@ def test_resize_tokens_embeddings(self):
         pass
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @is_flaky()
     @slow

tests/models/glm4/test_modeling_glm4.py

Lines changed: 5 additions & 2 deletions
@@ -25,7 +25,6 @@
     require_flash_attn,
     require_torch,
     require_torch_large_accelerator,
-    require_torch_large_gpu,
     slow,
     torch_device,
 )
@@ -177,7 +176,7 @@ def test_model_9b_sdpa(self):
         self.assertEqual(output_text, EXPECTED_TEXT)
 
     @require_flash_attn
-    @require_torch_large_gpu
+    @require_torch_large_accelerator
     @pytest.mark.flash_attn_test
     def test_model_9b_flash_attn(self):
         EXPECTED_TEXTS = Expectations(
@@ -187,6 +186,10 @@ def test_model_9b_flash_attn(self):
                     "Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
                     "Hi today I am going to tell you about the most common disease in the world. This disease is called diabetes",
                 ],
+                ("xpu", None): [
+                    "Hello I am doing a project on the history of the internet and I need to know what the first website was and what",
+                    "Hi today I am going to tell you about the most common disease in the world. This disease is called diabetes",
+                ],
             }
         )
         EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
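For reference, a sketch of the per-device expectation pattern used in the glm4 test above: `Expectations` maps device-keyed entries to expected outputs and `get_expectation()` selects the entry matching the current device. The strings below are placeholders; the key layout follows the diff above.

from transformers.testing_utils import Expectations

EXPECTED_TEXTS = Expectations(
    {
        ("cuda", None): ["...expected generation on CUDA..."],
        ("xpu", None): ["...expected generation on XPU..."],
    }
)
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()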
