@@ -85,6 +85,9 @@ async def async_step(self, *, verbose: Optional[bool] = False) -> Dict:
8585 to_remove = []
8686 for request_id , record in self .request_records .items ():
8787 if record [- 1 ].status == Status .ACTIVE_AND_GENERATING_TOKENS :
88+ record [- 1 ].sampling_params .num_tokens_to_generate -= 1
89+ if record [- 1 ].sampling_params .num_tokens_to_generate > 0 :
90+ continue
8891 record [- 1 ].status = Status .COMPLETED
8992 self .context .active_cnt -= 1
9093 finished_request_records .append (record )
@@ -122,6 +125,7 @@ class CoordinatorTestConfig:
122125 num_requests : int = 10 ** 1
123126 min_time_offset : float = 10 ** (- 4 )
124127 max_time_offset : float = 10 ** (- 3 )
128+ num_steps_to_finish : int = 1
125129 num_iterations : int = 1
126130
127131 tensor_model_parallel_size : int = 1
@@ -154,7 +158,10 @@ def _build_requests(cls, test_config: CoordinatorTestConfig) -> List[Tuple]:
154158
155159 for _ in range (test_config .num_requests ):
156160 arrival_delta = random .uniform (test_config .min_time_offset , test_config .max_time_offset )
157- ret .append (("Hello world!" , SamplingParams (), arrival_delta ))
161+ num_tokens = test_config .num_steps_to_finish
162+ ret .append (
163+ ("Hello world!" , SamplingParams (num_tokens_to_generate = num_tokens ), arrival_delta )
164+ )
158165 return ret
159166
160167 @classmethod
@@ -165,6 +172,7 @@ def _build_test_env(cls, test_config):
165172 )
166173 requests = cls ._build_requests (test_config )
167174 engine = DummyEngine ()
175+ engine .num_steps_to_finish = test_config .num_steps_to_finish
168176 return CoordinatorTestEnv (config = test_config , requests = requests , engine = engine )
169177
170178 @classmethod
@@ -180,37 +188,48 @@ async def _run_test(cls, **test_config_kwargs):
180188 launch_inference_coordinator = test_config .launch_inference_coordinator ,
181189 )
182190
183- if dist .get_rank () == 0 :
184- client = InferenceClient (test_config .port )
185- await client .start ()
186- env .timing_data ["init_time" ] = time .time ()
187-
188- all_results = []
189- for _ in range (test_config .num_iterations ):
190- futures = []
191- for request in tqdm (env .requests , "add_requests" ):
192- prompt , sampling_params , arrival_delta = request
193- await asyncio .sleep (arrival_delta )
194- fut = client .add_request (prompt = prompt , sampling_params = sampling_params )
195- futures .append (fut )
196- results : List [DynamicInferenceRequestRecord ] = await asyncio .gather (* futures )
197- all_results .append (results )
198- env .timing_data ["done_time" ] = time .time ()
199-
200- if test_config .stop_engines :
201- client .stop_engines ()
202- client .stop ()
203-
204- if test_config .stop_engines :
205- await env .engine .engine_loop_task
191+ results_success = False
192+ shutdown_success = False
193+ try :
194+ if dist .get_rank () == 0 :
195+ client = InferenceClient (test_config .port )
196+ await client .start ()
197+ env .timing_data ["init_time" ] = time .time ()
198+
199+ all_results = []
200+ for _ in range (test_config .num_iterations ):
201+ futures = []
202+ for request in tqdm (env .requests , "add_requests" ):
203+ prompt , sampling_params , arrival_delta = request
204+ await asyncio .sleep (arrival_delta )
205+ fut = client .add_request (prompt = prompt , sampling_params = sampling_params )
206+ futures .append (fut )
207+ results : List [DynamicInferenceRequestRecord ] = await asyncio .gather (* futures )
208+ all_results .append (results )
209+ env .timing_data ["done_time" ] = time .time ()
210+ results_success = True
211+ finally :
212+ try :
213+ if dist .get_rank () == 0 :
214+ if test_config .stop_engines :
215+ client .stop_engines ()
216+ client .stop ()
217+ if test_config .stop_engines :
218+ await env .engine .engine_loop_task
219+ shutdown_success = True
220+ except BaseException :
221+ env .engine .engine_loop_task .cancel ()
222+
206223 env .timing_data ["stop_time" ] = time .time ()
207224
225+ assert results_success , "Did not receive all results successfully."
226+ assert shutdown_success , "Did not shutdown successfully."
208227 if dist .get_rank () == 0 :
209228 env .responses = all_results
210229 if test_config .verify_results :
211230 for batch in all_results :
212- for result in batch :
213- assert result .status == Status .COMPLETED
231+ for record in batch :
232+ assert record [ - 1 ] .status == Status .COMPLETED
214233
215234 return env
216235
@@ -267,9 +286,9 @@ async def test_throughput(self):
267286 init_duration = (env .timing_data ["init_time" ] - env .timing_data ["start_time" ]) * 10 ** 3
268287 golden_init_duration = 4445.64 # ms
269288 run_duration = (env .timing_data ["done_time" ] - env .timing_data ["init_time" ]) * 10 ** 3
270- golden_run_duration = 3088.87 # ms
289+ golden_run_duration = 2906.29 # ms
271290 stop_duration = (env .timing_data ["stop_time" ] - env .timing_data ["done_time" ]) * 10 ** 3
272- golden_stop_duration = 129.57 # ms
291+ golden_stop_duration = 10.77 # ms
273292
274293 # Print current results.
275294 print (f"Initialization time: { init_duration :.2f} ms" )
@@ -288,7 +307,7 @@ def clamp_to_golden_value(value, golden_value, delta=0.1):
288307 f"WARNING: Run duration { run_duration :.2f} s deviates from "
289308 f"golden value { golden_run_duration :.2f} s"
290309 )
291- assert clamp_to_golden_value (stop_duration , golden_stop_duration , delta = 0.3 ), (
310+ assert clamp_to_golden_value (stop_duration , golden_stop_duration , delta = 1.0 ), (
292311 f"WARNING: Stop duration { stop_duration :.2f} s deviates from "
293312 f"golden value { golden_stop_duration :.2f} s"
294313 )
@@ -304,10 +323,10 @@ def clamp_to_golden_value(value, golden_value, delta=0.1):
304323if __name__ == "__main__" :
305324 test = TestCoordinator ()
306325 asyncio .run (test .test_simple ())
307- test .test_tp ()
308- test .test_pp ()
309- test_test . tp_pp ( )
310- test_test . throughput ( )
326+ asyncio . run ( test .test_tp () )
327+ asyncio . run ( test .test_pp () )
328+ asyncio . run ( test . test_tp_pp () )
329+ asyncio . run ( test . test_throughput () )
311330 test .teardown_method (None )
312331 print ("~~~" )
313332 print ("success." )