@@ -8,6 +8,8 @@
from lightllm.distributed import all_reduce
from typing import Tuple

from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer


class TransformerLayerInferTpl(TransformerLayerInfer):
""" """
@@ -73,10 +75,16 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines +78 to +85

critical

There is a correctness bug in the if branch. The variable input1 is used as input to self._ffn on line 88. In the else branch, input1 is correctly assigned the output of self._ffn_norm. However, in the if branch, input1 holds the value of o.view(-1, self.embed_dim_), which is the raw attention output, not the normalized tensor. The flashinfer.norm.fused_add_rmsnorm function is expected to update input_embdings in-place with the normalized result of input_embdings + o. Therefore, input1 should be updated to input_embdings after the fused kernel call to ensure the correct tensor is passed to the FFN.

Suggested change
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
residual = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
residual, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
input1 = input_embdings
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
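For reference, the unfused math both branches are meant to compute is a residual add followed by RMSNorm. Below is a minimal sketch (a hypothetical helper, not lightllm or flashinfer code) that returns both results explicitly; the fused kernel instead mutates its tensor arguments in place, and which of its two arguments ends up holding the normalized activations versus the residual sum is exactly the convention this thread hinges on.

import torch

def add_rmsnorm_reference(o: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float):
    # Residual update: the running hidden state absorbs the attention output.
    residual = residual + o
    # RMSNorm over the hidden dimension (computed in float32), scaled by the learned weight.
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype) * weight
    # `normed` is what the FFN should consume; `residual` feeds the next residual add.
    return normed, residual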


input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
o = None
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.tp_world_size_ > 1:
@@ -94,10 +102,16 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines +105 to +112

critical

This block has the same correctness bug as in context_forward. The input1 variable is not correctly updated in the if branch before being used in self._ffn on line 115. It should be assigned the value of input_embdings after the fused_add_rmsnorm operation.

Suggested change
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
residual = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
residual, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
input1 = input_embdings
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)


input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
o = None
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.tp_world_size_ > 1:
@@ -113,10 +127,16 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
q = None
o = self._tpsp_get_o(o, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines +130 to +137

critical

This block has the same correctness bug as in context_forward. The input1 variable is not correctly updated in the if branch before being used in self._tpsp_ffn on line 140. It should be assigned the value of input_embdings after the fused_add_rmsnorm operation.

Suggested change
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
residual = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
residual, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
input1 = input_embdings
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)


input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
o = None
ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight)
input1 = None
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
@@ -130,10 +150,16 @@ def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferSta
o = self._token_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._tpsp_get_o(o, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines +153 to +160

critical

This block has the same correctness bug as in context_forward. The input1 variable is not correctly updated in the if branch before being used in self._tpsp_ffn on line 163. It should be assigned the value of input_embdings after the fused_add_rmsnorm operation.

Suggested change
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
residual = o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
residual, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
input1 = input_embdings
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)


input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
o = None
ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight)
input1 = None
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))

@@ -14,6 +14,7 @@ def __init__(self, layer_num, data_type, network_config, mode, quant_cfg):
self.layer_num_ = layer_num
self.data_type_ = data_type
self.network_config_ = network_config
self.norm_type = "rms_norm" if "rms_norm_eps" in self.network_config_ else "layer_norm"
self.mode = mode
self.quant_cfg = quant_cfg
self._parse_config()

22 changes: 18 additions & 4 deletions lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -30,6 +30,7 @@
from lightllm.utils.dist_utils import get_global_world_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer

logger = init_logger(__name__)

@@ -820,9 +821,16 @@ def overlap_tpsp_token_forward(
_0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
_0_q = None
_0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = _0_o.view(-1, self.embed_dim_)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)

critical

There is a correctness bug here. The variable _0_input1 is used as input to layer_weight.moe_gate.mm on line 834. In the else branch, _0_input1 is correctly assigned the output of self._ffn_norm. However, in the if branch, _0_input1 is not updated after being initialized to _0_o.view(...). The fused_add_rmsnorm function updates input_embdings in-place. _0_input1 should be set to input_embdings after the fused kernel call to pass the correct normalized tensor to the MoE gate.

Suggested change
_0_input1 = _0_o.view(-1, self.embed_dim_)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
_0_input1 = input_embdings
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)


_0_o = None
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
_0_router_logits = layer_weight.moe_gate.mm(_0_input1)
# 1 hook
if getattr(infer_state1, "hook", None) is not None:
@@ -944,9 +952,15 @@ def overlap_tpsp_context_forward(
_0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
_0_q = None
_0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = _0_o.view(-1, self.embed_dim_)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)

critical

This block has the same correctness bug as in overlap_tpsp_token_forward. The _0_input1 variable is not correctly updated in the if branch before being used in layer_weight.moe_gate.mm on line 964. It should be assigned the value of input_embdings after the fused_add_rmsnorm operation.

Suggested change
_0_input1 = _0_o.view(-1, self.embed_dim_)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
_0_input1 = input_embdings
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)

_0_o = None
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
_0_router_logits = layer_weight.moe_gate.mm(_0_input1)

# wait last 1 combine

24 changes: 19 additions & 5 deletions lightllm/models/gemma3/layer_infer/transformer_layer_infer.py
@@ -18,6 +18,8 @@
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd

from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer


class Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer):
""" """
@@ -138,10 +140,16 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None
if FLASHINFER_AVAILABLE:
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.gemma_fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines +148 to +150

critical

The logic in the else branch is inconsistent with the if branch and contains dead code. The if branch uses gemma_fused_add_rmsnorm, which presumably updates input_embdings in-place to norm(input_embdings + o). However, the else branch only performs input_embdings.add_(...) and then calculates self._ffn_norm(...) into input1, which is immediately overwritten on line 153 and never used. This means the normalization step is effectively skipped in the else path, leading to inconsistent behavior. The else branch should also apply normalization to input_embdings.

Suggested change
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input_embdings.copy_(self._ffn_norm(input_embdings.float(), infer_state, layer_weight).to(input_embdings.dtype))
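For context on why a Gemma-specific kernel shows up in this file: Gemma-family RMSNorm is conventionally computed in float32 and scales by (1 + weight) rather than weight, which is presumably what gemma_fused_add_rmsnorm implements. A minimal unfused sketch under that assumption (a hypothetical helper, not the kernel's actual implementation):

import torch

def gemma_add_rmsnorm_reference(o: torch.Tensor, residual: torch.Tensor,
                                weight: torch.Tensor, eps: float):
    # Residual update, then RMSNorm in float32 with Gemma's (1 + weight) scaling.
    residual = residual + o
    x = residual.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    normed = x * torch.rsqrt(variance + eps) * (1.0 + weight.float())
    # Cast the normalized activations back to the activation dtype for the FFN.
    return normed.to(residual.dtype), residual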


o = None
input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
@@ -164,8 +172,14 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16)
input_embdings.add_(o.view(-1, self.embed_dim_))
if FLASHINFER_AVAILABLE:
input1 = o.view(-1, self.embed_dim_)
flashinfer.norm.gemma_fused_add_rmsnorm(
input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines +180 to +182

critical

This else branch has the same issue as in context_forward. The _ffn_norm result is calculated into input1 which is then discarded, and the normalization is not applied to input_embdings. This makes the logic inconsistent with the if branch which uses a fused add-and-norm operation.

Suggested change
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
else:
input_embdings.add_(o.view(-1, self.embed_dim_))
input_embdings.copy_(self._ffn_norm(input_embdings.float(), infer_state, layer_weight).to(input_embdings.dtype))

o = None

input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16)

41 changes: 33 additions & 8 deletions lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -17,6 +17,7 @@
from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor

logger = init_logger(__name__)
from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer


class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -226,9 +227,15 @@ def overlap_tpsp_token_forward(
_0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
_0_q = None
_0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = _0_o.view(-1, self.embed_dim_)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines 230 to 237

critical

There is a correctness bug here. The variable _0_input1 is used as input to layer_weight.moe_gate.mm on line 239. In the else branch, _0_input1 is correctly assigned the output of self._ffn_norm. However, in the if branch, _0_input1 is not updated after being initialized to _0_o.view(...). The fused_add_rmsnorm function updates input_embdings in-place. _0_input1 should be set to input_embdings after the fused kernel call to pass the correct normalized tensor to the MoE gate.

Suggested change
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
_0_input1 = input_embdings
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)

_0_o = None
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
_0_router_logits = layer_weight.moe_gate.mm(_0_input1)
# 1 hook
if getattr(infer_state1, "hook", None) is not None:
@@ -254,9 +261,15 @@ def overlap_tpsp_token_forward(
_1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
_1_q = None
_1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
_1_input1 = _1_o.view(-1, self.embed_dim_)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_1_o.view(-1, self.embed_dim_))
_1_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines 264 to 271

critical

This block has two critical issues:

  1. Similar to other parts of this PR, the if branch has a correctness bug. _1_input1 is not updated after the fused_add_rmsnorm call, causing the wrong tensor to be passed to layer_weight.moe_gate.mm on line 275. It should be set to input_embdings1 after the fused kernel call.
  2. The else branch has a copy-paste error, using input_embdings and infer_state instead of input_embdings1 and infer_state1.
Suggested change
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_1_o.view(-1, self.embed_dim_))
_1_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
_1_input1 = input_embdings1
else:
input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)

_1_o = None
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and disptatch

_1_router_logits = layer_weight.moe_gate.mm(_1_input1)
@@ -338,9 +351,15 @@ def overlap_tpsp_context_forward(
_0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
_0_q = None
_0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = _0_o.view(-1, self.embed_dim_)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines 354 to 361

critical

There is a correctness bug here. The variable _0_input1 is used as input to layer_weight.moe_gate.mm on line 363. In the else branch, _0_input1 is correctly assigned the output of self._ffn_norm. However, in the if branch, _0_input1 is not updated after being initialized to _0_o.view(...). The fused_add_rmsnorm function updates input_embdings in-place. _0_input1 should be set to input_embdings after the fused kernel call to pass the correct normalized tensor to the MoE gate.

Suggested change
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
flashinfer.norm.fused_add_rmsnorm(
_0_input1, input_embdings, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
_0_input1 = input_embdings
else:
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)

_0_o = None
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
_0_router_logits = layer_weight.moe_gate.mm(_0_input1)

# wait last 1 combine
@@ -363,9 +382,15 @@ def overlap_tpsp_context_forward(
_1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
_1_q = None
_1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
_1_input1 = _1_o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
_1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_1_o.view(-1, self.embed_dim_))
_1_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
Comment on lines +385 to +392

critical

This block has two critical issues:

  1. Similar to other parts of this PR, the if branch has a correctness bug. _1_input1 is not updated after the fused_add_rmsnorm call, causing the wrong tensor to be passed to layer_weight.moe_gate.mm on line 396. It should be set to input_embdings1 after the fused kernel call.
  2. The else branch has a copy-paste error, using input_embdings and infer_state instead of input_embdings1 and infer_state1.
Suggested change
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
_1_input1 = _1_o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
_1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
else:
input_embdings.add_(_1_o.view(-1, self.embed_dim_))
_1_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
_1_input1 = _1_o.view(-1, self.embed_dim_)
flashinfer.norm.fused_add_rmsnorm(
_1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
)
_1_input1 = input_embdings1
else:
input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)

_1_o = None
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and disptatch

_1_router_logits = layer_weight.moe_gate.mm(_1_input1)