Commit 7169860
Author: sangchengmeng
Parent: 735dace

    add flashinfer fused_norm_add op

File tree: 2 files changed (+5 / -5 lines)

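Both files get the same two-line fix: the `_0_input1 = _0_o.view(-1, self.embed_dim_)` reshape is moved from before the FlashInfer check to inside it, so the flattened view is only created on the code path that actually hands it to the fused kernel. The branches key off a FLASHINFER_AVAILABLE flag whose definition is not part of this diff; a typical import guard for an optional dependency looks like the following (an assumption for illustration, not the actual lightllm code):

# Hypothetical import guard; lightllm defines FLASHINFER_AVAILABLE elsewhere.
try:
    import flashinfer  # optional dependency providing fused CUDA kernels
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False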

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 2 deletions
@@ -821,8 +821,8 @@ def overlap_tpsp_token_forward(
         _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        _0_input1 = _0_o.view(-1, self.embed_dim_)
         if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
             flashinfer.norm.fused_add_rmsnorm(
                 _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
             )
@@ -952,8 +952,8 @@ def overlap_tpsp_context_forward(
         _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        _0_input1 = _0_o.view(-1, self.embed_dim_)
         if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
             flashinfer.norm.fused_add_rmsnorm(
                 _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
             )
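For reference, an unfused PyTorch equivalent of what flashinfer.norm.fused_add_rmsnorm computes in place is sketched below. This is a reading of the kernel's contract (the residual buffer accumulates the add, the first argument receives the normalized result); treat it as a sketch, not LightLLM or FlashInfer code:

import torch

def fused_add_rmsnorm_reference(x: torch.Tensor, residual: torch.Tensor,
                                weight: torch.Tensor, eps: float = 1e-6) -> None:
    # Assumed in-place contract: fold x into the residual stream first...
    residual.add_(x)
    # ...then RMS-normalize the updated residual in fp32 and write the
    # gained result back into x, ready to feed the FFN block.
    var = residual.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    normed = residual.to(torch.float32) * torch.rsqrt(var + eps)
    x.copy_((normed * weight.to(torch.float32)).to(x.dtype))

In the diff above, _0_input1 plays the role of x and input_embdings the role of residual, which is why the reshape to (-1, self.embed_dim_) now happens right before the call.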

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 3 deletions
@@ -227,8 +227,8 @@ def overlap_tpsp_token_forward(
         _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        _0_input1 = _0_o.view(-1, self.embed_dim_)
         if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
             flashinfer.norm.fused_add_rmsnorm(
                 _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
             )
@@ -261,8 +261,8 @@ def overlap_tpsp_token_forward(
         _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
         _1_q = None
         _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
-        _1_input1 = _1_o.view(-1, self.embed_dim_)
         if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _1_input1 = _1_o.view(-1, self.embed_dim_)
             flashinfer.norm.fused_add_rmsnorm(
                 _1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
             )
@@ -351,8 +351,8 @@ def overlap_tpsp_context_forward(
         _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        _0_input1 = _0_o.view(-1, self.embed_dim_)
         if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
             flashinfer.norm.fused_add_rmsnorm(
                 _0_input1, input_embdings, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
             )
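As a usage sketch, the fused call updates both buffers in place, so no new tensor is allocated on the hot path. Shapes and dtypes below are illustrative; the argument order mirrors the hunks above, but the in-place semantics in the comment are an assumption:

import torch
import flashinfer

hidden = torch.randn(4, 4096, device="cuda", dtype=torch.float16)    # attention output, flattened
residual = torch.randn(4, 4096, device="cuda", dtype=torch.float16)  # running residual stream
weight = torch.ones(4096, device="cuda", dtype=torch.float16)        # RMSNorm gain

# After the call (assumed contract): residual += hidden, and hidden
# holds rmsnorm(residual) * weight.
flashinfer.norm.fused_add_rmsnorm(hidden, residual, weight, eps=1e-6)

Note the small inconsistency the diff preserves: the qwen3_moe hunks at lines 261 and 351 pass weight= as a keyword while the others pass it positionally; both resolve to the same argument.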
