Commit a6b300c

add eagle3 gpt-oss test

Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent: 21f785c

9 files changed: +50 −25 lines

tensorrt_llm/_torch/attention_backend/trtllm.py
Lines changed: 4 additions & 4 deletions

@@ -465,7 +465,7 @@ def run(
                 self.spec_decoding_generation_lengths,
                 self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
             ]
-            if get_sm_version() >= 100:
+            if get_sm_version() >= 100 and get_sm_version() != 120:
                 spec_decoding_tensor_params.append(
                     self.spec_decoding_bl_tree_mask_offset)
                 spec_decoding_tensor_params.append(self.spec_decoding_bl_tree_mask)
@@ -1158,8 +1158,8 @@ def update_spec_dec_param(
         spec_decoding_generation_lengths = None

         self.is_spec_decoding_enabled = is_spec_decoding_enabled
-        if get_sm_version(
-        ) >= 100 and not is_spec_dec_tree and not is_spec_dec_dynamic_tree:
+        if (get_sm_version() >= 100 and get_sm_version() != 120
+            ) and not is_spec_dec_tree and not is_spec_dec_dynamic_tree:
             self.is_spec_decoding_enabled = False

         # use_spec_decoding is default to true by default, change in runtime by layers / requests
@@ -1190,7 +1190,7 @@ def update_spec_dec_param(
             dtype=torch.int,
             device='cuda',
         )
-        if get_sm_version() >= 100:
+        if get_sm_version() >= 100 and get_sm_version() != 120:
             self.spec_decoding_param_prepare_for_blackwell()
         else:
             self.spec_decoding_bl_tree_mask_offset = None
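The functional change above is one predicate applied at three call sites: the Blackwell tree-mask tensors for speculative decoding are prepared on SM 100-class GPUs but not on SM 120 (the RTX Pro 6000-class Blackwell parts covered by the test-list change further down). A minimal sketch of the gate, assuming get_sm_version() returns the CUDA SM as an integer as it does in trtllm.py; the helper name is hypothetical, not part of the diff:

# Hypothetical helper condensing the predicate this commit repeats at
# three call sites in tensorrt_llm/_torch/attention_backend/trtllm.py.
def sm_uses_bl_tree_mask(sm_version: int) -> bool:
    """True only for SM >= 100, excluding SM 120 (RTX Pro 6000-class)."""
    return sm_version >= 100 and sm_version != 120

# Hopper (SM 90) and SM 120 keep the packed-mask path; datacenter
# Blackwell (SM 100) gets the bl_tree_mask tensors appended.
assert not sm_uses_bl_tree_mask(90)
assert sm_uses_bl_tree_mask(100)
assert not sm_uses_bl_tree_mask(120)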

tests/integration/defs/accuracy/test_llm_api_pytorch.py
Lines changed: 13 additions & 5 deletions

@@ -3956,12 +3956,14 @@ def test_w4_chunked_prefill(self, kv_cache_dtype, moe_backend, mocker):
                            extra_evaluator_kwargs=extra_evaluator_kwargs)

     @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("eagle3_one_model", [False, True],
+                             ids=["two_model", "one_model"])
     @pytest.mark.parametrize(
         "moe_backend",
         ["CUTLASS",
          pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
         ids=["cutlass", "trtllm", "triton"])
-    def test_eagle3(self, moe_backend, mocker):
+    def test_eagle3(self, eagle3_one_model, moe_backend, mocker):
         if moe_backend == "TRITON":
             if not IS_TRITON_KERNELS_AVAILABLE:
                 pytest.skip("Triton kernels are not available")
@@ -3976,17 +3978,23 @@ def test_eagle3(self, moe_backend, mocker):
         mocker.patch.object(GPQADiamond, "MAX_OUTPUT_LEN", MAX_OUTPUT_LEN)
         mocker.patch.object(GPQADiamond, "MAX_INPUT_LEN", MAX_INPUT_LEN)

-        # https://nvbugs/5590408: 2-Model overlap scheduling has accuracy issue
-        pytorch_config = dict(disable_overlap_scheduler=True,
-                              cuda_graph_config=CudaGraphConfig())
+        if eagle3_one_model:
+            pytorch_config = dict(disable_overlap_scheduler=False,
+                                  max_batch_size=1,
+                                  cuda_graph_config=CudaGraphConfig(
+                                      enable_padding=True, max_batch_size=1))
+        else:
+            # https://nvbugs/5590408: 2-Model overlap scheduling has accuracy issue
+            pytorch_config = dict(disable_overlap_scheduler=True,
+                                  cuda_graph_config=CudaGraphConfig())
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                         dtype="auto")

         eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
         draft_len = 3
         spec_config = EagleDecodingConfig(max_draft_len=draft_len,
                                           speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=False)
+                                          eagle3_one_model=eagle3_one_model)

         max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
         llm = LLM(self.MODEL_PATH,
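With this change the test exercises both EAGLE3 deployment modes. A condensed sketch of the configuration split, mirroring the values in the diff; build_eagle3_llm is illustrative, and the llmapi import paths are an assumption based on how the rest of the suite imports these classes:

# Sketch only: one-model vs. two-model EAGLE3 config, as in the test above.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
                                 KvCacheConfig)

def build_eagle3_llm(model_dir, eagle_dir, eagle3_one_model):
    if eagle3_one_model:
        # One-model mode: draft layers run fused with the target model,
        # so overlap scheduling stays on; the test pins batch size to 1.
        extra = dict(disable_overlap_scheduler=False,
                     max_batch_size=1,
                     cuda_graph_config=CudaGraphConfig(enable_padding=True,
                                                       max_batch_size=1))
    else:
        # Two-model mode: overlap scheduling off (nvbugs/5590408).
        extra = dict(disable_overlap_scheduler=True,
                     cuda_graph_config=CudaGraphConfig())
    return LLM(model_dir,
               kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.6,
                                             dtype="auto"),
               speculative_config=EagleDecodingConfig(
                   max_draft_len=3,
                   speculative_model_dir=eagle_dir,
                   eagle3_one_model=eagle3_one_model),
               **extra)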

tests/integration/test_lists/qa/llm_function_core.txt
Lines changed: 6 additions & 3 deletions

@@ -564,9 +564,12 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
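The renamed entries follow pytest's id composition for stacked parametrize decorators: the decorator nearest the function contributes the leftmost id segment, so the moe_backend ids (cutlass/trtllm/triton) precede the new eagle3_one_model ids (two_model/one_model). A minimal, self-contained illustration; the test name and single backend value here are placeholders, not from the diff:

# Minimal illustration of how the new test ids are composed.
import pytest

@pytest.mark.parametrize("eagle3_one_model", [False, True],
                         ids=["two_model", "one_model"])
@pytest.mark.parametrize("moe_backend", ["CUTLASS"], ids=["cutlass"])
def test_eagle3_ids(eagle3_one_model, moe_backend):
    pass

# pytest --collect-only yields:
#   test_eagle3_ids[cutlass-two_model]
#   test_eagle3_ids[cutlass-one_model]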

tests/integration/test_lists/qa/llm_function_core_sanity.txt
Lines changed: 6 additions & 3 deletions

@@ -101,9 +101,12 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model]
 accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]

tests/integration/test_lists/qa/llm_function_nim.txt
Lines changed: 6 additions & 3 deletions

@@ -342,9 +342,12 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model]
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass]

tests/integration/test_lists/test-db/l0_dgx_b200.yml
Lines changed: 8 additions & 4 deletions

@@ -50,8 +50,10 @@ l0_dgx_b200:
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8]
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8]
-      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm]
-      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model]
       - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16]
       - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
@@ -196,6 +198,8 @@ l0_dgx_b200:
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-fp8]
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto]
-      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm]
-      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model]
       - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf]

tests/integration/test_lists/test-db/l0_dgx_h100.yml
Lines changed: 4 additions & 2 deletions

@@ -185,8 +185,10 @@ l0_dgx_h100:
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto]
       - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
-      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass]
-      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model]
 - condition:
     ranges:
       system_gpu_count:

tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml
Lines changed: 2 additions & 0 deletions

@@ -109,5 +109,7 @@ l0_rtx_pro_6000:
       # - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] # failed
       - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
       - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model]
+      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model]
       - test_e2e.py::test_ptp_quickstart_multimodal_2gpu[phi4-multimodal-instruct-fp8-multimodals/Phi-4-multimodal-instruct-FP8]
       - test_e2e.py::test_ptp_quickstart_multimodal_2gpu[phi4-multimodal-instruct-fp4-multimodals/Phi-4-multimodal-instruct-FP4]

tests/integration/test_lists/waives.txt
Lines changed: 1 addition & 1 deletion

@@ -346,7 +346,7 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] SKIP (https://nvbugs/5637220)
 llmapi/test_llm_examples.py::test_llmapi_example_multilora SKIP (https://nvbugs/5636857)
 unittest/_torch/modules/test_mla_helix.py::test_mla_helix_distributed SKIP (https://nvbugspro.nvidia.com/bug/5637012)
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass] SKIP (https://nvbugs/5636916)
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model] SKIP (https://nvbugs/5636916)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5616182)
 examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-small-128k-instruct-fp8-bfloat16] SKIP (https://nvbugs/5465143)
 examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5644684)
