 import logging
 import multiprocessing
 import os
+import socket
 import struct
 import time
 import warnings
@@ -144,13 +145,14 @@ def __init__(
         self.track_paused_request_events = track_paused_request_events
         self.enable_chunked_prefill = enable_chunked_prefill
         self.unified_memory_level = context.unified_memory_level
+        self.rank = torch.distributed.get_rank()

         self.inference_logging_step_interval = inference_logging_step_interval
         # Configure wandb to use separate step counter for inference metrics (only once)
         if self.inference_logging_step_interval > 0 and self.context.metrics_writer is not None:
             logging.info(
                 f"\033[1;93m[INFERENCE]\033[0m "
-                f"\033[1;95mLogging inference metrics to wandb (rank {torch.distributed.get_rank()})\033[0m"
+                f"\033[1;95mLogging inference metrics to wandb (rank {self.rank})\033[0m"
             )
             if HAVE_WANDB and self.context.metrics_writer.__name__ == "wandb":
                 # Make all inference/* metrics use inference_step as their x-axis
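The `define_metric` call that follows this comment sits outside the hunk, so purely as a hedged illustration: the snippet below shows the usual way wandb is told to plot a family of metrics against a custom step counter. The exact key names are assumptions for the example, not necessarily the keys this engine logs.

```python
import wandb

# Illustrative keys only: register a dedicated step counter and make every
# metric logged under inference/* use it as the x-axis instead of the
# global trainer step.
run = wandb.init(mode="offline")
wandb.define_metric("inference_step")
wandb.define_metric("inference/*", step_metric="inference_step")

# Later, metrics are logged together with the custom step value.
wandb.log({"inference/throughput_tokens_per_s": 1234.5, "inference_step": 1})
```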
@@ -250,7 +252,7 @@ def create_cuda_graphs(self, reset_context: bool = True):

         if moe_pad_experts and context.non_decode_cuda_graphs:
             context.non_decode_cuda_graphs = False
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "MoE models do not support non-decode cuda graphs. "
                     "Forcing non_decode_cuda_graphs to False."
@@ -350,16 +352,18 @@ async def start_listening_to_data_parallel_coordinator(
         `InferenceCoordinator`. It configures different ZMQ socket patterns
         based on the rank's role within the distributed topology.

+        Note that this method must be called on all ranks, as it uses blocking torch broadcasts.
+
         The setup involves two primary roles within each data-parallel group:
-        1. **TP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
+        1. **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
            to the central coordinator via a ZMQ `DEALER` socket. It receives
            requests and uses a ZMQ `PUB` (publisher) socket to broadcast them
-           to all other ranks within its tensor-parallel (TP) group.
-        2. **TP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
-           sockets to listen for requests broadcast by their local TP Coordinator.
+           to all other ranks within its model-parallel (MP) group.
+        2. **MP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
+           sockets to listen for requests broadcast by their local MP Coordinator.

-        This architecture uses fast Inter-Process Communication (`ipc`) sockets for
-        intra-node broadcasts within a TP group.
+        This architecture uses TCP sockets for both inter-node and intra-node broadcasts
+        within an MP group.

         Finally, after setting up the communication channels and ensuring all ranks
         are synchronized, this method starts the main engine processing loop
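To make the topology described above concrete, here is a minimal, single-process pyzmq sketch of the same DEALER registration plus PUB/SUB fan-out. The addresses and the identity string are placeholders; the real engine derives them from the coordinator port and from randomly bound TCP ports (see the hunks below).

```python
import zmq

coordinator_addr = "tcp://127.0.0.1:5555"  # placeholder for the DP coordinator endpoint
broadcast_addr = "tcp://127.0.0.1:5556"    # placeholder for the MP coordinator's PUB endpoint

ctx = zmq.Context.instance()

# MP coordinator rank: register with the central coordinator over a DEALER socket...
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, b"mp-coord-0")
dealer.connect(coordinator_addr)
dealer.send(b"")  # empty frame announces this engine to the coordinator

# ...and re-broadcast whatever it receives to the rest of the MP group.
publisher = ctx.socket(zmq.PUB)
publisher.bind(broadcast_addr)

# MP worker ranks: subscribe to everything published by their coordinator.
subscriber = ctx.socket(zmq.SUB)
subscriber.connect(broadcast_addr)
subscriber.setsockopt_string(zmq.SUBSCRIBE, "")
```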
@@ -371,13 +375,6 @@ async def start_listening_to_data_parallel_coordinator(
         launch_inference_coordinator (bool, optional): If True, the global rank 0
             process will spawn and manage the `InferenceCoordinator`
             process. Defaults to True.
-        verbose (bool): Whether to run in verbose mode.
-
-    Note:
-        The current implementation uses `ipc` sockets for broadcasting requests
-        within a Tensor Parallel group, which limits each TP group to a single
-        physical node. For example, if you have 8 GPUs per node, then this will only
-        work with TP=[1,2,4,8]
         """

         assert HAVE_ZMQ, (
@@ -388,7 +395,32 @@ async def start_listening_to_data_parallel_coordinator(
             "pip install msgpack"
         )

-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        self.zmq_context = zmq.Context.instance()
+        self.zmq_sockets = []  # keep track of all sockets created by this engine
+
+        # Get world info.
+        dp_group = parallel_state.get_data_parallel_group()
+        dp_src = parallel_state.get_data_parallel_src_rank()
+        dp_size = parallel_state.get_data_parallel_world_size()
+        dp_rank = parallel_state.get_data_parallel_rank()
+
+        mp_group = parallel_state.get_model_parallel_group()
+        mp_src = parallel_state.get_model_parallel_src_rank()
+        tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+        self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
+        self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator
+
+        # Get local IP.
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as tmp_sock:
+            tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+            tmp_sock.connect(('<broadcast>', 0))
+            local_ip = tmp_sock.getsockname()[0]
+        del tmp_sock
+
+        # Spawn a DP coordinator process and get the connection info.
+        if launch_inference_coordinator and self.is_dp_coordinator:
             spawn_context = multiprocessing.get_context('spawn')
             coordinator_ready_event = spawn_context.Event()
             self.inference_coordinator_process = spawn_context.Process(
@@ -401,61 +433,65 @@ async def start_listening_to_data_parallel_coordinator(
             )
             self.inference_coordinator_process.start()

-        # Todo [Siddharth]: can we move this code to another file?
-        self.zmq_context = zmq.Context()
-        self.zmq_sockets = []  # keep track of all sockets created by this engine
+        # Find available ports for MP and bind to them.
+        if self.is_mp_coordinator:
+            mp_req_sock = self.zmq_context.socket(zmq.PUB)
+            mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_req_addr = [mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)]
+
+            mp_len_sock = self.zmq_context.socket(zmq.PUB)
+            mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_len_addr = [mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)]
+        else:
+            mp_req_addr = [None]
+            mp_len_addr = [None]
+
+        # Broadcast addresses to respective ranks.
+        torch.distributed.broadcast_object_list(mp_req_addr, src=mp_src, group=mp_group)
+        torch.distributed.broadcast_object_list(mp_len_addr, src=mp_src, group=mp_group)
+
         ip_address_of_dp_coordinator = os.getenv('MASTER_ADDR', '127.0.0.1')
-        identity = f'tp-coord-{parallel_state.get_data_parallel_rank()}'
-        if (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_pipeline_model_parallel_rank() == 0
-        ):
+        dp_addr = [f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"]
+        identity = f'mp-coord-{dp_rank}'
+        if self.is_mp_coordinator:
             # 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
             # These will receive requests from an InferenceCoordinator.
             self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)

             self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
-            self.socket_for_receiving_requests.connect(
-                f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
-            )
+            self.socket_for_receiving_requests.connect(dp_addr[0])

             # send empty string. this is used to register with the coordinator.
             self.socket_for_receiving_requests.send(b"")

             # 2. Create a publisher socket. This is used to publish or broadcast
-            # requests within the tensor parallel group
-            self.tensor_parallel_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.tensor_parallel_publisher_socket.bind(f"ipc:///tmp/{identity}-tp-bcast-socket-req")
+            # requests within the model parallel group
+            self.model_parallel_publisher_socket = mp_req_sock

             # 3. Create another publisher socket to broadcast the number of messages to receive.
-            self.tensor_parallel_num_msgs_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.tensor_parallel_num_msgs_publisher_socket.bind(
-                f"ipc:///tmp/{identity}-tp-bcast-socket-len"
-            )
+            self.model_parallel_num_msgs_publisher_socket = mp_len_sock
             self.zmq_sockets += [
                 self.socket_for_receiving_requests,
-                self.tensor_parallel_num_msgs_publisher_socket,
-                self.tensor_parallel_publisher_socket,
+                self.model_parallel_num_msgs_publisher_socket,
+                self.model_parallel_publisher_socket,
             ]
-        # All TP ranks subscribe to the two publisher sockets
-        self.tensor_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.tensor_parallel_subscriber_socket.connect(f"ipc:///tmp/{identity}-tp-bcast-socket-req")
-        self.tensor_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
-
-        self.tensor_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.tensor_parallel_num_msgs_subscriber_socket.connect(
-            f"ipc:///tmp/{identity}-tp-bcast-socket-len"
-        )
-        self.tensor_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+        # All MP ranks subscribe to the two publisher sockets
+        self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
+        self.model_parallel_subscriber_socket.connect(mp_req_addr[0])
+        self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+
+        self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
+        self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr[0])
+        self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

         self.zmq_sockets += [
-            self.tensor_parallel_subscriber_socket,
-            self.tensor_parallel_num_msgs_subscriber_socket,
+            self.model_parallel_subscriber_socket,
+            self.model_parallel_num_msgs_subscriber_socket,
         ]

-        torch.distributed.barrier(parallel_state.get_tensor_model_parallel_group())
+        torch.distributed.barrier(mp_group)

-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        if launch_inference_coordinator and self.is_dp_coordinator:
             await await_process_event(coordinator_ready_event, self.inference_coordinator_process)
             logging.info("Inference co-ordinator is ready to receive requests!")

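The hunk above replaces the old `ipc` paths with a bind-then-broadcast handshake: the MP coordinator binds PUB sockets to random TCP ports and then shares the resulting endpoints with its group. A stripped-down sketch of that handshake, using a hypothetical helper name and assuming `torch.distributed` is already initialized, might look like this:

```python
import torch
import zmq


def share_pub_endpoint(is_src: bool, src_rank: int, group, local_ip: str):
    """Hypothetical helper: bind a PUB socket on the source rank and tell the
    whole process group where to find it."""
    ctx = zmq.Context.instance()
    pub, addr = None, [None]
    if is_src:
        pub = ctx.socket(zmq.PUB)
        pub.bind_to_random_port(f"tcp://{local_ip}")
        addr = [pub.getsockopt_string(zmq.LAST_ENDPOINT)]
    # Collective call: every rank in `group` must participate. Afterwards,
    # addr[0] holds the same "tcp://ip:port" string on all ranks.
    torch.distributed.broadcast_object_list(addr, src=src_rank, group=group)
    return pub, addr[0]
```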
@@ -664,7 +700,7 @@ def _add_request(
             try:
                 eod = self.controller.tokenizer.eod
             except AttributeError:
-                if torch.distributed.get_rank() == 0:
+                if self.rank == 0:
                     warnings.warn(
                         "Termination ID not specified, and tokenizer does not define eod."
                         "Defaulting to not using termination id."
@@ -1154,16 +1190,16 @@ def schedule_requests(self) -> int:
         """Drains the ZMQ socket for a batch of requests and adds them to the engine.

         This method is a collective and synchronous operation that must be called
-        by all ranks in a Tensor Parallel (TP) group at the same time. It ensures
+        by all ranks in a Model Parallel (MP) group at the same time. It ensures
         that all ranks process the exact same batch of incoming requests and
         control signals.

         The synchronization works as follows:
-        1. The TP rank 0 drains all pending messages from its subscriber socket
+        1. The MP rank 0 drains all pending messages from its subscriber socket
            in a non-blocking manner.
-        2. TP rank 0 then broadcasts the number of messages it received to all other
-           ranks in its TP group using a dedicated publisher socket.
-        3. The other TP ranks wait to receive this count, and then receive exactly
+        2. MP rank 0 then broadcasts the number of messages it received to all other
+           ranks in its MP group using a dedicated publisher socket.
+        3. The other MP ranks wait to receive this count, and then receive exactly
            that many messages from their subscriber sockets.

         Once all ranks have the same batch of messages, they are unpacked and
@@ -1173,18 +1209,17 @@ def schedule_requests(self) -> int:

         Note:
             This function is synchronous and must be called collectively by all
-            ranks in a TP group. It should not be launched in a separate coroutine
+            ranks in an MP group. It should not be launched in a separate coroutine
             to ensure all ranks execute it in lockstep before proceeding to the
             next engine step.

         Returns:
             int: The number of messages that were received and processed in this batch.
         """

-        rank = parallel_state.get_tensor_model_parallel_rank()
         torch.cuda.nvtx.range_push("drain_zmq_socket")
         all_messages = []
-        if rank == 0:
+        if self.is_mp_coordinator:
             while True:
                 try:
                     # Receive messages in a non-blocking way.
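The body of the non-blocking receive is cut off between this hunk and the next one. As an assumption about the typical shape of such a drain loop (not a verbatim copy of the elided lines), it usually looks like the following:

```python
import zmq


def drain_pending(sock: "zmq.Socket") -> list:
    """Return every message currently queued on `sock` without blocking."""
    messages = []
    while True:
        try:
            messages.append(sock.recv(flags=zmq.NOBLOCK))
        except zmq.Again:
            # Queue is empty right now; stop draining.
            break
    return messages
```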
@@ -1196,22 +1231,22 @@ def schedule_requests(self) -> int:
             # First publish the number of messages to dequeue.
             # This is important because we want all tensor parallel ranks
             # to dequeue the same number of messages.
-            self.tensor_parallel_num_msgs_publisher_socket.send(
+            self.model_parallel_num_msgs_publisher_socket.send(
                 struct.pack('!i', messages_to_dequeue)
             )
-            # Now publish the actual messages to all tensor parallel ranks
+            # Now publish the actual messages to all model parallel ranks
             for message in all_messages:
-                self.tensor_parallel_publisher_socket.send(message)
+                self.model_parallel_publisher_socket.send(message)
         else:
-            # First, receive the number of messages to dequeue from tp-rank 0
+            # First, receive the number of messages to dequeue from mp-rank 0
             messages_to_dequeue = struct.unpack(
-                '!i', self.tensor_parallel_num_msgs_subscriber_socket.recv()
+                '!i', self.model_parallel_num_msgs_subscriber_socket.recv()
             )[0]
             # Now, dequeue the same number of messages from the subscriber socket.
             # Note that these receives are blocking, because the messages
             # are guaranteed to be available after the tp-rank 0 has sent them.
             for _ in range(messages_to_dequeue):
-                all_messages.append(self.tensor_parallel_subscriber_socket.recv())
+                all_messages.append(self.model_parallel_subscriber_socket.recv())

         torch.cuda.nvtx.range_pop()
         for message in all_messages:
@@ -1325,13 +1360,8 @@ async def run_engine_with_coordinator(
             # Step.
             engine_output = await self.async_step(verbose=verbose)

-            # Send finished requests.
-            is_tp0_and_pp0 = (
-                parallel_state.get_tensor_model_parallel_rank() == 0
-                and parallel_state.get_pipeline_model_parallel_rank() == 0
-            )
             if (
-                is_tp0_and_pp0
+                self.is_mp_coordinator
                 and engine_output is not None
                 and engine_output["finished_request_records"]
             ):