-
Notifications
You must be signed in to change notification settings - Fork 283
add flashinfer fused_norm_add op #1105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,8 @@ | |||||||||||||||||||||||||||||||||||
| from lightllm.distributed import all_reduce | ||||||||||||||||||||||||||||||||||||
| from typing import Tuple | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class TransformerLayerInferTpl(TransformerLayerInfer): | ||||||||||||||||||||||||||||||||||||
| """ """ | ||||||||||||||||||||||||||||||||||||
|
|
@@ -73,10 +75,16 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei | |||||||||||||||||||||||||||||||||||
| o = self._get_o(o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||||||||||||||||||||||||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| ffn_out = self._ffn(input1, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input1 = None | ||||||||||||||||||||||||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -94,10 +102,16 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh | |||||||||||||||||||||||||||||||||||
| o = self._get_o(o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||||||||||||||||||||||||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+105
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has the same correctness bug as in
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| ffn_out = self._ffn(input1, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input1 = None | ||||||||||||||||||||||||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -113,10 +127,16 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS | |||||||||||||||||||||||||||||||||||
| o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| q = None | ||||||||||||||||||||||||||||||||||||
| o = self._tpsp_get_o(o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+130
to
+137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has the same correctness bug as in
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input1 = None | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -130,10 +150,16 @@ def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferSta | |||||||||||||||||||||||||||||||||||
| o = self._token_attention_kernel(q, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| q = None | ||||||||||||||||||||||||||||||||||||
| o = self._tpsp_get_o(o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+153
to
+160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has the same correctness bug as in
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input1 = None | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,8 @@ | |||||||||||||
| from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer | ||||||||||||||
| from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd | ||||||||||||||
|
|
||||||||||||||
| from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer): | ||||||||||||||
| """ """ | ||||||||||||||
|
|
@@ -138,10 +140,16 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei | |||||||||||||
| o = self._get_o(o, infer_state, layer_weight) | ||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||||||||||||||
| o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) | ||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||
| o = None | ||||||||||||||
| if FLASHINFER_AVAILABLE: | ||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||
| flashinfer.norm.gemma_fused_add_rmsnorm( | ||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||
|
Comment on lines
+148
to
+150
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic in the
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| o = None | ||||||||||||||
| input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16) | ||||||||||||||
| ffn_out = self._ffn(input1, infer_state, layer_weight) | ||||||||||||||
| input1 = None | ||||||||||||||
|
|
@@ -164,8 +172,14 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh | |||||||||||||
| o = self._get_o(o, infer_state, layer_weight) | ||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||||||||||||||
| o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) | ||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||
| if FLASHINFER_AVAILABLE: | ||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||
| flashinfer.norm.gemma_fused_add_rmsnorm( | ||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||
|
Comment on lines
+180
to
+182
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This
Suggested change
|
||||||||||||||
| o = None | ||||||||||||||
|
|
||||||||||||||
| input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16) | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |||||||||||||||||||||||||||||||||||
| from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| logger = init_logger(__name__) | ||||||||||||||||||||||||||||||||||||
| from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): | ||||||||||||||||||||||||||||||||||||
|
|
@@ -226,9 +227,15 @@ def overlap_tpsp_token_forward( | |||||||||||||||||||||||||||||||||||
| _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_q = None | ||||||||||||||||||||||||||||||||||||
| _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| _0_input1 = _0_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_o = None | ||||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_router_logits = layer_weight.moe_gate.mm(_0_input1) | ||||||||||||||||||||||||||||||||||||
| # 1 hook | ||||||||||||||||||||||||||||||||||||
| if getattr(infer_state1, "hook", None) is not None: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -254,9 +261,15 @@ def overlap_tpsp_token_forward( | |||||||||||||||||||||||||||||||||||
| _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _1_q = None | ||||||||||||||||||||||||||||||||||||
| _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings1.add_(_1_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| _1_input1 = _1_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| _1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_1_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _1_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _1_o = None | ||||||||||||||||||||||||||||||||||||
| _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| # to do gate and disptatch | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| _1_router_logits = layer_weight.moe_gate.mm(_1_input1) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -338,9 +351,15 @@ def overlap_tpsp_context_forward( | |||||||||||||||||||||||||||||||||||
| _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_q = None | ||||||||||||||||||||||||||||||||||||
| _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| _0_input1 = _0_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| _0_input1, input_embdings, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_o = None | ||||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_router_logits = layer_weight.moe_gate.mm(_0_input1) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # wait last 1 combine | ||||||||||||||||||||||||||||||||||||
|
|
@@ -363,9 +382,15 @@ def overlap_tpsp_context_forward( | |||||||||||||||||||||||||||||||||||
| _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _1_q = None | ||||||||||||||||||||||||||||||||||||
| _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings1.add_(_1_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| _1_input1 = _1_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| _1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_1_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _1_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+385
to
+392
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has two critical issues:
Suggested change
|
||||||||||||||||||||||||||||||||||||
| _1_o = None | ||||||||||||||||||||||||||||||||||||
| _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| # to do gate and disptatch | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| _1_router_logits = layer_weight.moe_gate.mm(_1_input1) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| from lightllm.utils.log_utils import init_logger | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| FLASHINFER_AVAILABLE = False | ||
| flashinfer = None | ||
| try: | ||
| import flashinfer as _flashinfer | ||
|
|
||
| flashinfer = _flashinfer | ||
| FLASHINFER_AVAILABLE = False | ||
| except ImportError: | ||
| logger.warning("flashinfer is not available") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a correctness bug in the
ifbranch. The variableinput1is used as input toself._ffnon line 88. In theelsebranch,input1is correctly assigned the output ofself._ffn_norm. However, in theifbranch,input1holds the value ofo.view(-1, self.embed_dim_), which is the raw attention output, not the normalized tensor. Theflashinfer.norm.fused_add_rmsnormfunction is expected to updateinput_embdingsin-place with the normalized result ofinput_embdings + o. Therefore,input1should be updated toinput_embdingsafter the fused kernel call to ensure the correct tensor is passed to the FFN.