
Commit 2e93004

extend FA2 and other cases to XPU (#42536)
* extend FA2 and other cases to XPU: we expect all model cases except the CUDAGraph-specific, CUDA compute-capability-specific, and FA3-specific ones to run on XPU. For FA3, support is still being developed.
  Signed-off-by: Yao, Matrix <[email protected]>
* fix style
  Signed-off-by: Yao, Matrix <[email protected]>
* Update modeling_mimi.py
---------
Signed-off-by: Yao, Matrix <[email protected]>
1 parent a48d68c commit 2e93004

29 files changed, +149 -93 lines changed
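The change is largely mechanical across the touched files: @require_torch_gpu decorators become @require_torch_accelerator so XPU runners are no longer skipped, and device-keyed Expectations dicts gain an ("xpu", None) entry (None meaning no compute-capability pin). Below is a minimal, self-contained sketch of that shape, assuming the Expectations helper and the decorators from transformers.testing_utils behave as they are used in the diffs that follow; the class name and expected strings are illustrative placeholders, not recorded model output.

import unittest

from transformers.testing_utils import Expectations, require_torch_accelerator, slow


class XpuExpectationSketch(unittest.TestCase):  # hypothetical test class, for illustration only
    @require_torch_accelerator  # was @require_torch_gpu: the test now also runs on XPU hosts
    @slow
    def test_device_keyed_expectations(self):
        # Keys are (device_type, compute_capability); None leaves the capability unpinned.
        expected = Expectations({
            ("cuda", (9, 0)): {"req_1": "placeholder reference output recorded on CUDA sm_90"},
            ("xpu", None): {"req_1": "placeholder reference output recorded on an Intel XPU"},
        }).get_expectation()  # resolves to the entry matching the current torch_device
        self.assertIn("req_1", expected)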

tests/generation/test_continuous_batching.py

Lines changed: 18 additions & 8 deletions
@@ -25,7 +25,6 @@
     require_kernels,
     require_read_token,
     require_torch_accelerator,
-    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -315,36 +314,47 @@ def test_continuous_batching_parity_gemma_sdpa(self) -> None:
     # GPT-OSS is not compatible with SDPA because it has an attention sink. TODO: is this fixable?
 
     # Flash attention test
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_kernels
     @slow
     def test_continuous_batching_parity_llama_flash(self) -> None:
         expected_outputs = Expectations({
             ("cuda", (9, 0)): {
                 "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5 bolts. The total number of bolts is 4.5 bolts.",
-            }
+            },
+            ("xpu", None): {
+                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5 bolts. The total number of bolts is 4.5 bolts.",
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|flash_attention_2", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_kernels
     @slow
     def test_continuous_batching_parity_gemma_flash(self) -> None:
         expected_outputs = Expectations({
             ("cuda", (9, 0)): {
                 "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
-            }
+            },
+            ("xpu", None): {
+                "req_0": "\n\n**$128**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 1",
+                "req_1": "\n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("google/gemma-2-2b-it", "paged|flash_attention_2", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_kernels
     @slow
     def test_continuous_batching_parity_qwen_flash(self) -> None:
-        expected_outputs = {}
+        expected_outputs = Expectations({
+            ("xpu", None): {
+                "req_1": " 3.5 bolts.\n\nLet's break it down step by step:\n\n- Blue fiber: 2 bolts\n- White fiber: half of 2 bolts = 1 bolt\n\nTotal = ",
+            },
+        }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|flash_attention_2", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_kernels
     @slow
     def test_continuous_batching_parity_gpt_oss_flash(self) -> None:

tests/generation/test_paged_attention.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 from parameterized import parameterized
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow
+from transformers.testing_utils import require_flash_attn, require_torch_accelerator, slow
 
 
 _TEST_PROMPTS = [
@@ -26,7 +26,7 @@
 
 @slow
 @require_flash_attn
-@require_torch_gpu
+@require_torch_accelerator
 class TestBatchGeneration(unittest.TestCase):
     @classmethod
     def setUpClass(cls):

tests/models/gemma2/test_modeling_gemma2.py

Lines changed: 15 additions & 7 deletions
@@ -33,7 +33,6 @@
     require_torch,
     require_torch_accelerator,
     require_torch_large_accelerator,
-    require_torch_large_gpu,
     run_test_using_subprocess,
     slow,
     torch_device,
@@ -172,16 +171,25 @@ def test_model_2b_pipeline_bf16_flex_attention(self):
 
     @require_read_token
     @require_flash_attn
-    @require_torch_large_gpu
+    @require_torch_large_accelerator
     @mark.flash_attn_test
     @slow
     def test_model_9b_flash_attn(self):
        # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
         model_id = "google/gemma-2-9b"
-        EXPECTED_TEXTS = [
-            '<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
-            "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic composed of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
-        ]  # fmt: skip
+        # fmt: off
+        EXPECTED_TEXTS = Expectations(
+            {
+                (None, None): ['<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
+                    "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic composed of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
+                ],
+                ("xpu", None): ['<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
+                    "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
+                ],
+            }
+        )
+        # fmt: on
+        EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
 
         model = AutoModelForCausalLM.from_pretrained(
             model_id, attn_implementation="flash_attention_2", dtype="float16"
@@ -192,7 +200,7 @@ def test_model_9b_flash_attn(self):
         output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
         output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
 
-        self.assertEqual(output_text, EXPECTED_TEXTS)
+        self.assertEqual(output_text, EXPECTED_TEXT)
 
     @pytest.mark.torch_export_test
     @slow
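
The gemma2 change above also shows the fallback form of the same helper: the previously unconditional EXPECTED_TEXTS list becomes an Expectations dict whose (None, None) entry keeps the original strings while ("xpu", None) carries the slightly different XPU generation ("consisting of" instead of "composed of"). Below is a minimal sketch of that default-plus-override selection, under the assumption that (None, None) acts as the catch-all key when no device-specific entry matches; the strings are placeholders.

from transformers.testing_utils import Expectations

EXPECTED_TEXTS = Expectations({
    (None, None): ["baseline text that most devices reproduce exactly"],
    ("xpu", None): ["the same text with the small wording drift observed on XPU"],
})
# On an XPU host this should pick the ("xpu", None) entry; elsewhere it should fall back to (None, None).
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()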

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 1 addition & 1 deletion
@@ -440,7 +440,7 @@ def test_automodelforcausallm(self):
         self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @mark.flash_attn_test
     @slow
     def test_flash_attn_2_from_config(self):

tests/models/glm4v/test_modeling_glm4v.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,7 @@
     require_deterministic_for_xpu,
     require_flash_attn,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -512,7 +512,7 @@ def test_small_model_integration_test_batch_different_resolutions(self):
 
     @slow
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_small_model_integration_test_batch_flashatt2(self):
         model = Glm4vForConditionalGeneration.from_pretrained(
             "THUDM/GLM-4.1V-9B-Thinking",
@@ -547,7 +547,7 @@ def test_small_model_integration_test_batch_flashatt2(self):
 
     @slow
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_small_model_integration_test_batch_wo_image_flashatt2(self):
         model = Glm4vForConditionalGeneration.from_pretrained(
             "THUDM/GLM-4.1V-9B-Thinking",

tests/models/glm4v_moe/test_modeling_glm4v_moe.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
     cleanup,
     require_flash_attn,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     run_first,
     slow,
     torch_device,
@@ -434,7 +434,7 @@ def test_small_model_integration_test_with_video(self):
 
     @run_first
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_small_model_integration_test_batch_flashatt2(self):
         model = Glm4vMoeForConditionalGeneration.from_pretrained(
             "zai-org/GLM-4.5V",

tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@
 from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -235,7 +235,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa
         pass
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @mark.flash_attn_test
     @slow
     @unittest.skip(
@@ -356,7 +356,7 @@ def test_config_requires_mamba_or_attention_layers(self):
 
 # TODO (@alex-jw-brooks) - update this once the model(s) are out
 @unittest.skip(reason="GraniteMoeHybrid models are not yet released")
-@require_torch_gpu
+@require_torch_accelerator
 class GraniteMoeHybridIntegrationTest(unittest.TestCase):
     @slow
     def test_model_logits(self):

tests/models/idefics2/test_modeling_idefics2.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torch_multi_accelerator,
     slow,
     torch_device,
@@ -645,7 +645,7 @@ def test_integration_test_4bit_batch2(self):
 
     @pytest.mark.flash_attn_test
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_bitsandbytes
     def test_flash_attn_2_eager_equivalence(self):
         # Create inputs

tests/models/kosmos2_5/test_modeling_kosmos2_5.py

Lines changed: 1 addition & 2 deletions
@@ -33,7 +33,6 @@
     require_flash_attn,
     require_torch,
     require_torch_accelerator,
-    require_torch_gpu,
     require_vision,
     slow,
     torch_device,
@@ -467,7 +466,7 @@ def test_model_parallelism(self):
         pass
 
     # TODO: ydshieh
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     @unittest.skip(reason="_update_causal_mask is not implemented yet which fails this test")
     def test_sdpa_can_dispatch_on_flash(self):

tests/models/longcat_flash/test_modeling_longcat_flash.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
     require_flash_attn,
     require_large_cpu_ram,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -285,7 +285,7 @@ def _prepare_config_headdim(config, requested_dim):
         return config
 
     @require_flash_attn
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_bitsandbytes
     @mark.flash_attn_test
     @slow
