 import logging
 import multiprocessing
 import os
+import socket
 import struct
 import time
 import warnings
@@ -142,13 +143,14 @@ def __init__(
         self.paused = False
         self.stopped = False
         self.enable_chunked_prefill = enable_chunked_prefill
+        self.rank = torch.distributed.get_rank()

         self.inference_logging_step_interval = inference_logging_step_interval
         # Configure wandb to use separate step counter for inference metrics (only once)
         if self.inference_logging_step_interval > 0 and self.context.metrics_writer is not None:
             logging.info(
                 f"\033[1;93m[INFERENCE]\033[0m "
-                f"\033[1;95mLogging inference metrics to wandb (rank {torch.distributed.get_rank()})\033[0m"
+                f"\033[1;95mLogging inference metrics to wandb (rank {self.rank})\033[0m"
             )
             if HAVE_WANDB and self.context.metrics_writer.__name__ == "wandb":
                 # Make all inference/* metrics use inference_step as their x-axis
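For reference, the wandb wiring that the comment above alludes to is typically done with `wandb.define_metric`; the following is a minimal, hypothetical sketch (the actual call site is outside this hunk, and the project and metric names here are made up):

import wandb

# Hypothetical sketch: give inference metrics their own x-axis so they do not
# advance the training step counter.
wandb.init(project="example-project")
wandb.define_metric("inference_step")
wandb.define_metric("inference/*", step_metric="inference_step")

# Inference metrics are then logged against that dedicated counter.
wandb.log({"inference/throughput": 1234.5, "inference_step": 10})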
@@ -202,7 +204,7 @@ def create_cuda_graphs(self, reset_context: bool = True):

         if moe_pad_experts and context.non_decode_cuda_graphs:
             context.non_decode_cuda_graphs = False
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "MoE models do not support non-decode cuda graphs. "
                     "Forcing non_decode_cuda_graphs to False."
@@ -301,16 +303,18 @@ async def start_listening_to_data_parallel_coordinator(
         `InferenceCoordinator`. It configures different ZMQ socket patterns
         based on the rank's role within the distributed topology.

+        Note that this method must be called on all ranks, as it uses blocking torch broadcasts.
+
         The setup involves two primary roles within each data-parallel group:
-        1. **TP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
+        1. **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
            to the central coordinator via a ZMQ `DEALER` socket. It receives
            requests and uses a ZMQ `PUB` (publisher) socket to broadcast them
-           to all other ranks within its tensor-parallel (TP) group.
-        2. **TP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
-           sockets to listen for requests broadcast by their local TP Coordinator.
+           to all other ranks within its model-parallel (MP) group.
+        2. **MP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
+           sockets to listen for requests broadcast by their local MP Coordinator.

-        This architecture uses fast Inter-Process Communication (`ipc`) sockets for
-        intra-node broadcasts within a TP group.
+        This architecture uses TCP sockets for both inter-node and intra-node broadcasts
+        within an MP group.

         Finally, after setting up the communication channels and ensuring all ranks
         are synchronized, this method starts the main engine processing loop
@@ -322,12 +326,6 @@ async def start_listening_to_data_parallel_coordinator(
             launch_inference_coordinator (bool, optional): If True, the global rank 0
                 process will spawn and manage the `InferenceCoordinator`
                 process. Defaults to True.
-
-        Note:
-            The current implementation uses `ipc` sockets for broadcasting requests
-            within a Tensor Parallel group, which limits each TP group to a single
-            physical node. For example, if you have 8 GPUs per node, then this will only
-            work with TP=[1,2,4,8]
         """

         assert HAVE_ZMQ, (
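To make the socket roles in the docstring concrete, a stripped-down pyzmq sketch of the MP coordinator and worker sides is shown below; the endpoint addresses and identity are placeholders, not the engine's real configuration:

import zmq

ctx = zmq.Context.instance()

# MP coordinator (tp_rank == 0 and pp_rank == 0): DEALER towards the central
# coordinator, PUB towards the rest of its model-parallel group.
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, b"mp-coord-0")
dealer.connect("tcp://127.0.0.1:5555")  # placeholder coordinator endpoint
dealer.send(b"")  # register with the coordinator

publisher = ctx.socket(zmq.PUB)
publisher.bind("tcp://127.0.0.1:5556")  # placeholder broadcast endpoint

# MP worker (any other rank in the group): SUB socket that receives everything
# the coordinator publishes.
subscriber = ctx.socket(zmq.SUB)
subscriber.connect("tcp://127.0.0.1:5556")
subscriber.setsockopt_string(zmq.SUBSCRIBE, "")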
@@ -338,7 +336,32 @@ async def start_listening_to_data_parallel_coordinator(
338336 "pip install msgpack"
339337 )
340338
341- if launch_inference_coordinator and torch .distributed .get_rank () == 0 :
339+ self .zmq_context = zmq .Context ().instance ()
340+ self .zmq_sockets = [] # keep track of all sockets created by this engine
341+
342+ # Get world info.
343+ dp_group = parallel_state .get_data_parallel_group ()
344+ dp_src = parallel_state .get_data_parallel_src_rank ()
345+ dp_size = parallel_state .get_data_parallel_world_size ()
346+ dp_rank = parallel_state .get_data_parallel_rank ()
347+
348+ mp_group = parallel_state .get_model_parallel_group ()
349+ mp_src = parallel_state .get_model_parallel_src_rank ()
350+ tp_rank = parallel_state .get_tensor_model_parallel_rank ()
351+ pp_rank = parallel_state .get_pipeline_model_parallel_rank ()
352+
353+ self .is_mp_coordinator = tp_rank == 0 and pp_rank == 0
354+ self .is_dp_coordinator = (dp_rank == 0 ) and self .is_mp_coordinator
355+
356+ # Get local IP.
357+ with socket .socket (socket .AF_INET , socket .SOCK_DGRAM ) as tmp_sock :
358+ tmp_sock .setsockopt (socket .SOL_SOCKET , socket .SO_BROADCAST , 1 )
359+ tmp_sock .connect (('<broadcast>' , 0 ))
360+ local_ip = tmp_sock .getsockname ()[0 ]
361+ del tmp_sock
362+
363+ # Spawn a DP coordinator process and get the connection info.
364+ if launch_inference_coordinator and self .is_dp_coordinator :
342365 spawn_context = multiprocessing .get_context ('spawn' )
343366 coordinator_ready_event = spawn_context .Event ()
344367 self .inference_coordinator_process = spawn_context .Process (
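The `# Get local IP` block above relies on the fact that connect() on a UDP socket sends no packets; it only asks the kernel to choose an outbound interface, after which getsockname() reports that interface's address. A common standalone variant of the same trick (a sketch, not this PR's code) points at a routable address instead of the broadcast address:

import socket

def get_local_ip() -> str:
    # No traffic is actually sent: connect() on a datagram socket only picks a route.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))  # any routable address works
        return s.getsockname()[0]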
@@ -351,61 +374,65 @@ async def start_listening_to_data_parallel_coordinator(
             )
             self.inference_coordinator_process.start()

-        # Todo [Siddharth]: can we move this code to another file?
-        self.zmq_context = zmq.Context()
-        self.zmq_sockets = []  # keep track of all sockets created by this engine
+        # Find available ports for MP and bind to them.
+        if self.is_mp_coordinator:
+            mp_req_sock = self.zmq_context.socket(zmq.PUB)
+            mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_req_addr = [mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)]
+
+            mp_len_sock = self.zmq_context.socket(zmq.PUB)
+            mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_len_addr = [mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)]
+        else:
+            mp_req_addr = [None]
+            mp_len_addr = [None]
+
+        # Broadcast addresses to respective ranks.
+        torch.distributed.broadcast_object_list(mp_req_addr, src=mp_src, group=mp_group)
+        torch.distributed.broadcast_object_list(mp_len_addr, src=mp_src, group=mp_group)
+
         ip_address_of_dp_coordinator = os.getenv('MASTER_ADDR', '127.0.0.1')
-        identity = f'tp-coord-{parallel_state.get_data_parallel_rank()}'
-        if (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_pipeline_model_parallel_rank() == 0
-        ):
+        dp_addr = [f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"]
+        identity = f'mp-coord-{dp_rank}'
+        if self.is_mp_coordinator:
             # 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
             # These will receive requests from an InferenceCoordinator.
             self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)

             self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
-            self.socket_for_receiving_requests.connect(
-                f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
-            )
+            self.socket_for_receiving_requests.connect(dp_addr[0])

             # send empty string. this is used to register with the coordinator.
             self.socket_for_receiving_requests.send(b"")

             # 2. Create a publisher socket. This is used to publish or broadcast
-            # requests within the tensor parallel group
-            self.tensor_parallel_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.tensor_parallel_publisher_socket.bind(f"ipc:///tmp/{identity}-tp-bcast-socket-req")
+            # requests within the model parallel group
+            self.model_parallel_publisher_socket = mp_req_sock

             # 3. Create another publisher socket to broadcast the number of messages to receive.
-            self.tensor_parallel_num_msgs_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.tensor_parallel_num_msgs_publisher_socket.bind(
-                f"ipc:///tmp/{identity}-tp-bcast-socket-len"
-            )
+            self.model_parallel_num_msgs_publisher_socket = mp_len_sock
             self.zmq_sockets += [
                 self.socket_for_receiving_requests,
-                self.tensor_parallel_num_msgs_publisher_socket,
-                self.tensor_parallel_publisher_socket,
+                self.model_parallel_num_msgs_publisher_socket,
+                self.model_parallel_publisher_socket,
             ]
-        # All TP ranks subscribe to the two publisher sockets
-        self.tensor_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.tensor_parallel_subscriber_socket.connect(f"ipc:///tmp/{identity}-tp-bcast-socket-req")
-        self.tensor_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
-
-        self.tensor_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.tensor_parallel_num_msgs_subscriber_socket.connect(
-            f"ipc:///tmp/{identity}-tp-bcast-socket-len"
-        )
-        self.tensor_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+        # All MP ranks subscribe to the two publisher sockets
+        self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
+        self.model_parallel_subscriber_socket.connect(mp_req_addr[0])
+        self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+
+        self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
+        self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr[0])
+        self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

         self.zmq_sockets += [
-            self.tensor_parallel_subscriber_socket,
-            self.tensor_parallel_num_msgs_subscriber_socket,
+            self.model_parallel_subscriber_socket,
+            self.model_parallel_num_msgs_subscriber_socket,
         ]

-        torch.distributed.barrier(parallel_state.get_tensor_model_parallel_group())
+        torch.distributed.barrier(mp_group)

-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        if launch_inference_coordinator and self.is_dp_coordinator:
             await await_process_event(coordinator_ready_event, self.inference_coordinator_process)
             logging.info("Inference co-ordinator is ready to receive requests!")

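The port handshake added in this hunk follows a bind-then-share pattern: only the MP coordinator binds PUB sockets on random TCP ports, and every rank then learns the chosen endpoints through a collective. A condensed sketch of that pattern, assuming `torch.distributed` is already initialized and using made-up names, is:

import torch
import zmq

def bind_and_share_endpoint(is_coordinator, local_ip, src_rank, group):
    """Coordinator binds a PUB socket on a random port; all ranks learn its address."""
    addr = [None]
    sock = None
    if is_coordinator:
        sock = zmq.Context.instance().socket(zmq.PUB)
        sock.bind_to_random_port(f"tcp://{local_ip}")
        addr = [sock.getsockopt_string(zmq.LAST_ENDPOINT)]
    # Non-source ranks receive the endpoint string in-place.
    torch.distributed.broadcast_object_list(addr, src=src_rank, group=group)
    return sock, addr[0]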
@@ -455,7 +482,7 @@ def _add_request(
         try:
             eod = self.controller.tokenizer.eod
         except AttributeError:
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "Termination ID not specified, and tokenizer does not define eod."
                     "Defaulting to not using termination id."
@@ -932,16 +959,16 @@ def schedule_requests(self) -> int:
932959 """Drains the ZMQ socket for a batch of requests and adds them to the engine.
933960
934961 This method is a collective and synchronous operation that must be called
935- by all ranks in a Tensor Parallel (TP ) group at the same time. It ensures
962+ by all ranks in a Model Parallel (MP ) group at the same time. It ensures
936963 that all ranks process the exact same batch of incoming requests and
937964 control signals.
938965
939966 The synchronization works as follows:
940- 1. The TP rank 0 drains all pending messages from its subscriber socket
967+ 1. The MP rank 0 drains all pending messages from its subscriber socket
941968 in a non-blocking manner.
942- 2. TP rank 0 then broadcasts the number of messages it received to all other
943- ranks in its TP group using a dedicated publisher socket.
944- 3. The other TP ranks wait to receive this count, and then receive exactly
969+ 2. MP rank 0 then broadcasts the number of messages it received to all other
970+ ranks in its MP group using a dedicated publisher socket.
971+ 3. The other MP ranks wait to receive this count, and then receive exactly
945972 that many messages from their subscriber sockets.
946973
947974 Once all ranks have the same batch of messages, they are unpacked and
@@ -950,18 +977,17 @@ def schedule_requests(self) -> int:

         Note:
             This function is synchronous and must be called collectively by all
-            ranks in a TP group. It should not be launched in a separate coroutine
+            ranks in a MP group. It should not be launched in a separate coroutine
             to ensure all ranks execute it in lockstep before proceeding to the
             next engine step.

         Returns:
             int: The number of messages that were received and processed in this batch.
         """

-        rank = parallel_state.get_tensor_model_parallel_rank()
         torch.cuda.nvtx.range_push("drain_zmq_socket")
         all_messages = []
-        if rank == 0:
+        if self.is_mp_coordinator:
             while True:
                 try:
                     # Receive messages in a non-blocking way.
@@ -973,22 +999,22 @@ def schedule_requests(self) -> int:
             # First publish the number of messages to dequeue.
             # This is important because we want all tensor parallel ranks
             # to dequeue the same number of messages.
-            self.tensor_parallel_num_msgs_publisher_socket.send(
+            self.model_parallel_num_msgs_publisher_socket.send(
                 struct.pack('!i', messages_to_dequeue)
             )
-            # Now publish the actual messages to all tensor parallel ranks
+            # Now publish the actual messages to all model parallel ranks
             for message in all_messages:
-                self.tensor_parallel_publisher_socket.send(message)
+                self.model_parallel_publisher_socket.send(message)
         else:
-            # First, receive the number of messages to dequeue from tp-rank 0
+            # First, receive the number of messages to dequeue from mp-rank 0
             messages_to_dequeue = struct.unpack(
-                '!i', self.tensor_parallel_num_msgs_subscriber_socket.recv()
+                '!i', self.model_parallel_num_msgs_subscriber_socket.recv()
             )[0]
             # Now, dequeue the same number of messages from the subscriber socket.
             # Note that these receives are blocking, because the messages
             # are guaranteed to be available after the tp-rank 0 has sent them.
             for _ in range(messages_to_dequeue):
-                all_messages.append(self.tensor_parallel_subscriber_socket.recv())
+                all_messages.append(self.model_parallel_subscriber_socket.recv())

         torch.cuda.nvtx.range_pop()
         for message in all_messages:
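In isolation, the count-then-payload protocol implemented above can be summarized as the sketch below; the socket arguments are assumed to exist and are not the engine's attributes:

import struct
import zmq

def drain_and_replicate(is_leader, request_sock, len_pub, len_sub, msg_pub, msg_sub):
    """Leader drains whatever is pending; followers replay exactly the same batch."""
    messages = []
    if is_leader:
        while True:
            try:
                messages.append(request_sock.recv(flags=zmq.NOBLOCK))
            except zmq.Again:
                break
        # Publish the count first so every rank dequeues the same number of messages.
        len_pub.send(struct.pack('!i', len(messages)))
        for m in messages:
            msg_pub.send(m)
    else:
        count = struct.unpack('!i', len_sub.recv())[0]
        for _ in range(count):
            messages.append(msg_sub.recv())
    return messages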
@@ -1080,12 +1106,8 @@ async def run_engine_with_coordinator(

             engine_output = await self.async_step(verbose=verbose)

-            is_tp0_and_pp0 = (
-                parallel_state.get_tensor_model_parallel_rank() == 0
-                and parallel_state.get_pipeline_model_parallel_rank() == 0
-            )
             if (
-                is_tp0_and_pp0
+                self.is_mp_coordinator
                 and engine_output is not None
                 and engine_output["finished_requests"]
             ):