@@ -3956,12 +3956,14 @@ def test_w4_chunked_prefill(self, kv_cache_dtype, moe_backend, mocker):
39563956 extra_evaluator_kwargs = extra_evaluator_kwargs )
39573957
39583958 @pytest .mark .skip_less_device (4 )
3959+ @pytest .mark .parametrize ("eagle3_one_model" , [False , True ],
3960+ ids = ["two_model" , "one_model" ])
39593961 @pytest .mark .parametrize (
39603962 "moe_backend" ,
39613963 ["CUTLASS" ,
39623964 pytest .param ("TRTLLM" , marks = skip_pre_blackwell ), "TRITON" ],
39633965 ids = ["cutlass" , "trtllm" , "triton" ])
3964- def test_eagle3 (self , moe_backend , mocker ):
3966+ def test_eagle3 (self , eagle3_one_model , moe_backend , mocker ):
39653967 if moe_backend == "TRITON" :
39663968 if not IS_TRITON_KERNELS_AVAILABLE :
39673969 pytest .skip ("Triton kernels are not available" )
@@ -3976,17 +3978,23 @@ def test_eagle3(self, moe_backend, mocker):
39763978 mocker .patch .object (GPQADiamond , "MAX_OUTPUT_LEN" , MAX_OUTPUT_LEN )
39773979 mocker .patch .object (GPQADiamond , "MAX_INPUT_LEN" , MAX_INPUT_LEN )
39783980
3979- # https://nvbugs/5590408: 2-Model overlap scheduling has accuracy issue
3980- pytorch_config = dict (disable_overlap_scheduler = True ,
3981- cuda_graph_config = CudaGraphConfig ())
3981+ if eagle3_one_model :
3982+ pytorch_config = dict (disable_overlap_scheduler = False ,
3983+ max_batch_size = 1 ,
3984+ cuda_graph_config = CudaGraphConfig (
3985+ enable_padding = True , max_batch_size = 1 ))
3986+ else :
3987+ # https://nvbugs/5590408: 2-Model overlap scheduling has accuracy issue
3988+ pytorch_config = dict (disable_overlap_scheduler = True ,
3989+ cuda_graph_config = CudaGraphConfig ())
39823990 kv_cache_config = KvCacheConfig (free_gpu_memory_fraction = 0.6 ,
39833991 dtype = "auto" )
39843992
39853993 eagle_model_dir = f"{ llm_models_root ()} /gpt_oss/gpt-oss-120b-Eagle3"
39863994 draft_len = 3
39873995 spec_config = EagleDecodingConfig (max_draft_len = draft_len ,
39883996 speculative_model_dir = eagle_model_dir ,
3989- eagle3_one_model = False )
3997+ eagle3_one_model = eagle3_one_model )
39903998
39913999 max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
39924000 llm = LLM (self .MODEL_PATH ,
0 commit comments