@@ -227,8 +227,8 @@ def overlap_tpsp_token_forward(
         _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        _0_input1 = _0_o.view(-1, self.embed_dim_)
         if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
             flashinfer.norm.fused_add_rmsnorm(
                 _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
             )
@@ -261,8 +261,8 @@ def overlap_tpsp_token_forward(
         _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
         _1_q = None
         _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
-        _1_input1 = _1_o.view(-1, self.embed_dim_)
         if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _1_input1 = _1_o.view(-1, self.embed_dim_)
             flashinfer.norm.fused_add_rmsnorm(
                 _1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
             )
@@ -351,8 +351,8 @@ def overlap_tpsp_context_forward(
         _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        _0_input1 = _0_o.view(-1, self.embed_dim_)
         if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
             flashinfer.norm.fused_add_rmsnorm(
                 _0_input1, input_embdings, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
             )
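
Each hunk makes the same change: the `.view(-1, self.embed_dim_)` flatten feeds only the `flashinfer` fast path, so it moves inside the `FLASHINFER_AVAILABLE` guard and is skipped when the guard is false. For reference, the sketch below is a hypothetical plain-PyTorch equivalent of what `fused_add_rmsnorm` computes; it is not part of this patch, and the exact in-place update order is an assumption based on flashinfer's documented semantics (residual is accumulated in place, then the normalized result is written back into the first argument).

```python
import torch


def fused_add_rmsnorm_ref(hidden: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float = 1e-6) -> None:
    """Hypothetical reference for flashinfer.norm.fused_add_rmsnorm (assumed semantics)."""
    # Accumulate the attention output into the residual stream, in place.
    residual.add_(hidden)
    # RMSNorm over the last dimension, computed in fp32 for stability.
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    # Write the scaled, normalized activations back into `hidden`.
    hidden.copy_((normed * weight.float()).to(hidden.dtype))
```

Under these assumed semantics, after the call `_0_input1` holds the normalized activations for the FFN while `input_embdings` carries the updated residual stream, which is why the flattened view only needs to exist on the fused path.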