@@ -231,10 +231,13 @@ def test_fp8_llm_sampler(self):
     @skip_pre_hopper
     @parametrize_with_ids("overlap_scheduler", [True, False])
     @parametrize_with_ids("eagle3_one_model", [True, False])
-    def test_eagle3(self, overlap_scheduler, eagle3_one_model):
+    @parametrize_with_ids("sampler_async_worker", [True, False])
+    def test_eagle3(self, overlap_scheduler, eagle3_one_model,
+                    sampler_async_worker):
         pytorch_config = dict(
             max_batch_size=
             1,  # add max_batch_size to avoid error in overlap scheduler
+            sampler_enable_async_worker=sampler_async_worker,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig(max_batch_size=1,
                                               enable_padding=True),
@@ -398,6 +401,7 @@ def test_auto_spec_decode(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @parametrize_with_ids("sampler_async_worker", [True, False])
     @parametrize_with_ids("disable_overlap_scheduler", [False, True])
     @parametrize_with_ids(
         "enable_cuda_graph,enable_padding",
@@ -407,7 +411,8 @@ def test_auto_spec_decode(self):
             (True, True),  # CUDA Graph with padding
         ])
     def test_auto_dtype_beam_search(self, enable_cuda_graph, enable_padding,
-                                    disable_overlap_scheduler):
+                                    disable_overlap_scheduler,
+                                    sampler_async_worker):
         max_beam_width = 2
         sampling_params = SamplingParams(n=max_beam_width,
                                          best_of=max_beam_width,
@@ -432,6 +437,7 @@ def test_auto_dtype_beam_search(self, enable_cuda_graph, enable_padding,
                 max_batch_size=max_beam_width,
                 max_seq_len=2048,
                 max_beam_width=max_beam_width,
+                sampler_enable_async_worker=sampler_async_worker,
                 disable_overlap_scheduler=disable_overlap_scheduler,
                 cuda_graph_config=cuda_graph_config,
         ) as llm:
@@ -441,6 +447,7 @@ def test_auto_dtype_beam_search(self, enable_cuda_graph, enable_padding,
                           extra_acc_spec="beam_width=2")
 
     @skip_pre_hopper
+    @parametrize_with_ids("sampler_async_worker", [True, False])
     @parametrize_with_ids("disable_overlap_scheduler", [False, True])
     @parametrize_with_ids(
         "enable_cuda_graph,enable_padding",
@@ -450,7 +457,7 @@ def test_auto_dtype_beam_search(self, enable_cuda_graph, enable_padding,
             (True, True),  # CUDA Graph with padding
         ])
     def test_fp8_beam_search(self, enable_cuda_graph, enable_padding,
-                             disable_overlap_scheduler):
+                             disable_overlap_scheduler, sampler_async_worker):
         model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
         max_beam_width = 2
         sampling_params = SamplingParams(n=max_beam_width,
@@ -476,6 +483,7 @@ def test_fp8_beam_search(self, enable_cuda_graph, enable_padding,
             max_seq_len=2048,
             max_beam_width=max_beam_width,
             disable_overlap_scheduler=disable_overlap_scheduler,
+            sampler_enable_async_worker=sampler_async_worker,
             cuda_graph_config=cuda_graph_config,
         )
 
@@ -506,14 +514,17 @@ def test_fp8_prequantized(self):
 
     @skip_pre_hopper
     @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("sampler_async_worker", [True, False])
     @pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
     @pytest.mark.parametrize("pp_size", [2, 4], ids=["pp2", "pp4"])
-    def test_return_logits_pp(self, pp_size, disable_overlap_scheduler):
+    def test_return_logits_pp(self, pp_size, disable_overlap_scheduler,
+                              sampler_async_worker):
         prompts = ["A B C"]
 
         llm = LLM(model=self.MODEL_PATH,
                   pipeline_parallel_size=pp_size,
-                  disable_overlap_scheduler=disable_overlap_scheduler)
+                  disable_overlap_scheduler=disable_overlap_scheduler,
+                  sampler_enable_async_worker=sampler_async_worker)
 
         sampling_params = SamplingParams(max_tokens=8,
                                          return_context_logits=True,
@@ -1470,6 +1481,7 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
     @pytest.mark.skip_less_device(4)
     @skip_pre_hopper
     @skip_ray
+    @parametrize_with_ids("sampler_async_worker", [True, False])
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1485,7 +1497,8 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
                              ids=["tp4", "ep4", "tp2pp2", "pp4"])
     def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                                     fp8kv, attention_dp, cuda_graph,
-                                    overlap_scheduler, torch_compile):
+                                    overlap_scheduler, torch_compile,
+                                    sampler_async_worker):
         if torch_compile and pp_size > 1:
             pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1500,6 +1513,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             torch_compile_config=torch_compile_config,
             moe_config=MoeConfig(
                 backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
+            sampler_enable_async_worker=sampler_async_worker,
         )
 
         if fp8kv:
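
For reference, a minimal sketch of what the new parametrization exercises, assuming the standard tensorrt_llm LLM API and a hypothetical local model path; sampler_enable_async_worker is the keyword the new sampler_async_worker parameter toggles in each hunk above:

    from tensorrt_llm import LLM, SamplingParams

    # Hypothetical model path; the tests above resolve models under llm_models_root().
    llm = LLM(model="/path/to/Llama-3.1-8B-Instruct",
              disable_overlap_scheduler=False,
              sampler_enable_async_worker=True)  # new flag exercised by these tests

    outputs = llm.generate(["A B C"], SamplingParams(max_tokens=8))
    print(outputs[0].outputs[0].text)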