@@ -3,7 +3,7 @@
 import random
 import time
 from collections import deque
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple
 
 import pytest
@@ -27,6 +27,9 @@
 except Exception:
     HAVE_ZMQ = False
 
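+# pyzmq-based coordinator tests are flaky in CI (see skipif markers below);
+# set to False to re-enable them locally.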
+IS_ZMQ_FLAKY = True
+
 
 class DummyContext:
     """Dummy inference context."""
@@ -77,6 +80,11 @@ async def async_step(
         to_remove = []
         for request_id, request in self.requests.items():
             if request.status == Status.ACTIVE_AND_GENERATING_TOKENS:
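+                # Spend one token of the request's budget; finish once it is exhausted.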
+                request.sampling_params.num_tokens_to_generate -= 1
+                if request.sampling_params.num_tokens_to_generate > 0:
+                    continue
                 request.status = Status.COMPLETED
                 self.context.active_cnt -= 1
                 finished_requests.append(request)
@@ -107,10 +114,16 @@ class CoordinatorTestConfig:
     """Test configuration args."""
 
     port: int = 46581
+    launch_inference_coordinator: bool = True
+    stop_engines: bool = True
+    verify_results: bool = True
 
     num_requests: int = 10**1
    min_time_offset: float = 10 ** (-4)
     max_time_offset: float = 10 ** (-3)
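+    # Tokens to generate per request, i.e. engine steps needed before each request completes.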
+    num_steps_to_finish: int = 1
+    num_iterations: int = 1
 
     tensor_model_parallel_size: int = 1
     pipeline_model_parallel_size: int = 1
@@ -123,6 +136,16 @@ class CoordinatorTestEnv:
     config: CoordinatorTestConfig
     requests: List[Tuple]
     engine: DummyEngine
+    responses: List[List[DynamicInferenceRequestRecord]] = field(default_factory=list)
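+    # Wall-clock checkpoints captured by _run_test.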
+    timing_data: Dict[str, Optional[float]] = field(
+        default_factory=lambda: {
+            "start_time": None,
+            "init_time": None,
+            "done_time": None,
+            "stop_time": None,
+        }
+    )
 
 
 class TestCoordinator:
@@ -133,7 +156,11 @@ def _build_requests(cls, test_config: CoordinatorTestConfig) -> List[Tuple]:
 
         for _ in range(test_config.num_requests):
             arrival_delta = random.uniform(test_config.min_time_offset, test_config.max_time_offset)
-            ret.append(("Hello world!", SamplingParams(), arrival_delta))
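+            # Bound each request's generation length so it finishes in a known number of steps.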
+            num_tokens = test_config.num_steps_to_finish
+            ret.append(
+                ("Hello world!", SamplingParams(num_tokens_to_generate=num_tokens), arrival_delta)
+            )
         return ret
 
     @classmethod
@@ -144,6 +171,7 @@ def _build_test_env(cls, test_config):
         )
         requests = cls._build_requests(test_config)
         engine = DummyEngine()
+        engine.num_steps_to_finish = test_config.num_steps_to_finish
         return CoordinatorTestEnv(config=test_config, requests=requests, engine=engine)
 
     @classmethod
@@ -152,67 +180,155 @@ async def _run_test(cls, **test_config_kwargs):
         test_config = CoordinatorTestConfig(**test_config_kwargs)
         env = cls._build_test_env(test_config)
 
+        # Connect each engine to its respective coordinator process.
+        env.timing_data["start_time"] = time.time()
         await env.engine.start_listening_to_data_parallel_coordinator(
-            inference_coordinator_port=test_config.port, launch_inference_coordinator=True
+            inference_coordinator_port=test_config.port,
+            launch_inference_coordinator=test_config.launch_inference_coordinator,
         )
 
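+        # Track the request phase and shutdown phase separately for precise failure reporting.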
+        results_success = False
+        shutdown_success = False
+        try:
+            if dist.get_rank() == 0:
+                client = InferenceClient(test_config.port)
+                await client.start()
+                env.timing_data["init_time"] = time.time()
+
+                all_results = []
+                for _ in range(test_config.num_iterations):
+                    futures = []
+                    for request in tqdm(env.requests, "add_requests"):
+                        prompt, sampling_params, arrival_delta = request
+                        await asyncio.sleep(arrival_delta)
+                        fut = client.add_request(prompt=prompt, sampling_params=sampling_params)
+                        futures.append(fut)
+                    results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)
+                    all_results.append(results)
+            env.timing_data["done_time"] = time.time()
+            results_success = True
+        finally:
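+            # Always attempt a clean shutdown, even if the request phase failed.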
+            try:
+                if dist.get_rank() == 0:
+                    if test_config.stop_engines:
+                        client.stop_engines()
+                    client.stop()
+                if test_config.stop_engines:
+                    await env.engine.engine_loop_task
+                shutdown_success = True
+            except Exception:
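+                # Graceful shutdown failed; cancel the engine loop outright.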
+                env.engine.engine_loop_task.cancel()
+
+        env.timing_data["stop_time"] = time.time()
+
+        assert results_success, "Did not receive all results successfully."
+        assert shutdown_success, "Did not shut down successfully."
         if dist.get_rank() == 0:
-            client = InferenceClient(test_config.port)
-            await client.start()
-            futures = []
-            for request in tqdm(env.requests, "add_requests"):
-                prompt, sampling_params, arrival_delta = request
-                await asyncio.sleep(arrival_delta)
-                fut = client.add_request(prompt=prompt, sampling_params=sampling_params)
-                futures.append(fut)
-            results: List[DynamicInferenceRequest] = await asyncio.gather(*futures)
-
-            client.stop_engines()
-            client.stop()
-
-            await env.engine.engine_loop_task
+            env.responses = all_results
+            if test_config.verify_results:
+                for batch in all_results:
+                    for request in batch:
+                        assert request.status == Status.COMPLETED
 
         return env
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
     @pytest.mark.internal
+    @pytest.mark.skipif(IS_ZMQ_FLAKY, reason="pyzmq is flaky in CI")
     @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test")
     @pytest.mark.asyncio
     async def test_simple(self):
         """Simple test with no TP or PP."""
         env = await self._run_test(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
 
     @pytest.mark.internal
+    @pytest.mark.skipif(IS_ZMQ_FLAKY, reason="pyzmq is flaky in CI")
     @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test")
     @pytest.mark.asyncio
     async def test_tp(self):
-        """Simple test with no TP or PP."""
+        """Simple test with TP, but no PP."""
         env = await self._run_test(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
 
     @pytest.mark.internal
+    @pytest.mark.skipif(IS_ZMQ_FLAKY, reason="pyzmq is flaky in CI")
+    @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test")
+    @pytest.mark.asyncio
+    async def test_pp(self):
+        """Simple test with no TP, but PP."""
+        env = await self._run_test(tensor_model_parallel_size=1, pipeline_model_parallel_size=2)
+
+    @pytest.mark.internal
+    @pytest.mark.skipif(IS_ZMQ_FLAKY, reason="pyzmq is flaky in CI")
+    @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test")
+    @pytest.mark.asyncio
+    async def test_tp_pp(self):
+        """Simple test with both TP and PP."""
+        env = await self._run_test(tensor_model_parallel_size=2, pipeline_model_parallel_size=2)
+
+    @pytest.mark.internal
+    @pytest.mark.skipif(IS_ZMQ_FLAKY, reason="pyzmq is flaky in CI")
     @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test")
     @pytest.mark.asyncio
     async def test_throughput(self):
         """Throughput test with no TP or PP."""
-        start = time.time()
         env = await self._run_test(
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
-            num_requests=10**3,
+            num_requests=10**4,
+            num_iterations=10,
             min_time_offset=0.0,
             max_time_offset=0.0,
         )
-        end = time.time()
         if dist.get_rank() == 0:
-            print(f"Throughput test time: {end - start} seconds.")
+            init_duration = (env.timing_data["init_time"] - env.timing_data["start_time"]) * 10**3
+            golden_init_duration = 4445.64  # ms
+            run_duration = (env.timing_data["done_time"] - env.timing_data["init_time"]) * 10**3
+            golden_run_duration = 2906.29  # ms
+            stop_duration = (env.timing_data["stop_time"] - env.timing_data["done_time"]) * 10**3
+            golden_stop_duration = 10.77  # ms
+
+            # Print current results.
+            print(f"Initialization time: {init_duration:.2f} ms")
+            print(f"Run time: {run_duration:.2f} ms")
+            print(f"Stop time: {stop_duration:.2f} ms")
+
+            # Check against golden values (relative tolerance `delta`).
+            def is_within_golden_value(value, golden_value, delta=0.1):
+                return golden_value * (1 - delta) < value < golden_value * (1 + delta)
+
+            assert is_within_golden_value(init_duration, golden_init_duration, delta=0.5), (
+                f"Init duration {init_duration:.2f} ms deviates from "
+                f"golden value {golden_init_duration:.2f} ms"
+            )
+            assert is_within_golden_value(run_duration, golden_run_duration, delta=0.2), (
+                f"Run duration {run_duration:.2f} ms deviates from "
+                f"golden value {golden_run_duration:.2f} ms"
+            )
+            assert is_within_golden_value(stop_duration, golden_stop_duration, delta=1.0), (
+                f"Stop duration {stop_duration:.2f} ms deviates from "
+                f"golden value {golden_stop_duration:.2f} ms"
+            )
+
+            # Print summary.
+            print(
+                f"ZMQ throughput is approximately "
+                f"{env.config.num_requests * env.config.num_iterations / run_duration:.2f} "
+                f"requests/ms"
+            )
 
 
 if __name__ == "__main__":
     test = TestCoordinator()
-    test.test_simple()
-    test.test_tp()
+    asyncio.run(test.test_simple())
+    asyncio.run(test.test_tp())
+    asyncio.run(test.test_pp())
+    asyncio.run(test.test_tp_pp())
+    asyncio.run(test.test_throughput())
     test.teardown_method(None)
     print("~~~")
     print("success.")