diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index e95409e08e..cf0de64788 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -312,7 +312,7 @@ def __init__( ) if is_te_min_version("0.8.0"): - if self.config.tp_comm_overlap: + if self.config.tp_comm_overlap and parallel_mode != "duplicated": if is_te_min_version("1.5.0"): # Use old overlap flags if they were supplied instead extra_kwargs["ub_overlap_ag"] = ( diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 9602beb2f7..4a95134b8e 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -104,9 +104,12 @@ def __init__( if self.config.gated_linear_unit: ffn_hidden_size *= 2 + # Use moe_latent_size only for routed experts. 'is_expert' is false for shared_experts + use_latent_size = (self.config.moe_latent_size is not None) and is_expert + self.linear_fc1 = build_module( submodules.linear_fc1, - self.input_size, + self.input_size if not use_latent_size else self.config.moe_latent_size, ffn_hidden_size, config=self.config, init_method=self.config.init_method, @@ -126,7 +129,7 @@ def __init__( self.linear_fc2 = build_module( submodules.linear_fc2, self.config.ffn_hidden_size, - self.config.hidden_size, + self.config.hidden_size if not use_latent_size else self.config.moe_latent_size, config=self.config, init_method=self.config.output_layer_init_method, bias=self.config.add_bias_linear, diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 68a3d53d2b..9ea26e3e2e 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -118,6 +118,9 @@ def __init__( assert ( config.add_bias_linear == False ), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead." + assert ( + config.moe_latent_size is None + ), "MoE latent projection not supported in GroupedMLP yet." self.expert_parallel = config.expert_model_parallel_size > 1 if self.config.gated_linear_unit: @@ -778,7 +781,7 @@ def __init__( self.linear_fc1 = build_module( submodules.linear_fc1, self.num_local_experts, - self.input_size, + self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size, ffn_hidden_size, config=self.config, init_method=self.config.init_method, @@ -799,7 +802,11 @@ def __init__( submodules.linear_fc2, self.num_local_experts, self.config.moe_ffn_hidden_size, - self.config.hidden_size, + ( + self.config.hidden_size + if self.config.moe_latent_size is None + else self.config.moe_latent_size + ), config=self.config, init_method=self.config.output_layer_init_method, bias=self.config.add_bias_linear, diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index e3de8220a5..d460dec5a9 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -23,7 +23,7 @@ try: import transformer_engine as te # pylint: disable=unused-import - from megatron.core.extensions.transformer_engine import te_checkpoint + from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint HAVE_TE = True except ImportError: @@ -123,6 +123,32 @@ def __init__( # Initialize router self.router = TopKRouter(config=self.config, pg_collection=pg_collection) + # Initialize latent projections + if self.config.moe_latent_size: + assert HAVE_TE + self.fc1_latent_proj = TELinear( + self.config.hidden_size, + self.config.moe_latent_size, + parallel_mode="duplicated", + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + skip_weight_param_allocation=False, + is_expert=False, + ) + self.fc2_latent_proj = TELinear( + self.config.moe_latent_size, + self.config.hidden_size, + parallel_mode="duplicated", + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + skip_bias_add=True, + skip_weight_param_allocation=False, + is_expert=False, + ) + # Initialize token dispatcher if config.moe_token_dispatcher_type == "allgather": self.token_dispatcher = MoEAllGatherTokenDispatcher( @@ -272,9 +298,26 @@ def forward(self, hidden_states: torch.Tensor): def custom_forward(hidden_states): shared_expert_output = self.shared_experts_compute(hidden_states) hidden_states, probs, residual = self.router_and_preprocess(hidden_states) + + # Project the hidden_states from hidden dimension down to latent dimenion. + if self.config.moe_latent_size: + assert ( + not self.shared_expert_overlap + ), "Shared expert overlap not supported when MoE latent projections are used." + hidden_states, _ = self.fc1_latent_proj(hidden_states) + dispatched_input, probs = self.dispatch(hidden_states, probs) output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual) + + if self.config.moe_latent_size and mlp_bias is not None: + output = output + mlp_bias + mlp_bias = None output = self.combine(output, shared_expert_output) + # Project the output back from latent dimension to hidden dimension after combine + # in latent dimension. + if self.config.moe_latent_size: + output, _ = self.fc2_latent_proj(output) + return output, mlp_bias if self.moe_layer_recompute: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index c0582a3a87..02b3d4c9a9 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -216,6 +216,9 @@ class TransformerConfig(ModelParallelConfig): """Number of SMs to use for HybridEP. In pure NVL scenarios, 16 SMs can generally achieve good bandwidth.""" + moe_latent_size: Optional[int] = None + """Latent projection dimension for MoE. If None, MoE latent projections are not used.""" + #################### # initialization #################### diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 8c533e36f7..616e4f7f7a 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1203,6 +1203,13 @@ def validate_args(args, defaults={}): args.recompute_granularity != 'full' ), 'recompute_granularity must not be full when CUDA Graphs are enabled.' + # MoE latent projections + if args.moe_latent_size is not None: + assert args.moe_latent_size > 0, "MoE latent projection dimension has to be greater than zero." + assert args.num_experts is not None, "MoE latent projections are applicable only for MoE models." + assert not args.use_legacy_models, "MoE latent projections are only supported for mcore models." + assert not args.moe_use_legacy_grouped_gemm, "MoE latent projection is not supported yet with legacy grouped GEMM." + # Print arguments. _print_args("arguments", args) @@ -1302,6 +1309,7 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['use_kitchen'] = True kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number) + kw_args['moe_latent_size'] = args.moe_latent_size # Return config. return config_class(**kw_args) @@ -1671,6 +1679,8 @@ def _add_network_size_args(parser): 'We compute the average of the MTP losses across all depths, ' 'and multiply it the scaling factor to obtain the overall MTP loss, ' 'which serves as an additional training objective.') + group.add_argument('--moe-latent-size', type=int, default=None, + help='Latent projection dimension for MoE. If None, MoE latent projections are not used.') return parser diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 95bd015a90..0ccfe4188a 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1335,6 +1335,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('heterogeneous_layers_config_path', force=True) _set_arg('heterogeneous_layers_config_encoded_json', force=True) + # MoE latent projection + _set_arg('moe_latent_size', force=True) + # Tokenizer args. _set_arg('tokenizer_type', force=True) # Using checkpoint version might not always be safe (e.g., if running on different cluster). diff --git a/megatron/training/training.py b/megatron/training/training.py index b162aa87ac..c80bc5cb6e 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -178,11 +178,19 @@ def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=Fals return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, - shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False): + shared_expert_ffn_hidden_size, num_experts_routed_to, + moe_latent_size=None, swiglu=False): """Calculate FLOPs for an MoE layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 - routed_flops = (4 * batch_size * seq_len * hidden_size * - moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + if moe_latent_size is None: + routed_flops = (4 * batch_size * seq_len * hidden_size * + moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + else: + # Routed experts run on moe_latent_size. + routed_flops = (4 * batch_size * seq_len * moe_latent_size * + moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + # Up proj and down proj. + routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size) shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor return routed_flops + shared_flops @@ -230,6 +238,7 @@ def hybrid_flops(batch_size, seq_len, hidden_size, num_attn_heads=32, gqa=True, gqa_groups=8, kv_channels=None, mlp_expansion=4.0, swiglu=False, + moe_latent_size=None, moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1, vocab_size=256000): """Calculate total FLOPs for the hybrid model.""" @@ -242,7 +251,8 @@ def hybrid_flops(batch_size, seq_len, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups, mamba_num_heads) + num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, - shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu) + + shared_expert_ffn_hidden_size, num_experts_routed_to, + moe_latent_size, swiglu) + (2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation ) return flops_fwd * 3 @@ -447,6 +457,7 @@ def transformer_flops(): kv_channels=args.kv_channels, mlp_expansion=args.ffn_hidden_size / args.hidden_size, swiglu=args.swiglu, + moe_latent_size=args.moe_latent_size, moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None else args.ffn_hidden_size), shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None