diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
index 7567bc644..7617f9515 100755
--- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
+++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
@@ -8,6 +8,8 @@
 from lightllm.distributed import all_reduce
 from typing import Tuple
 
+from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer
+
 
 class TransformerLayerInferTpl(TransformerLayerInfer):
     """ """
@@ -73,10 +75,16 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
         o = self._get_o(o, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
-        input_embdings.add_(o.view(-1, self.embed_dim_))
-        o = None
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            input1 = o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(o.view(-1, self.embed_dim_))
+            input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        o = None
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
         if self.tp_world_size_ > 1:
@@ -94,10 +102,16 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
         o = self._get_o(o, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
-        input_embdings.add_(o.view(-1, self.embed_dim_))
-        o = None
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            input1 = o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(o.view(-1, self.embed_dim_))
+            input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        o = None
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
         if self.tp_world_size_ > 1:
@@ -113,10 +127,16 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
         o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
         q = None
         o = self._tpsp_get_o(o, infer_state, layer_weight)
-        input_embdings.add_(o.view(-1, self.embed_dim_))
-        o = None
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            input1 = o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(o.view(-1, self.embed_dim_))
+            input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        o = None
         ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight)
         input1 = None
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
@@ -130,10 +150,16 @@ def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferSta
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._tpsp_get_o(o, infer_state, layer_weight)
-        input_embdings.add_(o.view(-1, self.embed_dim_))
-        o = None
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            input1 = o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(o.view(-1, self.embed_dim_))
+            input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        o = None
         ffn_out = self._tpsp_ffn(input1, infer_state, layer_weight)
         input1 = None
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py
index 97bc76237..9fb319c43 100644
--- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py
+++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py
@@ -14,6 +14,7 @@ def __init__(self, layer_num, data_type, network_config, mode, quant_cfg):
         self.layer_num_ = layer_num
         self.data_type_ = data_type
         self.network_config_ = network_config
+        self.norm_type = "rms_norm" if "rms_norm_eps" in self.network_config_ else "layer_norm"
         self.mode = mode
         self.quant_cfg = quant_cfg
         self._parse_config()
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 30d37d1df..3f0e1203f 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -30,6 +30,7 @@
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
+from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer
 
 logger = init_logger(__name__)
 
@@ -820,9 +821,16 @@ def overlap_tpsp_token_forward(
         _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+            _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        _0_o = None
-        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
         # 1 hook
         if getattr(infer_state1, "hook", None) is not None:
@@ -944,9 +952,15 @@ def overlap_tpsp_context_forward(
         _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+            _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         _0_o = None
-        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
 
         # wait last 1 combine
diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py
index 09efe9a36..86b56d92f 100644
--- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py
@@ -18,6 +18,8 @@
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 
+from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer
+
 
 class Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer):
     """ """
@@ -138,10 +140,16 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
         o = self._get_o(o, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
-        o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16)
-        input_embdings.add_(o.view(-1, self.embed_dim_))
-        o = None
+        if FLASHINFER_AVAILABLE:
+            input1 = o.view(-1, self.embed_dim_)
+            flashinfer.norm.gemma_fused_add_rmsnorm(
+                input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(o.view(-1, self.embed_dim_))
+            input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        o = None
         input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
@@ -164,8 +172,14 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
         o = self._get_o(o, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
-        o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16)
-        input_embdings.add_(o.view(-1, self.embed_dim_))
+        if FLASHINFER_AVAILABLE:
+            input1 = o.view(-1, self.embed_dim_)
+            flashinfer.norm.gemma_fused_add_rmsnorm(
+                input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(o.view(-1, self.embed_dim_))
+            input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         o = None
 
         input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16)
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 45f1f59d7..aecde4309 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -17,6 +17,7 @@
 from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
 
 logger = init_logger(__name__)
+from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer
 
 
 class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -226,9 +227,15 @@ def overlap_tpsp_token_forward(
         _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                _0_input1, input_embdings, layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+            _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         _0_o = None
-        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
         # 1 hook
         if getattr(infer_state1, "hook", None) is not None:
@@ -254,9 +261,15 @@ def overlap_tpsp_token_forward(
         _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
         _1_q = None
         _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
-        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _1_input1 = _1_o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                _1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+            _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
         _1_o = None
-        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
         # to do gate and disptatch
         _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
@@ -338,9 +351,15 @@ def overlap_tpsp_context_forward(
         _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
         _0_q = None
         _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
-        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _0_input1 = _0_o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                _0_input1, input_embdings, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+            _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         _0_o = None
-        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
 
         # wait last 1 combine
@@ -363,9 +382,15 @@ def overlap_tpsp_context_forward(
         _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
         _1_q = None
         _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
-        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+        if FLASHINFER_AVAILABLE and layer_weight.norm_type == "rms_norm":
+            _1_input1 = _1_o.view(-1, self.embed_dim_)
+            flashinfer.norm.fused_add_rmsnorm(
+                _1_input1, input_embdings1, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_
+            )
+        else:
+            input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+            _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
         _1_o = None
-        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
         # to do gate and disptatch
         _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
diff --git a/lightllm/utils/flashinfer_utils.py b/lightllm/utils/flashinfer_utils.py
new file mode 100644
index 000000000..4bb35a66a
--- /dev/null
+++ b/lightllm/utils/flashinfer_utils.py
@@ -0,0 +1,13 @@
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+FLASHINFER_AVAILABLE = False
+flashinfer = None
+try:
+    import flashinfer as _flashinfer
+
+    flashinfer = _flashinfer
+    FLASHINFER_AVAILABLE = True
+except ImportError:
+    logger.warning("flashinfer is not available")
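Note (not part of the patch): a minimal sketch of how the fused path above is expected to line up with the eager fallback in the else branches, assuming flashinfer.norm.fused_add_rmsnorm works in place (residual += input, then input is overwritten with RMSNorm(residual) * weight). The shapes, dtype, and eps below are arbitrary, and running it requires a CUDA device with flashinfer installed.

import torch

from lightllm.utils.flashinfer_utils import FLASHINFER_AVAILABLE, flashinfer


def ref_add_rmsnorm(o, residual, weight, eps):
    # Eager fallback used in the else branches: add the attention output to the
    # residual first, then RMS-normalize the sum (computed in fp32 for stability).
    residual = residual + o
    var = residual.float().pow(2).mean(-1, keepdim=True)
    normed = (residual.float() * torch.rsqrt(var + eps)).to(residual.dtype)
    return normed * weight, residual


if FLASHINFER_AVAILABLE:
    torch.manual_seed(0)
    o = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(o)
    weight = torch.randn(4096, dtype=torch.float16, device="cuda")
    eps = 1e-6

    ref_out, ref_residual = ref_add_rmsnorm(o, residual, weight, eps)

    # Assumed in-place semantics of the fused kernel, matching how the patch uses it:
    # fused_residual <- fused_residual + fused_in, fused_in <- RMSNorm(fused_residual) * weight.
    fused_in, fused_residual = o.clone(), residual.clone()
    flashinfer.norm.fused_add_rmsnorm(fused_in, fused_residual, weight, eps=eps)

    torch.testing.assert_close(fused_residual, ref_residual, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(fused_in, ref_out, rtol=1e-2, atol=1e-2)
    print("fused_add_rmsnorm matches the eager add + RMSNorm fallback")

The Gemma3 branches call gemma_fused_add_rmsnorm instead, which follows the Gemma convention of scaling by (1 + weight); the in-place add-then-normalize behavior is otherwise the same as far as the patch relies on it.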