
Commit f075875

Clean up DP coord unit-test and code reuse
1 parent 29eed5d commit f075875

3 files changed: +202 / -82 lines

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 94 additions & 64 deletions
@@ -4,6 +4,7 @@
 import logging
 import multiprocessing
 import os
+import socket
 import struct
 import time
 import warnings
@@ -144,13 +145,14 @@ def __init__(
         self.track_paused_request_events = track_paused_request_events
         self.enable_chunked_prefill = enable_chunked_prefill
         self.unified_memory_level = context.unified_memory_level
+        self.rank = torch.distributed.get_rank()

         self.inference_logging_step_interval = inference_logging_step_interval
         # Configure wandb to use separate step counter for inference metrics (only once)
         if self.inference_logging_step_interval > 0 and self.context.metrics_writer is not None:
             logging.info(
                 f"\033[1;93m[INFERENCE]\033[0m "
-                f"\033[1;95mLogging inference metrics to wandb (rank {torch.distributed.get_rank()})\033[0m"
+                f"\033[1;95mLogging inference metrics to wandb (rank {self.rank})\033[0m"
             )
             if HAVE_WANDB and self.context.metrics_writer.__name__ == "wandb":
                 # Make all inference/* metrics use inference_step as their x-axis
@@ -250,7 +252,7 @@ def create_cuda_graphs(self, reset_context: bool = True):

         if moe_pad_experts and context.non_decode_cuda_graphs:
             context.non_decode_cuda_graphs = False
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "MoE models do not support non-decode cuda graphs. "
                     "Forcing non_decode_cuda_graphs to False."
@@ -350,16 +352,18 @@ async def start_listening_to_data_parallel_coordinator(
         `InferenceCoordinator`. It configures different ZMQ socket patterns
         based on the rank's role within the distributed topology.

+        Note that this method must be called on all ranks, as it uses blocking torch broadcasts.
+
         The setup involves two primary roles within each data-parallel group:
-        1. **TP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
+        1. **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
            to the central coordinator via a ZMQ `DEALER` socket. It receives
            requests and uses a ZMQ `PUB` (publisher) socket to broadcast them
-           to all other ranks within its tensor-parallel (TP) group.
-        2. **TP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
-           sockets to listen for requests broadcast by their local TP Coordinator.
+           to all other ranks within its model-parallel (MP) group.
+        2. **MP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
+           sockets to listen for requests broadcast by their local MP Coordinator.

-        This architecture uses fast Inter-Process Communication (`ipc`) sockets for
-        intra-node broadcasts within a TP group.
+        This architecture uses TCP sockets for both inter-node and intra-node broadcasts
+        within an MP group.

         Finally, after setting up the communication channels and ensuring all ranks
         are synchronized, this method starts the main engine processing loop
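
As background for the coordinator/worker split described in the updated docstring, below is a minimal pyzmq sketch of the DEALER-to-PUB-to-SUB relay it refers to. The endpoints, identity string, and variable names are illustrative only and are not taken from the change itself.

import zmq

ctx = zmq.Context.instance()

# MP coordinator role: a DEALER socket talks to the central coordinator,
# and a PUB socket rebroadcasts whatever arrives to the local MP workers.
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, b"mp-coord-0")
dealer.connect("tcp://127.0.0.1:12346")  # assumed coordinator endpoint
dealer.send(b"")                         # empty frame registers with the coordinator

publisher = ctx.socket(zmq.PUB)
publisher.bind("tcp://127.0.0.1:5556")   # assumed broadcast endpoint

# MP worker role: a SUB socket subscribed to everything the coordinator republishes.
subscriber = ctx.socket(zmq.SUB)
subscriber.connect("tcp://127.0.0.1:5556")
subscriber.setsockopt_string(zmq.SUBSCRIBE, "")

# Relay step on the coordinator: forward a received request to the workers.
# (Blocks until the central coordinator actually sends something.)
request = dealer.recv()
publisher.send(request)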
@@ -371,13 +375,16 @@ async def start_listening_to_data_parallel_coordinator(
             launch_inference_coordinator (bool, optional): If True, the global rank 0
                 process will spawn and manage the `InferenceCoordinator`
                 process. Defaults to True.
+<<<<<<< HEAD
             verbose (bool): Whether to run in verbose mode.

         Note:
             The current implementation uses `ipc` sockets for broadcasting requests
             within a Tensor Parallel group, which limits each TP group to a single
             physical node. For example, if you have 8 GPUs per node, then this will only
             work with TP=[1,2,4,8]
+=======
+>>>>>>> a28d34db94 (Clean up DP coord unit-test and code reuse)
         """

         assert HAVE_ZMQ, (
@@ -388,7 +395,32 @@ async def start_listening_to_data_parallel_coordinator(
             "pip install msgpack"
         )

-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        self.zmq_context = zmq.Context().instance()
+        self.zmq_sockets = []  # keep track of all sockets created by this engine
+
+        # Get world info.
+        dp_group = parallel_state.get_data_parallel_group()
+        dp_src = parallel_state.get_data_parallel_src_rank()
+        dp_size = parallel_state.get_data_parallel_world_size()
+        dp_rank = parallel_state.get_data_parallel_rank()
+
+        mp_group = parallel_state.get_model_parallel_group()
+        mp_src = parallel_state.get_model_parallel_src_rank()
+        tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+        self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
+        self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator
+
+        # Get local IP.
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as tmp_sock:
+            tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+            tmp_sock.connect(('<broadcast>', 0))
+            local_ip = tmp_sock.getsockname()[0]
+        del tmp_sock
+
+        # Spawn a DP coordinator process and get the connection info.
+        if launch_inference_coordinator and self.is_dp_coordinator:
             spawn_context = multiprocessing.get_context('spawn')
             coordinator_ready_event = spawn_context.Event()
             self.inference_coordinator_process = spawn_context.Process(
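
The UDP-socket trick the new code uses to discover the node's routable IP address can be read in isolation; the sketch below restates it outside the engine (only the final print is added here). Connecting a datagram socket sends no packets; it only asks the OS which local interface it would use for that destination.

import socket

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as tmp_sock:
    # SO_BROADCAST allows the socket to target the broadcast address at all.
    tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
    # No traffic is transmitted; the OS just selects the outgoing interface.
    tmp_sock.connect(('<broadcast>', 0))
    local_ip = tmp_sock.getsockname()[0]

print(local_ip)  # typically the node's LAN address rather than 127.0.0.1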
@@ -401,61 +433,65 @@ async def start_listening_to_data_parallel_coordinator(
             )
             self.inference_coordinator_process.start()

-        # Todo [Siddharth]: can we move this code to another file?
-        self.zmq_context = zmq.Context()
-        self.zmq_sockets = []  # keep track of all sockets created by this engine
+        # Find available ports for MP and bind to them.
+        if self.is_mp_coordinator:
+            mp_req_sock = self.zmq_context.socket(zmq.PUB)
+            mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_req_addr = [mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)]
+
+            mp_len_sock = self.zmq_context.socket(zmq.PUB)
+            mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_len_addr = [mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)]
+        else:
+            mp_req_addr = [None]
+            mp_len_addr = [None]
+
+        # Broadcast addresses to respective ranks.
+        torch.distributed.broadcast_object_list(mp_req_addr, src=mp_src, group=mp_group)
+        torch.distributed.broadcast_object_list(mp_len_addr, src=mp_src, group=mp_group)
+
         ip_address_of_dp_coordinator = os.getenv('MASTER_ADDR', '127.0.0.1')
-        identity = f'tp-coord-{parallel_state.get_data_parallel_rank()}'
-        if (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_pipeline_model_parallel_rank() == 0
-        ):
+        dp_addr = [f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"]
+        identity = f'mp-coord-{dp_rank}'
+        if self.is_mp_coordinator:
             # 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
             # These will receive requests from an InferenceCoordinator.
             self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)

             self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
-            self.socket_for_receiving_requests.connect(
-                f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
-            )
+            self.socket_for_receiving_requests.connect(dp_addr[0])

             # send empty string. this is used to register with the coordinator.
             self.socket_for_receiving_requests.send(b"")

             # 2. Create a publisher socket. This is used to publish or broadcast
-            # requests within the tensor parallel group
-            self.tensor_parallel_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.tensor_parallel_publisher_socket.bind(f"ipc:///tmp/{identity}-tp-bcast-socket-req")
+            # requests within the model parallel group
+            self.model_parallel_publisher_socket = mp_req_sock

             # 3. Create another publisher socket to broadcast the number of messages to receive.
-            self.tensor_parallel_num_msgs_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.tensor_parallel_num_msgs_publisher_socket.bind(
-                f"ipc:///tmp/{identity}-tp-bcast-socket-len"
-            )
+            self.model_parallel_num_msgs_publisher_socket = mp_len_sock
             self.zmq_sockets += [
                 self.socket_for_receiving_requests,
-                self.tensor_parallel_num_msgs_publisher_socket,
-                self.tensor_parallel_publisher_socket,
+                self.model_parallel_num_msgs_publisher_socket,
+                self.model_parallel_publisher_socket,
             ]
-        # All TP ranks subscribe to the two publisher sockets
-        self.tensor_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.tensor_parallel_subscriber_socket.connect(f"ipc:///tmp/{identity}-tp-bcast-socket-req")
-        self.tensor_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
-
-        self.tensor_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.tensor_parallel_num_msgs_subscriber_socket.connect(
-            f"ipc:///tmp/{identity}-tp-bcast-socket-len"
-        )
-        self.tensor_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+        # All MP ranks subscribe to the two publisher sockets
+        self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
+        self.model_parallel_subscriber_socket.connect(mp_req_addr[0])
+        self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+
+        self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
+        self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr[0])
+        self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

         self.zmq_sockets += [
-            self.tensor_parallel_subscriber_socket,
-            self.tensor_parallel_num_msgs_subscriber_socket,
+            self.model_parallel_subscriber_socket,
+            self.model_parallel_num_msgs_subscriber_socket,
         ]

-        torch.distributed.barrier(parallel_state.get_tensor_model_parallel_group())
+        torch.distributed.barrier(mp_group)

-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        if launch_inference_coordinator and self.is_dp_coordinator:
             await await_process_event(coordinator_ready_event, self.inference_coordinator_process)
             logging.info("Inference co-ordinator is ready to receive requests!")
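The hunk above binds each PUB socket to a random TCP port on the coordinator rank and then shares the resulting endpoint string with the rest of the model-parallel group via a blocking broadcast. Below is a minimal sketch of that hand-off, assuming an already-initialized torch.distributed process group; the helper name and arguments are illustrative and not part of the change.

import torch
import zmq

def share_pub_endpoint(ctx, local_ip, group, src_rank, is_src):
    """On the source rank, bind a PUB socket and broadcast its endpoint string."""
    sock = None
    if is_src:
        sock = ctx.socket(zmq.PUB)
        sock.bind_to_random_port(f"tcp://{local_ip}")
        endpoint = [sock.getsockopt_string(zmq.LAST_ENDPOINT)]
    else:
        endpoint = [None]
    # Collective call: every rank in `group` must reach this line.
    torch.distributed.broadcast_object_list(endpoint, src=src_rank, group=group)
    return sock, endpoint[0]

Non-source ranks then connect a SUB socket to the returned endpoint and subscribe to everything, which is what the engine does with mp_req_addr[0] and mp_len_addr[0].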
@@ -664,7 +700,7 @@ def _add_request(
         try:
             eod = self.controller.tokenizer.eod
         except AttributeError:
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "Termination ID not specified, and tokenizer does not define eod."
                     "Defaulting to not using termination id."
@@ -1154,16 +1190,16 @@ def schedule_requests(self) -> int:
         """Drains the ZMQ socket for a batch of requests and adds them to the engine.

         This method is a collective and synchronous operation that must be called
-        by all ranks in a Tensor Parallel (TP) group at the same time. It ensures
+        by all ranks in a Model Parallel (MP) group at the same time. It ensures
         that all ranks process the exact same batch of incoming requests and
         control signals.

         The synchronization works as follows:
-        1. The TP rank 0 drains all pending messages from its subscriber socket
+        1. The MP rank 0 drains all pending messages from its subscriber socket
            in a non-blocking manner.
-        2. TP rank 0 then broadcasts the number of messages it received to all other
-           ranks in its TP group using a dedicated publisher socket.
-        3. The other TP ranks wait to receive this count, and then receive exactly
+        2. MP rank 0 then broadcasts the number of messages it received to all other
+           ranks in its MP group using a dedicated publisher socket.
+        3. The other MP ranks wait to receive this count, and then receive exactly
            that many messages from their subscriber sockets.

         Once all ranks have the same batch of messages, they are unpacked and
@@ -1173,18 +1209,17 @@ def schedule_requests(self) -> int:

         Note:
             This function is synchronous and must be called collectively by all
-            ranks in a TP group. It should not be launched in a separate coroutine
+            ranks in a MP group. It should not be launched in a separate coroutine
             to ensure all ranks execute it in lockstep before proceeding to the
             next engine step.

         Returns:
             int: The number of messages that were received and processed in this batch.
         """

-        rank = parallel_state.get_tensor_model_parallel_rank()
         torch.cuda.nvtx.range_push("drain_zmq_socket")
         all_messages = []
-        if rank == 0:
+        if self.is_mp_coordinator:
             while True:
                 try:
                     # Receive messages in a non-blocking way.
@@ -1196,22 +1231,22 @@ def schedule_requests(self) -> int:
             # First publish the number of messages to dequeue.
             # This is important because we want all tensor parallel ranks
             # to dequeue the same number of messages.
-            self.tensor_parallel_num_msgs_publisher_socket.send(
+            self.model_parallel_num_msgs_publisher_socket.send(
                 struct.pack('!i', messages_to_dequeue)
             )
-            # Now publish the actual messages to all tensor parallel ranks
+            # Now publish the actual messages to all model parallel ranks
             for message in all_messages:
-                self.tensor_parallel_publisher_socket.send(message)
+                self.model_parallel_publisher_socket.send(message)
         else:
-            # First, receive the number of messages to dequeue from tp-rank 0
+            # First, receive the number of messages to dequeue from mp-rank 0
             messages_to_dequeue = struct.unpack(
-                '!i', self.tensor_parallel_num_msgs_subscriber_socket.recv()
+                '!i', self.model_parallel_num_msgs_subscriber_socket.recv()
             )[0]
             # Now, dequeue the same number of messages from the subscriber socket.
             # Note that these receives are blocking, because the messages
             # are guaranteed to be available after the tp-rank 0 has sent them.
             for _ in range(messages_to_dequeue):
-                all_messages.append(self.tensor_parallel_subscriber_socket.recv())
+                all_messages.append(self.model_parallel_subscriber_socket.recv())

         torch.cuda.nvtx.range_pop()
         for message in all_messages:
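
Condensed, the count-then-payload protocol implemented above looks like the following. The function and parameter names are illustrative; the engine keeps these sockets as attributes rather than passing them in.

import struct
import zmq

def drain_and_sync(is_coordinator, requests_sock, num_pub, num_sub, msg_pub, msg_sub):
    """Coordinator drains its socket, then tells workers exactly how many to read."""
    messages = []
    if is_coordinator:
        while True:
            try:
                messages.append(requests_sock.recv(flags=zmq.NOBLOCK))
            except zmq.Again:
                break  # nothing left to drain
        num_pub.send(struct.pack('!i', len(messages)))   # publish the count first
        for m in messages:
            msg_pub.send(m)                              # then the messages themselves
    else:
        count = struct.unpack('!i', num_sub.recv())[0]   # blocking: the count always arrives
        for _ in range(count):
            messages.append(msg_sub.recv())
    return messages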
@@ -1325,13 +1360,8 @@ async def run_engine_with_coordinator(
             # Step.
             engine_output = await self.async_step(verbose=verbose)

-            # Send finished requests.
-            is_tp0_and_pp0 = (
-                parallel_state.get_tensor_model_parallel_rank() == 0
-                and parallel_state.get_pipeline_model_parallel_rank() == 0
-            )
             if (
-                is_tp0_and_pp0
+                self.is_mp_coordinator
                 and engine_output is not None
                 and engine_output["finished_request_records"]
             ):

megatron/training/arguments.py

Lines changed: 2 additions & 1 deletion
@@ -1489,7 +1489,8 @@ def _add_inference_args(parser):
     group.add_argument('--inference-wandb-logging-step-interval', type=int, default=0,
                        help='Step interval for logging inference metrics to wandb. '
                        'Default to 0 to disable inference wandb logging.')
-
+    group.add_argument("--inference-coordinator-port", type=int, default=12346,
+                       help="This port will be used to setup the inference coordinator on node-0")
     return parser

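
For reference, the new flag behaves like any other argparse integer option. A small hypothetical stand-alone example follows; in Megatron the flag is added by _add_inference_args() and read from the parsed global args.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--inference-coordinator-port", type=int, default=12346,
                    help="This port will be used to setup the inference coordinator on node-0")

args = parser.parse_args(["--inference-coordinator-port", "12400"])
print(args.inference_coordinator_port)  # 12400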