@@ -102,22 +102,22 @@ def _get_qkv(
 
     @override
     def _get_o(
-        self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
+        self, input, gate, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
     ) -> torch.Tensor:
-        input = input * layer_weight._gate
-        layer_weight._gate = None
+        input = input * gate
         o_tensor = layer_weight.o_proj.mm(input)
         return o_tensor
 
     def _context_full_attn(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
     ):
+        gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input))
         q, cache_kv = self._get_qkv(input, infer_state, layer_weight)
         input = None
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
         o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
         q = None
-        o = self._get_o(o, infer_state, layer_weight)
+        o = self._get_o(o, gate, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         return o
@@ -129,7 +129,6 @@ def context_forward(
         if self.is_gdn:
             o = self.gdn_infer.forward(input1, infer_state, layer_weight.gdn_layer_weight)
         else:
-            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
             o = self._context_full_attn(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None
@@ -143,12 +142,13 @@ def context_forward(
         return input_embdings
 
     def _token_full_attn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight):
+        gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input))
         q, cache_kv = self._get_qkv(input, infer_state, layer_weight)
         input = None
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
-        o = self._get_o(o, infer_state, layer_weight)
+        o = self._get_o(o, gate, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         return o
@@ -160,7 +160,6 @@ def token_forward(
         if self.is_gdn:
             o = self.gdn_infer.forward(input1, infer_state, layer_weight.gdn_layer_weight)
         else:
-            layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
             o = self._token_full_attn(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None
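
For reference, a minimal standalone sketch of the pattern this change moves to: the sigmoid output gate is computed from the layer input at the start of the attention call and passed to `_get_o` as an explicit argument, instead of being stashed on `layer_weight._gate` between methods. The `_Proj` wrapper and the `get_o` / `full_attn_step` helpers below are illustrative stand-ins, not the project's actual weight classes or kernels.

```python
import torch


class _Proj:
    """Hypothetical stand-in for o_gate_proj / o_proj: a weight with an .mm() matmul."""

    def __init__(self, in_features: int, out_features: int):
        self.weight = torch.randn(in_features, out_features)

    def mm(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight


def get_o(attn_out: torch.Tensor, gate: torch.Tensor, o_proj: _Proj) -> torch.Tensor:
    # Gate the attention output, then apply the output projection,
    # mirroring the new _get_o(input, gate, ...) signature in the diff.
    return o_proj.mm(attn_out * gate)


def full_attn_step(x: torch.Tensor, o_gate_proj: _Proj, o_proj: _Proj) -> torch.Tensor:
    # The gate is derived from the layer input inside the call and handed
    # down explicitly; no state is left on the weight object afterwards.
    gate = torch.sigmoid(o_gate_proj.mm(x))
    attn_out = x  # placeholder for the real attention kernel output
    return get_o(attn_out, gate, o_proj)


if __name__ == "__main__":
    hidden = 8
    x = torch.randn(4, hidden)
    print(full_attn_step(x, _Proj(hidden, hidden), _Proj(hidden, hidden)).shape)
```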