
Commit 82df7a1

hiworldwzj, wangzaijun, and niushengxiao authored
Cpu KV Cache feature (#997)
Co-authored-by: wangzaijun <[email protected]>
Co-authored-by: niushengxiao <[email protected]>
1 parent 3291109 commit 82df7a1

46 files changed: +2795 -385 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 20 additions & 1 deletion
@@ -61,6 +61,8 @@ def __init__(self, kvargs):
         self.finetune_config = kvargs.get("finetune_config", None)
         self.max_req_num = kvargs.get("max_req_num", 1000)
         self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
+        # Used to wait for peripheral modules to finish initializing (e.g., CPU KV Cache registration).
+        self.wait_events = kvargs.get("wait_events", [])
         # is_token_healing and return_all_prompt_logics are mutually exclusive modes; only one can be active at a time.
         # They mainly control how many tokens the prefill stage returns for downstream processing.
         self.is_token_healing = kvargs.get("is_token_healing", False)
@@ -110,12 +112,19 @@ def __init__(self, kvargs):
         self._init_inferstate_cls()
         self._autotune_warmup()
         self._init_padded_req()
+        # The wait must happen before init cudagraph, to avoid erroneous graph capture.
+        self._wait_other_modules_ready()
         self._init_cudagraph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
         set_model_init_status(True)
         return

+    def _wait_other_modules_ready(self):
+        for event in self.wait_events:
+            event.wait()
+        return
+
     def _init_config(self):
         with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
             self.config = json.load(json_file)
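
Any object exposing a blocking wait() works with the new wait_events hook (threading.Event, multiprocessing.Event, etc.). Below is a minimal sketch of how a caller might wire it up, assuming a hypothetical CPU KV Cache registration worker; the cpu_cache_ready name and the thread body are illustrative, not part of this commit:

import threading

# Hypothetical event, signaled once CPU KV Cache registration finishes.
cpu_cache_ready = threading.Event()

def register_cpu_kv_cache():
    # ... allocate and pin host memory for the CPU KV Cache (assumed work) ...
    cpu_cache_ready.set()  # mark registration complete

threading.Thread(target=register_cpu_kv_cache, daemon=True).start()

kvargs = {
    # ... the usual model kvargs ...
    "wait_events": [cpu_cache_ready],
}
# During __init__, _wait_other_modules_ready() blocks on each event before
# _init_cudagraph(), so CUDA graph capture cannot race the registration.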
@@ -352,8 +361,13 @@ def _prefill(
             alloc_mem_index=infer_state.mem_index,
             max_q_seq_len=infer_state.max_q_seq_len,
         )
+        prefill_mem_indexes_ready_event = torch.cuda.Event()
+        prefill_mem_indexes_ready_event.record()
+
         infer_state.init_some_extra_state(self, model_input.input_ids)
-        return self._context_forward(model_input.input_ids, infer_state)
+        model_output = self._context_forward(model_input.input_ids, infer_state)
+        model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        return model_output

     def _decode(
         self,
@@ -505,13 +519,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
         )
         infer_state1.init_some_extra_state(self, input_ids1)

+        prefill_mem_indexes_ready_event = torch.cuda.Event()
+        prefill_mem_indexes_ready_event.record()
+
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
         )

         # When deepep is enabled, clear_deepep_buffer must be called to clean up
         # resources; when it is disabled, the call has no practical effect.
         dist_group_manager.clear_deepep_buffer()
+        model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
         return model_output0, model_output1

     @torch.no_grad()
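
The event recorded in _prefill and microbatch_overlap_prefill marks the point on the current CUDA stream at which the prefill mem_indexes have been written. A hedged sketch of consumer-side use follows; the prefill entry point and the offload_prefill_kv_to_cpu helper are assumptions for illustration, not APIs defined by this commit:

# model_output is a ModelOutput returned from a prefill call (assumed entry point).
ev = model_output.prefill_mem_indexes_ready_event
if ev is not None:  # decode-path outputs leave this field as None
    ev.synchronize()  # block the host until the recorded stream point has passed
# mem_indexes are now visible in the req manager, so copying the freshly
# written KV blocks out to the CPU cache cannot race the prefill kernels.
offload_prefill_kv_to_cpu(model_output)  # hypothetical helper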

lightllm/common/basemodel/batch_objs.py

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,8 @@ def to_cuda(self):
 class ModelOutput:
     # Common variables
     logits: torch.Tensor
+    # Event object used to determine whether mem_indexes were successfully written into the req manager.
+    prefill_mem_indexes_ready_event: torch.Event = None

     # Dedicated variables: used by some special models in special modes to pass special
     # output values. They are only used and take effect in those special model modes.
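
Since prefill_mem_indexes_ready_event defaults to None, existing ModelOutput consumers are unaffected. A consumer that wants to avoid blocking the host could instead order a side stream after the event; this copy-stream pattern is an assumption about intended usage, not part of the diff:

import torch

copy_stream = torch.cuda.Stream()
ev = model_output.prefill_mem_indexes_ready_event
if ev is not None:
    copy_stream.wait_event(ev)  # side stream waits for the recorded point
    with torch.cuda.stream(copy_stream):
        # launch async device-to-host copies of the new KV blocks here
        pass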
