-
Notifications
You must be signed in to change notification settings - Fork 284
add flashinfer fused_norm_add op #1105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,8 @@ | |||||||||||||||||||||||||||||||||||
| from lightllm.distributed import all_reduce | ||||||||||||||||||||||||||||||||||||
| from typing import Tuple | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class TransformerLayerInferTpl(TransformerLayerInfer): | ||||||||||||||||||||||||||||||||||||
| """ """ | ||||||||||||||||||||||||||||||||||||
|
|
@@ -73,10 +75,16 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei | |||||||||||||||||||||||||||||||||||
| o = self._get_o(o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||||||||||||||||||||||||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| ffn_out = self._ffn(input1, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input1 = None | ||||||||||||||||||||||||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -94,10 +102,16 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh | |||||||||||||||||||||||||||||||||||
| o = self._get_o(o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||||||||||||||||||||||||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+105
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has the same correctness bug as in
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| ffn_out = self._ffn(input1, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input1 = None | ||||||||||||||||||||||||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -113,10 +127,16 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS | |||||||||||||||||||||||||||||||||||
| o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| q = None | ||||||||||||||||||||||||||||||||||||
| o = self._tpsp_get_o(o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+130
to
+137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has the same correctness bug as in
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input1 = None | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -130,10 +150,16 @@ def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferSta | |||||||||||||||||||||||||||||||||||
| o = self._token_attention_kernel(q, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| q = None | ||||||||||||||||||||||||||||||||||||
| o = self._tpsp_get_o(o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+153
to
+160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has the same correctness bug as in
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| o = None | ||||||||||||||||||||||||||||||||||||
| ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input1 = None | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |||||||||||||||||||||||||||||||||
| from lightllm.utils.dist_utils import get_global_world_size | ||||||||||||||||||||||||||||||||||
| from lightllm.utils.log_utils import init_logger | ||||||||||||||||||||||||||||||||||
| from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2 | ||||||||||||||||||||||||||||||||||
| from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| logger = init_logger(__name__) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
@@ -820,9 +821,16 @@ def overlap_tpsp_token_forward( | |||||||||||||||||||||||||||||||||
| _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||
| _0_q = None | ||||||||||||||||||||||||||||||||||
| _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||
| _0_input1 = _0_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||
| _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
| _0_input1 = _0_o.view(-1, self.embed_dim_) | |
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | |
| flashinfer.norm.fused_add_rmsnorm( | |
| _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | |
| ) | |
| else: | |
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | |
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | |
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | |
| flashinfer.norm.fused_add_rmsnorm( | |
| _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | |
| ) | |
| _0_input1 = input_embdings | |
| else: | |
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | |
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block has the same correctness bug as in overlap_tpsp_token_forward. The _0_input1 variable is not correctly updated in the if branch before being used in layer_weight.moe_gate.mm on line 964. It should be assigned the value of input_embdings after the fused_add_rmsnorm operation.
| _0_input1 = _0_o.view(-1, self.embed_dim_) | |
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | |
| flashinfer.norm.fused_add_rmsnorm( | |
| _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | |
| ) | |
| else: | |
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | |
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | |
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | |
| flashinfer.norm.fused_add_rmsnorm( | |
| _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | |
| ) | |
| _0_input1 = input_embdings | |
| else: | |
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | |
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) |
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,8 @@ | |||||||||||||
| from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer | ||||||||||||||
| from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd | ||||||||||||||
|
|
||||||||||||||
| from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer): | ||||||||||||||
| """ """ | ||||||||||||||
|
|
@@ -138,10 +140,16 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei | |||||||||||||
| o = self._get_o(o, infer_state, layer_weight) | ||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||||||||||||||
| o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) | ||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||
| o = None | ||||||||||||||
| if FLASHINFER_AVAILABLE: | ||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||
| flashinfer.norm.gemma_fused_add_rmsnorm( | ||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||
|
Comment on lines
+148
to
+150
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic in the
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| o = None | ||||||||||||||
| input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16) | ||||||||||||||
| ffn_out = self._ffn(input1, infer_state, layer_weight) | ||||||||||||||
| input1 = None | ||||||||||||||
|
|
@@ -164,8 +172,14 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh | |||||||||||||
| o = self._get_o(o, infer_state, layer_weight) | ||||||||||||||
| if self.tp_world_size_ > 1: | ||||||||||||||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||||||||||||||
| o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) | ||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||
| if FLASHINFER_AVAILABLE: | ||||||||||||||
| input1 = o.view(-1, self.embed_dim_) | ||||||||||||||
| flashinfer.norm.gemma_fused_add_rmsnorm( | ||||||||||||||
| input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||||||||||||||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||
|
Comment on lines
+180
to
+182
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This
Suggested change
|
||||||||||||||
| o = None | ||||||||||||||
|
|
||||||||||||||
| input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16) | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |||||||||||||||||||||||||||||||||||
| from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| logger = init_logger(__name__) | ||||||||||||||||||||||||||||||||||||
| from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): | ||||||||||||||||||||||||||||||||||||
|
|
@@ -226,9 +227,15 @@ def overlap_tpsp_token_forward( | |||||||||||||||||||||||||||||||||||
| _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_q = None | ||||||||||||||||||||||||||||||||||||
| _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _0_input1 = _0_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
230
to
237
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a correctness bug here. The variable
Suggested change
|
||||||||||||||||||||||||||||||||||||
| _0_o = None | ||||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_router_logits = layer_weight.moe_gate.mm(_0_input1) | ||||||||||||||||||||||||||||||||||||
| # 1 hook | ||||||||||||||||||||||||||||||||||||
| if getattr(infer_state1, "hook", None) is not None: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -254,9 +261,15 @@ def overlap_tpsp_token_forward( | |||||||||||||||||||||||||||||||||||
| _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _1_q = None | ||||||||||||||||||||||||||||||||||||
| _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings1.add_(_1_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _1_input1 = _1_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| _1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_1_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _1_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
264
to
271
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has two critical issues:
Suggested change
|
||||||||||||||||||||||||||||||||||||
| _1_o = None | ||||||||||||||||||||||||||||||||||||
| _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| # to do gate and disptatch | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| _1_router_logits = layer_weight.moe_gate.mm(_1_input1) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -338,9 +351,15 @@ def overlap_tpsp_context_forward( | |||||||||||||||||||||||||||||||||||
| _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_q = None | ||||||||||||||||||||||||||||||||||||
| _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _0_input1 = _0_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| _0_input1, input_embdings, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_0_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
354
to
361
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a correctness bug here. The variable
Suggested change
|
||||||||||||||||||||||||||||||||||||
| _0_o = None | ||||||||||||||||||||||||||||||||||||
| _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _0_router_logits = layer_weight.moe_gate.mm(_0_input1) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # wait last 1 combine | ||||||||||||||||||||||||||||||||||||
|
|
@@ -363,9 +382,15 @@ def overlap_tpsp_context_forward( | |||||||||||||||||||||||||||||||||||
| _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| _1_q = None | ||||||||||||||||||||||||||||||||||||
| _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| input_embdings1.add_(_1_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm": | ||||||||||||||||||||||||||||||||||||
| _1_input1 = _1_o.view(-1, self.embed_dim_) | ||||||||||||||||||||||||||||||||||||
| flashinfer.norm.fused_add_rmsnorm( | ||||||||||||||||||||||||||||||||||||
| _1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_ | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| input_embdings.add_(_1_o.view(-1, self.embed_dim_)) | ||||||||||||||||||||||||||||||||||||
| _1_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+385
to
+392
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has two critical issues:
Suggested change
|
||||||||||||||||||||||||||||||||||||
| _1_o = None | ||||||||||||||||||||||||||||||||||||
| _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) | ||||||||||||||||||||||||||||||||||||
| # to do gate and disptatch | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| _1_router_logits = layer_weight.moe_gate.mm(_1_input1) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a correctness bug in the
ifbranch. The variableinput1is used as input toself._ffnon line 88. In theelsebranch,input1is correctly assigned the output ofself._ffn_norm. However, in theifbranch,input1holds the value ofo.view(-1, self.embed_dim_), which is the raw attention output, not the normalized tensor. Theflashinfer.norm.fused_add_rmsnormfunction is expected to updateinput_embdingsin-place with the normalized result ofinput_embdings + o. Therefore,input1should be updated toinput_embdingsafter the fused kernel call to ensure the correct tensor is passed to the FFN.