
Commit b99deda

remove _pre_cache_kv

1 parent 645baa1 commit b99deda

File tree

14 files changed: +18 -76 lines changed
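
The change repeated across the files below is a single pattern: the `kv_buffer_shapedtype` state field and the base-class `_pre_cache_kv` helper are deleted, and each model's `_get_qkv` now runs the KV projection directly and reshapes the result with `.view()` instead of writing into a pre-allocated buffer through a matmul's `out=` argument. A minimal standalone sketch of the before/after, with made-up sizes (num_tokens, embed_dim, kv_heads, head_dim are illustrative only, not taken from the commit):

import torch

num_tokens, embed_dim, kv_heads, head_dim = 4, 16, 2, 8
x = torch.randn(num_tokens, embed_dim)
kv_weight = torch.randn(embed_dim, 2 * kv_heads * head_dim)

# Old pattern: allocate the KV buffer up front, then write the projection into it via out=.
cache_kv_old = torch.empty(num_tokens, 2 * kv_heads, head_dim)
torch.mm(x, kv_weight, out=cache_kv_old.view(-1, 2 * kv_heads * head_dim))

# New pattern: project first, then view the result as (tokens, k_heads + v_heads, head_dim).
cache_kv_new = torch.mm(x, kv_weight).view(-1, 2 * kv_heads, head_dim)

assert torch.allclose(cache_kv_old, cache_kv_new)

Both produce the same tensor; the difference is only who owns the allocation.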

lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 4 deletions

@@ -290,10 +290,6 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.req_manager = self.req_manager

         infer_state.mem_index = model_input.mem_indexes
-        infer_state.kv_buffer_shapedtype = (
-            (model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            self.data_type,
-        )
         infer_state.microbatch_index = microbatch_index
         infer_state.dist_group = dist_group_manager.get_group(microbatch_index)

lightllm/common/basemodel/infer_struct.py

Lines changed: 0 additions & 1 deletion

@@ -40,7 +40,6 @@ def __init__(self):
        self.req_manager: ReqManager = None

        self.mem_index: torch.Tensor = None
-       self.kv_buffer_shapedtype: Tuple[Any, Any] = None

        self.is_token_healing: bool = False
        self.return_all_prompt_logics: bool = False

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py

Lines changed: 3 additions & 4 deletions

@@ -44,13 +44,12 @@ def _bind_rotary_emb_fwd(self):
     def _get_qkv(
         self, input, infer_state: InferStateInfo, layer_weight
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
         q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
-        torch.mm(
+        cache_kv = torch.mm(
             input.view(-1, self.embed_dim_),
             layer_weight.kv_weight_,
-            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
-        )
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+

         if self.use_qk_norm_:
             q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
             k = cache_kv[:, 0 : self.tp_k_head_num_, :]

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 0 additions & 16 deletions

@@ -31,22 +31,6 @@ def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
     def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
         raise Exception("need to impl")

-    def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
-        if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
-            shape = infer_state.kv_buffer_shapedtype[0]
-            shape = (len(infer_state.position_ids), *shape[1:])
-        else:
-            shape = infer_state.kv_buffer_shapedtype[0]
-
-        cache_kv = self.alloc_tensor(
-            shape=shape,
-            dtype=infer_state.kv_buffer_shapedtype[1],
-            device="cuda",
-            is_graph_out=False,
-            microbatch_index=infer_state.microbatch_index,
-        )
-        return cache_kv
-
     def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
         raise Exception("need to impl")
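
The removed base-class helper had to compute the buffer shape itself, including a special case when `enable_dp_prefill_balance` is active during prefill. With the direct-projection pattern the row count simply falls out of the matmul, so there is no separate shape to track. A small sketch under assumed sizes (names and numbers are illustrative, not from the repository):

import torch

embed_dim, kv_heads, head_dim = 16, 2, 8
kv_weight = torch.randn(embed_dim, 2 * kv_heads * head_dim)

# Whatever number of tokens the layer receives, the projection output already has
# that many rows, so the .view() never needs an externally computed shape.
for num_tokens in (3, 7):
    x = torch.randn(num_tokens, embed_dim)
    cache_kv = torch.mm(x, kv_weight).view(-1, 2 * kv_heads, head_dim)
    assert cache_kv.shape == (num_tokens, 2 * kv_heads, head_dim)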

lightllm/models/bloom/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 4 deletions

@@ -47,10 +47,7 @@ def _get_qkv(
         self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-        cache_kv = layer_weight.kv_proj.mm(
-            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
-        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         return q, cache_kv

     def _context_attention_kernel(

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 12 deletions

@@ -143,14 +143,6 @@ def _bind_attention(self):
                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
            )

-    def _pre_cache_kv(
-        self, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
-    ) -> torch.Tensor:
-        # when q_lora_rank is not None, q_a_proj and kv_a_proj_with_mqa are fused
-        if self.q_lora_rank is None:
-            return super()._pre_cache_kv(infer_state, layer_weight)
-        return None
-
     def _get_qkv(
         self,
         input: torch.Tensor,
@@ -161,8 +153,7 @@ def _get_qkv(

         if self.q_lora_rank is None:
             q = layer_weight.q_weight_.mm(input)
-            cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-            layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
+            cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
         else:
             q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -203,8 +194,7 @@ def _tpsp_get_qkv(

             input = input.view(-1, self.embed_dim_)
             q = layer_weight.q_weight_.mm(input)
-            cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-            layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
+            cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
         else:
             input = input.view(-1, self.embed_dim_)
             qkv = layer_weight.qkv_a_proj_with_mqa_.mm(input)
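
In the MLA branch shown here, the compressed latent KV is stored as a single head of width `kv_lora_rank + qk_rope_head_dim`, which is why the new code views the projection as `(-1, 1, ...)`. A shape-only sketch with illustrative numbers (not the model's real configuration):

import torch

num_tokens, embed_dim = 4, 32
kv_lora_rank, qk_rope_head_dim = 16, 8  # illustrative values only

x = torch.randn(num_tokens, embed_dim)
kv_a_proj = torch.randn(embed_dim, kv_lora_rank + qk_rope_head_dim)

# One latent "head" per token: (num_tokens, 1, kv_lora_rank + qk_rope_head_dim)
cache_kv = torch.mm(x, kv_a_proj).view(-1, 1, kv_lora_rank + qk_rope_head_dim)
assert cache_kv.shape == (num_tokens, 1, kv_lora_rank + qk_rope_head_dim)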

lightllm/models/gemma3/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 3 deletions

@@ -87,9 +87,9 @@ def _get_qkv(
         # kv = kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         k = layer_weight.k_proj.mm(input)
         v = layer_weight.v_proj.mm(input)
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-        cache_kv[:, 0 : self.tp_k_head_num_, :] = k.view(-1, self.tp_k_head_num_, self.head_dim_)
-        cache_kv[:, self.tp_k_head_num_ :, :] = v.view(-1, self.tp_v_head_num_, self.head_dim_)
+        cache_kv = torch.cat(
+            [k.view(-1, self.tp_k_head_num_, self.head_dim_), v.view(-1, self.tp_v_head_num_, self.head_dim_)], dim=1
+        )

         # gemma3 use qk norm
         q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
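
Here the pre-allocated buffer plus two slice assignments is replaced by a single `torch.cat` along the head dimension. The two are equivalent, as this self-contained check with made-up sizes shows:

import torch

num_tokens, k_heads, v_heads, head_dim = 4, 2, 2, 8  # illustrative sizes
k = torch.randn(num_tokens, k_heads * head_dim)
v = torch.randn(num_tokens, v_heads * head_dim)

# Old pattern: copy K and V into slices of a pre-allocated buffer.
buf = torch.empty(num_tokens, k_heads + v_heads, head_dim)
buf[:, 0:k_heads, :] = k.view(-1, k_heads, head_dim)
buf[:, k_heads:, :] = v.view(-1, v_heads, head_dim)

# New pattern: concatenate along dim=1 (the head dimension).
cat = torch.cat([k.view(-1, k_heads, head_dim), v.view(-1, v_heads, head_dim)], dim=1)

assert torch.equal(buf, cat)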

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 8 deletions

@@ -197,10 +197,7 @@ def _get_qkv(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         q = layer_weight.q_proj.mm(input)
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-        cache_kv = layer_weight.kv_proj.mm(
-            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
-        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

         rotary_emb_fwd(
             q.view(-1, self.tp_q_head_num_, self.head_dim_),
@@ -222,10 +219,7 @@ def _tpsp_get_qkv(
         input = gather_input[0 : len(infer_state.position_cos), :]

         q = layer_weight.q_proj.mm(input)
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-        cache_kv = layer_weight.kv_proj.mm(
-            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
-        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

         rotary_emb_fwd(
             q.view(-1, self.tp_q_head_num_, self.head_dim_),

lightllm/models/phi3/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 2 deletions

@@ -29,10 +29,8 @@ def _bind_attention(self):

     def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight):
         q = layer_weight.q_proj.mm(input_emb.view(-1, self.embed_dim_))
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
         cache_kv = layer_weight.kv_proj.mm(
             input_emb.view(-1, self.embed_dim_),
-            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         rotary_emb_fwd(
             q.view(-1, self.tp_q_head_num_, self.head_dim_),

lightllm/models/qwen/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 4 deletions

@@ -18,10 +18,9 @@ def __init__(self, layer_num, network_config, mode=[]):

     def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeight):
         q = layer_weight.q_proj.mm(input_emb)
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-        cache_kv = layer_weight.kv_proj.mm(
-            input_emb, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
-        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        cache_kv = layer_weight.kv_proj.mm(input_emb).view(
+            -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_
+        )

         rotary_emb_fwd(
             q.view(-1, self.tp_q_head_num_, self.head_dim_),
