Skip to content

Commit 56a5a0f

Browse files
committed
Merge
1 parent f075875 commit 56a5a0f

File tree

3 files changed

+52
-51
lines changed

3 files changed

+52
-51
lines changed

examples/inference/gpt/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,6 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser:
117117
'`--prompt-file` above). The first `--prompt-file-num-truncate` samples '
118118
'will be used, in order.',
119119
)
120-
group.add_argument(
121-
"--inference-coordinator-port",
122-
type=int,
123-
help="This port will be used to setup the inference co-ordinator on node-0",
124-
default=12346
125-
)
126120
group.add_argument(
127121
"--use-flashinfer-fused-rope",
128122
action='store_true',

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -375,16 +375,7 @@ async def start_listening_to_data_parallel_coordinator(
375375
launch_inference_coordinator (bool, optional): If True, the global rank 0
376376
process will spawn and manage the `InferenceCoordinator`
377377
process. Defaults to True.
378-
<<<<<<< HEAD
379378
verbose (bool): Whether to run in verbose mode.
380-
381-
Note:
382-
The current implementation uses `ipc` sockets for broadcasting requests
383-
within a Tensor Parallel group, which limits each TP group to a single
384-
physical node. For example, if you have 8 GPUs per node, then this will only
385-
work with TP=[1,2,4,8]
386-
=======
387-
>>>>>>> a28d34db94 (Clean up DP coord unit-test and code reuse)
388379
"""
389380

390381
assert HAVE_ZMQ, (
@@ -1285,7 +1276,6 @@ def stop(self):
12851276
for socket in self.zmq_sockets:
12861277
socket.close()
12871278
self.zmq_context.term()
1288-
parallel_state.destroy_model_parallel()
12891279

12901280
@trace_async_exceptions
12911281
async def run_engine(
@@ -1306,7 +1296,6 @@ async def run_engine(
13061296
)
13071297
)
13081298
)
1309-
13101299
await self.async_step(verbose=verbose)
13111300
except asyncio.CancelledError:
13121301
pass
@@ -1345,7 +1334,6 @@ async def run_engine_with_coordinator(
13451334
self.suspend()
13461335
await asyncio.sleep(0.02)
13471336
continue
1348-
13491337
else:
13501338
self.resume()
13511339

tests/unit_tests/inference/test_data_parallel_inference_coordinator.py

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ async def async_step(self, *, verbose: Optional[bool] = False) -> Dict:
8585
to_remove = []
8686
for request_id, record in self.request_records.items():
8787
if record[-1].status == Status.ACTIVE_AND_GENERATING_TOKENS:
88+
record[-1].sampling_params.num_tokens_to_generate -= 1
89+
if record[-1].sampling_params.num_tokens_to_generate > 0:
90+
continue
8891
record[-1].status = Status.COMPLETED
8992
self.context.active_cnt -= 1
9093
finished_request_records.append(record)
@@ -122,6 +125,7 @@ class CoordinatorTestConfig:
122125
num_requests: int = 10**1
123126
min_time_offset: float = 10 ** (-4)
124127
max_time_offset: float = 10 ** (-3)
128+
num_steps_to_finish: int = 1
125129
num_iterations: int = 1
126130

127131
tensor_model_parallel_size: int = 1
@@ -154,7 +158,10 @@ def _build_requests(cls, test_config: CoordinatorTestConfig) -> List[Tuple]:
154158

155159
for _ in range(test_config.num_requests):
156160
arrival_delta = random.uniform(test_config.min_time_offset, test_config.max_time_offset)
157-
ret.append(("Hello world!", SamplingParams(), arrival_delta))
161+
num_tokens = test_config.num_steps_to_finish
162+
ret.append(
163+
("Hello world!", SamplingParams(num_tokens_to_generate=num_tokens), arrival_delta)
164+
)
158165
return ret
159166

160167
@classmethod
@@ -165,6 +172,7 @@ def _build_test_env(cls, test_config):
165172
)
166173
requests = cls._build_requests(test_config)
167174
engine = DummyEngine()
175+
engine.num_steps_to_finish = test_config.num_steps_to_finish
168176
return CoordinatorTestEnv(config=test_config, requests=requests, engine=engine)
169177

170178
@classmethod
@@ -180,37 +188,48 @@ async def _run_test(cls, **test_config_kwargs):
180188
launch_inference_coordinator=test_config.launch_inference_coordinator,
181189
)
182190

183-
if dist.get_rank() == 0:
184-
client = InferenceClient(test_config.port)
185-
await client.start()
186-
env.timing_data["init_time"] = time.time()
187-
188-
all_results = []
189-
for _ in range(test_config.num_iterations):
190-
futures = []
191-
for request in tqdm(env.requests, "add_requests"):
192-
prompt, sampling_params, arrival_delta = request
193-
await asyncio.sleep(arrival_delta)
194-
fut = client.add_request(prompt=prompt, sampling_params=sampling_params)
195-
futures.append(fut)
196-
results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)
197-
all_results.append(results)
198-
env.timing_data["done_time"] = time.time()
199-
200-
if test_config.stop_engines:
201-
client.stop_engines()
202-
client.stop()
203-
204-
if test_config.stop_engines:
205-
await env.engine.engine_loop_task
191+
results_success = False
192+
shutdown_success = False
193+
try:
194+
if dist.get_rank() == 0:
195+
client = InferenceClient(test_config.port)
196+
await client.start()
197+
env.timing_data["init_time"] = time.time()
198+
199+
all_results = []
200+
for _ in range(test_config.num_iterations):
201+
futures = []
202+
for request in tqdm(env.requests, "add_requests"):
203+
prompt, sampling_params, arrival_delta = request
204+
await asyncio.sleep(arrival_delta)
205+
fut = client.add_request(prompt=prompt, sampling_params=sampling_params)
206+
futures.append(fut)
207+
results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)
208+
all_results.append(results)
209+
env.timing_data["done_time"] = time.time()
210+
results_success = True
211+
finally:
212+
try:
213+
if dist.get_rank() == 0:
214+
if test_config.stop_engines:
215+
client.stop_engines()
216+
client.stop()
217+
if test_config.stop_engines:
218+
await env.engine.engine_loop_task
219+
shutdown_success = True
220+
except:
221+
env.engine.engine_loop_task.cancel()
222+
206223
env.timing_data["stop_time"] = time.time()
207224

225+
assert results_success, "Did not receive all results successfully."
226+
assert shutdown_success, "Did not shutdown successfully."
208227
if dist.get_rank() == 0:
209228
env.responses = all_results
210229
if test_config.verify_results:
211230
for batch in all_results:
212-
for result in batch:
213-
assert result.status == Status.COMPLETED
231+
for record in batch:
232+
assert record[-1].status == Status.COMPLETED
214233

215234
return env
216235

@@ -267,9 +286,9 @@ async def test_throughput(self):
267286
init_duration = (env.timing_data["init_time"] - env.timing_data["start_time"]) * 10**3
268287
golden_init_duration = 4445.64 # ms
269288
run_duration = (env.timing_data["done_time"] - env.timing_data["init_time"]) * 10**3
270-
golden_run_duration = 3088.87 # ms
289+
golden_run_duration = 2906.29 # ms
271290
stop_duration = (env.timing_data["stop_time"] - env.timing_data["done_time"]) * 10**3
272-
golden_stop_duration = 129.57 # ms
291+
golden_stop_duration = 10.77 # ms
273292

274293
# Print current results.
275294
print(f"Initialization time: {init_duration:.2f} ms")
@@ -288,7 +307,7 @@ def clamp_to_golden_value(value, golden_value, delta=0.1):
288307
f"WARNING: Run duration {run_duration:.2f}s deviates from "
289308
f"golden value {golden_run_duration:.2f}s"
290309
)
291-
assert clamp_to_golden_value(stop_duration, golden_stop_duration, delta=0.3), (
310+
assert clamp_to_golden_value(stop_duration, golden_stop_duration, delta=1.0), (
292311
f"WARNING: Stop duration {stop_duration:.2f}s deviates from "
293312
f"golden value {golden_stop_duration:.2f}s"
294313
)
@@ -304,10 +323,10 @@ def clamp_to_golden_value(value, golden_value, delta=0.1):
304323
if __name__ == "__main__":
305324
test = TestCoordinator()
306325
asyncio.run(test.test_simple())
307-
test.test_tp()
308-
test.test_pp()
309-
test_test.tp_pp()
310-
test_test.throughput()
326+
asyncio.run(test.test_tp())
327+
asyncio.run(test.test_pp())
328+
asyncio.run(test.test_tp_pp())
329+
asyncio.run(test.test_throughput())
311330
test.teardown_method(None)
312331
print("~~~")
313332
print("success.")

0 commit comments

Comments
 (0)