Skip to content

Commit 12f3eff

Browse files
committed
Merge
1 parent a28d34d commit 12f3eff

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

tests/unit_tests/inference/test_data_parallel_inference_coordinator.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ async def async_step(
7979
to_remove = []
8080
for request_id, request in self.requests.items():
8181
if request.status == Status.ACTIVE_AND_GENERATING_TOKENS:
82+
request.sampling_params.num_tokens_to_generate -= 1
83+
if request.sampling_params.num_tokens_to_generate > 0:
84+
continue
8285
request.status = Status.COMPLETED
8386
self.context.active_cnt -= 1
8487
finished_requests.append(request)
@@ -116,6 +119,7 @@ class CoordinatorTestConfig:
116119
num_requests: int = 10**1
117120
min_time_offset: float = 10 ** (-4)
118121
max_time_offset: float = 10 ** (-3)
122+
num_steps_to_finish: int = 1
119123
num_iterations: int = 1
120124

121125
tensor_model_parallel_size: int = 1
@@ -148,7 +152,10 @@ def _build_requests(cls, test_config: CoordinatorTestConfig) -> List[Tuple]:
148152

149153
for _ in range(test_config.num_requests):
150154
arrival_delta = random.uniform(test_config.min_time_offset, test_config.max_time_offset)
151-
ret.append(("Hello world!", SamplingParams(), arrival_delta))
155+
num_tokens = test_config.num_steps_to_finish
156+
ret.append(
157+
("Hello world!", SamplingParams(num_tokens_to_generate=num_tokens), arrival_delta)
158+
)
152159
return ret
153160

154161
@classmethod
@@ -159,6 +166,7 @@ def _build_test_env(cls, test_config):
159166
)
160167
requests = cls._build_requests(test_config)
161168
engine = DummyEngine()
169+
engine.num_steps_to_finish = test_config.num_steps_to_finish
162170
return CoordinatorTestEnv(config=test_config, requests=requests, engine=engine)
163171

164172
@classmethod
@@ -174,31 +182,42 @@ async def _run_test(cls, **test_config_kwargs):
174182
launch_inference_coordinator=test_config.launch_inference_coordinator,
175183
)
176184

177-
if dist.get_rank() == 0:
178-
client = InferenceClient(test_config.port)
179-
await client.start()
180-
env.timing_data["init_time"] = time.time()
181-
182-
all_results = []
183-
for _ in range(test_config.num_iterations):
184-
futures = []
185-
for request in tqdm(env.requests, "add_requests"):
186-
prompt, sampling_params, arrival_delta = request
187-
await asyncio.sleep(arrival_delta)
188-
fut = client.add_request(prompt=prompt, sampling_params=sampling_params)
189-
futures.append(fut)
190-
results: List[DynamicInferenceRequest] = await asyncio.gather(*futures)
191-
all_results.append(results)
192-
env.timing_data["done_time"] = time.time()
193-
194-
if test_config.stop_engines:
195-
client.stop_engines()
196-
client.stop()
197-
198-
if test_config.stop_engines:
199-
await env.engine.engine_loop_task
185+
results_success = False
186+
shutdown_success = False
187+
try:
188+
if dist.get_rank() == 0:
189+
client = InferenceClient(test_config.port)
190+
await client.start()
191+
env.timing_data["init_time"] = time.time()
192+
193+
all_results = []
194+
for _ in range(test_config.num_iterations):
195+
futures = []
196+
for request in tqdm(env.requests, "add_requests"):
197+
prompt, sampling_params, arrival_delta = request
198+
await asyncio.sleep(arrival_delta)
199+
fut = client.add_request(prompt=prompt, sampling_params=sampling_params)
200+
futures.append(fut)
201+
results: List[DynamicInferenceRequest] = await asyncio.gather(*futures)
202+
all_results.append(results)
203+
env.timing_data["done_time"] = time.time()
204+
results_success = True
205+
finally:
206+
try:
207+
if dist.get_rank() == 0:
208+
if test_config.stop_engines:
209+
client.stop_engines()
210+
client.stop()
211+
if test_config.stop_engines:
212+
await env.engine.engine_loop_task
213+
shutdown_success = True
214+
except:
215+
env.engine.engine_loop_task.cancel()
216+
200217
env.timing_data["stop_time"] = time.time()
201218

219+
assert results_success, "Did not receive all results successfully."
220+
assert shutdown_success, "Did not shutdown successfully."
202221
if dist.get_rank() == 0:
203222
env.responses = all_results
204223
if test_config.verify_results:
@@ -297,11 +316,11 @@ def clamp_to_golden_value(value, golden_value, delta=0.1):
297316

298317
if __name__ == "__main__":
299318
test = TestCoordinator()
300-
test.test_simple()
301-
test.test_tp()
302-
test.test_pp()
303-
test_test.tp_pp()
304-
test_test.throughput()
319+
asyncio.run(test.test_simple())
320+
asyncio.run(test.test_tp())
321+
asyncio.run(test.test_pp())
322+
asyncio.run(test.test_tp_pp())
323+
asyncio.run(test.test_throughput())
305324
test.teardown_method(None)
306325
print("~~~")
307326
print("success.")

0 commit comments

Comments
 (0)