 )
 from megatron.core.inference.utils import Counter, await_process_event
 from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
-from megatron.core.utils import get_asyncio_loop, trace_async_exceptions
+from megatron.core.utils import get_asyncio_loop, internal_api, trace_async_exceptions
 
 try:
     from tqdm import tqdm
@@ -237,10 +237,6 @@ def reset(self) -> None:
 
         # Coordinator state.
         self.use_coordinator = False
-        self.is_tp0_and_pp0 = (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_pipeline_model_parallel_rank() == 0
-        )
 
     def create_cuda_graphs(self, reset_context: bool = True):
         """Create cuda graphs.
@@ -263,7 +259,7 @@ def create_cuda_graphs(self, reset_context: bool = True):
 
         if moe_pad_experts and context.non_decode_cuda_graphs:
             context.non_decode_cuda_graphs = False
-            if torch.distributed.get_rank() == 0:
+            if self.rank == 0:
                 warnings.warn(
                     "MoE models do not support non-decode cuda graphs. "
                     "Forcing non_decode_cuda_graphs to False."
@@ -348,10 +344,10 @@ def create_cuda_graphs(self, reset_context: bool = True):
 
         self.capture_stats = capture_stats
 
+    @internal_api
     async def start_listening_to_data_parallel_coordinator(
         self,
         inference_coordinator_port: int,
-        inference_mp_coordinator_port: int = 20000,
         launch_inference_coordinator: bool = True,
         verbose: bool = False,
         *,
@@ -364,6 +360,8 @@ async def start_listening_to_data_parallel_coordinator(
         `InferenceCoordinator`. It configures different ZMQ socket patterns
         based on the rank's role within the distributed topology.
 
+        Note that this method must be called on all ranks, as it uses blocking torch broadcasts.
+
         The setup involves two primary roles within each data-parallel group:
         1. **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
            to the central coordinator via a ZMQ `DEALER` socket. It receives
@@ -382,9 +380,6 @@ async def start_listening_to_data_parallel_coordinator(
         Args:
             inference_coordinator_port (int): The network port where the central
                 `InferenceCoordinator` is or will be listening.
-            inference_mp_coordinator_port (int): The base network port where each model parallel
-                coordinator will broadcast messages from. Each MP group will compute an independent
-                port offset from this base port.
             launch_inference_coordinator (bool, optional): If True, the global rank 0
                 process will spawn and manage the `InferenceCoordinator`
                 process. Defaults to True.
@@ -399,7 +394,25 @@ async def start_listening_to_data_parallel_coordinator(
                 "pip install msgpack"
             )
 
-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        self.zmq_context = zmq.Context.instance()
+        self.zmq_sockets = []  # keep track of all sockets created by this engine
+
+        # Get world info.
+        dp_group = parallel_state.get_data_parallel_group()
+        dp_src = parallel_state.get_data_parallel_src_rank()
+        dp_size = parallel_state.get_data_parallel_world_size()
+        dp_rank = parallel_state.get_data_parallel_rank()
+
+        mp_group = parallel_state.get_model_parallel_group()
+        mp_src = parallel_state.get_model_parallel_src_rank()
+        tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+        self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
+        self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator
+
+        # Spawn a DP coordinator process and get the connection info.
+        if launch_inference_coordinator and self.is_dp_coordinator:
             spawn_context = multiprocessing.get_context('spawn')
             coordinator_ready_event = spawn_context.Event()
             self.inference_coordinator_process = spawn_context.Process(
@@ -412,80 +425,67 @@ async def start_listening_to_data_parallel_coordinator(
             )
             self.inference_coordinator_process.start()
 
-        # Todo [Siddharth]: can we move this code to another file?
-        self.zmq_context = zmq.Context()
-        self.zmq_sockets = []  # keep track of all sockets created by this engine
-
-        # We need to broadcast the hostname of the (TP=0, PP=0) rank
-        # to all other ranks in the same model parallel group.
-        tp_rank = parallel_state.get_tensor_model_parallel_rank()
-        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
-
-        hostname_list = [None]
-        if tp_rank == 0 and pp_rank == 0:
-            hostname_list[0] = socket.gethostname()
+        # Find available ports for MP and bind to them.
+        if self.is_mp_coordinator:
+            local_ip = socket.gethostname()
+            mp_req_sock = self.zmq_context.socket(zmq.PUB)
+            mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_req_addr = mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)
 
-        # Find the global rank of the (TP=0, PP=0) rank in our MP group
-        src_global_rank = parallel_state.get_model_parallel_src_rank()
-
-        torch.distributed.broadcast_object_list(
-            hostname_list, src=src_global_rank, group=parallel_state.get_model_parallel_group()
-        )
-        bcast_hostname = hostname_list[0]
+            mp_len_sock = self.zmq_context.socket(zmq.PUB)
+            mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
+            mp_len_addr = mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)
+        else:
+            mp_req_addr = None
+            mp_len_addr = None
 
-        # We need unique ports for each MP group, so we compute an offset using the DP rank.
-        dp_rank = parallel_state.get_data_parallel_rank()
-        req_port = inference_mp_coordinator_port + (dp_rank * 2)
-        len_port = inference_mp_coordinator_port + (dp_rank * 2) + 1
+        # Broadcast addresses to respective ranks.
+        bcast = [mp_req_addr, mp_len_addr]
+        torch.distributed.broadcast_object_list(bcast, src=mp_src, group=mp_group)
+        [mp_req_addr, mp_len_addr] = bcast
 
         ip_address_of_dp_coordinator = os.getenv('MASTER_ADDR', '127.0.0.1')
-        identity = f'mp-coord-{parallel_state.get_data_parallel_rank()}'
-        if (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_pipeline_model_parallel_rank() == 0
-        ):
+        dp_addr = f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
+        identity = f'mp-coord-{dp_rank}'
+        if self.is_mp_coordinator:
             # 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
             # These will receive requests from an InferenceCoordinator.
             self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)
 
             self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
-            self.socket_for_receiving_requests.connect(
-                f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
-            )
+            self.socket_for_receiving_requests.connect(dp_addr)
 
             # send empty string. this is used to register with the coordinator.
             self.socket_for_receiving_requests.send(b"")
 
             # 2. Create a publisher socket. This is used to publish or broadcast
             # requests within the model parallel group
-            self.model_parallel_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.model_parallel_publisher_socket.bind(f"tcp://*:{req_port}")
+            self.model_parallel_publisher_socket = mp_req_sock
 
             # 3. Create another publisher socket to broadcast the number of messages to receive.
-            self.model_parallel_num_msgs_publisher_socket = self.zmq_context.socket(zmq.PUB)
-            self.model_parallel_num_msgs_publisher_socket.bind(f"tcp://*:{len_port}")
+            self.model_parallel_num_msgs_publisher_socket = mp_len_sock
             self.zmq_sockets += [
                 self.socket_for_receiving_requests,
                 self.model_parallel_num_msgs_publisher_socket,
                 self.model_parallel_publisher_socket,
             ]
         # All MP ranks subscribe to the two publisher sockets
         self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.model_parallel_subscriber_socket.connect(f"tcp://{bcast_hostname}:{req_port}")
+        self.model_parallel_subscriber_socket.connect(mp_req_addr)
         self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
 
         self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
-        self.model_parallel_num_msgs_subscriber_socket.connect(f"tcp://{bcast_hostname}:{len_port}")
+        self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr)
        self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")
 
         self.zmq_sockets += [
             self.model_parallel_subscriber_socket,
             self.model_parallel_num_msgs_subscriber_socket,
         ]
 
-        torch.distributed.barrier(parallel_state.get_model_parallel_group())
+        torch.distributed.barrier(mp_group)
 
-        if launch_inference_coordinator and torch.distributed.get_rank() == 0:
+        if launch_inference_coordinator and self.is_dp_coordinator:
             await await_process_event(coordinator_ready_event, self.inference_coordinator_process)
             logging.info("Inference co-ordinator is ready to receive requests!")
 
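For context on the hunk above: the fixed `inference_mp_coordinator_port` is replaced by a rendezvous in which each MP coordinator binds its PUB sockets to random ports and shares the resulting endpoints with its model-parallel group via a blocking torch broadcast. The following is a minimal standalone sketch of that pattern, not the engine's actual API; the function and argument names are illustrative, and it assumes `torch.distributed` is already initialized.

```python
import socket

import torch
import zmq


def setup_mp_endpoint(is_mp_coordinator: bool, mp_src: int, mp_group):
    """Bind a PUB socket on the coordinator rank, broadcast its endpoint,
    and return the socket this rank should use (PUB on the coordinator,
    a connected SUB everywhere else)."""
    ctx = zmq.Context.instance()
    pub_sock, endpoint = None, None

    if is_mp_coordinator:
        pub_sock = ctx.socket(zmq.PUB)
        # Bind to any free port on this host; LAST_ENDPOINT holds "tcp://host:port".
        pub_sock.bind_to_random_port(f"tcp://{socket.gethostname()}")
        endpoint = pub_sock.getsockopt_string(zmq.LAST_ENDPOINT)

    # Blocking collective: every rank in the MP group must reach this call.
    bcast = [endpoint]
    torch.distributed.broadcast_object_list(bcast, src=mp_src, group=mp_group)
    endpoint = bcast[0]

    if is_mp_coordinator:
        return pub_sock, endpoint

    sub_sock = ctx.socket(zmq.SUB)
    sub_sock.connect(endpoint)
    sub_sock.setsockopt_string(zmq.SUBSCRIBE, "")
    return sub_sock, endpoint
```

Because the bound endpoint (host and port) travels through the broadcast, no per-DP-group port arithmetic is needed, which is why the old `req_port`/`len_port` offsets could be removed.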
@@ -697,7 +697,7 @@ def _add_request(
             try:
                 eod = self.controller.tokenizer.eod
             except AttributeError:
-                if torch.distributed.get_rank() == 0:
+                if self.rank == 0:
                     warnings.warn(
                         "Termination ID not specified, and tokenizer does not define eod."
                         "Defaulting to not using termination id."
@@ -1093,7 +1093,7 @@ async def async_bookkeep(
         self.failed_request_ids.clear()
 
         # Handle necessary ZMQ DP coordinator communication.
-        if self.use_coordinator and self.is_tp0_and_pp0 and finished_request_records:
+        if self.use_coordinator and self.is_mp_coordinator and finished_request_records:
             payload = msgpack.packb(
                 [Headers.ENGINE_REPLY.value, [r.serialize() for r in finished_request_records]],
                 use_bin_type=True,
@@ -1277,11 +1277,9 @@ def schedule_requests(self) -> int:
             int: The number of messages that were received and processed in this batch.
         """
 
-        tp_rank = parallel_state.get_tensor_model_parallel_rank()
-        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
         torch.cuda.nvtx.range_push("drain_zmq_socket")
         all_messages = []
-        if tp_rank == 0 and pp_rank == 0:
+        if self.is_mp_coordinator:
            while True:
                try:
                    # Receive messages in a non-blocking way.
@@ -1297,8 +1295,8 @@ def schedule_requests(self) -> int:
                 struct.pack('!i', messages_to_dequeue)
             )
             # Now publish the actual messages to all model parallel ranks
-            for message in all_messages:
-                self.model_parallel_publisher_socket.send(message)
+            if messages_to_dequeue > 0:
+                self.model_parallel_publisher_socket.send_multipart(all_messages)
         else:
             # First, receive the number of messages to dequeue from mp-rank 0
             messages_to_dequeue = struct.unpack(
@@ -1307,8 +1305,10 @@ def schedule_requests(self) -> int:
             # Now, dequeue the same number of messages from the subscriber socket.
             # Note that these receives are blocking, because the messages
             # are guaranteed to be available after the tp-rank 0 has sent them.
-            for _ in range(messages_to_dequeue):
-                all_messages.append(self.model_parallel_subscriber_socket.recv())
+            if messages_to_dequeue > 0:
+                all_messages = self.model_parallel_subscriber_socket.recv_multipart()
+            else:
+                all_messages = []
 
         torch.cuda.nvtx.range_pop()
         for message in all_messages:
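The revised drain-and-fan-out logic in the two hunks above amounts to a small protocol: the MP coordinator publishes a 4-byte big-endian message count, then a single multipart frame carrying the whole batch only when that count is non-zero, and subscribers mirror those two steps. A minimal sketch of both sides is shown below, assuming already-connected PUB/SUB/DEALER sockets; the names `len_pub`, `req_pub`, `len_sub`, `req_sub`, and `dealer` are hypothetical, not engine attributes.

```python
import struct

import zmq


def publish_batch(dealer: zmq.Socket, len_pub: zmq.Socket, req_pub: zmq.Socket) -> list:
    """Coordinator side: drain the DEALER socket, then fan the batch out."""
    messages = []
    while True:
        try:
            messages.append(dealer.recv(flags=zmq.NOBLOCK))
        except zmq.Again:
            break  # nothing left to drain
    # Always publish the count; publish the payload only when there is one.
    len_pub.send(struct.pack('!i', len(messages)))
    if messages:
        req_pub.send_multipart(messages)
    return messages


def receive_batch(len_sub: zmq.Socket, req_sub: zmq.Socket) -> list:
    """Follower side: read the count first, then the multipart payload if any."""
    count = struct.unpack('!i', len_sub.recv())[0]
    return req_sub.recv_multipart() if count > 0 else []
```

Sending one multipart frame instead of N individual sends keeps the batch atomic on the subscriber side, which is the point of the `send_multipart`/`recv_multipart` change.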
@@ -1347,7 +1347,6 @@ def stop(self):
        for socket in self.zmq_sockets:
            socket.close()
        self.zmq_context.term()
-        parallel_state.destroy_model_parallel()
 
    @trace_async_exceptions
    async def run_engine(
@@ -1369,7 +1368,6 @@ async def run_engine(
                        )
                    )
                )
-
                await self.async_step(verbose=verbose)
        except asyncio.CancelledError:
            pass