2020from tensorrt_llm .llmapi import CompletionOutput , RequestOutput , SamplingParams
2121from tensorrt_llm .llmapi .llm_args import LlmArgs
2222
23- from ..conftest import llm_models_root , parametrize_with_ids , skip_pre_hopper
23+ from ..conftest import (get_device_count , llm_models_root , parametrize_with_ids ,
24+ skip_pre_hopper )
2425from ..trt_test_alternative import popen
25- from .accuracy_core import GSM8K , MMLU , LlmapiAccuracyTestHarness
26+ from .accuracy_core import (GSM8K , MMLU , LlmapiAccuracyTestHarness ,
27+ get_accuracy_task )
2628
2729
2830class Result (GenerationResultBase ):
@@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
7173 temp_dir = tempfile .TemporaryDirectory ()
7274 disaggregated_serving_config_path = os .path .join (
7375 temp_dir .name , "disaggregated_serving_config.yaml" )
76+
77+ if tensor_parallel_size > 1 :
78+ print (
79+ f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
80+ )
81+
7482 with open (disaggregated_serving_config_path , "w" ) as f :
7583 yaml .dump (disaggregated_server_config , f )
7684 ctx_server_config_path = os .path .join (temp_dir .name ,
@@ -88,21 +96,47 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
8896 trtllm_serve_path = "trtllm-serve"
8997 # Common arguments for both servers
9098 common_args = [
91- trtllm_serve_path , model_name , "--host" , "localhost" , "--backend" ,
92- "pytorch"
99+ trtllm_serve_path ,
100+ model_name ,
101+ "--host" ,
102+ "localhost" ,
103+ "--backend" ,
104+ "pytorch" ,
93105 ]
94- if tensor_parallel_size > 1 :
95- common_args .append (f"--tp_size={ tensor_parallel_size } " )
106+ gen_tp , gen_pp = gen_server_config .get (
107+ "tensor_parallel_size" ,
108+ tensor_parallel_size ), gen_server_config .get ("pipeline_parallel_size" ,
109+ 1 )
110+ ctx_tp , ctx_pp = ctx_server_config .get (
111+ "tensor_parallel_size" ,
112+ tensor_parallel_size ), ctx_server_config .get ("pipeline_parallel_size" ,
113+ 1 )
114+
115+ ctx_total_gpus = ctx_tp * ctx_pp
116+ gen_total_gpus = gen_tp * gen_pp
96117
97118 env_ctx = os .environ .copy ()
98119 env_ctx ["TRTLLM_USE_UCX_KVCACHE" ] = "1"
99- env_ctx ["CUDA_VISIBLE_DEVICES" ] = "," .join (
100- map (str , range (tensor_parallel_size )))
120+ env_ctx ["CUDA_VISIBLE_DEVICES" ] = "," .join (map (str , range (ctx_total_gpus )))
101121
102122 env_gen = os .environ .copy ()
103123 env_gen ["TRTLLM_USE_UCX_KVCACHE" ] = "1"
104124 env_gen ["CUDA_VISIBLE_DEVICES" ] = "," .join (
105- map (str , range (tensor_parallel_size , 2 * tensor_parallel_size )))
125+ map (str , range (ctx_total_gpus , ctx_total_gpus + gen_total_gpus )))
126+ ctx_server_args = common_args + [
127+ "--port" , "8001" , "--extra_llm_api_options" , ctx_server_config_path ,
128+ f"--tp_size={ ctx_tp } " , f"--pp_size={ ctx_pp } "
129+ ]
130+ gen_server_args = common_args + [
131+ "--port" , "8002" , "--extra_llm_api_options" , gen_server_config_path ,
132+ f"--tp_size={ gen_tp } " , f"--pp_size={ gen_pp } "
133+ ]
134+ if "max_num_tokens" in ctx_server_config :
135+ ctx_server_args .append (
136+ f"--max_num_tokens={ ctx_server_config ['max_num_tokens' ]} " )
137+ if "max_num_tokens" in gen_server_config :
138+ gen_server_args .append (
139+ f"--max_num_tokens={ gen_server_config ['max_num_tokens' ]} " )
106140
107141 with (MyThreadPoolExecutor (max_workers = 16 ) as thread_pool , temp_dir ,
108142 popen (common_args + [
@@ -177,6 +211,56 @@ def generate_async(prompt: str,
177211 disaggregated_server .wait ()
178212
179213
def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
                      ctx_tp: int, gen_pp: int, gen_tp: int,
                      test_set: LlmapiAccuracyTestHarness):
    """Run a disaggregated-serving accuracy test with the given parallel mapping.

    Launches one context server (``ctx_tp * ctx_pp`` GPUs) and one generation
    server (``gen_tp * gen_pp`` GPUs) behind a local disaggregated endpoint,
    then evaluates ``test_set`` against it.

    Args:
        model_name: Model identifier passed to the accuracy task.
        model_path: Filesystem path of the model checkpoint to serve.
        ctx_pp: Pipeline-parallel size of the context server.
        ctx_tp: Tensor-parallel size of the context server.
        gen_pp: Pipeline-parallel size of the generation server.
        gen_tp: Tensor-parallel size of the generation server.
        test_set: Accuracy task class (e.g. GSM8K, MMLU) to instantiate and run.
    """
    required_gpus = ctx_tp * ctx_pp + gen_tp * gen_pp
    if required_gpus > get_device_count():
        # Fail (rather than skip) so a misconfigured CI device allocation is
        # surfaced instead of silently reducing coverage.
        pytest.fail(
            f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and "
            f"gen_pp={gen_pp}+gen_tp={gen_tp} test")

    def _server_config(tp: int, pp: int) -> dict:
        # The two servers differ only in their parallel sizes; build each one
        # independently so neither shares mutable sub-dicts with the other
        # (the original code aliased one kv_cache_config dict in both configs).
        return {
            "tensor_parallel_size": tp,
            "pipeline_parallel_size": pp,
            "disable_overlap_scheduler": True,
            "kv_cache_config": {
                "free_gpu_memory_fraction": 0.5,
                "enable_block_reuse": False,
            },
            "cache_transceiver_config": {
                "backend": "default"
            },
        }

    ctx_server_config = _server_config(ctx_tp, ctx_pp)
    gen_server_config = _server_config(gen_tp, gen_pp)
    # One context instance on :8001 and one generation instance on :8002,
    # fronted by the disaggregated router on :8000.
    disaggregated_server_config = {
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        "context_servers": {
            "num_instances": 1,
            "urls": ["localhost:8001"]
        },
        "generation_servers": {
            "num_instances": 1,
            "urls": ["localhost:8002"]
        }
    }
    with launch_disaggregated_llm(disaggregated_server_config,
                                  ctx_server_config, gen_server_config,
                                  model_path) as llm:
        task = test_set(model_name)
        task.evaluate(llm)
262+
263+
180264@pytest .mark .timeout (3600 )
181265class TestLlama3_1_8BInstruct (LlmapiAccuracyTestHarness ):
182266 MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
@@ -252,6 +336,20 @@ def test_ngram(self):
252336 task = GSM8K (self .MODEL_NAME )
253337 task .evaluate (llm )
254338
339+ @pytest .mark .parametrize ("tp,pp" , [(1 , 2 ), (2 , 1 ), (2 , 2 )],
340+ ids = ["tp1pp2" , "tp2pp1" , "tp2pp2" ])
341+ @pytest .mark .parametrize ("testset" , ["GSM8K" , "MMLU" ])
342+ def test_tp_pp_symmetric (self , tp , pp , testset ):
343+ return run_parallel_test (self .MODEL_NAME , self .MODEL_PATH , pp , tp , pp ,
344+ tp , get_accuracy_task (testset ))
345+
346+ @parametrize_with_ids ("ctx_pp" , [2 , 4 ])
347+ @parametrize_with_ids ("gen_tp" , [1 , 2 ])
348+ @pytest .mark .parametrize ("testset" , ["GSM8K" , "MMLU" ])
349+ def test_ctx_pp_gen_tp_asymmetric (self , ctx_pp , gen_tp , testset ):
350+ return run_parallel_test (self .MODEL_NAME , self .MODEL_PATH , ctx_pp , 1 , 1 ,
351+ gen_tp , get_accuracy_task (testset ))
352+
255353
256354@pytest .mark .timeout (3600 )
257355@pytest .mark .skip_less_device_memory (140000 )
0 commit comments