2 changes: 1 addition & 1 deletion examples/inference/gpt/gpt_dynamic_inference.py
@@ -30,7 +30,7 @@
ContextOverflowError,
DynamicInferenceContext,
)
from megatron.core.inference.context.attention_context.mamba_metadata import (
from megatron.core.inference.contexts.attention_context.mamba_metadata import (
MambaInferenceStateConfig,
)
from megatron.core.inference.engines import DynamicInferenceEngine, EngineSuspendedError
@@ -43,7 +43,6 @@ async def main(
engine: DynamicInferenceEngine,
requests: List[Request],
port: int,
mp_port: int,
sampling_params: SamplingParams | None = None,
):
if sampling_params is not None:
@@ -58,7 +57,6 @@ async def main(

await engine.start_listening_to_data_parallel_coordinator(
inference_coordinator_port=port,
inference_mp_coordinator_port=mp_port,
launch_inference_coordinator=True,
verbose=True,
)
@@ -258,6 +256,5 @@ async def main(
engine,
requests,
args.inference_coordinator_port,
args.inference_mp_coordinator_port
)
)
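Note for readers: with the `mp_port` parameter gone, the coordinator example boils down to the call below. This is a minimal sketch, assuming an initialized `engine`, a `requests` list, and parsed `args` as in the script above; `run_example` is an illustrative name, not part of the change.

```python
import asyncio


async def run_example(engine, requests, port):
    # Only the central coordinator port is passed now; the engine works out
    # the model-parallel coordination endpoints internally.
    await engine.start_listening_to_data_parallel_coordinator(
        inference_coordinator_port=port,
        launch_inference_coordinator=True,
        verbose=True,
    )
    # ... submit `requests` and await results as in the full example ...


# asyncio.run(run_example(engine, requests, args.inference_coordinator_port))
```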
107 changes: 51 additions & 56 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -154,13 +154,14 @@ def __init__(
self.track_paused_request_events = track_paused_request_events
self.enable_chunked_prefill = enable_chunked_prefill
self.unified_memory_level = context.unified_memory_level
self.rank = torch.distributed.get_rank()

self.inference_logging_step_interval = inference_logging_step_interval
# Configure wandb to use separate step counter for inference metrics (only once)
if self.inference_logging_step_interval > 0 and self.context.metrics_writer is not None:
logging.info(
f"\033[1;93m[INFERENCE]\033[0m "
f"\033[1;95mLogging inference metrics to wandb (rank {torch.distributed.get_rank()})\033[0m"
f"\033[1;95mLogging inference metrics to wandb (rank {self.rank})\033[0m"
)
if HAVE_WANDB and self.context.metrics_writer.__name__ == "wandb":
# Make all inference/* metrics use inference_step as their x-axis
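The elided body presumably routes `inference/*` metrics onto their own step axis; a hedged sketch of that wandb pattern (an assumption about the hidden code, not necessarily the engine's exact wiring):

```python
import wandb

# Assumed wiring: give inference metrics their own x-axis so they do not
# advance the training step counter.
wandb.define_metric("inference_step")
wandb.define_metric("inference/*", step_metric="inference_step")
```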
@@ -259,7 +260,7 @@ def create_cuda_graphs(self, reset_context: bool = True):

if moe_pad_experts and context.non_decode_cuda_graphs:
context.non_decode_cuda_graphs = False
if torch.distributed.get_rank() == 0:
if self.rank == 0:
warnings.warn(
"MoE models do not support non-decode cuda graphs. "
"Forcing non_decode_cuda_graphs to False."
@@ -347,7 +348,6 @@ def create_cuda_graphs(self, reset_context: bool = True):
async def start_listening_to_data_parallel_coordinator(
self,
inference_coordinator_port: int,
inference_mp_coordinator_port: int = 20000,
launch_inference_coordinator: bool = True,
verbose: bool = False,
*,
@@ -360,6 +360,8 @@ async def start_listening_to_data_parallel_coordinator(
`InferenceCoordinator`. It configures different ZMQ socket patterns
based on the rank's role within the distributed topology.

Note that this method must be called on all ranks, as it uses blocking torch broadcasts.

The setup involves two primary roles within each data-parallel group:
1. **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
to the central coordinator via a ZMQ `DEALER` socket. It receives
@@ -378,9 +380,6 @@
Args:
inference_coordinator_port (int): The network port where the central
`InferenceCoordinator` is or will be listening.
inference_mp_coordinator_port (int): The base network port where each model parallel
coordinator will broadcast messages from. Each MP group will compute an independent
port offset from this base port.
launch_inference_coordinator (bool, optional): If True, the global rank 0
process will spawn and manage the `InferenceCoordinator`
process. Defaults to True.
@@ -395,7 +394,25 @@ async def start_listening_to_data_parallel_coordinator(
"pip install msgpack"
)

if launch_inference_coordinator and torch.distributed.get_rank() == 0:
self.zmq_context = zmq.Context().instance()
self.zmq_sockets = [] # keep track of all sockets created by this engine

# Get world info.
dp_group = parallel_state.get_data_parallel_group()
dp_src = parallel_state.get_data_parallel_src_rank()
dp_size = parallel_state.get_data_parallel_world_size()
dp_rank = parallel_state.get_data_parallel_rank()

mp_group = parallel_state.get_model_parallel_group()
mp_src = parallel_state.get_model_parallel_src_rank()
tp_rank = parallel_state.get_tensor_model_parallel_rank()
pp_rank = parallel_state.get_pipeline_model_parallel_rank()

self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator

# Spawn a DP coordinator process and get the connection info.
if launch_inference_coordinator and self.is_dp_coordinator:
spawn_context = multiprocessing.get_context('spawn')
coordinator_ready_event = spawn_context.Event()
self.inference_coordinator_process = spawn_context.Process(
@@ -408,80 +425,67 @@ async def start_listening_to_data_parallel_coordinator(
)
self.inference_coordinator_process.start()

# Todo [Siddharth]: can we move this code to another file?
self.zmq_context = zmq.Context()
self.zmq_sockets = [] # keep track of all sockets created by this engine

# We need to broadcast the hostname of the (TP=0, PP=0) rank
# to all other ranks in the same model parallel group.
tp_rank = parallel_state.get_tensor_model_parallel_rank()
pp_rank = parallel_state.get_pipeline_model_parallel_rank()

hostname_list = [None]
if tp_rank == 0 and pp_rank == 0:
hostname_list[0] = socket.gethostname()
# Find available ports for MP and bind to them.
if self.is_mp_coordinator:
local_ip = socket.gethostname()
mp_req_sock = self.zmq_context.socket(zmq.PUB)
mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
mp_req_addr = mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)

# Find the global rank of the (TP=0, PP=0) rank in our MP group
src_global_rank = parallel_state.get_model_parallel_src_rank()
mp_len_sock = self.zmq_context.socket(zmq.PUB)
mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
mp_len_addr = mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)
else:
mp_req_addr = None
mp_len_addr = None

# Broadcast addresses to respective ranks.
torch.distributed.broadcast_object_list(
hostname_list, src=src_global_rank, group=parallel_state.get_model_parallel_group()
[mp_req_addr, mp_len_addr], src=mp_src, group=mp_group
)
bcast_hostname = hostname_list[0]

# We need unique ports for each MP group, so we compute an offset using the DP rank.
dp_rank = parallel_state.get_data_parallel_rank()
req_port = inference_mp_coordinator_port + (dp_rank * 2)
len_port = inference_mp_coordinator_port + (dp_rank * 2) + 1

ip_address_of_dp_coordinator = os.getenv('MASTER_ADDR', '127.0.0.1')
identity = f'mp-coord-{parallel_state.get_data_parallel_rank()}'
if (
parallel_state.get_tensor_model_parallel_rank() == 0
and parallel_state.get_pipeline_model_parallel_rank() == 0
):
dp_addr = f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
identity = f'mp-coord-{dp_rank}'
if self.is_mp_coordinator:
# 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
# These will receive requests from an InferenceCoordinator.
self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)

self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
self.socket_for_receiving_requests.connect(
f"tcp://{ip_address_of_dp_coordinator}:{inference_coordinator_port}"
)
self.socket_for_receiving_requests.connect(dp_addr)

# send empty string. this is used to register with the coordinator.
self.socket_for_receiving_requests.send(b"")

# 2. Create a publisher socket. This is used to publish or broadcast
# requests within the model parallel group
self.model_parallel_publisher_socket = self.zmq_context.socket(zmq.PUB)
self.model_parallel_publisher_socket.bind(f"tcp://*:{req_port}")
self.model_parallel_publisher_socket = mp_req_sock

# 3. Create another publisher socket to broadcast the number of messages to receive.
self.model_parallel_num_msgs_publisher_socket = self.zmq_context.socket(zmq.PUB)
self.model_parallel_num_msgs_publisher_socket.bind(f"tcp://*:{len_port}")
self.model_parallel_num_msgs_publisher_socket = mp_len_sock
self.zmq_sockets += [
self.socket_for_receiving_requests,
self.model_parallel_num_msgs_publisher_socket,
self.model_parallel_publisher_socket,
]
# All MP ranks subscribe to the two publisher sockets
self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
self.model_parallel_subscriber_socket.connect(f"tcp://{bcast_hostname}:{req_port}")
self.model_parallel_subscriber_socket.connect(mp_req_addr)
self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
self.model_parallel_num_msgs_subscriber_socket.connect(f"tcp://{bcast_hostname}:{len_port}")
self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr)
self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

self.zmq_sockets += [
self.model_parallel_subscriber_socket,
self.model_parallel_num_msgs_subscriber_socket,
]

torch.distributed.barrier(parallel_state.get_model_parallel_group())
torch.distributed.barrier(mp_group)

if launch_inference_coordinator and torch.distributed.get_rank() == 0:
if launch_inference_coordinator and self.is_dp_coordinator:
await await_process_event(coordinator_ready_event, self.inference_coordinator_process)
logging.info("Inference co-ordinator is ready to receive requests!")
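The net effect of the block above is that fixed, user-supplied MP ports are replaced by OS-assigned ports whose addresses are shared over a torch broadcast. A minimal standalone sketch of that bind-then-broadcast pattern, assuming pyzmq and an initialized `torch.distributed` process group (function and variable names are illustrative):

```python
import socket

import torch
import zmq


def share_pub_endpoint(zmq_ctx, is_mp_coordinator, mp_src_rank, mp_group):
    """Bind a PUB socket to a random free port on the coordinator rank and
    broadcast the resulting address to the rest of the model-parallel group."""
    pub_sock, addr = None, None
    if is_mp_coordinator:
        pub_sock = zmq_ctx.socket(zmq.PUB)
        pub_sock.bind_to_random_port(f"tcp://{socket.gethostname()}")
        # LAST_ENDPOINT reports the concrete tcp://host:port chosen by bind.
        addr = pub_sock.getsockopt_string(zmq.LAST_ENDPOINT)
    payload = [addr]
    # Blocking collective: every rank in the group must reach this call.
    torch.distributed.broadcast_object_list(payload, src=mp_src_rank, group=mp_group)
    return pub_sock, payload[0]
```

Subscribers on the other ranks can then simply connect to whatever endpoint was broadcast, which is what removes the need for the old `--inference-mp-coordinator-port` argument.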

@@ -693,7 +697,7 @@ def _add_request(
try:
eod = self.controller.tokenizer.eod
except AttributeError:
if torch.distributed.get_rank() == 0:
if self.rank == 0:
warnings.warn(
"Termination ID not specified, and tokenizer does not define eod."
"Defaulting to not using termination id."
@@ -1204,11 +1208,9 @@ def schedule_requests(self) -> int:
int: The number of messages that were received and processed in this batch.
"""

tp_rank = parallel_state.get_tensor_model_parallel_rank()
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
torch.cuda.nvtx.range_push("drain_zmq_socket")
all_messages = []
if tp_rank == 0 and pp_rank == 0:
if self.is_mp_coordinator:
while True:
try:
# Receive messages in a non-blocking way.
@@ -1274,7 +1276,6 @@ def stop(self):
for socket in self.zmq_sockets:
socket.close()
self.zmq_context.term()
parallel_state.destroy_model_parallel()

@trace_async_exceptions
async def run_engine(
@@ -1295,7 +1296,6 @@ async def run_engine(
)
)
)

await self.async_step(verbose=verbose)
except asyncio.CancelledError:
pass
@@ -1349,13 +1349,8 @@ async def run_engine_with_coordinator(
# Step.
engine_output = await self.async_step(verbose=verbose)

# Send finished requests.
is_tp0_and_pp0 = (
parallel_state.get_tensor_model_parallel_rank() == 0
and parallel_state.get_pipeline_model_parallel_rank() == 0
)
if (
is_tp0_and_pp0
self.is_mp_coordinator
and engine_output is not None
and engine_output["finished_request_records"]
):
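For context on the `schedule_requests` hunk above: the MP-coordinator rank drains its DEALER socket without blocking, then fans the batch out to the other model-parallel ranks via the publisher sockets created earlier. A simplified sketch of that drain-and-publish loop (the helper name and exact framing are assumptions):

```python
import zmq


def drain_and_publish(dealer, count_pub, payload_pub):
    """Drain all pending frames from the DEALER socket, then publish the
    message count followed by the payloads to the model-parallel group."""
    messages = []
    while True:
        try:
            # Non-blocking receive; raises zmq.Again once the queue is empty.
            messages.append(dealer.recv(flags=zmq.NOBLOCK))
        except zmq.Again:
            break
    # Followers first learn how many messages to expect, then receive them.
    count_pub.send_string(str(len(messages)))
    for msg in messages:
        payload_pub.send(msg)
    return messages
```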
2 changes: 0 additions & 2 deletions megatron/training/arguments.py
@@ -1496,8 +1496,6 @@ def _add_inference_args(parser):
'Default to 0 to disable inference wandb logging.')
group.add_argument("--inference-coordinator-port", type=int, default=12346,
help="This port will be used to setup the inference coordinator on node-0")
group.add_argument("--inference-mp-coordinator-port", type=int, default=20000,
help="This port will be used to setup the inference model parallel coordinators")
return parser


@@ -188,7 +188,6 @@ async def _run_test(cls, **test_config_kwargs):
env.timing_data["start_time"] = time.time()
await env.engine.start_listening_to_data_parallel_coordinator(
inference_coordinator_port=test_config.port,
inference_mp_coordinator_port=test_config.mp_port,
launch_inference_coordinator=test_config.launch_inference_coordinator,
)
