@@ -79,6 +79,9 @@ async def async_step(
         to_remove = []
         for request_id, request in self.requests.items():
             if request.status == Status.ACTIVE_AND_GENERATING_TOKENS:
+                request.sampling_params.num_tokens_to_generate -= 1
+                if request.sampling_params.num_tokens_to_generate > 0:
+                    continue
                 request.status = Status.COMPLETED
                 self.context.active_cnt -= 1
                 finished_requests.append(request)
@@ -116,6 +119,7 @@ class CoordinatorTestConfig:
     num_requests: int = 10**1
     min_time_offset: float = 10 ** (-4)
     max_time_offset: float = 10 ** (-3)
+    num_steps_to_finish: int = 1
     num_iterations: int = 1
 
     tensor_model_parallel_size: int = 1
@@ -148,7 +152,10 @@ def _build_requests(cls, test_config: CoordinatorTestConfig) -> List[Tuple]:
 
         for _ in range(test_config.num_requests):
             arrival_delta = random.uniform(test_config.min_time_offset, test_config.max_time_offset)
-            ret.append(("Hello world!", SamplingParams(), arrival_delta))
+            num_tokens = test_config.num_steps_to_finish
+            ret.append(
+                ("Hello world!", SamplingParams(num_tokens_to_generate=num_tokens), arrival_delta)
+            )
         return ret
 
     @classmethod
@@ -159,6 +166,7 @@ def _build_test_env(cls, test_config):
         )
         requests = cls._build_requests(test_config)
         engine = DummyEngine()
+        engine.num_steps_to_finish = test_config.num_steps_to_finish
         return CoordinatorTestEnv(config=test_config, requests=requests, engine=engine)
 
     @classmethod
@@ -174,31 +182,42 @@ async def _run_test(cls, **test_config_kwargs):
             launch_inference_coordinator=test_config.launch_inference_coordinator,
         )
 
-        if dist.get_rank() == 0:
-            client = InferenceClient(test_config.port)
-            await client.start()
-            env.timing_data["init_time"] = time.time()
-
-            all_results = []
-            for _ in range(test_config.num_iterations):
-                futures = []
-                for request in tqdm(env.requests, "add_requests"):
-                    prompt, sampling_params, arrival_delta = request
-                    await asyncio.sleep(arrival_delta)
-                    fut = client.add_request(prompt=prompt, sampling_params=sampling_params)
-                    futures.append(fut)
-                results: List[DynamicInferenceRequest] = await asyncio.gather(*futures)
-                all_results.append(results)
-            env.timing_data["done_time"] = time.time()
-
-            if test_config.stop_engines:
-                client.stop_engines()
-            client.stop()
-
-        if test_config.stop_engines:
-            await env.engine.engine_loop_task
+        results_success = False
+        shutdown_success = False
+        try:
+            if dist.get_rank() == 0:
+                client = InferenceClient(test_config.port)
+                await client.start()
+                env.timing_data["init_time"] = time.time()
+
+                all_results = []
+                for _ in range(test_config.num_iterations):
+                    futures = []
+                    for request in tqdm(env.requests, "add_requests"):
+                        prompt, sampling_params, arrival_delta = request
+                        await asyncio.sleep(arrival_delta)
+                        fut = client.add_request(prompt=prompt, sampling_params=sampling_params)
+                        futures.append(fut)
+                    results: List[DynamicInferenceRequest] = await asyncio.gather(*futures)
+                    all_results.append(results)
+                env.timing_data["done_time"] = time.time()
+            results_success = True
+        finally:
+            try:
+                if dist.get_rank() == 0:
+                    if test_config.stop_engines:
+                        client.stop_engines()
+                    client.stop()
+                if test_config.stop_engines:
+                    await env.engine.engine_loop_task
+                shutdown_success = True
+            except:
+                env.engine.engine_loop_task.cancel()
+
         env.timing_data["stop_time"] = time.time()
 
+        assert results_success, "Did not receive all results successfully."
+        assert shutdown_success, "Did not shutdown successfully."
         if dist.get_rank() == 0:
             env.responses = all_results
             if test_config.verify_results:
@@ -297,11 +316,11 @@ def clamp_to_golden_value(value, golden_value, delta=0.1):
 
 if __name__ == "__main__":
     test = TestCoordinator()
-    test.test_simple()
-    test.test_tp()
-    test.test_pp()
-    test_test.tp_pp()
-    test_test.throughput()
+    asyncio.run(test.test_simple())
+    asyncio.run(test.test_tp())
+    asyncio.run(test.test_pp())
+    asyncio.run(test.test_tp_pp())
+    asyncio.run(test.test_throughput())
     test.teardown_method(None)
     print("~~~")
     print("success.")