2 changes: 1 addition & 1 deletion megatron/core/extensions/transformer_engine.py
@@ -312,7 +312,7 @@ def __init__(
             )
 
         if is_te_min_version("0.8.0"):
-            if self.config.tp_comm_overlap:
+            if self.config.tp_comm_overlap and parallel_mode != "duplicated":
                 if is_te_min_version("1.5.0"):
                     # Use old overlap flags if they were supplied instead
                     extra_kwargs["ub_overlap_ag"] = (
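
Note: the userbuffer overlap kwargs only apply when the TELinear is actually split across tensor-parallel ranks; the latent projections added later in this PR construct TELinear with parallel_mode="duplicated", where the weight is replicated and there is no all-gather or reduce-scatter to overlap. A hypothetical helper (ours, not part of the PR) restating the guard:

from typing import Optional

def _wants_tp_comm_overlap(config, parallel_mode: Optional[str]) -> bool:
    # Overlap flags are only meaningful for row/column-parallel layers;
    # "duplicated" (replicated-weight) layers skip them.
    return bool(config.tp_comm_overlap) and parallel_mode != "duplicated"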
7 changes: 5 additions & 2 deletions megatron/core/transformer/mlp.py
@@ -104,9 +104,12 @@ def __init__(
         if self.config.gated_linear_unit:
             ffn_hidden_size *= 2
 
+        # Use moe_latent_size only for routed experts. 'is_expert' is false for shared_experts
+        use_latent_size = (self.config.moe_latent_size is not None) and is_expert
+
         self.linear_fc1 = build_module(
             submodules.linear_fc1,
-            self.input_size,
+            self.input_size if not use_latent_size else self.config.moe_latent_size,
             ffn_hidden_size,
             config=self.config,
             init_method=self.config.init_method,
@@ -126,7 +129,7 @@ def __init__(
         self.linear_fc2 = build_module(
             submodules.linear_fc2,
             self.config.ffn_hidden_size,
-            self.config.hidden_size,
+            self.config.hidden_size if not use_latent_size else self.config.moe_latent_size,
             config=self.config,
             init_method=self.config.output_layer_init_method,
             bias=self.config.add_bias_linear,
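
Note: with moe_latent_size set, an MLP that belongs to a routed expert (is_expert=True) reads its input at the latent width and writes its output back at the latent width, while shared experts and dense MLPs keep hidden_size. A small stand-alone sketch (ours, with illustrative sizes, not from the PR):

def expert_mlp_io_width(hidden_size, moe_latent_size, is_expert):
    # Mirrors the use_latent_size condition introduced above.
    use_latent = (moe_latent_size is not None) and is_expert
    return moe_latent_size if use_latent else hidden_size

assert expert_mlp_io_width(4096, 512, is_expert=True) == 512    # routed expert: latent in/out
assert expert_mlp_io_width(4096, 512, is_expert=False) == 4096  # shared expert: unchanged
assert expert_mlp_io_width(4096, None, is_expert=True) == 4096  # feature disabled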
11 changes: 9 additions & 2 deletions megatron/core/transformer/moe/experts.py
@@ -118,6 +118,9 @@ def __init__(
         assert (
             config.add_bias_linear == False
         ), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
+        assert (
+            config.moe_latent_size is None
+        ), "MoE latent projection not supported in GroupedMLP yet."
 
         self.expert_parallel = config.expert_model_parallel_size > 1
         if self.config.gated_linear_unit:
@@ -778,7 +781,7 @@ def __init__(
         self.linear_fc1 = build_module(
             submodules.linear_fc1,
             self.num_local_experts,
-            self.input_size,
+            self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size,
             ffn_hidden_size,
             config=self.config,
             init_method=self.config.init_method,
@@ -799,7 +802,11 @@ def __init__(
             submodules.linear_fc2,
             self.num_local_experts,
             self.config.moe_ffn_hidden_size,
-            self.config.hidden_size,
+            (
+                self.config.hidden_size
+                if self.config.moe_latent_size is None
+                else self.config.moe_latent_size
+            ),
             config=self.config,
             init_method=self.config.output_layer_init_method,
             bias=self.config.add_bias_linear,
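
Note: the same substitution is applied to TEGroupedMLP, so each routed expert's grouped GEMMs run at the latent width, while the legacy GroupedMLP path now fails fast with an assertion. A rough sketch (ours, not from the PR) of the per-expert weight shapes that result, assuming torch's (out_features, in_features) convention and illustrative sizes:

def grouped_expert_weight_shapes(hidden_size, moe_latent_size, moe_ffn_hidden_size, gated=True):
    # Input of linear_fc1 and output of linear_fc2 move to the latent width when it is set.
    in_out = hidden_size if moe_latent_size is None else moe_latent_size
    fc1_out = moe_ffn_hidden_size * 2 if gated else moe_ffn_hidden_size
    return {"linear_fc1": (fc1_out, in_out), "linear_fc2": (in_out, moe_ffn_hidden_size)}

print(grouped_expert_weight_shapes(4096, 512, 1024))
# {'linear_fc1': (2048, 512), 'linear_fc2': (512, 1024)}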
45 changes: 44 additions & 1 deletion megatron/core/transformer/moe/moe_layer.py
@@ -23,7 +23,7 @@
 try:
     import transformer_engine as te  # pylint: disable=unused-import
 
-    from megatron.core.extensions.transformer_engine import te_checkpoint
+    from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint
 
     HAVE_TE = True
 except ImportError:
@@ -123,6 +123,32 @@ def __init__(
         # Initialize router
         self.router = TopKRouter(config=self.config, pg_collection=pg_collection)
 
+        # Initialize latent projections
+        if self.config.moe_latent_size:
+            assert HAVE_TE
+            self.fc1_latent_proj = TELinear(
+                self.config.hidden_size,
+                self.config.moe_latent_size,
+                parallel_mode="duplicated",
+                config=self.config,
+                init_method=self.config.init_method,
+                bias=self.config.add_bias_linear,
+                skip_bias_add=False,
+                skip_weight_param_allocation=False,
+                is_expert=False,
+            )
+            self.fc2_latent_proj = TELinear(
+                self.config.moe_latent_size,
+                self.config.hidden_size,
+                parallel_mode="duplicated",
+                config=self.config,
+                init_method=self.config.output_layer_init_method,
+                bias=self.config.add_bias_linear,
+                skip_bias_add=True,
+                skip_weight_param_allocation=False,
+                is_expert=False,
+            )
+
         # Initialize token dispatcher
         if config.moe_token_dispatcher_type == "allgather":
             self.token_dispatcher = MoEAllGatherTokenDispatcher(
@@ -272,9 +298,26 @@ def forward(self, hidden_states: torch.Tensor):
         def custom_forward(hidden_states):
             shared_expert_output = self.shared_experts_compute(hidden_states)
             hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
+
+            # Project the hidden_states from hidden dimension down to latent dimension.
+            if self.config.moe_latent_size:
+                assert (
+                    not self.shared_expert_overlap
+                ), "Shared expert overlap not supported when MoE latent projections are used."
+                hidden_states, _ = self.fc1_latent_proj(hidden_states)
+
             dispatched_input, probs = self.dispatch(hidden_states, probs)
             output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
+
+            if self.config.moe_latent_size and mlp_bias is not None:
+                output = output + mlp_bias
+                mlp_bias = None
             output = self.combine(output, shared_expert_output)
+            # Project the output back from latent dimension to hidden dimension after combine
+            # in latent dimension.
+            if self.config.moe_latent_size:
+                output, _ = self.fc2_latent_proj(output)
+
             return output, mlp_bias
 
         if self.moe_layer_recompute:
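
Note: the overall dataflow is: project tokens from hidden_size down to moe_latent_size before dispatch, run dispatch/combine and the routed experts at the latent width, then project back to hidden_size after combine. Shared-expert overlap is asserted off, presumably because the overlapped path would mix latent-width routed output with hidden-width shared output inside the dispatcher. A condensed restatement (ours) of the custom_forward path added above; `layer` stands for the MoELayer instance, and attribute and method names follow the diff:

def latent_moe_forward(layer, hidden_states):
    shared_out = layer.shared_experts_compute(hidden_states)
    hidden_states, probs, residual = layer.router_and_preprocess(hidden_states)
    hidden_states, _ = layer.fc1_latent_proj(hidden_states)   # hidden_size -> moe_latent_size
    dispatched, probs = layer.dispatch(hidden_states, probs)  # token communication at latent width
    out, bias = layer.routed_experts_compute(dispatched, probs, residual)
    if bias is not None:
        # The expert bias is latent-width, so fold it into the output here
        # instead of returning it as the layer-level bias.
        out, bias = out + bias, None
    out = layer.combine(out, shared_out)
    out, _ = layer.fc2_latent_proj(out)                       # moe_latent_size -> hidden_size
    return out, bias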
3 changes: 3 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -216,6 +216,9 @@ class TransformerConfig(ModelParallelConfig):
     """Number of SMs to use for HybridEP. In pure NVL scenarios,
     16 SMs can generally achieve good bandwidth."""
 
+    moe_latent_size: Optional[int] = None
+    """Latent projection dimension for MoE. If None, MoE latent projections are not used."""
+
     ####################
     # initialization
     ####################
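
Note: the new field defaults to None, so existing configs are unaffected. A minimal sketch (ours, with illustrative values; depending on the Megatron-LM version, additional TransformerConfig fields may be required) of enabling it programmatically:

from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=4,
    hidden_size=4096,
    num_attention_heads=32,
    num_moe_experts=64,
    moe_ffn_hidden_size=1024,
    moe_latent_size=512,  # new field added by this PR
)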
10 changes: 10 additions & 0 deletions megatron/training/arguments.py
@@ -1203,6 +1203,13 @@ def validate_args(args, defaults={}):
             args.recompute_granularity != 'full'
         ), 'recompute_granularity must not be full when CUDA Graphs are enabled.'
 
+    # MoE latent projections
+    if args.moe_latent_size is not None:
+        assert args.moe_latent_size > 0, "MoE latent projection dimension has to be greater than zero."
+        assert args.num_experts is not None, "MoE latent projections are applicable only for MoE models."
+        assert not args.use_legacy_models, "MoE latent projections are only supported for mcore models."
+        assert not args.moe_use_legacy_grouped_gemm, "MoE latent projection is not supported yet with legacy grouped GEMM."
+
     # Print arguments.
     _print_args("arguments", args)
 
@@ -1302,6 +1309,7 @@ def core_transformer_config_from_args(args, config_class=None):
         kw_args['use_kitchen'] = True
         kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number)
 
+    kw_args['moe_latent_size'] = args.moe_latent_size
 
     # Return config.
     return config_class(**kw_args)
@@ -1671,6 +1679,8 @@ def _add_network_size_args(parser):
                        'We compute the average of the MTP losses across all depths, '
                        'and multiply it the scaling factor to obtain the overall MTP loss, '
                        'which serves as an additional training objective.')
+    group.add_argument('--moe-latent-size', type=int, default=None,
+                       help='Latent projection dimension for MoE. If None, MoE latent projections are not used.')
     return parser
 
 
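
Note: the flag flows from --moe-latent-size into args.moe_latent_size, is sanity-checked in validate_args, and is copied into TransformerConfig.moe_latent_size by core_transformer_config_from_args. A stand-alone restatement (ours, not from the PR) of the constraints being enforced:

def moe_latent_args_ok(moe_latent_size, num_experts, use_legacy_models, moe_use_legacy_grouped_gemm):
    if moe_latent_size is None:
        return True  # feature disabled, nothing to check
    return (moe_latent_size > 0                    # positive latent width
            and num_experts is not None            # only meaningful for MoE models
            and not use_legacy_models              # mcore models only
            and not moe_use_legacy_grouped_gemm)   # legacy grouped GEMM unsupported

assert moe_latent_args_ok(512, num_experts=64, use_legacy_models=False,
                          moe_use_legacy_grouped_gemm=False)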
3 changes: 3 additions & 0 deletions megatron/training/checkpointing.py
@@ -1335,6 +1335,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
     _set_arg('heterogeneous_layers_config_path', force=True)
     _set_arg('heterogeneous_layers_config_encoded_json', force=True)
 
+    # MoE latent projection
+    _set_arg('moe_latent_size', force=True)
+
     # Tokenizer args.
     _set_arg('tokenizer_type', force=True)
     # Using checkpoint version might not always be safe (e.g., if running on different cluster).
19 changes: 15 additions & 4 deletions megatron/training/training.py
@@ -178,11 +178,19 @@ def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
         return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2
 
     def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
-                        shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False):
+                        shared_expert_ffn_hidden_size, num_experts_routed_to,
+                        moe_latent_size=None, swiglu=False):
         """Calculate FLOPs for an MoE layer."""
         scale_factor = 3.0 / 2.0 if swiglu else 1.0
-        routed_flops = (4 * batch_size * seq_len * hidden_size *
-                        moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
+        if moe_latent_size is None:
+            routed_flops = (4 * batch_size * seq_len * hidden_size *
+                            moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
+        else:
+            # Routed experts run on moe_latent_size.
+            routed_flops = (4 * batch_size * seq_len * moe_latent_size *
+                            moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
+            # Up proj and down proj.
+            routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size)
         shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor
         return routed_flops + shared_flops
 
@@ -230,6 +238,7 @@ def hybrid_flops(batch_size, seq_len, hidden_size,
                      num_attn_heads=32, gqa=True,
                      gqa_groups=8, kv_channels=None,
                      mlp_expansion=4.0, swiglu=False,
+                     moe_latent_size=None,
                      moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1,
                      vocab_size=256000):
         """Calculate total FLOPs for the hybrid model."""
@@ -242,7 +251,8 @@ def hybrid_flops(batch_size, seq_len, hidden_size,
                                                  mamba_state_dim, mamba_head_dim,
                                                  mamba_num_groups, mamba_num_heads) +
             num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
-                                             shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu) +
+                                             shared_expert_ffn_hidden_size, num_experts_routed_to,
+                                             moe_latent_size, swiglu) +
             (2 * batch_size * seq_len * hidden_size * vocab_size)  # logits computation
         )
         return flops_fwd * 3
@@ -447,6 +457,7 @@ def transformer_flops():
             kv_channels=args.kv_channels,
             mlp_expansion=args.ffn_hidden_size / args.hidden_size,
             swiglu=args.swiglu,
+            moe_latent_size=args.moe_latent_size,
             moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None
                                  else args.ffn_hidden_size),
            shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None
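
Note on the FLOPs accounting: with a latent width, the routed expert GEMMs are counted at moe_latent_size instead of hidden_size, and the two replicated projections contribute 4 * batch * seq * hidden_size * moe_latent_size (two matmuls at 2 FLOPs per multiply-accumulate). A quick numeric check (ours, with illustrative sizes):

def routed_moe_flops(b, s, hidden, moe_ffn, topk, latent=None, scale=1.0):
    # Mirrors the routed-expert branch of moe_layer_flops above.
    if latent is None:
        return 4 * b * s * hidden * moe_ffn * topk * scale
    flops = 4 * b * s * latent * moe_ffn * topk * scale  # expert GEMMs at latent width
    flops += 4 * b * s * hidden * latent                 # down- and up-projection
    return flops

baseline = routed_moe_flops(1, 4096, 4096, 1024, 8)              # ~5.50e11
latent   = routed_moe_flops(1, 4096, 4096, 1024, 8, latent=512)  # ~1.03e11
print(f"{baseline:.2e} -> {latent:.2e}")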