@@ -92,14 +92,22 @@ def _nemotron_h_block_forward(
 def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
     """
     Uses NemotronH router (returns indices, weights) and dispatches through auto_deploy::torch_moe
-    with act_fn='relu2'. Falls back to original forward if any expert has bias.
+    with act_fn='relu2'. Handles both latent MOE and direct MOE architectures.
     """

     residuals = hidden_states
     orig_shape = hidden_states.shape
     topk_indices, topk_weights = self.gate(hidden_states)
     x_flat = hidden_states.view(-1, hidden_states.shape[-1])

+    # Check if this is a latent MOE (has fc1_latent_proj and fc2_latent_proj)
+    has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj")
+
+    if has_latent_proj:
+        # Latent MOE: project to latent space before routing
+        x_flat = self.fc1_latent_proj(x_flat)
+
+    # Route through experts (operates in latent space if latent MOE, full space otherwise)
     out_flat = torch.ops.auto_deploy.torch_moe(
         x_flat,
         topk_indices,
@@ -111,6 +119,10 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
         mlp_style="mlp",
     )

+    if has_latent_proj:
+        # Latent MOE: project back from latent space
+        out_flat = self.fc2_latent_proj(out_flat)
+
     out = out_flat.view(*orig_shape)
     out = out + self.shared_experts(residuals)
     return out
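For reference, a minimal standalone sketch of the latent-MoE wrapping idea this patch implements: the router still sees the full hidden size, experts operate in a smaller latent space, and `fc1_latent_proj`/`fc2_latent_proj` bracket the expert mixing. The dense expert loop and layer sizes below are hypothetical stand-ins for the real `torch.ops.auto_deploy.torch_moe` dispatch, not the actual implementation.

```python
import torch
import torch.nn as nn


class LatentMoESketch(nn.Module):
    """Illustrative only: down-project to a latent dim, mix experts there, project back."""

    def __init__(self, hidden: int = 1024, latent: int = 256, n_experts: int = 4):
        super().__init__()
        self.fc1_latent_proj = nn.Linear(hidden, latent, bias=False)  # hidden -> latent
        self.fc2_latent_proj = nn.Linear(latent, hidden, bias=False)  # latent -> hidden
        # Stand-in for the routed expert MLPs that torch_moe would dispatch to.
        self.experts = nn.ModuleList(
            nn.Linear(latent, latent, bias=False) for _ in range(n_experts)
        )
        self.gate = nn.Linear(hidden, n_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x_flat = hidden_states.view(-1, hidden_states.shape[-1])
        # Router operates on the full hidden size, as in the patched forward above.
        weights = self.gate(x_flat).softmax(dim=-1)            # [tokens, n_experts]
        latent = self.fc1_latent_proj(x_flat)                  # experts see the latent space
        expert_out = torch.stack([e(latent) for e in self.experts], dim=1)
        mixed = (weights.unsqueeze(-1) * expert_out).sum(dim=1)
        out_flat = self.fc2_latent_proj(mixed)                 # back to hidden size
        return out_flat.view(*hidden_states.shape)


x = torch.randn(2, 8, 1024)
print(LatentMoESketch()(x).shape)  # torch.Size([2, 8, 1024])
```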