
Commit a9e9054

remove gate
1 parent 48882f0 commit a9e9054

2 files changed: +6 additions, -8 deletions

lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py

Lines changed: 6 additions & 7 deletions
@@ -102,22 +102,22 @@ def _get_qkv(
 
     @override
     def _get_o(
-        self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
+        self, input, gate, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
     ) -> torch.Tensor:
-        input = input * layer_weight._gate
-        layer_weight._gate = None
+        input = input * gate
         o_tensor = layer_weight.o_proj.mm(input)
         return o_tensor
 
     def _context_full_attn(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
     ):
+        gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input))
         q, cache_kv = self._get_qkv(input, infer_state, layer_weight)
         input = None
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
         o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
         q = None
-        o = self._get_o(o, infer_state, layer_weight)
+        o = self._get_o(o, gate, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         return o
@@ -129,7 +129,6 @@ def context_forward(
         if self.is_gdn:
             o = self.gdn_infer.forward(input1, infer_state, layer_weight.gdn_layer_weight)
         else:
-            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
             o = self._context_full_attn(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None
@@ -143,12 +142,13 @@ def context_forward(
         return input_embdings
 
     def _token_full_attn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight):
+        gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input))
         q, cache_kv = self._get_qkv(input, infer_state, layer_weight)
         input = None
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
-        o = self._get_o(o, infer_state, layer_weight)
+        o = self._get_o(o, gate, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         return o
@@ -160,7 +160,6 @@ def token_forward(
         if self.is_gdn:
             o = self.gdn_infer.forward(input1, infer_state, layer_weight.gdn_layer_weight)
         else:
-            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
             o = self._token_full_attn(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None

lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py

Lines changed: 0 additions & 1 deletion
@@ -61,7 +61,6 @@ def _init_weight(self):
             layer_num=self.layer_num_,
             name="o_gate_proj",
         )
-        self._gate = None
         return
 
     @override
