
Commit 0574624

update to ToM, clean up a bit, move to cancel_request
Signed-off-by: raayandhar <[email protected]>
1 parent 2b8722b commit 0574624

File tree: 9 files changed, +109 −18 lines changed

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 5 additions & 2 deletions
@@ -1456,13 +1456,15 @@ class CacheTransceiverConfig
         UCX = 2,
         NIXL = 3
     };
-    explicit CacheTransceiverConfig(
-        std::optional<BackendType> backendType = std::nullopt, std::optional<size_t> maxNumTokens = std::nullopt);
+    explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
+        std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt);
 
     bool operator==(CacheTransceiverConfig const& other) const;
     void setBackendType(std::optional<BackendType> backendType);
     void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
+    void setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs);
 
+    [[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
     [[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
     [[nodiscard]] std::optional<BackendType> getBackendType() const;
 
@@ -1472,6 +1474,7 @@ class CacheTransceiverConfig
     /// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache
     /// transfer may be degraded.
     std::optional<size_t> mMaxTokensInBuffer;
+    std::optional<int> mKvTransferTimeoutMs;
 };
 
 /// @brief Configuration class for the model executor

cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp

Lines changed: 15 additions & 3 deletions
@@ -22,15 +22,17 @@ namespace tensorrt_llm::executor
 {
 
 CacheTransceiverConfig::CacheTransceiverConfig(
-    std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens)
+    std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
     : mBackendType(backendType)
-    , mMaxTokensInBuffer(maxNumTokens)
+    , mMaxTokensInBuffer(maxNumTokens)
+    , mKvTransferTimeoutMs(kvTransferTimeoutMs)
 {
 }
 
 bool CacheTransceiverConfig::operator==(CacheTransceiverConfig const& other) const
 {
-    return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType;
+    return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType
+        && mKvTransferTimeoutMs == other.mKvTransferTimeoutMs;
 }
 
 void CacheTransceiverConfig::setBackendType(std::optional<BackendType> backendType)
@@ -43,6 +45,11 @@ void CacheTransceiverConfig::setMaxTokensInBuffer(std::optional<size_t> maxToken
     mMaxTokensInBuffer = maxTokensInBuffer;
 }
 
+void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs)
+{
+    mKvTransferTimeoutMs = kvTransferTimeoutMs;
+}
+
 std::optional<CacheTransceiverConfig::BackendType> CacheTransceiverConfig::getBackendType() const
 {
     return mBackendType;
@@ -53,4 +60,9 @@ std::optional<size_t> CacheTransceiverConfig::getMaxTokensInBuffer() const
     return mMaxTokensInBuffer;
 }
 
+std::optional<int> CacheTransceiverConfig::getKvTransferTimeoutMs() const
+{
+    return mKvTransferTimeoutMs;
+}
+
 } // namespace tensorrt_llm::executor

cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp

Lines changed: 9 additions & 6 deletions
@@ -433,15 +433,15 @@ void initConfigBindings(nb::module_& m)
         .def("__setstate__", guidedDecodingConfigSetstate);
 
     auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
-    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
+    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
     auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
     {
-        if (state.size() != 2)
+        if (state.size() != 3)
         {
             throw std::runtime_error("Invalid CacheTransceiverConfig state!");
         }
-        new (&self) tle::CacheTransceiverConfig(
-            nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
+        new (&self) tle::CacheTransceiverConfig(nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]),
+            nb::cast<std::optional<size_t>>(state[1]), nb::cast<std::optional<int>>(state[2]));
     };
 
     nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -464,12 +464,15 @@ void initConfigBindings(nb::module_& m)
         });
 
     nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
-            nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
+        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+                 std::optional<int>>(), nb::arg("backend") = std::nullopt,
+            nb::arg("max_tokens_in_buffer") = std::nullopt, nb::arg("kv_transfer_timeout_ms") = std::nullopt)
         .def_prop_rw(
             "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
         .def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
             &tle::CacheTransceiverConfig::setMaxTokensInBuffer)
+        .def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
+            &tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
         .def("__getstate__", cacheTransceiverConfigGetstate)
         .def("__setstate__", cacheTransceiverConfigSetstate);
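
For reference, a minimal usage sketch of the extended binding surface. This is not part of the commit; it assumes the bindings are importable as `tensorrt_llm.bindings.executor`, while the keyword names come from the `nb::arg` declarations above.

```python
# Minimal sketch, assuming the module path tensorrt_llm.bindings.executor.
from tensorrt_llm.bindings import executor as tle

cfg = tle.CacheTransceiverConfig(
    backend=tle.CacheTransceiverBackendType.UCX,
    max_tokens_in_buffer=8192,
    kv_transfer_timeout_ms=5000,  # cancel transfers stuck for more than 5 s
)
assert cfg.kv_transfer_timeout_ms == 5000
cfg.kv_transfer_timeout_ms = None  # read-write property via def_prop_rw
```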

cpp/tensorrt_llm/pybind/executor/executorConfig.cpp

Lines changed: 9 additions & 6 deletions
@@ -415,15 +415,15 @@ void initConfigBindings(pybind11::module_& m)
         .def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate));
 
     auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
-    { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
+    { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
     auto cacheTransceiverConfigSetstate = [](py::tuple const& state)
     {
-        if (state.size() != 2)
+        if (state.size() != 3)
         {
             throw std::runtime_error("Invalid CacheTransceiverConfig state!");
         }
-        return tle::CacheTransceiverConfig(
-            state[0].cast<tle::CacheTransceiverConfig::BackendType>(), state[1].cast<std::optional<size_t>>());
+        return tle::CacheTransceiverConfig(state[0].cast<tle::CacheTransceiverConfig::BackendType>(),
+            state[1].cast<std::optional<size_t>>(), state[2].cast<std::optional<int>>());
     };
 
     py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -446,12 +446,15 @@ void initConfigBindings(pybind11::module_& m)
         });
 
     py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
-            py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt)
+        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+                 std::optional<int>>(), py::arg("backend") = std::nullopt,
+            py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt)
         .def_property(
             "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
         .def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
             &tle::CacheTransceiverConfig::setMaxTokensInBuffer)
+        .def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
+            &tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
         .def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));
 
     auto executorConfigGetState = [](py::object const& self)
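
Because the pickle state now carries three fields, a round-trip sketch looks like this (same module-path assumption as the sketch above):

```python
# Minimal sketch: the extended getstate/setstate round-trips the new field.
import pickle

from tensorrt_llm.bindings import executor as tle

cfg = tle.CacheTransceiverConfig(backend=tle.CacheTransceiverBackendType.NIXL,
                                 max_tokens_in_buffer=4096,
                                 kv_transfer_timeout_ms=2000)
restored = pickle.loads(pickle.dumps(cfg))
assert restored.kv_transfer_timeout_ms == 2000
```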

examples/disaggregated/README.md

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@ cache_transceiver_config:
   backend: <str>
   # KV cache buffer size. Set it ≥ the maximum ISL (Input Sequence Length) for best performance.
   max_tokens_in_buffer: <int>
+  # KV cache transfer timeout in milliseconds.
+  # Requests that do not finish sending or receiving the KV cache within this window are cancelled and cleaned up.
+  kv_transfer_timeout_ms: <int>
 ```
 
 The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below.

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 2 additions & 0 deletions
@@ -109,6 +109,8 @@ def __init__(self, mapping: Mapping, dist: Distributed,
         # get the layer num per pp rank, which is required by cache transceiver.
         pp_layer_num = len(kv_cache_manager.pp_layers)
         pp_layer_num_per_pp_rank = dist.pp_allgather(pp_layer_num)
+
+        self.kv_transfer_timeout_ms = cache_transceiver_config.kv_transfer_timeout_ms
         self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
                                         total_num_kv_heads_per_layer, head_dim,
                                         tokens_per_block, world_config,

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions
@@ -442,6 +442,8 @@ def __init__(
         self.py_lora_task_layer_module_configs: list[
             tensorrt_llm.bindings.internal.runtime.
             TaskLayerModuleConfig] | None = None
+        self.py_kv_transfer_start_time = None
+        self.py_kv_transfer_timed_out = False
 
         self.py_num_logprobs = num_logprobs
         self.py_return_log_probs = return_log_probs

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 56 additions & 0 deletions
@@ -978,6 +978,7 @@ def _executor_loop_pp(self):
                     self.micro_batches[prev_microbatch_id] = None
 
                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._check_kv_transfer_timeout()
                     self._terminate_ctx_finished_requests()
 
                 if self._disagg_pp_termination_handler is not None:
@@ -1006,6 +1007,7 @@ def _prepare_and_schedule_batch(self):
 
         if self.kv_cache_transceiver:
             self._check_disagg_gen_transfer_status()
+            self._check_kv_transfer_timeout()
 
         iter_stats = None
         if self.enable_iter_perf_stats:
@@ -1179,6 +1181,7 @@ def _executor_loop(self):
                 self._add_kv_cache_events()
 
                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._check_kv_transfer_timeout()
                     self._terminate_ctx_finished_requests()
 
                 self._kv_connector_terminate_requests()
@@ -1364,6 +1367,7 @@ def _executor_loop_overlap(self):
                     ctx_transmission_reqs=ctx_transmission_reqs)
 
                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._check_kv_transfer_timeout()
                     self._terminate_ctx_finished_requests()
 
                 self._kv_connector_terminate_requests()
@@ -1572,6 +1576,38 @@ def _check_disagg_gen_transfer_status(self):
 
         return
 
+    @nvtx_range("_check_kv_transfer_timeout")
+    def _check_kv_transfer_timeout(self):
+        if not self.kv_cache_transceiver:
+            return
+        timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms
+        if timeout_ms is None or timeout_ms <= 0:
+            return
+
+        current_time = time.time()
+
+        # ctx_in_transmission_requests holds (request, block_id) pairs.
+        for req, _ in self.ctx_in_transmission_requests:
+            if req.py_kv_transfer_start_time is None:
+                continue
+            elapsed_time = (current_time -
+                            req.py_kv_transfer_start_time) * 1000
+            if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
+                logger.warning(
+                    f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
+                )
+                req.py_kv_transfer_timed_out = True
+
+        for req in self.active_requests:
+            if req.is_disagg_generation_transmission_in_progress and req.py_kv_transfer_start_time is not None:
+                elapsed_time = (current_time -
+                                req.py_kv_transfer_start_time) * 1000
+                if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
+                    logger.warning(
+                        f"Terminating generation request {req.py_request_id} due to KV cache transfer timeout"
+                    )
+                    req.py_kv_transfer_timed_out = True
+
+        return
+
     @nvtx_range("_pad_attention_dp_dummy_request")
     def _pad_attention_dp_dummy_request(self):
         """
@@ -1646,6 +1682,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             req.context_current_position = req.prompt_len
             req.decoding_iter = 1
             req.py_decoding_iter = 1
+            req.py_kv_transfer_start_time = None
             first_gen_tokens = req.context_phase_params.first_gen_tokens
             ctx_draft_tokens = req.context_phase_params.draft_tokens
             req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
@@ -1669,6 +1706,11 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
         for req in new_gen_reqs:
             self.kv_cache_transceiver.request_and_receive_async(req)
 
+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in new_gen_reqs:
+                if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
+                    req.py_kv_transfer_start_time = time.time()
+
         block_transfer = all([
             req.is_disagg_generation_transmission_in_progress
             for req in self.active_requests
@@ -1701,6 +1743,11 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
             if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
         ]
 
+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in ctx_in_transmission_requests:
+                if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
+                    req.py_kv_transfer_start_time = time.time()
+
         return ctx_transmission_reqs
 
     def _get_disagg_reqs_in_error_state(self):
@@ -2018,6 +2065,12 @@ def _handle_responses(self):
                 requests_to_terminate.append(request)
                 continue
 
+            # Check if the generation request needs cleanup due to a KV cache transfer timeout.
+            if request.py_kv_transfer_timed_out:
+                # Previously this went through _handle_errors, which sends an error
+                # response; we should decide how the timeout should be reported now.
+                self.kv_cache_transceiver.cancel_request(request)
+
             if request.is_generation_only_request():
                 # If the request is in transmission, we don't need to emit a response.
                 # Also, for the first iteration with overlap, we should skip since first
@@ -2068,6 +2121,9 @@ def _handle_responses(self):
     def _terminate_ctx_finished_requests(self):
        for request, block_id in self.ctx_in_transmission_requests[:]:
            if request.is_disagg_context_complete_state:
+                if request.py_kv_transfer_timed_out:
+                    request.py_kv_transfer_start_time = None
+                    self.kv_cache_transceiver.cancel_request(request)
                if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa:
                    self._terminate_request(request)
                else:
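
The timeout machinery above reduces to a stamp-then-poll pattern: stamp `time.time()` when a transfer enters the in-progress state, flag the request once the elapsed milliseconds exceed the configured budget, and let the response-handling path call `cancel_request`. A self-contained sketch, where `_Request` is a hypothetical stand-in for `LlmRequest`:

```python
# Self-contained sketch of the stamp-then-poll pattern; _Request is a
# hypothetical stand-in that models only the two py_kv_* attributes
# introduced in llm_request.py.
import time


class _Request:

    def __init__(self, request_id):
        self.py_request_id = request_id
        self.py_kv_transfer_start_time = None  # stamped when the transfer starts
        self.py_kv_transfer_timed_out = False  # flipped once past the deadline


def check_kv_transfer_timeout(requests, timeout_ms):
    """Flag requests whose KV transfer has run longer than timeout_ms."""
    if timeout_ms is None or timeout_ms <= 0:
        return
    now = time.time()
    for req in requests:
        if req.py_kv_transfer_start_time is None or req.py_kv_transfer_timed_out:
            continue
        if (now - req.py_kv_transfer_start_time) * 1000 > timeout_ms:
            req.py_kv_transfer_timed_out = True  # cancellation happens later


req = _Request(42)
req.py_kv_transfer_start_time = time.time() - 10  # pretend it started 10 s ago
check_kv_transfer_timeout([req], timeout_ms=5000)
assert req.py_kv_transfer_timed_out
```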

tensorrt_llm/llmapi/llm_args.py

Lines changed: 8 additions & 1 deletion
@@ -1286,10 +1286,17 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
         default=None,
         description="The max number of tokens the transfer buffer can fit.")
 
+    kv_transfer_timeout_ms: Optional[int] = Field(
+        default=None,
+        description=
+        "Timeout in milliseconds for KV cache transfer. Requests exceeding this timeout will be cancelled."
+    )
+
     def _to_pybind(self):
         return _CacheTransceiverConfig(
             backend=_CacheTransceiverBackendType.from_string(self.backend),
-            max_tokens_in_buffer=self.max_tokens_in_buffer)
+            max_tokens_in_buffer=self.max_tokens_in_buffer,
+            kv_transfer_timeout_ms=self.kv_transfer_timeout_ms)
 
 
 @dataclass
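
A hedged usage sketch for the new llmapi field; the import path is taken from this diff, while the backend string and values are illustrative:

```python
# Minimal sketch: populating the new field on the llmapi config.
# _to_pybind() forwards all three fields to the executor-side config.
from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig

cfg = CacheTransceiverConfig(
    backend="UCX",
    max_tokens_in_buffer=8192,
    kv_transfer_timeout_ms=5000,
)
```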
