@@ -978,6 +978,7 @@ def _executor_loop_pp(self):
978978 self .micro_batches [prev_microbatch_id ] = None
979979
980980 if self .kv_cache_transceiver and self .ctx_in_transmission_requests :
981+ self ._check_kv_transfer_timeout ()
981982 self ._terminate_ctx_finished_requests ()
982983
983984 if self ._disagg_pp_termination_handler is not None :
@@ -1006,6 +1007,7 @@ def _prepare_and_schedule_batch(self):
10061007
10071008 if self .kv_cache_transceiver :
10081009 self ._check_disagg_gen_transfer_status ()
1010+ self ._check_kv_transfer_timeout ()
10091011
10101012 iter_stats = None
10111013 if self .enable_iter_perf_stats :
@@ -1179,6 +1181,7 @@ def _executor_loop(self):
11791181 self ._add_kv_cache_events ()
11801182
11811183 if self .kv_cache_transceiver and self .ctx_in_transmission_requests :
1184+ self ._check_kv_transfer_timeout ()
11821185 self ._terminate_ctx_finished_requests ()
11831186
11841187 self ._kv_connector_terminate_requests ()
@@ -1364,6 +1367,7 @@ def _executor_loop_overlap(self):
13641367 ctx_transmission_reqs = ctx_transmission_reqs )
13651368
13661369 if self .kv_cache_transceiver and self .ctx_in_transmission_requests :
1370+ self ._check_kv_transfer_timeout ()
13671371 self ._terminate_ctx_finished_requests ()
13681372
13691373 self ._kv_connector_terminate_requests ()
@@ -1572,6 +1576,38 @@ def _check_disagg_gen_transfer_status(self):
15721576
15731577 return
15741578
1579+ @nvtx_range ("_check_kv_transfer_timeout" )
1580+ def _check_kv_transfer_timeout (self ):
1581+ if not self .kv_cache_transceiver :
1582+ return
1583+ timeout_ms = self .kv_cache_transceiver .kv_transfer_timeout_ms
1584+ if timeout_ms is None or timeout_ms <= 0 :
1585+ return
1586+
1587+ current_time = time .time ()
1588+
1589+ for req in self .ctx_in_transmission_requests :
1590+ if req .py_kv_transfer_start_time is None :
1591+ continue
1592+ elapsed_time = (current_time - req .py_kv_transfer_start_time ) * 1000
1593+ if elapsed_time > timeout_ms and not req .py_kv_transfer_timed_out :
1594+ logger .warning (
1595+ f"Terminating context request { req .py_request_id } due to KV cache transfer timeout"
1596+ )
1597+ req .py_kv_transfer_timed_out = True
1598+
1599+ for req in self .active_requests :
1600+ if req .is_disagg_generation_transmission_in_progress and req .py_kv_transfer_start_time is not None :
1601+ elapsed_time = (current_time -
1602+ req .py_kv_transfer_start_time ) * 1000
1603+ if elapsed_time > timeout_ms and not req .py_kv_transfer_timed_out :
1604+ logger .warning (
1605+ f"Terminating generation request { req .py_request_id } due to KV cache transfer timeout"
1606+ )
1607+ req .py_kv_transfer_timed_out = True
1608+
1609+ return
1610+
15751611 @nvtx_range ("_pad_attention_dp_dummy_request" )
15761612 def _pad_attention_dp_dummy_request (self ):
15771613 """
@@ -1646,6 +1682,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
16461682 req .context_current_position = req .prompt_len
16471683 req .decoding_iter = 1
16481684 req .py_decoding_iter = 1
1685+ req .py_kv_transfer_start_time = None
16491686 first_gen_tokens = req .context_phase_params .first_gen_tokens
16501687 ctx_draft_tokens = req .context_phase_params .draft_tokens
16511688 req .py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
@@ -1669,6 +1706,11 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
16691706 for req in new_gen_reqs :
16701707 self .kv_cache_transceiver .request_and_receive_async (req )
16711708
1709+ if self .kv_cache_transceiver .kv_transfer_timeout_ms is not None :
1710+ for req in new_gen_reqs :
1711+ if req .state == LlmRequestState .DISAGG_GENERATION_TRANS_IN_PROGRESS :
1712+ req .py_kv_transfer_start_time = time .time ()
1713+
16721714 block_transfer = all ([
16731715 req .is_disagg_generation_transmission_in_progress
16741716 for req in self .active_requests
@@ -1701,6 +1743,11 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
17011743 if req .state == LlmRequestState .DISAGG_CONTEXT_TRANS_IN_PROGRESS
17021744 ]
17031745
1746+ if self .kv_cache_transceiver .kv_transfer_timeout_ms is not None :
1747+ for req in ctx_in_transmission_requests :
1748+ if req .state == LlmRequestState .DISAGG_CONTEXT_TRANS_IN_PROGRESS :
1749+ req .py_kv_transfer_start_time = time .time ()
1750+
17041751 return ctx_transmission_reqs
17051752
17061753 def _get_disagg_reqs_in_error_state (self ):
@@ -2018,6 +2065,12 @@ def _handle_responses(self):
20182065 requests_to_terminate .append (request )
20192066 continue
20202067
2068+ # Check if generation request needs cleanup due to KV cache transfer timeout
2069+ if request .py_kv_transfer_timed_out :
2070+ # Previously, we were doing _handle_errors, which sends an error response.
2071+ # We should consider how we should be doing this now?
2072+ self .kv_cache_transceiver .cancel_request (request )
2073+
20212074 if request .is_generation_only_request ():
20222075 # If request is in transmission, so we don't need to emit a response
20232076 # Also, for the first iteration with overlap, we should skip since first
@@ -2068,6 +2121,9 @@ def _handle_responses(self):
20682121 def _terminate_ctx_finished_requests (self ):
20692122 for request , block_id in self .ctx_in_transmission_requests [:]:
20702123 if request .is_disagg_context_complete_state :
2124+ if request .py_kv_transfer_timed_out :
2125+ request .py_kv_transfer_start_time = None
2126+ self .kv_cache_transceiver .cancel_request (request )
20712127 if not self .block_reuse_enabled or self .kv_cache_manager .is_vswa :
20722128 self ._terminate_request (request )
20732129 else :
0 commit comments