
Commit 79257a6

Clean up DP coord code & unit test (#2277)
1 parent: e35495d

File tree

4 files changed, +61 -68 lines

examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py

Lines changed: 0 additions & 3 deletions
@@ -43,7 +43,6 @@ async def main(
     engine: DynamicInferenceEngine,
     requests: List[Request],
     port: int,
-    mp_port: int,
     sampling_params: SamplingParams | None = None,
 ):
     if sampling_params is not None:
@@ -58,7 +57,6 @@ async def main(
 
     await engine.start_listening_to_data_parallel_coordinator(
         inference_coordinator_port=port,
-        inference_mp_coordinator_port=mp_port,
         launch_inference_coordinator=True,
         verbose=True,
     )
@@ -258,6 +256,5 @@ async def main(
             engine,
             requests,
             args.inference_coordinator_port,
-            args.inference_mp_coordinator_port
         )
     )
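
With mp_port gone, the example script passes only the coordinator port; the per-group MP ports are now chosen automatically. A minimal sketch of the updated call site (inside an async function, names as in the script above):

    await engine.start_listening_to_data_parallel_coordinator(
        inference_coordinator_port=args.inference_coordinator_port,
        launch_inference_coordinator=True,
        verbose=True,
    )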

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 59 additions & 61 deletions
@@ -42,7 +42,7 @@
 )
 from megatron.core.inference.utils import Counter, await_process_event
 from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
-from megatron.core.utils import get_asyncio_loop, trace_async_exceptions
+from megatron.core.utils import get_asyncio_loop, internal_api, trace_async_exceptions
 
 try:
     from tqdm import tqdm
@@ -237,10 +237,6 @@ def reset(self) -> None:
 
         # Coordinator state.
         self.use_coordinator = False
-        self.is_tp0_and_pp0 = (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_pipeline_model_parallel_rank() == 0
-        )
 
     def create_cuda_graphs(self, reset_context: bool = True):
         """Create cuda graphs.
@@ -263,7 +259,7 @@ def create_cuda_graphs(self, reset_context: bool = True):
 
         if moe_pad_experts and context.non_decode_cuda_graphs:
             context.non_decode_cuda_graphs = False
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "MoE models do not support non-decode cuda graphs. "
                     "Forcing non_decode_cuda_graphs to False."
@@ -348,10 +344,10 @@ def create_cuda_graphs(self, reset_context: bool = True):
 
         self.capture_stats = capture_stats
 
+    @internal_api
     async def start_listening_to_data_parallel_coordinator(
         self,
         inference_coordinator_port: int,
-        inference_mp_coordinator_port: int = 20000,
         launch_inference_coordinator: bool = True,
         verbose: bool = False,
         *,
@@ -364,6 +360,8 @@ async def start_listening_to_data_parallel_coordinator(
         `InferenceCoordinator`. It configures different ZMQ socket patterns
         based on the rank's role within the distributed topology.
 
+        Note that this method must be called on all ranks, as it uses blocking torch broadcasts.
+
         The setup involves two primary roles within each data-parallel group:
         1. **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
            to the central coordinator via a ZMQ `DEALER` socket. It receives
@@ -382,9 +380,6 @@ async def start_listening_to_data_parallel_coordinator(
         Args:
             inference_coordinator_port (int): The network port where the central
                 `InferenceCoordinator` is or will be listening.
-            inference_mp_coordinator_port (int): The base network port where each model parallel
-                coordinator will broadcast messages from. Each MP group will compute an independent
-                port offset from this base port.
             launch_inference_coordinator (bool, optional): If True, the global rank 0
                 process will spawn and manage the `InferenceCoordinator`
                 process. Defaults to True.
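
For readers unfamiliar with the socket pattern this docstring describes, here is a minimal, self-contained sketch of the DEALER-side registration handshake; the identity and address are illustrative stand-ins, not Megatron's actual values:

    import zmq

    ctx = zmq.Context.instance()
    dealer = ctx.socket(zmq.DEALER)
    # The identity lets the coordinator's ROUTER socket address replies to this rank.
    dealer.setsockopt(zmq.IDENTITY, b"mp-coord-0")
    dealer.connect("tcp://127.0.0.1:12346")
    dealer.send(b"")  # an empty frame registers this engine with the coordinator
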
@@ -399,7 +394,25 @@ async def start_listening_to_data_parallel_coordinator(
                 "pip install msgpack"
             )
 
-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        self.zmq_context = zmq.Context().instance()
+        self.zmq_sockets = []  # keep track of all sockets created by this engine
+
+        # Get world info.
+        dp_group = parallel_state.get_data_parallel_group()
+        dp_src = parallel_state.get_data_parallel_src_rank()
+        dp_size = parallel_state.get_data_parallel_world_size()
+        dp_rank = parallel_state.get_data_parallel_rank()
+
+        mp_group = parallel_state.get_model_parallel_group()
+        mp_src = parallel_state.get_model_parallel_src_rank()
+        tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+        self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
+        self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator
+
+        # Spawn a DP coordinator process and get the connection info.
+        if launch_inference_coordinator and self.is_dp_coordinator:
             spawn_context = multiprocessing.get_context('spawn')
             coordinator_ready_event = spawn_context.Event()
             self.inference_coordinator_process = spawn_context.Process(
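
The spawn-and-signal pattern used for the coordinator process is standard multiprocessing; a minimal sketch, assuming the child sets the shared event once its setup is done (_coordinator_main is a hypothetical stand-in for the real InferenceCoordinator entry point):

    import multiprocessing

    def _coordinator_main(ready_event):
        # ... bind coordinator sockets here, then signal the parent ...
        ready_event.set()

    if __name__ == "__main__":
        spawn_context = multiprocessing.get_context('spawn')
        ready_event = spawn_context.Event()
        proc = spawn_context.Process(target=_coordinator_main, args=(ready_event,))
        proc.start()
        ready_event.wait()  # the engine instead awaits this via await_process_event
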
@@ -412,80 +425,67 @@ async def start_listening_to_data_parallel_coordinator(
             )
             self.inference_coordinator_process.start()
 
-        # Todo [Siddharth]: can we move this code to another file?
-        self.zmq_context = zmq.Context()
-        self.zmq_sockets = []  # keep track of all sockets created by this engine
-
-        # We need to broadcast the hostname of the (TP=0, PP=0) rank
-        # to all other ranks in the same model parallel group.
-        tp_rank = parallel_state.get_tensor_model_parallel_rank()
-        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
-
-        hostname_list = [None]
-        if tp_rank == 0 and pp_rank == 0:
-            hostname_list[0] = socket.gethostname()
+        # Find available ports for MP and bind to them.
+        if self.is_mp_coordinator:
+            local_ip = socket.gethostname()
+            mp_req_sock = self.zmq_context.socket(zmq.PUB)
+            mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_req_addr = mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)
 
-        # Find the global rank of the (TP=0, PP=0) rank in our MP group
-        src_global_rank = parallel_state.get_model_parallel_src_rank()
-
-        torch.distributed.broadcast_object_list(
-            hostname_list, src=src_global_rank, group=parallel_state.get_model_parallel_group()
-        )
-        bcast_hostname = hostname_list[0]
+            mp_len_sock = self.zmq_context.socket(zmq.PUB)
+            mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_len_addr = mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)
+        else:
+            mp_req_addr = None
+            mp_len_addr = None
 
-        # We need unique ports for each MP group, so we compute an offset using the DP rank.
-        dp_rank = parallel_state.get_data_parallel_rank()
-        req_port = inference_mp_coordinator_port + (dp_rank * 2)
-        len_port = inference_mp_coordinator_port + (dp_rank * 2) + 1
+        # Broadcast addresses to respective ranks.
+        bcast = [mp_req_addr, mp_len_addr]
+        torch.distributed.broadcast_object_list(bcast, src=mp_src, group=mp_group)
+        [mp_req_addr, mp_len_addr] = bcast
 
         ip_address_of_dp_coordinator = os.getenv('MASTER_ADDR', '127.0.0.1')
-        identity = f'mp-coord-{parallel_state.get_data_parallel_rank()}'
-        if (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_pipeline_model_parallel_rank() == 0
-        ):
+        dp_addr = f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
+        identity = f'mp-coord-{dp_rank}'
+        if self.is_mp_coordinator:
             # 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
             # These will receive requests from an InferenceCoordinator.
             self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)
 
             self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
-            self.socket_for_receiving_requests.connect(
-                f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
-            )
+            self.socket_for_receiving_requests.connect(dp_addr)
 
             # send empty string. this is used to register with the coordinator.
             self.socket_for_receiving_requests.send(b"")
 
             # 2. Create a publisher socket. This is used to publish or broadcast
             # requests within the model parallel group
-            self.model_parallel_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.model_parallel_publisher_socket.bind(f"tcp://*:{req_port}")
+            self.model_parallel_publisher_socket = mp_req_sock
 
             # 3. Create another publisher socket to broadcast the number of messages to receive.
-            self.model_parallel_num_msgs_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.model_parallel_num_msgs_publisher_socket.bind(f"tcp://*:{len_port}")
+            self.model_parallel_num_msgs_publisher_socket = mp_len_sock
             self.zmq_sockets += [
                 self.socket_for_receiving_requests,
                 self.model_parallel_num_msgs_publisher_socket,
                 self.model_parallel_publisher_socket,
             ]
         # All MP ranks subscribe to the two publisher sockets
         self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.model_parallel_subscriber_socket.connect(f"tcp://{bcast_hostname}:{req_port}")
+        self.model_parallel_subscriber_socket.connect(mp_req_addr)
         self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
 
         self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.model_parallel_num_msgs_subscriber_socket.connect(f"tcp://{bcast_hostname}:{len_port}")
+        self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr)
         self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
 
         self.zmq_sockets += [
             self.model_parallel_subscriber_socket,
             self.model_parallel_num_msgs_subscriber_socket,
         ]
 
-        torch.distributed.barrier(parallel_state.get_model_parallel_group())
+        torch.distributed.barrier(mp_group)
 
-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        if launch_inference_coordinator and self.is_dp_coordinator:
             await await_process_event(coordinator_ready_event, self.inference_coordinator_process)
             logging.info("Inference co-ordinator is ready to receive requests!")

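The core of this rewrite: instead of deriving per-DP-group ports from a base offset, each MP source rank lets the OS pick free ports and shares the resulting endpoint strings over a torch broadcast, eliminating port collisions between groups. A standalone sketch of the pattern, assuming torch.distributed is initialized and is_source_rank, group, and group_src_rank are supplied by the caller:

    import socket

    import torch
    import zmq

    ctx = zmq.Context.instance()
    if is_source_rank:  # e.g. tp_rank == 0 and pp_rank == 0
        pub = ctx.socket(zmq.PUB)
        # Bind to an OS-chosen free port; no pre-agreed numbering needed.
        pub.bind_to_random_port(f"tcp://{socket.gethostname()}")
        endpoint = pub.getsockopt_string(zmq.LAST_ENDPOINT)  # e.g. "tcp://host:54321"
    else:
        endpoint = None

    # Ship the endpoint string to every rank in the group.
    bcast = [endpoint]
    torch.distributed.broadcast_object_list(bcast, src=group_src_rank, group=group)
    endpoint = bcast[0]

    sub = ctx.socket(zmq.SUB)
    sub.connect(endpoint)
    sub.setsockopt_string(zmq.SUBSCRIBE, "")  # subscribe to all topics
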
@@ -697,7 +697,7 @@ def _add_request(
         try:
             eod = self.controller.tokenizer.eod
         except AttributeError:
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "Termination ID not specified, and tokenizer does not define eod."
                     "Defaulting to not using termination id."
@@ -1093,7 +1093,7 @@ async def async_bookkeep(
         self.failed_request_ids.clear()
 
         # Handle necessary ZMQ DP coordinator communication.
-        if self.use_coordinator and self.is_tp0_and_pp0 and finished_request_records:
+        if self.use_coordinator and self.is_mp_coordinator and finished_request_records:
             payload = msgpack.packb(
                 [Headers.ENGINE_REPLY.value, [r.serialize() for r in finished_request_records]],
                 use_bin_type=True,
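
The reply payload is a msgpack-encoded pair of header and serialized records; a small round-trip sketch, with stand-ins for Headers.ENGINE_REPLY and the record format (both simplified here):

    import msgpack

    ENGINE_REPLY = 2  # stand-in for Headers.ENGINE_REPLY.value
    records = [{"request_id": 7, "status": "COMPLETED"}]  # stand-in serialized records

    payload = msgpack.packb([ENGINE_REPLY, records], use_bin_type=True)
    header, decoded = msgpack.unpackb(payload, raw=False)
    assert header == ENGINE_REPLY and decoded[0]["request_id"] == 7
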
@@ -1277,11 +1277,9 @@ def schedule_requests(self) -> int:
            int: The number of messages that were received and processed in this batch.
        """
 
-        tp_rank = parallel_state.get_tensor_model_parallel_rank()
-        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        torch.cuda.nvtx.range_push("drain_zmq_socket")
        all_messages = []
-        if tp_rank == 0 and pp_rank == 0:
+        if self.is_mp_coordinator:
            while True:
                try:
                    # Receive messages in a non-blocking way.
@@ -1297,8 +1295,8 @@ def schedule_requests(self) -> int:
                struct.pack('!i', messages_to_dequeue)
            )
            # Now publish the actual messages to all model parallel ranks
-            for message in all_messages:
-                self.model_parallel_publisher_socket.send(message)
+            if messages_to_dequeue > 0:
+                self.model_parallel_publisher_socket.send_multipart(all_messages)
        else:
            # First, receive the number of messages to dequeue from mp-rank 0
            messages_to_dequeue = struct.unpack(
@@ -1307,8 +1305,10 @@ def schedule_requests(self) -> int:
            # Now, dequeue the same number of messages from the subscriber socket.
            # Note that these receives are blocking, because the messages
            # are guaranteed to be available after the tp-rank 0 has sent them.
-            for _ in range(messages_to_dequeue):
-                all_messages.append(self.model_parallel_subscriber_socket.recv())
+            if messages_to_dequeue > 0:
+                all_messages = self.model_parallel_subscriber_socket.recv_multipart()
+            else:
+                all_messages = []
 
        torch.cuda.nvtx.range_pop()
        for message in all_messages:
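
Taken together, the last three hunks change the MP fan-out from N single-part sends to one multipart exchange, with a fixed 4-byte count on a side channel so subscribers stay in lockstep even for empty batches. A self-contained sketch of the protocol (the socket arguments are illustrative):

    import struct

    import zmq

    def drain(sock):
        """Coordinator side: drain all currently queued messages without blocking."""
        messages = []
        while True:
            try:
                messages.append(sock.recv(flags=zmq.NOBLOCK))
            except zmq.Again:
                break
        return messages

    def publish(count_pub, msg_pub, messages):
        """Announce the count, then ship all frames as one multipart message."""
        count_pub.send(struct.pack('!i', len(messages)))
        if messages:
            msg_pub.send_multipart(messages)

    def subscribe(count_sub, msg_sub):
        """Read the count first, so the blocking receive never stalls on an empty batch."""
        (count,) = struct.unpack('!i', count_sub.recv())
        return msg_sub.recv_multipart() if count > 0 else []
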
@@ -1347,7 +1347,6 @@ def stop(self):
        for socket in self.zmq_sockets:
            socket.close()
        self.zmq_context.term()
-        parallel_state.destroy_model_parallel()
 
    @trace_async_exceptions
    async def run_engine(
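
A note on the teardown order in stop(): ZMQ's context.term() blocks until every socket created from that context is closed, so the tracked sockets must be closed first; dropping destroy_model_parallel() leaves parallel-state teardown to the caller. The close-then-term ordering in miniature:

    import zmq

    ctx = zmq.Context.instance()
    sockets = [ctx.socket(zmq.PUB), ctx.socket(zmq.SUB)]

    # Close every socket before terminating the context; term() blocks until
    # all sockets are closed, so skipping this step can hang the process.
    for sock in sockets:
        sock.close()
    ctx.term()
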
@@ -1369,7 +1368,6 @@ async def run_engine(
                    )
                )
            )
-
            await self.async_step(verbose=verbose)
        except asyncio.CancelledError:
            pass

megatron/training/arguments.py

Lines changed: 0 additions & 2 deletions
@@ -1510,8 +1510,6 @@ def _add_inference_args(parser):
                       'Default to 0 to disable inference wandb logging.')
    group.add_argument("--inference-coordinator-port", type=int, default=12346,
                       help="This port will be used to setup the inference coordinator on node-0")
-    group.add_argument("--inference-mp-coordinator-port", type=int, default=20000,
-                       help="This port will be used to setup the inference model parallel coordinators")
    return parser
 

tests/unit_tests/inference/test_data_parallel_inference_coordinator.py

Lines changed: 2 additions & 2 deletions
@@ -188,7 +188,6 @@ async def _run_test(cls, **test_config_kwargs):
        env.timing_data["start_time"] = time.time()
        await env.engine.start_listening_to_data_parallel_coordinator(
            inference_coordinator_port=test_config.port,
-            inference_mp_coordinator_port=test_config.mp_port,
            launch_inference_coordinator=test_config.launch_inference_coordinator,
        )
 
@@ -232,7 +231,8 @@ async def _run_test(cls, **test_config_kwargs):
        env.responses = all_results
        if test_config.verify_results:
            for batch in all_results:
-                for request in batch:
+                for record in batch:
+                    request = record[-1]
                    assert request.status == Status.COMPLETED
 
        return env
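
The assertion change reflects that each completed item is now a request record rather than a bare request, with the final state at index -1. A hypothetical illustration of that shape (the Status enum and snapshot class below are stand-ins inferred from this diff alone):

    from dataclasses import dataclass
    from enum import Enum

    class Status(Enum):
        ACTIVE = 1
        COMPLETED = 2

    @dataclass
    class Snapshot:
        status: Status

    # A record is assumed to be the request's history of snapshots.
    record = [Snapshot(Status.ACTIVE), Snapshot(Status.COMPLETED)]
    assert record[-1].status == Status.COMPLETED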
