37 changes: 37 additions & 0 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -1042,6 +1042,30 @@ async def run_engine(
except asyncio.CancelledError:
pass

def should_run_dummy_forward_for_expert_model_parallelism(self):
"""Determines if a dummy forward pass should be run on this rank.
This is used to keep expert parallel ranks in sync when some ranks have no work to do.
"""
range_push("should_run_dummy_forward_for_expert_model_parallelism")
if parallel_state.get_expert_model_parallel_world_size() == 1:
return False
local_work = self.context.get_active_request_count() + len(self.waiting_request_ids)
expert_model_parallel_group = parallel_state.get_expert_model_parallel_group_gloo()
# All-reduce the local work count across the expert model parallel (Gloo) group.
local_work_tensor = torch.tensor([local_work], device='cpu')
torch.distributed.all_reduce(
local_work_tensor,
op=torch.distributed.ReduceOp.SUM,
group=expert_model_parallel_group,
)
global_work = local_work_tensor.item()
range_pop()
return (local_work == 0 and global_work > 0)


@trace_async_exceptions
async def run_engine_with_coordinator(
self, *, loop: Optional[asyncio.AbstractEventLoop] = None, verbose: Optional[bool] = False
@@ -1054,6 +1078,19 @@ async def run_engine_with_coordinator(
if self.stopped:
self.stop()
return
# We need to run a dummy forward pass when expert model parallelism is enabled
# and other EP ranks have work but this one does not; the check below all-reduces
# the per-rank work counts over the expert model parallel group.
if self.should_run_dummy_forward_for_expert_model_parallelism():
self.controller.dummy_forward()
# The `continue` is essential to avoid premature pauses/stops. For example, if this
# rank has received a stop signal while other EP ranks still have work, falling
# through here would process the stop signal and shut down the engine prematurely.
continue

# for the cases below (engine is paused or no active requests),
# do not use asyncio.sleep(0)
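To make the gating rule above concrete, here is a small, self-contained sketch of the same decision, with the Gloo SUM all-reduce replaced by an explicit list of per-rank work counts; the function name should_run_dummy_forward is illustrative, not Megatron-LM API.

# Minimal sketch of the dummy-forward decision rule used above. The all-reduce over the
# expert-model-parallel Gloo group is replaced by an explicit list of per-rank work counts.
from typing import List


def should_run_dummy_forward(local_work: int, all_rank_work: List[int]) -> bool:
    """Return True when this rank is idle but at least one EP peer has work."""
    global_work = sum(all_rank_work)  # what the SUM all-reduce would produce
    return local_work == 0 and global_work > 0


if __name__ == "__main__":
    # Expert-parallel world of 4 ranks: ranks 1 and 3 are idle, ranks 0 and 2 have requests.
    work_per_rank = [2, 0, 1, 0]
    decisions = [should_run_dummy_forward(w, work_per_rank) for w in work_per_rank]
    assert decisions == [False, True, False, True]
    # If every rank is idle, nobody runs a dummy forward and the engine may pause safely.
    assert not any(should_run_dummy_forward(0, [0, 0, 0, 0]) for _ in range(4))
    print("dummy-forward decisions:", decisions)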
@@ -158,6 +158,7 @@ def _forward(self, inference_input):
tokens = inference_input["tokens"]
position_ids = inference_input["position_ids"]
attention_mask = inference_input["attention_mask"]
print(tokens.shape, position_ids.shape)
return self.model(
tokens,
position_ids,
@@ -166,6 +167,15 @@
runtime_gather_output=True, # Inference should always gather the logits
)

def dummy_forward(self):
"""Run a minimal single-token forward pass so this rank still participates in
collective ops (e.g. the expert-parallel all-to-all) when it has no real requests."""
tokens = torch.zeros((1, 1), dtype=torch.long, device=torch.cuda.current_device())
position_ids = torch.zeros((1, 1), dtype=torch.long, device=torch.cuda.current_device())
attention_mask = None
return self.model(
tokens,
position_ids,
attention_mask)

def _get_batch_size_and_seq_len(
self, tokens: torch.Tensor, recv_buffer_seq_len: Optional[int] = None
):
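As a sanity check on the inputs the wrapper's dummy_forward builds, here is a hedged, CPU-only sketch with a stand-in model; the real code allocates on torch.cuda.current_device(), and ToyModel is purely illustrative, not Megatron's GPT model.

# CPU-only sketch of the dummy forward inputs: a single [1, 1] zero token and position id,
# with no attention mask. ToyModel is a stand-in with the same call signature.
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, vocab_size: int = 8, hidden: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, tokens, position_ids, attention_mask=None):
        # position_ids and attention_mask are accepted but unused in this toy model.
        return self.head(self.embed(tokens))


model = ToyModel()
tokens = torch.zeros((1, 1), dtype=torch.long)        # batch size 1, sequence length 1
position_ids = torch.zeros((1, 1), dtype=torch.long)
logits = model(tokens, position_ids, attention_mask=None)
assert logits.shape == (1, 1, 8)                      # [batch, seq, vocab]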
@@ -832,6 +832,12 @@ async def async_generate_output_tokens_dynamic_batch(
ret.update(request_bookkeeping)
return ret

def dummy_forward(self):
"""Run a dummy forward pass on the wrapped model to keep this rank in sync with its expert-parallel peers."""
rank = torch.distributed.get_rank()
print(f"Rank {rank} dummy forward called")
return self.inference_wrapped_model.dummy_forward()


@torch.inference_mode()
def generate_output_tokens_dynamic_batch(
self, loop: Optional[asyncio.AbstractEventLoop] = None
@@ -1472,3 +1478,5 @@ def stream_token(
output_log_probs,
)
list(ret)


22 changes: 21 additions & 1 deletion megatron/core/parallel_state.py
@@ -51,6 +51,7 @@

# Expert model parallel group that current rank belongs to.
_EXPERT_MODEL_PARALLEL_GROUP = None
_EXPERT_MODEL_PARALLEL_GROUP_GLOO = None
# Expert tensor parallel group that current rank belongs to.
_EXPERT_TENSOR_PARALLEL_GROUP = None
# Expert tensor and model combined parallel group
@@ -1117,16 +1118,27 @@ def initialize_model_parallel(

### Expert-related parallel groups initialization
# Build the expert model parallel group
global _EXPERT_MODEL_PARALLEL_GROUP
global _EXPERT_MODEL_PARALLEL_GROUP, _EXPERT_MODEL_PARALLEL_GROUP_GLOO
assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
assert _EXPERT_MODEL_PARALLEL_GROUP_GLOO is None, "Expert parallel group-gloo is already initialized"
for ranks in expert_decoder_rank_generator.get_ranks('ep'):
group = create_group(
ranks,
pg_options=get_nccl_options("ep", nccl_comm_cfgs),
group_desc="EXPERT_MODEL_PARALLEL_GROUP",
)
if create_gloo_process_groups:
group_gloo = create_group(
ranks,
timeout=timeout,
backend="gloo",
group_desc="EXPERT_MODEL_PARALLEL_GROUP_GLOO"
)
else:
group_gloo = None
if rank in ranks:
_EXPERT_MODEL_PARALLEL_GROUP = group
_EXPERT_MODEL_PARALLEL_GROUP_GLOO = group_gloo

# Build the expert tensor parallel group
global _EXPERT_TENSOR_PARALLEL_GROUP
@@ -1721,6 +1733,14 @@ def get_expert_model_parallel_group(check_initialized=True):
), "expert model parallel group is not initialized"
return _EXPERT_MODEL_PARALLEL_GROUP

### Expert-related parallel states functions
def get_expert_model_parallel_group_gloo(check_initialized=True):
"""Get the expert-model-parallel group the caller rank belongs to."""
if check_initialized:
assert (
_EXPERT_MODEL_PARALLEL_GROUP_GLOO is not None
), "expert model parallel group gloo is not initialized"
return _EXPERT_MODEL_PARALLEL_GROUP_GLOO

def get_expert_model_parallel_world_size():
"""Return world size for the expert-model-parallel group."""
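A hedged sketch of the pattern parallel_state.py follows here: build a Gloo-backed group alongside the default one so that small bookkeeping collectives, such as the work-count all-reduce above, can run on CPU tensors. It assumes a torchrun launch with at least two processes; the variable names are illustrative.

# Sketch: a Gloo group for CPU-side bookkeeping collectives, mirroring
# _EXPERT_MODEL_PARALLEL_GROUP_GLOO above. Run with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")            # CPU-friendly backend for the sketch
ranks = list(range(dist.get_world_size()))         # in Megatron these would be the EP ranks
bookkeeping_group = dist.new_group(ranks=ranks, backend="gloo")

local_work = 3 if dist.get_rank() == 0 else 0      # pretend only rank 0 has requests
t = torch.tensor([local_work])                     # CPU tensor; no GPU synchronization needed
dist.all_reduce(t, op=dist.ReduceOp.SUM, group=bookkeeping_group)
print(f"rank {dist.get_rank()}: global work = {t.item()}")
dist.destroy_process_group()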
3 changes: 3 additions & 0 deletions megatron/core/tensor_parallel/mappings.py
@@ -441,13 +441,16 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
dtype=input.dtype,
device=torch.cuda.current_device(),
)
rank = torch.distributed.get_rank(group)
print(f"{rank}: starting all to all")
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
print(f"{rank}: finished all to all")
return output

@staticmethod
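Because the prints above bracket torch.distributed.all_to_all_single, a hedged sketch of how its split sizes must line up may help; every rank has to enter the call with mutually consistent splits, which is exactly why an idle expert-parallel rank still needs a dummy forward. The sketch assumes a torchrun launch with one GPU per process and the NCCL backend, which supports all_to_all_single.

# Sketch of all_to_all_single with uneven splits: rank i sends (j + 1) elements to rank j,
# so every rank j receives (j + 1) elements from each peer.
# Launch with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

input_split_sizes = [j + 1 for j in range(world)]        # sizes this rank sends to each peer
output_split_sizes = [rank + 1 for _ in range(world)]    # sizes this rank receives from each peer

inp = torch.full((sum(input_split_sizes),), float(rank), device="cuda")
out = torch.empty(sum(output_split_sizes), device="cuda")
dist.all_to_all_single(
    out, inp, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes
)
# out now holds (rank + 1) copies of each peer's rank id, in peer order.
print(f"rank {rank}: {out.tolist()}")
dist.destroy_process_group()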
66 changes: 64 additions & 2 deletions megatron/core/transformer/attention.py
@@ -24,7 +24,7 @@
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import (
deprecate_inference_params,
@@ -131,7 +131,7 @@ class CrossAttentionSubmodules:
linear_proj: Union[ModuleSpec, type] = None


class Attention(MegatronModule, ABC):
class Attention(GraphableMegatronModule, MegatronModule, ABC):
"""Attention layer abstract class.

This layer only contains common modules required for the "self attn" and
@@ -637,6 +637,67 @@ def flash_decode_and_prefill(
output_total = flash_attn_with_kvcache(**flash_attn_args)
return output_total

def _should_call_local_cudagraph(self, *args, **kwargs):
"""
Check if we should call the local cudagraph path.
"""
if not self.training and (
hasattr(self, 'cudagraph_manager')
and kwargs['attention_mask'] is None
and (
kwargs.get('inference_context') is not None
or kwargs.get('inference_params') is not None
)
and self.config.cuda_graph_scope == 'attn'
):
if kwargs['inference_context'].is_static_batching():
using_cuda_graph = kwargs['inference_context'].is_decode_only()
else:
using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step()

if using_cuda_graph:
return True
return False

def __call__(self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Optional[Tensor] = None,
inference_context: Optional[BaseInferenceContext] = None,
rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
rotary_pos_cos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,):
inference_context = deprecate_inference_params(inference_context, inference_params)
kwargs = {
'hidden_states': hidden_states,
'attention_mask': attention_mask,
'key_value_states': key_value_states,
'inference_context': inference_context,
'rotary_pos_emb': rotary_pos_emb,
'rotary_pos_cos': rotary_pos_cos,
'rotary_pos_sin': rotary_pos_sin,
'rotary_pos_cos_sin': rotary_pos_cos_sin,
'attention_bias': attention_bias,
'packed_seq_params': packed_seq_params,
'sequence_len_offset': sequence_len_offset,
'inference_params': inference_context,
}
if self._should_call_local_cudagraph(**kwargs):
# dynamic_inference_decode_only is not used by forward itself; it exists only to
# key a separate CUDA graph for decode versus non-decode inference.
dynamic_inference_decode_only = kwargs['inference_context'].is_decode_only()
return super().__call__(
dynamic_inference_decode_only=dynamic_inference_decode_only, **kwargs
)
return super().__call__(**kwargs)

def forward(
self,
hidden_states: Tensor,
@@ -652,6 +713,7 @@ def forward(
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
dynamic_inference_decode_only: Optional[bool] = None,
) -> Tuple[Tensor, Tensor]:
"""
Perform a forward pass through the attention module.
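The __call__ override above is essentially a dispatch shim: it inspects the inference context and, when a captured CUDA graph can be used, passes an extra keyword that keys the decode graph separately from the non-decode one. Below is a hedged, framework-free sketch of that pattern; the class and attribute names are illustrative and no real CUDA graph capture happens.

# Sketch of the "route __call__ to a graphed callable" pattern used by Attention above.
from typing import Optional


class FakeContext:
    def __init__(self, decode_only: bool, use_graph: bool):
        self._decode_only, self._use_graph = decode_only, use_graph

    def is_decode_only(self) -> bool:
        return self._decode_only

    def using_cuda_graph_this_step(self) -> bool:
        return self._use_graph


class GraphedAttention:
    def __init__(self, has_graph: bool = True):
        self.has_graph = has_graph  # stands in for hasattr(self, 'cudagraph_manager')

    def _should_use_graph(self, attention_mask, ctx: Optional[FakeContext]) -> bool:
        return (
            self.has_graph
            and attention_mask is None
            and ctx is not None
            and ctx.using_cuda_graph_this_step()
        )

    def __call__(self, hidden_states, attention_mask=None, inference_context=None):
        if self._should_use_graph(attention_mask, inference_context):
            # The extra flag only distinguishes the decode graph from the non-decode one.
            return self.forward(
                hidden_states,
                attention_mask,
                inference_context,
                dynamic_inference_decode_only=inference_context.is_decode_only(),
            )
        return self.forward(hidden_states, attention_mask, inference_context)

    def forward(self, hidden_states, attention_mask=None, inference_context=None,
                dynamic_inference_decode_only: Optional[bool] = None):
        path = "graphed" if dynamic_inference_decode_only is not None else "eager"
        return hidden_states, path


attn = GraphedAttention()
assert attn("h", inference_context=FakeContext(decode_only=True, use_graph=True))[1] == "graphed"
assert attn("h", inference_context=None)[1] == "eager"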
14 changes: 13 additions & 1 deletion megatron/core/transformer/moe/moe_layer.py
@@ -77,6 +77,7 @@ def __init__(
self.shared_experts = None
self.token_dispatcher: Optional[MoETokenDispatcher] = None
self.layer_number = layer_number


@abstractmethod
def forward(self, hidden_states):
@@ -122,7 +123,7 @@ def __init__(

# Initialize router
self.router = TopKRouter(config=self.config, pg_collection=pg_collection)

# Initialize token dispatcher
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
@@ -174,11 +175,15 @@ def router_and_preprocess(self, hidden_states: torch.Tensor):
hidden states and probabilities for the token dispatcher. The original
hidden states are returned as a residual connection.
"""
rank = torch.distributed.get_rank()
residual = hidden_states
print(f"{rank}: about to run router...")
probs, routing_map = self.router(hidden_states)
print(f"{rank}: ran router...")
hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
hidden_states, routing_map, probs
)
print(f"{rank}: ran dispatch preprocess...")
return hidden_states, probs, residual

def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
@@ -262,6 +267,8 @@ def forward(self, hidden_states: torch.Tensor):
Returns:
A tuple containing the output tensor and the MLP bias, if any.
"""
rank = torch.distributed.get_rank()
print(f"{rank}: inside moe layer FW pass...")
if self.training and self.attn_tp_group.size() > 1 and not self.config.sequence_parallel:
raise ValueError(
"During training, performance may degrade if MoE and tensor parallelism"
@@ -271,10 +278,15 @@ def forward(self, hidden_states: torch.Tensor):
# MoE forward: route -> dispatch -> compute -> combine
def custom_forward(hidden_states):
shared_expert_output = self.shared_experts_compute(hidden_states)
print(f"{rank}: ran shared experts compute ...")
hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
print(f"{rank}: ran router and preprocess ...")
dispatched_input, probs = self.dispatch(hidden_states, probs)
print(f"{rank}: ran dispatch ...")
output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
print(f"{rank}: ran routed experts compute ...")
output = self.combine(output, shared_expert_output)
print(f"{rank}: ran combine ...")
return output, mlp_bias

if self.moe_layer_recompute:
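The prints added above trace the staged MoE forward (route -> dispatch -> compute -> combine). Here is a hedged toy version of those stages with top-1 routing over two linear "experts"; it illustrates the data flow only and is not Megatron's token dispatcher.

# Toy route -> dispatch -> compute -> combine pipeline with top-1 routing.
import torch
import torch.nn as nn

torch.manual_seed(0)
num_tokens, hidden, num_experts = 6, 4, 2
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])
router = nn.Linear(hidden, num_experts)

tokens = torch.randn(num_tokens, hidden)

# Route: pick one expert per token and keep its probability.
probs = torch.softmax(router(tokens), dim=-1)
top_prob, top_expert = probs.max(dim=-1)              # both [num_tokens]

# Dispatch + compute: group tokens by expert and run each expert on its slice.
output = torch.zeros_like(tokens)
for e, expert in enumerate(experts):
    mask = top_expert == e
    if mask.any():
        output[mask] = expert(tokens[mask])

# Combine: scale by the routing probability and add back the residual connection.
combined = tokens + top_prob.unsqueeze(-1) * output
assert combined.shape == (num_tokens, hidden)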