Commit 18a65c0

add eagle3 gpt-oss test
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent: 4a34055

6 files changed: +64 −3 lines

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 3 additions & 3 deletions
@@ -1071,10 +1071,10 @@ def update_spec_dec_param(
         spec_decoding_packed_mask = None
         spec_decoding_generation_lengths = None
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
-        self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
-        ) < 100
+        self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
+            get_sm_version() < 100 or get_sm_version() == 120)
 
-        if get_sm_version() >= 100:
+        if get_sm_version() >= 100 and get_sm_version() != 120:
             if is_spec_dec_tree or is_spec_dec_dynamic_tree:
                 assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
                 assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
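
The gate now reduces to a small predicate on the SM version: spec-dec mode also covers SM 120 (RTX PRO 6000 Blackwell), while SM 100 (B200) stays excluded and keeps rejecting spec-dec trees; note the untouched context comment above still describes only the pre-Blackwell case. A minimal sketch of the predicate, for illustration only (these standalone functions are not the TensorRT-LLM API; only the thresholds come from the diff):

# Illustrative predicates mirroring the updated gate; not actual
# TensorRT-LLM code. SM versions: 90 = H100, 100 = B200,
# 120 = RTX PRO 6000 Blackwell.

def spec_decoding_mode_allowed(sm_version: int) -> bool:
    """Spec-dec mode: pre-Blackwell (SM < 100) or, newly, SM 120."""
    return sm_version < 100 or sm_version == 120

def rejects_spec_dec_tree(sm_version: int) -> bool:
    """Mirror of the branch that asserts against spec-dec trees."""
    return sm_version >= 100 and sm_version != 120

assert spec_decoding_mode_allowed(90)       # H100: allowed as before
assert not spec_decoding_mode_allowed(100)  # B200: still excluded
assert spec_decoding_mode_allowed(120)      # RTX PRO 6000: newly allowed
assert not rejects_spec_dec_tree(120)       # trees no longer rejected on SM 120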

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 5 additions & 0 deletions
@@ -212,6 +212,11 @@ GPT-OSS/BF16:
   - accuracy: 90.3
   - kv_cache_quant_algo: FP8
     accuracy: 90.3
+  - quant_algo: W4A16_MXFP4
+    accuracy: 90.3
+  - quant_algo: W4A16_MXFP4
+    spec_dec_algo: Eagle
+    accuracy: 90.3
 GPT-OSS/MXFP4:
   - accuracy: 90.3
   - quant_algo: W4A8_MXFP4_MXFP8
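
The two new entries let the harness distinguish a plain W4A16_MXFP4 run from one that also uses Eagle speculative decoding. Purely as a hypothetical illustration of how such entries could key an expected score (the field names come from the YAML above; the lookup function and its most-specific-match rule are assumptions, not the harness's real logic):

# Hypothetical reference table mirroring the GPT-OSS/BF16 entries above.
REFERENCES = [
    {"accuracy": 90.3},
    {"kv_cache_quant_algo": "FP8", "accuracy": 90.3},
    {"quant_algo": "W4A16_MXFP4", "accuracy": 90.3},
    {"quant_algo": "W4A16_MXFP4", "spec_dec_algo": "Eagle", "accuracy": 90.3},
]

def expected_accuracy(config: dict) -> float | None:
    """Return the score of the most specific entry whose fields all match
    the run configuration (assumed semantics, for illustration only)."""
    best, best_rank = None, -1
    for entry in REFERENCES:
        fields = {k: v for k, v in entry.items() if k != "accuracy"}
        if all(config.get(k) == v for k, v in fields.items()):
            if len(fields) > best_rank:
                best, best_rank = entry["accuracy"], len(fields)
    return best

# The new Eagle entry matches the spec-dec run added in this commit:
print(expected_accuracy(
    {"quant_algo": "W4A16_MXFP4", "spec_dec_algo": "Eagle"}))  # -> 90.3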

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 53 additions & 0 deletions
@@ -3595,6 +3595,59 @@ def test_w4a16(self, kv_cache_dtype, tp_size, pp_size, ep_size,
             task.evaluate(llm,
                           extra_evaluator_kwargs=self.extra_evaluator_kwargs)
 
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("kv_cache_dtype", ["auto"])
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
+            (4, 1, 4, False, True, True),
+        ],
+        ids=["tep4"])
+    @pytest.mark.parametrize(
+        "moe_backend",
+        ["triton", "cutlass",
+         pytest.param("trtllm", marks=skip_pre_blackwell)])
+    def test_w4a16_eagle3(self, kv_cache_dtype, tp_size, pp_size, ep_size,
+                          attention_dp, cuda_graph, overlap_scheduler,
+                          moe_backend, monkeypatch, mocker):
+        mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
+        mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
+                          {"scores_filter": "exact_match,flexible-extract"})
+        if moe_backend == "triton" and not IS_TRITON_KERNELS_AVAILABLE:
+            pytest.skip("Triton kernels are not available")
+        monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4")
+
+        cuda_graph_config = CudaGraphConfig(enable_padding=True,
+                                            max_batch_size=8)
+
+        pytorch_config = dict(
+            max_batch_size=8,
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=cuda_graph_config if cuda_graph else None)
+
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
+                                        dtype=kv_cache_dtype)
+        spec_config = EagleDecodingConfig(
+            max_draft_len=3,
+            speculative_model_dir=
+            f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3/",
+            eagle3_one_model=True)
+
+        llm = LLM(self.MODEL_PATH,
+                  tensor_parallel_size=tp_size,
+                  pipeline_parallel_size=pp_size,
+                  moe_expert_parallel_size=ep_size,
+                  kv_cache_config=kv_cache_config,
+                  **pytorch_config,
+                  enable_attention_dp=attention_dp,
+                  moe_config=MoeConfig(backend=moe_backend),
+                  speculative_config=spec_config)
+
+        with llm:
+            model_name = "GPT-OSS/BF16"
+            task = GSM8K(model_name)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=self.extra_evaluator_kwargs)
+
     @pytest.mark.skip_less_device(2)
     @pytest.mark.parametrize(
         "kv_cache_dtype",

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8]
+  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[trtllm-tep4-auto]
   - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16]
   - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ l0_dgx_h100:
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
+  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[triton-tep4-auto]
 - condition:
     ranges:
       system_gpu_count:

tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml

Lines changed: 1 addition & 0 deletions
@@ -107,5 +107,6 @@ l0_rtx_pro_6000:
   # - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] # failed
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]
+  - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[cutlass-tep4-auto]
   - test_e2e.py::test_ptp_quickstart_multimodal_2gpu[phi4-multimodal-instruct-fp8-multimodals/Phi-4-multimodal-instruct-FP8]
   - test_e2e.py::test_ptp_quickstart_multimodal_2gpu[phi4-multimodal-instruct-fp4-multimodals/Phi-4-multimodal-instruct-FP4]
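
Taken together, the three list updates exercise one MoE backend per platform: trtllm on DGX B200 (that variant is marked skip_pre_blackwell in the test above), triton on DGX H100, and cutlass on RTX PRO 6000, whose SM 120 GPUs the attention-backend change above newly admits into spec-dec mode.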
