
Commit d088236

venmugil authored and deepakn94 committed
Changes to support latent MoEs
Signed-off-by: Deepak Narayanan <[email protected]>
1 parent 3b83c3f commit d088236

File tree

7 files changed, +73 -6 lines changed

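At a glance, the commit threads a new moe_latent_size dimension through the MoE stack: tokens are down-projected from hidden_size to moe_latent_size before being dispatched to routed experts, the experts run entirely in the latent space, and the output is projected back up afterwards; shared experts keep operating at hidden_size. Below is a minimal sketch of that data flow, using plain PyTorch linear layers and made-up sizes in place of Megatron's TELinear modules and config.

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from TransformerConfig.
hidden_size, moe_latent_size, moe_ffn_hidden_size = 1024, 512, 2048
tokens = torch.randn(16, hidden_size)

# Down-projection applied after routing, before token dispatch (fc1_latent_proj).
down_proj = nn.Linear(hidden_size, moe_latent_size)
# Stand-in for a routed expert MLP; its input/output width is now the latent size.
expert = nn.Sequential(
    nn.Linear(moe_latent_size, moe_ffn_hidden_size),
    nn.GELU(),
    nn.Linear(moe_ffn_hidden_size, moe_latent_size),
)
# Up-projection applied after expert compute, before combine (fc2_latent_proj).
up_proj = nn.Linear(moe_latent_size, hidden_size)

latent = down_proj(tokens)        # [16, 512]
expert_out = expert(latent)       # [16, 512]
output = up_proj(expert_out)      # [16, 1024]
print(output.shape)
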
megatron/core/extensions/transformer_engine.py

Lines changed: 1 addition & 1 deletion
@@ -312,7 +312,7 @@ def __init__(
         )
 
         if is_te_min_version("0.8.0"):
-            if self.config.tp_comm_overlap:
+            if self.config.tp_comm_overlap and parallel_mode != "duplicated":
                 if is_te_min_version("1.5.0"):
                     # Use old overlap flags if they were supplied instead
                     extra_kwargs["ub_overlap_ag"] = (

megatron/core/transformer/mlp.py

Lines changed: 5 additions & 2 deletions
@@ -104,9 +104,12 @@ def __init__(
         if self.config.gated_linear_unit:
             ffn_hidden_size *= 2
 
+        # Use moe_latent_size only for routed experts. 'is_expert' is false for shared_experts
+        use_latent_size = (self.config.moe_latent_size is not None) and is_expert
+
         self.linear_fc1 = build_module(
             submodules.linear_fc1,
-            self.input_size,
+            self.input_size if not use_latent_size else self.config.moe_latent_size,
             ffn_hidden_size,
             config=self.config,
             init_method=self.config.init_method,
@@ -126,7 +129,7 @@ def __init__(
         self.linear_fc2 = build_module(
             submodules.linear_fc2,
             self.config.ffn_hidden_size,
-            self.config.hidden_size,
+            self.config.hidden_size if not use_latent_size else self.config.moe_latent_size,
             config=self.config,
             init_method=self.config.output_layer_init_method,
             bias=self.config.add_bias_linear,

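The net effect on MLP is that linear_fc1's input width and linear_fc2's output width switch from hidden_size to moe_latent_size, but only for routed experts (is_expert=True); shared experts are unchanged. A small sketch of that sizing rule under assumed config values (not the actual MLP class):

from typing import Optional

# Assumed values for illustration.
hidden_size, moe_latent_size, ffn_hidden_size = 1024, 512, 2048

def mlp_io_sizes(is_expert: bool, latent: Optional[int]):
    """Mirrors the new sizing logic in MLP.__init__, simplified."""
    use_latent_size = (latent is not None) and is_expert
    io_size = latent if use_latent_size else hidden_size
    return io_size, ffn_hidden_size, io_size   # fc1 in, fc1/fc2 hidden, fc2 out

print(mlp_io_sizes(is_expert=True, latent=moe_latent_size))    # (512, 2048, 512)
print(mlp_io_sizes(is_expert=False, latent=moe_latent_size))   # (1024, 2048, 1024)
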
megatron/core/transformer/moe/experts.py

Lines changed: 9 additions & 2 deletions
@@ -118,6 +118,9 @@ def __init__(
         assert (
             config.add_bias_linear == False
         ), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
+        assert (
+            config.moe_latent_size is None
+        ), "MoE latent projection not supported in GroupedMLP yet."
 
         self.expert_parallel = config.expert_model_parallel_size > 1
         if self.config.gated_linear_unit:
@@ -778,7 +781,7 @@ def __init__(
         self.linear_fc1 = build_module(
             submodules.linear_fc1,
             self.num_local_experts,
-            self.input_size,
+            self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size,
             ffn_hidden_size,
             config=self.config,
             init_method=self.config.init_method,
@@ -799,7 +802,11 @@ def __init__(
             submodules.linear_fc2,
             self.num_local_experts,
             self.config.moe_ffn_hidden_size,
-            self.config.hidden_size,
+            (
+                self.config.hidden_size
+                if self.config.moe_latent_size is None
+                else self.config.moe_latent_size
+            ),
             config=self.config,
             init_method=self.config.output_layer_init_method,
             bias=self.config.add_bias_linear,

megatron/core/transformer/moe/moe_layer.py

Lines changed: 42 additions & 1 deletion
@@ -23,7 +23,7 @@
 try:
     import transformer_engine as te  # pylint: disable=unused-import
 
-    from megatron.core.extensions.transformer_engine import te_checkpoint
+    from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint
 
     HAVE_TE = True
 except ImportError:
@@ -123,6 +123,32 @@ def __init__(
         # Initialize router
         self.router = TopKRouter(config=self.config, pg_collection=pg_collection)
 
+        # Initialize latent projections
+        if self.config.moe_latent_size:
+            assert HAVE_TE
+            self.fc1_latent_proj = TELinear(
+                self.config.hidden_size,
+                self.config.moe_latent_size,
+                parallel_mode="duplicated",
+                config=self.config,
+                init_method=self.config.init_method,
+                bias=self.config.add_bias_linear,
+                skip_bias_add=False,
+                skip_weight_param_allocation=False,
+                is_expert=False,
+            )
+            self.fc2_latent_proj = TELinear(
+                self.config.moe_latent_size,
+                self.config.hidden_size,
+                parallel_mode="duplicated",
+                config=self.config,
+                init_method=self.config.output_layer_init_method,
+                bias=self.config.add_bias_linear,
+                skip_bias_add=True,
+                skip_weight_param_allocation=False,
+                is_expert=False,
+            )
+
         # Initialize token dispatcher
         if config.moe_token_dispatcher_type == "allgather":
             self.token_dispatcher = MoEAllGatherTokenDispatcher(
@@ -272,8 +298,23 @@ def forward(self, hidden_states: torch.Tensor):
         def custom_forward(hidden_states):
             shared_expert_output = self.shared_experts_compute(hidden_states)
             hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
+
+            # Project the hidden_states from hidden dimension down to latent dimenion.
+            if self.config.moe_latent_size:
+                assert (
+                    not self.shared_expert_overlap
+                ), "Shared expert overlap not supported when MoE latent projections are used."
+                hidden_states, _ = self.fc1_latent_proj(hidden_states)
+
             dispatched_input, probs = self.dispatch(hidden_states, probs)
             output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
+
+            # Project the output back from latent dimension to hidden dimension
+            if self.config.moe_latent_size:
+                if mlp_bias is not None:
+                    output = output + mlp_bias
+                output, mlp_bias = self.fc2_latent_proj(output)
+
             output = self.combine(output, shared_expert_output)
             return output, mlp_bias

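One subtlety in the new forward path: the routed experts' bias (mlp_bias) is produced in the latent dimension, so it has to be folded into output before the up-projection; after the call, mlp_bias is whatever bias fc2_latent_proj returns, since that layer is built with skip_bias_add=True. A toy sketch of that ordering with made-up shapes:

import torch

num_tokens, hidden_size, moe_latent_size = 16, 1024, 512   # made-up shapes

output = torch.randn(num_tokens, moe_latent_size)   # routed-expert output (latent dim)
mlp_bias = torch.randn(moe_latent_size)             # expert bias, also latent dim
up_proj = torch.nn.Linear(moe_latent_size, hidden_size)

# Add the latent-dim bias first; it cannot be added after the projection,
# because the output is then back in the hidden dimension.
output = output + mlp_bias
output = up_proj(output)                            # [16, 1024]
print(output.shape)
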
megatron/core/transformer/transformer_config.py

Lines changed: 3 additions & 0 deletions
@@ -216,6 +216,9 @@ class TransformerConfig(ModelParallelConfig):
     """Number of SMs to use for HybridEP. In pure NVL scenarios,
     16 SMs can generally achieve good bandwidth."""
 
+    moe_latent_size: Optional[int] = None
+    """Latent projection dimension for MoE. If None, MoE latent projections are not used."""
+
     ####################
     # initialization
     ####################

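With the new field in place, enabling latent MoE from Python is just a matter of setting moe_latent_size on the config. A rough sketch follows; the other fields are existing TransformerConfig options, and the exact minimal set of required constructor arguments may differ across Megatron-LM versions.

from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=8,
    num_moe_experts=8,
    moe_ffn_hidden_size=2048,
    moe_latent_size=512,   # new field: routed experts and dispatch run at width 512
)
print(config.moe_latent_size)
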
megatron/training/arguments.py

Lines changed: 10 additions & 0 deletions
@@ -1203,6 +1203,13 @@ def validate_args(args, defaults={}):
             args.recompute_granularity != 'full'
         ), 'recompute_granularity must not be full when CUDA Graphs are enabled.'
 
+    # MoE latent projections
+    if args.moe_latent_size is not None:
+        assert args.moe_latent_size > 0, "MoE latent projection dimension has to be greater than zero."
+        assert args.num_experts is not None, "MoE latent projections are applicable only for MoE models."
+        assert not args.use_legacy_models, "MoE latent projections are only supported for mcore models."
+        assert not args.moe_use_legacy_grouped_gemm, "MoE latent projection is not supported yet with legacy grouped GEMM."
+
     # Print arguments.
     _print_args("arguments", args)
@@ -1302,6 +1309,7 @@ def core_transformer_config_from_args(args, config_class=None):
         kw_args['use_kitchen'] = True
         kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number)
 
+    kw_args['moe_latent_size'] = args.moe_latent_size
 
     # Return config.
     return config_class(**kw_args)
@@ -1671,6 +1679,8 @@ def _add_network_size_args(parser):
                        'We compute the average of the MTP losses across all depths, '
                        'and multiply it the scaling factor to obtain the overall MTP loss, '
                        'which serves as an additional training objective.')
+    group.add_argument('--moe-latent-size', type=int, default=None,
+                       help='Latent projection dimension for MoE. If None, MoE latent projections are not used.')
     return parser

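On the command line, the feature is driven by the new --moe-latent-size flag, which validate_args only accepts for mcore MoE models. A hypothetical fragment of a launch-argument list, with illustrative values only:

# Hypothetical argument fragment; pair --moe-latent-size with an MoE model
# and avoid --use-legacy-models / --moe-use-legacy-grouped-gemm.
MOE_LATENT_ARGS = [
    "--num-experts", "8",          # latent projections apply only to MoE models
    "--moe-latent-size", "512",    # must be a positive integer
]
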
megatron/training/checkpointing.py

Lines changed: 3 additions & 0 deletions
@@ -1335,6 +1335,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
     _set_arg('heterogeneous_layers_config_path', force=True)
     _set_arg('heterogeneous_layers_config_encoded_json', force=True)
 
+    # MoE latent projection
+    _set_arg('moe_latent_size', force=True)
+
     # Tokenizer args.
     _set_arg('tokenizer_type', force=True)
     # Using checkpoint version might not always be safe (e.g., if running on different cluster).

0 commit comments
