
Commit a28d34d

Clean up DP coord unit-test and code reuse
1 parent: c4ba666

File tree: 3 files changed (+199, -87 lines)


megatron/core/inference/engines/dynamic_engine.py

Lines changed: 91 additions & 69 deletions
@@ -4,6 +4,7 @@
 import logging
 import multiprocessing
 import os
+import socket
 import struct
 import time
 import warnings
@@ -142,13 +143,14 @@ def __init__(
         self.paused = False
         self.stopped = False
         self.enable_chunked_prefill = enable_chunked_prefill
+        self.rank = torch.distributed.get_rank()

         self.inference_logging_step_interval = inference_logging_step_interval
         # Configure wandb to use separate step counter for inference metrics (only once)
         if self.inference_logging_step_interval > 0 and self.context.metrics_writer is not None:
             logging.info(
                 f"\033[1;93m[INFERENCE]\033[0m "
-                f"\033[1;95mLogging inference metrics to wandb (rank {torch.distributed.get_rank()})\033[0m"
+                f"\033[1;95mLogging inference metrics to wandb (rank {self.rank})\033[0m"
             )
             if HAVE_WANDB and self.context.metrics_writer.__name__ == "wandb":
                 # Make all inference/* metrics use inference_step as their x-axis
@@ -202,7 +204,7 @@ def create_cuda_graphs(self, reset_context: bool = True):

         if moe_pad_experts and context.non_decode_cuda_graphs:
             context.non_decode_cuda_graphs = False
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "MoE models do not support non-decode cuda graphs. "
                     "Forcing non_decode_cuda_graphs to False."
@@ -301,16 +303,18 @@ async def start_listening_to_data_parallel_coordinator(
         `InferenceCoordinator`. It configures different ZMQ socket patterns
         based on the rank's role within the distributed topology.

+        Note that this method must be called on all ranks, as it uses blocking torch broadcasts.
+
         The setup involves two primary roles within each data-parallel group:
-        1. **TP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
+        1. **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
            to the central coordinator via a ZMQ `DEALER` socket. It receives
            requests and uses a ZMQ `PUB` (publisher) socket to broadcast them
-           to all other ranks within its tensor-parallel (TP) group.
-        2. **TP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
-           sockets to listen for requests broadcast by their local TP Coordinator.
+           to all other ranks within its model-parallel (MP) group.
+        2. **MP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
+           sockets to listen for requests broadcast by their local MP Coordinator.

-        This architecture uses fast Inter-Process Communication (`ipc`) sockets for
-        intra-node broadcasts within a TP group.
+        This architecture uses TCP sockets for both inter-node and intra-node broadcasts
+        within an MP group.

         Finally, after setting up the communication channels and ensuring all ranks
         are synchronized, this method starts the main engine processing loop
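
For readers unfamiliar with the ZMQ pattern this docstring describes, the sketch below (not part of this commit) shows the two roles with plain pyzmq. The endpoint, port, and identity values are illustrative only.

import zmq

ctx = zmq.Context.instance()

# Role 1: MP coordinator (tp_rank == 0 and pp_rank == 0).
# A DEALER socket talks to the central InferenceCoordinator; a PUB socket
# fans incoming requests out to the rest of the model-parallel group.
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, b"mp-coord-0")   # identity value is illustrative
dealer.connect("tcp://127.0.0.1:12346")          # assumed coordinator endpoint
dealer.send(b"")                                 # empty frame registers with the coordinator

publisher = ctx.socket(zmq.PUB)
req_port = publisher.bind_to_random_port("tcp://127.0.0.1")

# Role 2: MP worker. A SUB socket subscribed to everything the local
# MP coordinator publishes.
subscriber = ctx.socket(zmq.SUB)
subscriber.connect(f"tcp://127.0.0.1:{req_port}")
subscriber.setsockopt_string(zmq.SUBSCRIBE, "")

The PUB/SUB pair means the coordinator publishes each request once, regardless of how many MP workers subscribe; over TCP this works across nodes, which is what lifts the old single-node `ipc` restriction.
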
@@ -322,12 +326,6 @@ async def start_listening_to_data_parallel_coordinator(
             launch_inference_coordinator (bool, optional): If True, the global rank 0
                 process will spawn and manage the `InferenceCoordinator`
                 process. Defaults to True.
-
-        Note:
-            The current implementation uses `ipc` sockets for broadcasting requests
-            within a Tensor Parallel group, which limits each TP group to a single
-            physical node. For example, if you have 8 GPUs per node, then this will only
-            work with TP=[1,2,4,8]
         """

         assert HAVE_ZMQ, (
@@ -338,7 +336,32 @@ async def start_listening_to_data_parallel_coordinator(
             "pip install msgpack"
         )

-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        self.zmq_context = zmq.Context().instance()
+        self.zmq_sockets = []  # keep track of all sockets created by this engine
+
+        # Get world info.
+        dp_group = parallel_state.get_data_parallel_group()
+        dp_src = parallel_state.get_data_parallel_src_rank()
+        dp_size = parallel_state.get_data_parallel_world_size()
+        dp_rank = parallel_state.get_data_parallel_rank()
+
+        mp_group = parallel_state.get_model_parallel_group()
+        mp_src = parallel_state.get_model_parallel_src_rank()
+        tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+        self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
+        self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator
+
+        # Get local IP.
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as tmp_sock:
+            tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+            tmp_sock.connect(('<broadcast>', 0))
+            local_ip = tmp_sock.getsockname()[0]
+        del tmp_sock
+
+        # Spawn a DP coordinator process and get the connection info.
+        if launch_inference_coordinator and self.is_dp_coordinator:
             spawn_context = multiprocessing.get_context('spawn')
             coordinator_ready_event = spawn_context.Event()
             self.inference_coordinator_process = spawn_context.Process(
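
The `# Get local IP.` block introduced above can be exercised on its own. The standalone snippet below mirrors it: the UDP socket is never used to send anything; "connecting" it to the broadcast address just lets the OS pick an outgoing interface whose address we read back.

import socket

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as tmp_sock:
    tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
    tmp_sock.connect(('<broadcast>', 0))      # no packet is actually sent
    local_ip = tmp_sock.getsockname()[0]      # address of the chosen interface

print(local_ip)
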
@@ -351,61 +374,65 @@ async def start_listening_to_data_parallel_coordinator(
             )
             self.inference_coordinator_process.start()

-        # Todo [Siddharth]: can we move this code to another file?
-        self.zmq_context = zmq.Context()
-        self.zmq_sockets = []  # keep track of all sockets created by this engine
+        # Find available ports for MP and bind to them.
+        if self.is_mp_coordinator:
+            mp_req_sock = self.zmq_context.socket(zmq.PUB)
+            mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_req_addr = [mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)]
+
+            mp_len_sock = self.zmq_context.socket(zmq.PUB)
+            mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_len_addr = [mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)]
+        else:
+            mp_req_addr = [None]
+            mp_len_addr = [None]
+
+        # Broadcast addresses to respective ranks.
+        torch.distributed.broadcast_object_list(mp_req_addr, src=mp_src, group=mp_group)
+        torch.distributed.broadcast_object_list(mp_len_addr, src=mp_src, group=mp_group)
+
         ip_address_of_dp_coordinator = os.getenv('MASTER_ADDR', '127.0.0.1')
-        identity = f'tp-coord-{parallel_state.get_data_parallel_rank()}'
-        if (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_pipeline_model_parallel_rank() == 0
-        ):
+        dp_addr = [f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"]
+        identity = f'mp-coord-{dp_rank}'
+        if self.is_mp_coordinator:
             # 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
             # These will receive requests from an InferenceCoordinator.
             self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)

             self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
-            self.socket_for_receiving_requests.connect(
-                f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
-            )
+            self.socket_for_receiving_requests.connect(dp_addr[0])

             # send empty string. this is used to register with the coordinator.
             self.socket_for_receiving_requests.send(b"")

             # 2. Create a publisher socket. This is used to publish or broadcast
-            # requests within the tensor parallel group
-            self.tensor_parallel_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.tensor_parallel_publisher_socket.bind(f"ipc:///tmp/{identity}-tp-bcast-socket-req")
+            # requests within the model parallel group
+            self.model_parallel_publisher_socket = mp_req_sock

             # 3. Create another publisher socket to broadcast the number of messages to receive.
-            self.tensor_parallel_num_msgs_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.tensor_parallel_num_msgs_publisher_socket.bind(
-                f"ipc:///tmp/{identity}-tp-bcast-socket-len"
-            )
+            self.model_parallel_num_msgs_publisher_socket = mp_len_sock
             self.zmq_sockets += [
                 self.socket_for_receiving_requests,
-                self.tensor_parallel_num_msgs_publisher_socket,
-                self.tensor_parallel_publisher_socket,
+                self.model_parallel_num_msgs_publisher_socket,
+                self.model_parallel_publisher_socket,
             ]
-        # All TP ranks subscribe to the two publisher sockets
-        self.tensor_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.tensor_parallel_subscriber_socket.connect(f"ipc:///tmp/{identity}-tp-bcast-socket-req")
-        self.tensor_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
-
-        self.tensor_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.tensor_parallel_num_msgs_subscriber_socket.connect(
-            f"ipc:///tmp/{identity}-tp-bcast-socket-len"
-        )
-        self.tensor_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+        # All MP ranks subscribe to the two publisher sockets
+        self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
+        self.model_parallel_subscriber_socket.connect(mp_req_addr[0])
+        self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+
+        self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
+        self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr[0])
+        self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

         self.zmq_sockets += [
-            self.tensor_parallel_subscriber_socket,
-            self.tensor_parallel_num_msgs_subscriber_socket,
+            self.model_parallel_subscriber_socket,
+            self.model_parallel_num_msgs_subscriber_socket,
         ]

-        torch.distributed.barrier(parallel_state.get_tensor_model_parallel_group())
+        torch.distributed.barrier(mp_group)

-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        if launch_inference_coordinator and self.is_dp_coordinator:
             await await_process_event(coordinator_ready_event, self.inference_coordinator_process)
             logging.info("Inference co-ordinator is ready to receive requests!")

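The bind-then-broadcast handshake in this hunk generalizes to any per-group socket: the group source binds a PUB socket to a free TCP port, then shares the endpoint string with the rest of the group via a blocking object broadcast. The sketch below is not the engine's code; the helper name is hypothetical, and it assumes torch.distributed is already initialized with `src` given as a global rank.

import torch
import zmq

def exchange_pub_endpoint(local_ip: str, src: int, group=None):
    ctx = zmq.Context.instance()
    is_src = torch.distributed.get_rank() == src

    if is_src:
        pub = ctx.socket(zmq.PUB)
        pub.bind_to_random_port(f"tcp://{local_ip}")
        endpoint = [pub.getsockopt_string(zmq.LAST_ENDPOINT)]
    else:
        pub, endpoint = None, [None]

    # Collective call: every rank in `group` must reach this line; afterwards
    # endpoint[0] holds the source rank's "tcp://ip:port" string on all ranks.
    torch.distributed.broadcast_object_list(endpoint, src=src, group=group)
    return pub, endpoint[0]

Non-source ranks would then connect SUB sockets to the returned endpoint, which is exactly what the docstring's earlier note about blocking torch broadcasts refers to.
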
@@ -455,7 +482,7 @@ def _add_request(
         try:
             eod = self.controller.tokenizer.eod
         except AttributeError:
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "Termination ID not specified, and tokenizer does not define eod."
                     "Defaulting to not using termination id."
@@ -932,16 +959,16 @@ def schedule_requests(self) -> int:
         """Drains the ZMQ socket for a batch of requests and adds them to the engine.

         This method is a collective and synchronous operation that must be called
-        by all ranks in a Tensor Parallel (TP) group at the same time. It ensures
+        by all ranks in a Model Parallel (MP) group at the same time. It ensures
         that all ranks process the exact same batch of incoming requests and
         control signals.

         The synchronization works as follows:
-        1. The TP rank 0 drains all pending messages from its subscriber socket
+        1. The MP rank 0 drains all pending messages from its subscriber socket
            in a non-blocking manner.
-        2. TP rank 0 then broadcasts the number of messages it received to all other
-           ranks in its TP group using a dedicated publisher socket.
-        3. The other TP ranks wait to receive this count, and then receive exactly
+        2. MP rank 0 then broadcasts the number of messages it received to all other
+           ranks in its MP group using a dedicated publisher socket.
+        3. The other MP ranks wait to receive this count, and then receive exactly
            that many messages from their subscriber sockets.

         Once all ranks have the same batch of messages, they are unpacked and
@@ -950,18 +977,17 @@ def schedule_requests(self) -> int:

         Note:
             This function is synchronous and must be called collectively by all
-            ranks in a TP group. It should not be launched in a separate coroutine
+            ranks in a MP group. It should not be launched in a separate coroutine
             to ensure all ranks execute it in lockstep before proceeding to the
             next engine step.

         Returns:
             int: The number of messages that were received and processed in this batch.
         """

-        rank = parallel_state.get_tensor_model_parallel_rank()
         torch.cuda.nvtx.range_push("drain_zmq_socket")
         all_messages = []
-        if rank == 0:
+        if self.is_mp_coordinator:
             while True:
                 try:
                     # Receive messages in a non-blocking way.
@@ -973,22 +999,22 @@ def schedule_requests(self) -> int:
                 # First publish the number of messages to dequeue.
                 # This is important because we want all tensor parallel ranks
                 # to dequeue the same number of messages.
-                self.tensor_parallel_num_msgs_publisher_socket.send(
+                self.model_parallel_num_msgs_publisher_socket.send(
                     struct.pack('!i', messages_to_dequeue)
                 )
-                # Now publish the actual messages to all tensor parallel ranks
+                # Now publish the actual messages to all model parallel ranks
                 for message in all_messages:
-                    self.tensor_parallel_publisher_socket.send(message)
+                    self.model_parallel_publisher_socket.send(message)
             else:
-                # First, receive the number of messages to dequeue from tp-rank 0
+                # First, receive the number of messages to dequeue from mp-rank 0
                 messages_to_dequeue = struct.unpack(
-                    '!i', self.tensor_parallel_num_msgs_subscriber_socket.recv()
+                    '!i', self.model_parallel_num_msgs_subscriber_socket.recv()
                 )[0]
                 # Now, dequeue the same number of messages from the subscriber socket.
                 # Note that these receives are blocking, because the messages
                 # are guaranteed to be available after the tp-rank 0 has sent them.
                 for _ in range(messages_to_dequeue):
-                    all_messages.append(self.tensor_parallel_subscriber_socket.recv())
+                    all_messages.append(self.model_parallel_subscriber_socket.recv())

         torch.cuda.nvtx.range_pop()
         for message in all_messages:
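
Outside the engine, the count-then-payload protocol this hunk renames can be written against bare pyzmq sockets as below; the socket arguments are placeholders for the publisher/subscriber pairs set up earlier, not the engine's attributes.

import struct
import zmq

def drain_and_publish(input_sock, count_pub, msg_pub):
    # Coordinator side: drain whatever is queued without blocking,
    # publish the count as a packed big-endian int, then the payloads.
    messages = []
    while True:
        try:
            messages.append(input_sock.recv(flags=zmq.NOBLOCK))
        except zmq.Again:
            break
    count_pub.send(struct.pack('!i', len(messages)))
    for m in messages:
        msg_pub.send(m)
    return messages

def receive_published(count_sub, msg_sub):
    # Worker side: block on the count, then read exactly that many messages.
    n = struct.unpack('!i', count_sub.recv())[0]
    return [msg_sub.recv() for _ in range(n)]

Broadcasting the count first is what keeps every rank dequeuing the same batch even though the drain on rank 0 is non-blocking.
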
@@ -1080,12 +1106,8 @@ async def run_engine_with_coordinator(

             engine_output = await self.async_step(verbose=verbose)

-            is_tp0_and_pp0 = (
-                parallel_state.get_tensor_model_parallel_rank() == 0
-                and parallel_state.get_pipeline_model_parallel_rank() == 0
-            )
             if (
-                is_tp0_and_pp0
+                self.is_mp_coordinator
                 and engine_output is not None
                 and engine_output["finished_requests"]
             ):

megatron/training/arguments.py

Lines changed: 2 additions & 1 deletion
@@ -1491,7 +1491,8 @@ def _add_inference_args(parser):
     group.add_argument('--inference-wandb-logging-step-interval', type=int, default=0,
                        help='Step interval for logging inference metrics to wandb. '
                        'Default to 0 to disable inference wandb logging.')
-
+    group.add_argument("--inference-coordinator-port", type=int, default=12346,
+                       help="This port will be used to setup the inference coordinator on node-0")
     return parser
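
A quick, throwaway check of how the new flag parses; the parser here is a stand-in for the argument group built in _add_inference_args.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--inference-coordinator-port", type=int, default=12346,
                    help="This port will be used to setup the inference coordinator on node-0")

args = parser.parse_args(["--inference-coordinator-port", "12400"])
assert args.inference_coordinator_port == 12400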
