
Commit 1a78e7a

[None][feat] AutoDeploy: Support Latent MOE for Nemotron (#8955)
Signed-off-by: Chenghao Zhang <[email protected]>
1 parent ada93f1 commit 1a78e7a

File tree

1 file changed: +13 -1 lines changed


tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py

Lines changed: 13 additions & 1 deletion
@@ -92,14 +92,22 @@ def _nemotron_h_block_forward(
 def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
     """
     Uses NemotronH router (returns indices, weights) and dispatches through auto_deploy::torch_moe
-    with act_fn='relu2'. Falls back to original forward if any expert has bias.
+    with act_fn='relu2'. Handles both latent MOE and direct MOE architectures.
     """

     residuals = hidden_states
     orig_shape = hidden_states.shape
     topk_indices, topk_weights = self.gate(hidden_states)
     x_flat = hidden_states.view(-1, hidden_states.shape[-1])

+    # Check if this is a latent MOE (has fc1_latent_proj and fc2_latent_proj)
+    has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj")
+
+    if has_latent_proj:
+        # Latent MOE: project to latent space before routing
+        x_flat = self.fc1_latent_proj(x_flat)
+
+    # Route through experts (operates in latent space if latent MOE, full space otherwise)
     out_flat = torch.ops.auto_deploy.torch_moe(
         x_flat,
         topk_indices,
@@ -111,6 +119,10 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
         mlp_style="mlp",
     )

+    if has_latent_proj:
+        # Latent MOE: project back from latent space
+        out_flat = self.fc2_latent_proj(out_flat)
+
     out = out_flat.view(*orig_shape)
     out = out + self.shared_experts(residuals)
     return out
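
For readers unfamiliar with the pattern, the sketch below shows what a "latent" MOE layer looks like in plain PyTorch: routing is decided on the full-width hidden states, tokens are projected down to a smaller latent width before the experts run, and the expert output is projected back up afterwards. Everything here is an illustrative assumption -- the class name LatentMoESketch, the Relu2 module (assuming act_fn='relu2' means squared ReLU), the dimension names, the plain Linear gate, and the dense per-expert loop. The actual NemotronH module uses its own gate (which returns indices and weights) and dispatches through torch.ops.auto_deploy.torch_moe as shown in the diff above.

import torch
import torch.nn as nn


class Relu2(nn.Module):
    # Squared ReLU; an assumption for what act_fn='relu2' denotes.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x).square()


class LatentMoESketch(nn.Module):
    # Hypothetical module illustrating the latent-MOE layout the patch detects
    # via hasattr(self, "fc1_latent_proj") / hasattr(self, "fc2_latent_proj").
    def __init__(self, hidden_dim: int, latent_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        # Down-/up-projections around the expert computation -- the "latent" part.
        self.fc1_latent_proj = nn.Linear(hidden_dim, latent_dim, bias=False)
        self.fc2_latent_proj = nn.Linear(latent_dim, hidden_dim, bias=False)
        # Experts operate entirely at the latent width.
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(latent_dim, 4 * latent_dim, bias=False),
                Relu2(),
                nn.Linear(4 * latent_dim, latent_dim, bias=False),
            )
            for _ in range(num_experts)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        x_flat = hidden_states.view(-1, orig_shape[-1])

        # Routing decisions are made on the full-width hidden states.
        topk_weights, topk_indices = torch.topk(
            self.gate(x_flat).softmax(dim=-1), self.top_k, dim=-1
        )

        # Latent MOE: project down before the experts run.
        x_lat = self.fc1_latent_proj(x_flat)

        # Dense reference loop over experts; the real patch dispatches through
        # torch.ops.auto_deploy.torch_moe instead of looping in Python.
        out_lat = torch.zeros_like(x_lat)
        for expert_id, expert in enumerate(self.experts):
            token_ids, slot_ids = (topk_indices == expert_id).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            weight = topk_weights[token_ids, slot_ids].unsqueeze(-1)
            out_lat.index_add_(0, token_ids, expert(x_lat[token_ids]) * weight)

        # Latent MOE: project back to the model width afterwards.
        out_flat = self.fc2_latent_proj(out_lat)
        return out_flat.view(*orig_shape)


# Example usage (hypothetical sizes): same input/output shape as a dense MLP block.
moe = LatentMoESketch(hidden_dim=1024, latent_dim=256, num_experts=8, top_k=2)
out = moe(torch.randn(2, 16, 1024))  # (batch, seq, hidden) -> (batch, seq, hidden)

A direct (non-latent) MOE is the same layout without the two projections; the patched forward covers both cases with the single hasattr check, skipping the projections and routing at the full hidden width when they are absent.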
