@@ -61,6 +61,8 @@ def __init__(self, kvargs):
6161 self .finetune_config = kvargs .get ("finetune_config" , None )
6262 self .max_req_num = kvargs .get ("max_req_num" , 1000 )
6363 self .max_seq_length = kvargs .get ("max_seq_length" , 1024 * 5 )
64+ # 用于等待外围的一些模块的初始化完成(如 CPU KV Cache 注册完成)
65+ self .wait_events = kvargs .get ("wait_events" , [])
6466 # is_token_healing 和 return_all_prompt_logics 是有排斥关系的两个模式,只能单独有一个生效
6567 # 主要是在prefill阶段返回多少个token的用于后续处理相关。
6668 self .is_token_healing = kvargs .get ("is_token_healing" , False )
@@ -110,12 +112,19 @@ def __init__(self, kvargs):
110112 self ._init_inferstate_cls ()
111113 self ._autotune_warmup ()
112114 self ._init_padded_req ()
115+ # wait必须在init cudagraph 之前,避免错误捕获
116+ self ._wait_other_modules_ready ()
113117 self ._init_cudagraph ()
114118 self ._check_max_len_infer ()
115119 torch .cuda .empty_cache ()
116120 set_model_init_status (True )
117121 return
118122
def _wait_other_modules_ready(self):
    """Block until every event in ``self.wait_events`` has been waited on.

    ``wait()`` is called on each entry of ``self.wait_events`` in order,
    with no timeout, so this method returns only after all of those calls
    have returned. The model constructor invokes it before the CUDA graph
    initialisation step so that external modules finish their own setup
    first and are not captured by mistake.

    ``self.wait_events`` may be empty or ``None``; both mean there is
    nothing to wait for and the method returns immediately.

    Returns:
        None
    """
    pending_events = self.wait_events
    # A caller may pass ``wait_events=None`` explicitly, which bypasses the
    # ``[]`` default of ``kvargs.get``; treat that as "no events".
    if pending_events is None:
        return
    for event in pending_events:
        event.wait()
127+
119128 def _init_config (self ):
120129 with open (os .path .join (self .weight_dir_ , "config.json" ), "r" ) as json_file :
121130 self .config = json .load (json_file )
@@ -352,8 +361,13 @@ def _prefill(
352361 alloc_mem_index = infer_state .mem_index ,
353362 max_q_seq_len = infer_state .max_q_seq_len ,
354363 )
364+ prefill_mem_indexes_ready_event = torch .cuda .Event ()
365+ prefill_mem_indexes_ready_event .record ()
366+
355367 infer_state .init_some_extra_state (self , model_input .input_ids )
356- return self ._context_forward (model_input .input_ids , infer_state )
368+ model_output = self ._context_forward (model_input .input_ids , infer_state )
369+ model_output .prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
370+ return model_output
357371
358372 def _decode (
359373 self ,
@@ -505,13 +519,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
505519 )
506520 infer_state1 .init_some_extra_state (self , input_ids1 )
507521
522+ prefill_mem_indexes_ready_event = torch .cuda .Event ()
523+ prefill_mem_indexes_ready_event .record ()
524+
508525 model_output0 , model_output1 = self ._overlap_tpsp_context_forward (
509526 input_ids0 , infer_state0 , input_ids1 = input_ids1 , infer_state1 = infer_state1
510527 )
511528
512529 # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候
513530 # 该调用没有实际意义
514531 dist_group_manager .clear_deepep_buffer ()
532+ model_output0 .prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
533+ model_output1 .prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
515534 return model_output0 , model_output1
516535
517536 @torch .no_grad ()
0 commit comments