@@ -23,7 +23,7 @@
 try:
     import transformer_engine as te  # pylint: disable=unused-import
 
-    from megatron.core.extensions.transformer_engine import te_checkpoint
+    from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint
 
     HAVE_TE = True
 except ImportError:
@@ -123,6 +123,32 @@ def __init__(
         # Initialize router
         self.router = TopKRouter(config=self.config, pg_collection=pg_collection)
 
+        # Initialize latent projections
+        if self.config.moe_latent_size:
+            assert HAVE_TE
+            self.fc1_latent_proj = TELinear(
+                self.config.hidden_size,
+                self.config.moe_latent_size,
+                parallel_mode="duplicated",
+                config=self.config,
+                init_method=self.config.init_method,
+                bias=self.config.add_bias_linear,
+                skip_bias_add=False,
+                skip_weight_param_allocation=False,
+                is_expert=False,
+            )
+            self.fc2_latent_proj = TELinear(
+                self.config.moe_latent_size,
+                self.config.hidden_size,
+                parallel_mode="duplicated",
+                config=self.config,
+                init_method=self.config.output_layer_init_method,
+                bias=self.config.add_bias_linear,
+                skip_bias_add=True,
+                skip_weight_param_allocation=False,
+                is_expert=False,
+            )
+
         # Initialize token dispatcher
         if config.moe_token_dispatcher_type == "allgather":
             self.token_dispatcher = MoEAllGatherTokenDispatcher(
@@ -272,8 +298,23 @@ def forward(self, hidden_states: torch.Tensor):
         def custom_forward(hidden_states):
             shared_expert_output = self.shared_experts_compute(hidden_states)
             hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
+
+            # Project the hidden_states from hidden dimension down to latent dimension.
+            if self.config.moe_latent_size:
+                assert (
+                    not self.shared_expert_overlap
+                ), "Shared expert overlap not supported when MoE latent projections are used."
+                hidden_states, _ = self.fc1_latent_proj(hidden_states)
+
             dispatched_input, probs = self.dispatch(hidden_states, probs)
             output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
+
+            # Project the output back from latent dimension to hidden dimension.
+            if self.config.moe_latent_size:
+                if mlp_bias is not None:
+                    output = output + mlp_bias
+                output, mlp_bias = self.fc2_latent_proj(output)
+
             output = self.combine(output, shared_expert_output)
             return output, mlp_bias
 
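
For orientation, here is a minimal, self-contained sketch of the pattern this change introduces, written in plain PyTorch rather than Megatron-Core. It is an illustration under assumed names (`ToyLatentMoE`, `latent_size`, the per-expert Python loop), not the real implementation: the actual layer uses `TELinear`, the MoE token dispatcher, and the `(output, mlp_bias)` convention shown in the diff above.

```python
# Illustrative sketch only (not Megatron-Core code): tokens are routed at full
# width, projected down to a smaller latent width before they reach the experts,
# processed by experts that operate in the latent space, and projected back up
# to the hidden width afterwards.
import torch
import torch.nn as nn


class ToyLatentMoE(nn.Module):
    def __init__(self, hidden_size=1024, latent_size=256, num_experts=4, ffn_size=512, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        # Down-projection applied before dispatch (plays the role of fc1_latent_proj).
        self.fc1_latent_proj = nn.Linear(hidden_size, latent_size)
        # Experts work entirely in the latent width.
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(latent_size, ffn_size), nn.GELU(), nn.Linear(ffn_size, latent_size)
            )
            for _ in range(num_experts)
        )
        # Up-projection applied after the expert outputs are combined (fc2_latent_proj).
        self.fc2_latent_proj = nn.Linear(latent_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size]
        probs = torch.softmax(self.router(hidden_states), dim=-1)
        top_probs, top_idx = probs.topk(self.top_k, dim=-1)

        # Shrink the tokens before "dispatch"; in a distributed setting only
        # latent-width activations would cross devices from here on.
        latent = self.fc1_latent_proj(hidden_states)  # [num_tokens, latent_size]

        # Dense stand-in for the token dispatcher + expert computation.
        expert_out = torch.zeros_like(latent)
        for e, expert in enumerate(self.experts):
            for k in range(self.top_k):
                mask = top_idx[:, k] == e
                if mask.any():
                    expert_out[mask] += top_probs[mask, k : k + 1] * expert(latent[mask])

        # Restore the hidden width before the caller's residual connection.
        return self.fc2_latent_proj(expert_out)


if __name__ == "__main__":
    layer = ToyLatentMoE()
    x = torch.randn(8, 1024)
    print(layer(x).shape)  # torch.Size([8, 1024])
```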
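Two details in the diff differ from the toy sketch and are worth calling out. First, the bias handling: `fc1_latent_proj` is created with `skip_bias_add=False` so the dispatcher sees the complete down-projected activation, while `fc2_latent_proj` uses `skip_bias_add=True` and hands its bias back through the existing `(output, mlp_bias)` contract of `custom_forward`; any bias produced by the routed experts is folded into `output` before the up-projection. Second, the guards: latent projections require Transformer Engine (`HAVE_TE`), since both projections are `TELinear` layers in `"duplicated"` parallel mode, and they assert that shared-expert overlap is disabled. The presumable motivation is that everything between the two projections, including token dispatch and the expert computation, operates at `moe_latent_size` rather than `hidden_size`, which narrows the dispatched activations and the experts' input width.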