diff --git a/examples/inference/gpt/benchmark.py b/examples/inference/gpt/benchmark.py
new file mode 100644
index 0000000000..16ef18c290
--- /dev/null
+++ b/examples/inference/gpt/benchmark.py
@@ -0,0 +1,142 @@
+from megatron.core.inference.inference_client import InferenceClient
+from examples.inference.gpt.utils import add_common_inference_args
+import asyncio
+import torch.distributed as dist
+from examples.inference.gpt.gpt_dynamic_inference import get_model, get_inference_context, get_inference_controller, add_dynamic_inference_args
+from megatron.core.inference.inference_request import DynamicInferenceRequest
+from megatron.training import initialize_megatron
+import torch
+import os
+from megatron.training import get_args, get_tokenizer
+from megatron.core.inference.sampling_params import SamplingParams
+from examples.inference.gpt.utils import build_requests, build_dynamic_engine_setup_prefix, Request
+from megatron.core.inference.engines import DynamicInferenceEngine
+import time
+from tqdm import tqdm
+from typing import List
+import json
+from megatron.training.arguments import parse_args
+from megatron.core import parallel_state
+
+if __name__ == "__main__":
+    # Enable inference mode at the very beginning, as some FP8 optimizations
+    # check for it.
+    with torch.inference_mode():
+        initialize_megatron(
+            #parsed_args=args
+            extra_args_provider=add_dynamic_inference_args,
+            args_defaults={'no_load_rng': True, 'no_load_optim': True},
+        )
+
+        # Start Nsight profiler.
+        if os.environ.get("NSIGHT_PREFIX"):
+            torch.cuda.cudart().cudaProfilerStart()
+
+        args = get_args()
+        tokenizer = get_tokenizer()
+
+        # Sampling params.
+        sampling_params = SamplingParams(
+            temperature=args.temperature,
+            top_k=args.top_k,
+            top_p=args.top_p,
+            return_log_probs=args.return_log_probs,
+            num_tokens_to_generate=args.num_tokens_to_generate,
+        )
+
+        # Requests, context, controller.
+        model = get_model()
+        requests = build_requests(args, tokenizer) if dist.get_rank() == 0 else None
+
+
+        context = get_inference_context(None,
+                                        None,
+                                        calculate_max_sequence_length_from_requests=False)
+
+        controller = get_inference_controller(model, context)
+
+        # Inference engine.
+        engine = DynamicInferenceEngine(
+            controller,
+            context,
+            termination_id=tokenizer.eod,
+            enable_cuda_graph=args.cuda_graph_impl == "local",
+            random_seed=args.seed,
+            enable_chunked_prefill=not args.disable_chunked_prefill
+        )
+
+
+        if dist.get_rank() == 0:
+            setup_prefix = build_dynamic_engine_setup_prefix(args, model, context, requests)
+            print("~~~")
+            print(setup_prefix)
+            print("~~~")
+
+        batch_size = args.inference_dynamic_batching_max_requests_override
+
+
+        # Warmup.
+        for _ in range(5):
+            context.initialize_attention_state(
+                num_warmup_tokens=batch_size,
+            )
+            input_ids, position_ids = context.current_input_and_position_ids(
+                num_warmup_tokens=batch_size
+            )
+
+            # Forward pass -> logits.
+            with torch.inference_mode():
+                controller.inference_wrapped_model.run_one_forward_step(
+                    {
+                        "tokens": input_ids,
+                        "position_ids": position_ids,
+                        "attention_mask": None,
+                    }
+                )
+            context.reset()
+
+        TIMED_ITERS = 10
+        st_events = [torch.cuda.Event(enable_timing=True) for _ in range(TIMED_ITERS)]
+        en_events = [torch.cuda.Event(enable_timing=True) for _ in range(TIMED_ITERS)]
+
+        for i in range(TIMED_ITERS):
+            context.initialize_attention_state(
+                num_warmup_tokens=batch_size,
+            )
+            input_ids, position_ids = context.current_input_and_position_ids(
+                num_warmup_tokens=batch_size
+            )
+            st_events[i].record()
+            # Forward pass -> logits.
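+            # Note: CUDA events record on the GPU stream, so the averaged elapsed time
+            # reported below measures device time for the forward pass, not host overhead.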
+            with torch.inference_mode():
+                controller.inference_wrapped_model.run_one_forward_step(
+                    {
+                        "tokens": input_ids,
+                        "position_ids": position_ids,
+                        "attention_mask": None,
+                    }
+                )
+            context.reset()
+            en_events[i].record()
+        torch.cuda.synchronize()
+        elapsed_times = [st_events[i].elapsed_time(en_events[i]) for i in range(TIMED_ITERS)]
+        elapsed_time = sum(elapsed_times) / TIMED_ITERS
+        torch.cuda.synchronize()
+        if dist.get_rank() == 0:
+            print(f"Overlapped GEMM: {args.tp_comm_overlap}")
+            print(f"Inference Optimized Layers: {args.use_inference_optimized_layers}")
+            print(f"Avg latency per forward pass: {elapsed_time:.2f} ms")
+            with open("bench_tp.jsonl", "a") as f:
+                json.dump({
+                    "tp_comm_overlap": args.tp_comm_overlap,
+                    "use_inference_optimized_layers": args.use_inference_optimized_layers,
+                    "avg_latency_ms": elapsed_time,
+                    "batch_size": batch_size,
+                }, f)
+                f.write("\n")
+
+        if os.environ.get("NSIGHT_PREFIX"):
+            torch.cuda.cudart().cudaProfilerStop()
+
+
+
\ No newline at end of file
diff --git a/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py b/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py
index 9e2b6bfa98..74a7f51d64 100644
--- a/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py
+++ b/examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py
@@ -163,4 +163,7 @@ async def main(
 
     asyncio.run(main(engine, requests, args.inference_coordinator_port))
 
+
+    if os.environ.get("NSIGHT_PREFIX"):
+        torch.cuda.cudart().cudaProfilerStop()
diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py
index f325720b9b..49f04a8628 100644
--- a/megatron/core/inference/contexts/dynamic_context.py
+++ b/megatron/core/inference/contexts/dynamic_context.py
@@ -66,6 +66,7 @@ if TYPE_CHECKING:
     import wandb as WandbModule
 
 
+RUN_ALL_IN_CUDA_GRAPHED_MODE = True
 
 class ContextOverflowError(Exception):
     """Base exception for when a new request does not fit.
@@ -384,6 +385,8 @@ def __init__(
         self.max_total_requests = self.block_allocator.total_count - 1  # -1 for dummy block
         self.max_active_requests = self.block_allocator.active_count
         self.max_tokens = max_tokens or self.DEFAULT_MAX_TOKENS
+        if RUN_ALL_IN_CUDA_GRAPHED_MODE:
+            self.max_tokens = self.max_active_requests
 
         assert self.max_tokens >= self.max_active_requests, (
             f"max_tokens ({self.max_tokens}) must be >= "
diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py
index 1a5f718fb0..377b6545ad 100644
--- a/megatron/core/inference/engines/dynamic_engine.py
+++ b/megatron/core/inference/engines/dynamic_engine.py
@@ -200,13 +200,13 @@ def create_cuda_graphs(self, reset_context: bool = True):
         config = controller.inference_wrapped_model.inference_wrapper_config
         moe_pad_experts = config.moe_pad_experts_for_cuda_graph_inference
 
-        if moe_pad_experts and context.non_decode_cuda_graphs:
-            context.non_decode_cuda_graphs = False
-            if torch.distributed.get_rank() == 0:
-                warnings.warn(
-                    "MoE models do not support non-decode cuda graphs. "
-                    "Forcing non_decode_cuda_graphs to False."
-                )
+        # if moe_pad_experts and context.non_decode_cuda_graphs:
+        #     context.non_decode_cuda_graphs = False
+        #     if torch.distributed.get_rank() == 0:
+        #         warnings.warn(
+        #             "MoE models do not support non-decode cuda graphs. "
+        #             "Forcing non_decode_cuda_graphs to False."
+        #         )
 
         time_start = time.time()
         mem_stats_start = torch.cuda.memory_stats()
@@ -1042,6 +1042,45 @@ async def run_engine(
         except asyncio.CancelledError:
             pass
 
+    def get_local_and_global_work(self):
+        """Determine whether a dummy forward pass should be run on this rank.
+        This is used to keep expert parallel ranks in sync when some ranks have no work to do.
+        """
+        range_push("get_local_and_global_work")
+        local_work = self.context.get_active_request_count() + len(self.waiting_request_ids)
+
+        if parallel_state.get_expert_model_parallel_world_size() > 1:
+            expert_model_parallel_group = parallel_state.get_expert_model_parallel_group_gloo()
+            # All-reduce local work across the expert model parallel group.
+
+            local_work_tensor = torch.tensor(
+                [local_work], device='cpu'
+            )
+            torch.distributed.all_reduce(
+                local_work_tensor,
+                op=torch.distributed.ReduceOp.SUM,
+                group=expert_model_parallel_group,
+            )
+            global_work = local_work_tensor.item()
+        else:
+            global_work = local_work
+        range_pop()
+        return local_work, global_work
+
+    def dummy_forward(self):
+        """Performs a dummy forward pass to keep expert parallel ranks in sync."""
+        range_push("dummy_forward")
+        input_ids, position_ids = self.controller._dynamic_step_context_init(
+            num_warmup_tokens=self.context.cuda_graph_token_counts[0],
+            warmup_engine_mode=WarmupEngineMode.NON_DECODE
+        )
+        # Forward pass -> logits.
+        self.controller._dynamic_step_forward_logits(input_ids, position_ids)
+
+        with torch.inference_mode():
+            self.context.reset()  # todo: @lmcafee, remove if unnecessary
+        range_pop()
+
     @trace_async_exceptions
     async def run_engine_with_coordinator(
         self, *, loop: Optional[asyncio.AbstractEventLoop] = None, verbose: Optional[bool] = False
@@ -1051,7 +1090,19 @@ async def run_engine_with_coordinator(
         try:
             while True:
                 self.schedule_requests()
-                if self.stopped:
+
+                # We need to run dummy forwards in the case of expert model parallelism,
+                # when there is work on other EP ranks but none on this one.
+                local_work, global_work = self.get_local_and_global_work()
+                if local_work == 0 and global_work > 0:
+                    self.dummy_forward()
+                    # The continue is extremely important to avoid premature pauses/stops.
+                    # For example, say this rank has received a stop signal but others still
+                    # have work: if we don't continue here, this rank will process the stop
+                    # signal and stop the engine prematurely.
+                    continue
+
+                if self.stopped and global_work == 0:
                     self.stop()
                     return
 
@@ -1067,7 +1118,7 @@ async def run_engine_with_coordinator(
 
                 # todo [Siddharth]: Can this hardcoded sleep be avoided
                 # with asyncio zmq sockets?
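+                # Only sleep while paused once no EP rank still has pending work; if any
+                # rank has work, keep stepping so expert-parallel ranks stay in sync.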
-                if self.paused:
+                if self.paused and global_work == 0:
                     await asyncio.sleep(0.02)
                     continue
 
diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
index 95d476a9f8..414f7ac7fe 100644
--- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
+++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
@@ -158,6 +158,7 @@ def _forward(self, inference_input):
         tokens = inference_input["tokens"]
         position_ids = inference_input["position_ids"]
         attention_mask = inference_input["attention_mask"]
+        print(tokens.shape, position_ids.shape)
         return self.model(
             tokens,
             position_ids,
@@ -166,6 +167,15 @@ def _forward(self, inference_input):
             runtime_gather_output=True,  # Inference should always gather the logits
         )
 
+    def dummy_forward(self):
+        tokens = torch.zeros((1, 1), dtype=torch.long, device=torch.cuda.current_device())
+        position_ids = torch.zeros((1, 1), dtype=torch.long, device=torch.cuda.current_device())
+        attention_mask = None
+        return self.model(
+            tokens,
+            position_ids,
+            attention_mask)
+
     def _get_batch_size_and_seq_len(
         self, tokens: torch.Tensor, recv_buffer_seq_len: Optional[int] = None
     ):
diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
index 2b44b41874..8f54423b08 100644
--- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -515,12 +515,12 @@ def _dynamic_step_context_init(
             inference_wrapper_config.moe_pad_experts_for_cuda_graph_inference
         )
         if moe_pad_experts_for_cuda_graph_inference:
-            assert warmup_engine_mode is not WarmupEngineMode.NON_DECODE
-            if context.is_decode_only():
-                capacity_factor = model_config.num_moe_experts / model_config.moe_router_topk
-                set_decode_expert_padding(unwrapped_model, True, capacity_factor=capacity_factor)
-            else:
-                set_decode_expert_padding(unwrapped_model, False)
+            # assert warmup_engine_mode is not WarmupEngineMode.NON_DECODE
+            # if context.is_decode_only():
+            capacity_factor = model_config.num_moe_experts / model_config.moe_router_topk
+            set_decode_expert_padding(unwrapped_model, True, capacity_factor=capacity_factor)
+            # else:
+            #     set_decode_expert_padding(unwrapped_model, False)
 
         if nccl_all_reduce_for_prefill and symmetric_ar_type is not None:
             if context.is_decode_only():
@@ -832,6 +832,12 @@ async def async_generate_output_tokens_dynamic_batch(
             ret.update(request_bookkeeping)
         return ret
 
+    def dummy_forward(self):
+        rank = torch.distributed.get_rank()
+        print(f"Rank {rank} dummy forward called")
+        return self.inference_wrapped_model.dummy_forward()
+
+
     @torch.inference_mode()
     def generate_output_tokens_dynamic_batch(
         self, loop: Optional[asyncio.AbstractEventLoop] = None
@@ -1472,3 +1478,5 @@ def stream_token(
             output_log_probs,
         )
         list(ret)
+
+
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 1e41bf9d8c..e482910cf7 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -51,6 +51,7 @@
 
 # Expert model parallel group that current rank belongs to.
 _EXPERT_MODEL_PARALLEL_GROUP = None
+_EXPERT_MODEL_PARALLEL_GROUP_GLOO = None
 # Expert tensor parallel group that current rank belongs to.
 _EXPERT_TENSOR_PARALLEL_GROUP = None
 # Expert tensor and model combined parallel group
@@ -1117,16 +1118,27 @@ def initialize_model_parallel(
 
     ### Expert-related parallel groups initialization
     # Build the expert model parallel group
-    global _EXPERT_MODEL_PARALLEL_GROUP
+    global _EXPERT_MODEL_PARALLEL_GROUP, _EXPERT_MODEL_PARALLEL_GROUP_GLOO
     assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
+    assert _EXPERT_MODEL_PARALLEL_GROUP_GLOO is None, "Expert parallel group-gloo is already initialized"
     for ranks in expert_decoder_rank_generator.get_ranks('ep'):
         group = create_group(
             ranks,
             pg_options=get_nccl_options("ep", nccl_comm_cfgs),
             group_desc="EXPERT_MODEL_PARALLEL_GROUP",
         )
+        if create_gloo_process_groups:
+            group_gloo = create_group(
+                ranks,
+                timeout=timeout,
+                backend="gloo",
+                group_desc="EXPERT_MODEL_PARALLEL_GROUP_GLOO"
+            )
+        else:
+            group_gloo = None
         if rank in ranks:
             _EXPERT_MODEL_PARALLEL_GROUP = group
+            _EXPERT_MODEL_PARALLEL_GROUP_GLOO = group_gloo
 
     # Build the expert tensor parallel group
     global _EXPERT_TENSOR_PARALLEL_GROUP
@@ -1721,6 +1733,14 @@ def get_expert_model_parallel_group(check_initialized=True):
         ), "expert model parallel group is not initialized"
     return _EXPERT_MODEL_PARALLEL_GROUP
 
+### Expert-related parallel state functions
+def get_expert_model_parallel_group_gloo(check_initialized=True):
+    """Get the Gloo expert-model-parallel group the caller rank belongs to."""
+    if check_initialized:
+        assert (
+            _EXPERT_MODEL_PARALLEL_GROUP_GLOO is not None
+        ), "expert model parallel group gloo is not initialized"
+    return _EXPERT_MODEL_PARALLEL_GROUP_GLOO
 
 def get_expert_model_parallel_world_size():
     """Return world size for the expert-model-parallel group."""
diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py
index 9ff69c9dc3..e4f57e0e09 100644
--- a/megatron/core/tensor_parallel/mappings.py
+++ b/megatron/core/tensor_parallel/mappings.py
@@ -441,6 +441,8 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
             dtype=input.dtype,
             device=torch.cuda.current_device(),
         )
+        rank = torch.distributed.get_rank(group)
+        print(f"{rank}: starting all to all")
        torch.distributed.all_to_all_single(
             output,
             input,
@@ -448,6 +450,7 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
             input_split_sizes=input_split_sizes,
             group=group,
         )
+        print(f"{rank}: finished all to all")
         return output
 
     @staticmethod
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index e221c0ea00..ebf5013d0f 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -24,7 +24,7 @@
 )
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.identity_op import IdentityOp
-from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.utils import (
     deprecate_inference_params,
@@ -131,7 +131,7 @@ class CrossAttentionSubmodules:
     linear_proj: Union[ModuleSpec, type] = None
 
 
-class Attention(MegatronModule, ABC):
+class Attention(GraphableMegatronModule, MegatronModule, ABC):
     """Attention layer abstract class.
 
     This layer only contains common modules required for the "self attn" and
@@ -637,6 +637,67 @@ def flash_decode_and_prefill(
         output_total = flash_attn_with_kvcache(**flash_attn_args)
         return output_total
 
+    def _should_call_local_cudagraph(self, *args, **kwargs):
+        """
+        Check if we should call the local cudagraph path.
+        """
+        if not self.training and (
+            hasattr(self, 'cudagraph_manager')
+            and kwargs['attention_mask'] is None
+            and (
+                kwargs.get('inference_context') is not None
+                or kwargs.get('inference_params') is not None
+            )
+            and self.config.cuda_graph_scope == 'attn'
+        ):
+            if kwargs['inference_context'].is_static_batching():
+                using_cuda_graph = kwargs['inference_context'].is_decode_only()
+            else:
+                using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step()
+
+            if using_cuda_graph:
+                return True
+        return False
+
+    def __call__(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Tensor,
+        key_value_states: Optional[Tensor] = None,
+        inference_context: Optional[BaseInferenceContext] = None,
+        rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None,
+        rotary_pos_cos: Optional[Tensor] = None,
+        rotary_pos_sin: Optional[Tensor] = None,
+        rotary_pos_cos_sin: Optional[Tensor] = None,
+        attention_bias: Optional[Tensor] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+        sequence_len_offset: Optional[int] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
+    ):
+        inference_context = deprecate_inference_params(inference_context, inference_params)
+        kwargs = {
+            'hidden_states': hidden_states,
+            'attention_mask': attention_mask,
+            'key_value_states': key_value_states,
+            'inference_context': inference_context,
+            'rotary_pos_emb': rotary_pos_emb,
+            'rotary_pos_cos': rotary_pos_cos,
+            'rotary_pos_sin': rotary_pos_sin,
+            'rotary_pos_cos_sin': rotary_pos_cos_sin,
+            'attention_bias': attention_bias,
+            'packed_seq_params': packed_seq_params,
+            'sequence_len_offset': sequence_len_offset,
+            'inference_params': inference_context,
+        }
+        if self._should_call_local_cudagraph(**kwargs):
+            # dynamic_inference_decode_only is not a real argument to forward; it is only used
+            # to differentiate the cuda graph used for decode from the one used for non-decode
+            # inference.
+            dynamic_inference_decode_only = kwargs['inference_context'].is_decode_only()
+            return super().__call__(
+                dynamic_inference_decode_only=dynamic_inference_decode_only, **kwargs
+            )
+        return super().__call__(**kwargs)
+
     def forward(
         self,
         hidden_states: Tensor,
@@ -652,6 +713,7 @@ def forward(
         sequence_len_offset: Optional[int] = None,
         *,
         inference_params: Optional[BaseInferenceContext] = None,
+        dynamic_inference_decode_only: Optional[bool] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
         Perform a forward pass through the attention module.
diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py
index d5a6be9224..fb36a924f1 100644
--- a/megatron/core/transformer/moe/moe_layer.py
+++ b/megatron/core/transformer/moe/moe_layer.py
@@ -77,6 +77,7 @@ def __init__(
         self.shared_experts = None
         self.token_dispatcher: Optional[MoETokenDispatcher] = None
         self.layer_number = layer_number
+
 
     @abstractmethod
     def forward(self, hidden_states):
@@ -122,7 +123,7 @@ def __init__(
 
         # Initialize router
         self.router = TopKRouter(config=self.config, pg_collection=pg_collection)
-
+
         # Initialize token dispatcher
         if config.moe_token_dispatcher_type == "allgather":
             self.token_dispatcher = MoEAllGatherTokenDispatcher(
@@ -174,11 +175,15 @@ def router_and_preprocess(self, hidden_states: torch.Tensor):
         hidden states and probabilities for the token dispatcher. The original hidden
         states are returned as a residual connection.
         """
+        rank = torch.distributed.get_rank()
         residual = hidden_states
+        print(f"{rank}: about to run router...")
         probs, routing_map = self.router(hidden_states)
+        print(f"{rank}: ran router...")
         hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
             hidden_states, routing_map, probs
         )
+        print(f"{rank}: ran dispatch preprocess...")
         return hidden_states, probs, residual
 
     def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
@@ -262,6 +267,8 @@ def forward(self, hidden_states: torch.Tensor):
         Returns:
             A tuple containing the output tensor and the MLP bias, if any.
         """
+        rank = torch.distributed.get_rank()
+        print(f"{rank}: inside moe layer FW pass...")
         if self.training and self.attn_tp_group.size() > 1 and not self.config.sequence_parallel:
             raise ValueError(
                 "During training, performance may degrade if MoE and tensor parallelism"
@@ -271,10 +278,15 @@ def forward(self, hidden_states: torch.Tensor):
 
         # MoE forward: route -> dispatch -> compute -> combine
         def custom_forward(hidden_states):
             shared_expert_output = self.shared_experts_compute(hidden_states)
+            print(f"{rank}: ran shared experts compute ...")
             hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
+            print(f"{rank}: ran router and preprocess ...")
             dispatched_input, probs = self.dispatch(hidden_states, probs)
+            print(f"{rank}: ran dispatch ...")
             output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
+            print(f"{rank}: ran routed experts compute ...")
             output = self.combine(output, shared_expert_output)
+            print(f"{rank}: ran combine ...")
             return output, mlp_bias
 
         if self.moe_layer_recompute:
diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py
index 82fb7b0058..5ecfb6eb6f 100644
--- a/megatron/core/transformer/moe/token_dispatcher.py
+++ b/megatron/core/transformer/moe/token_dispatcher.py
@@ -442,6 +442,8 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
         Returns:
             A tensor with the number of tokens for each local expert.
         """
+        rank = torch.distributed.get_rank()
+        print(f"Rank {rank} start preprocess....")
         if self.drop_and_pad:
             # Drop and pad the input to capacity.
             num_tokens = routing_map.size(0) * self.config.moe_router_topk
@@ -469,7 +471,7 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
 
         # [num_experts], number of tokens assigned to each expert from the current rank's input.
         num_local_tokens_per_expert = routing_map.sum(dim=0).long()
-
+        print(f"Rank {rank} routing map sum")
         if (
             self.config.moe_expert_capacity_factor is not None
             or self.config.moe_router_padding_for_fp8
@@ -482,7 +484,9 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
             # For dropless training, output size is static (num_tokens * topk)
             # No explicit sync needed
             self.num_out_tokens = routing_map.size(0) * self.config.moe_router_topk
+            print(f"Rank {rank} calculated num_out_tokens...")
         if self.ep_size > 1 or self.tp_size > 1:
+            print(f"Rank {rank} starting calculate splits...")
             # ===================================================
             # Calculate input_splits, output_splits for alltoall/allgather in variable size.
             # ===================================================
@@ -495,6 +499,8 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
             # num_global_tokens_per_expert represents the number of tokens sent to each
             # expert by all ranks.
             # [tp_size, ep_size, num_experts]
+            print(f"Rank {rank} starting gather from SP region")
+            print(f"Size of tp_ep group = {utils.get_pg_size(self.tp_ep_group)}")
             num_global_tokens_per_expert = (
                 gather_from_sequence_parallel_region(
                     num_local_tokens_per_expert, group=self.tp_ep_group
                 )
                 .reshape(self.ep_size, self.tp_size, self.num_experts)
                 .transpose(0, 1)
             )
+            print(f"Rank {rank} completed gather from SP region")
             # [tp_size, ep_size, num_experts] -> [tp_size, ep_size, num_local_experts]
             num_global_tokens_per_local_expert = num_global_tokens_per_expert[
                 :, :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
@@ -521,7 +528,9 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
 
             # A synchronization is needed before expert parallel AlltoAll communication
             # to get the `input_splits` and `output_splits` CPU values.
+            print(f"Rank {rank} starting update cuda sync point")
             self._maybe_update_cuda_sync_point("before_ep_alltoall")
+            print(f"Rank {rank} finished update cuda sync point")
         else:
             num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
                 self.num_experts
@@ -531,7 +540,7 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
             # A synchronization is needed before the returns
             # to get the `num_tokens_per_local_expert` CPU value.
             self._maybe_update_cuda_sync_point("before_finish")
-
+        print(f"Rank {rank} finished calculating splits...")
         if self.num_local_experts > 1:
             # [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
             # to each local expert by all ranks.
@@ -542,7 +551,7 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
             # A synchronization is needed before permutation 2
             # to get the `num_global_tokens_per_local_expert` CPU value.
             self._maybe_update_cuda_sync_point("before_permutation_2")
-
+        print(f"Rank {rank} finished preprocess....")
         assert (
             self.cuda_sync_point_priority[self.cuda_dtoh_point]
             <= self.cuda_sync_point_priority[self.cuda_sync_point]
@@ -566,6 +575,7 @@ def dispatch_preprocess(
         A tuple of permuted hidden states and probabilities.
         """
         # Preprocess: Get the metadata for communication, permutation and computation operations.
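+        # Debug tracing added in this change: get_rank() without a group argument returns
+        # the global rank, so the prints below interleave output from every rank in the job.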
+        rank = torch.distributed.get_rank()
         self.hidden_shape = hidden_states.shape
         self.probs = probs
         self.routing_map = routing_map
@@ -580,15 +590,17 @@ def dispatch_preprocess(
                 self.routing_map = fused_pad_routing_map(self.routing_map, pad_multiple)
             else:
                 self.routing_map = pad_routing_map(self.routing_map, pad_multiple)
+        print(f"Rank {rank} start preprocess....")
         self.tokens_per_expert = self.preprocess(self.routing_map)
-
+        print(f"Rank {rank} finished preprocess....")
         if self.shared_experts is not None:
             self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
-
+        print(f"Rank {rank} finished pre forward comm")
         # Permutation 1: input to AlltoAll input
         self.tokens_per_expert = self._maybe_dtoh_and_synchronize(
             "before_permutation_1", self.tokens_per_expert
         )
+        print(f"Rank {rank} finished d2h synchronize....")
         self.hidden_shape_before_permute = hidden_states.shape
         (
             permutated_local_input_tokens,
@@ -602,6 +614,7 @@ def dispatch_preprocess(
             fused=self.config.moe_permute_fusion,
             drop_and_pad=self.drop_and_pad,
         )
+        print(f"Rank {rank} finished permute...")
         return permutated_local_input_tokens, permuted_probs
 
     def token_dispatch(self, permutated_local_input_tokens, permuted_probs):
@@ -619,17 +632,21 @@ def token_dispatch(self, permutated_local_input_tokens, permuted_probs):
         Returns:
             A tuple of tokens and probabilities after All-to-All.
         """
+        rank = torch.distributed.get_rank()
+        print(f"Rank {rank} is using MoEAlltoAll dispatch....")
         # Perform expert parallel AlltoAll communication
         self.tokens_per_expert = self._maybe_dtoh_and_synchronize(
             "before_ep_alltoall", self.tokens_per_expert
         )
+        print(f"Rank {rank} finished cuda sync before ep alltoall....")
         global_input_tokens = all_to_all(
             self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
         )
+        print(f"Rank {rank} finished all to all on tokens...")
         global_probs = all_to_all(
             self.ep_group, permuted_probs, self.output_splits, self.input_splits
         )
-
+        print(f"Rank {rank} finished all to all on probs...")
         return global_input_tokens, global_probs
 
     def dispatch_postprocess(self, global_input_tokens, global_probs):
@@ -1257,7 +1274,7 @@ def token_dispatch(
 
         Returns:
             A tuple of dispatched tokens and probabilities.
-        """
+        """
         return (
             self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
             self._comm_manager.dispatched_probs,
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index a5babece9d..b0495d8138 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -826,7 +826,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
                 (kwargs.get('inference_context') is not None)
                 or (kwargs.get('inference_params') is not None)
             )
-            and self.config.cuda_graph_scope != 'full_iteration'
+            and self.config.cuda_graph_scope == 'full'
         ):
             if kwargs['inference_context'].is_static_batching():
                 using_cuda_graph = kwargs['inference_context'].is_decode_only()