From da1cbe4e8e9ab9f9723975a8bedd27ac7b207531 Mon Sep 17 00:00:00 2001
From: jthomson04
Date: Tue, 22 Jul 2025 09:56:58 -0700
Subject: [PATCH 01/50] Initial cpp binding stuff

Signed-off-by: jthomson04
---
 .../batch_manager/kvCacheConnector.h          | 81 +++++++++++++++++++
 cpp/tensorrt_llm/batch_manager/CMakeLists.txt |  1 +
 .../batch_manager/kvCacheConnector.cpp        | 31 +++++++
 .../pybind/batch_manager/bindings.cpp         | 22 +++++
 .../pybind/batch_manager/kvCacheConnector.h   | 77 ++++++++++++++++++
 5 files changed, 212 insertions(+)
 create mode 100644 cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h
 create mode 100644 cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp
 create mode 100644 cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h

diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h
new file mode 100644
index 00000000000..3ad9f21db93
--- /dev/null
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/batch_manager/llmRequest.h"
+#include "tensorrt_llm/runtime/common.h"
+
+#include
+#include
+
+using SizeType32 = tensorrt_llm::runtime::SizeType32;
+using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType;
+
+using namespace tensorrt_llm::batch_manager;
+
+namespace tensorrt_llm::batch_manager::kv_connector
+{
+enum KvCacheConnectorRole : std::int8_t
+{
+    Scheduler,
+    Worker
+};
+
+class KvCacheConnector
+{
+public:
+    explicit KvCacheConnector(KvCacheConnectorRole role);
+    virtual ~KvCacheConnector() = default;
+
+    [[nodiscard]] KvCacheConnectorRole role() const;
+
+    //
+    // WORKER SIDE METHODS
+    //
+
+    // TODO(jothomson): Need arguments here.
+    virtual void registerKvCaches();
+
+    // TODO(jothomson): Need arguments here.
+    virtual void startLoadKv() = 0;
+
+    virtual void waitForLayerLoad(SizeType32 layer_idx) = 0;
+
+    // TODO(jothomson): Need arguments here.
+    virtual void saveKvLayer(SizeType32 layer_idx) = 0;
+
+    virtual void waitForSave() = 0;
+
+    virtual std::tuple<std::vector<RequestIdType>, std::vector<RequestIdType>> getFinished(
+        std::vector<RequestIdType> const& finishedReqIds);
+
+    //
+    // SCHEDULER SIDE METHODS
+    //
+
+    virtual std::tuple<SizeType32, bool> getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens)
+        = 0;
+
+    // TODO(jothomson): Need arguments here. Also, is this even needed?
+    virtual void updateStateAfterAlloc();
+
+    virtual bool requestFinished(LlmRequest const& request);
+
+private:
+    KvCacheConnectorRole mRole;
+};
+} // namespace tensorrt_llm::batch_manager::kv_connector
diff --git a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt
index 5f7d774c0b0..75f1e0fa20b 100644
--- a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt
+++ b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt
@@ -30,6 +30,7 @@ set(SRCS
   guidedDecoder.cpp
   handleContextLogits.cpp
   handleGenerationLogits.cpp
+  kvCacheConnector.cpp
   kvCacheManager.cpp
   kvCacheEventManager.cpp
   kvCacheTransferManager.cpp
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp
new file mode 100644
index 00000000000..e564716e76b
--- /dev/null
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp
@@ -0,0 +1,31 @@
+#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
+
+namespace tensorrt_llm::batch_manager::kv_connector
+{
+
+KvCacheConnector::KvCacheConnector(KvCacheConnectorRole role)
+    : mRole(role)
+{
+}
+
+KvCacheConnectorRole KvCacheConnector::role() const
+{
+    return mRole;
+}
+
+void KvCacheConnector::registerKvCaches() {}
+
+std::tuple<std::vector<RequestIdType>, std::vector<RequestIdType>> KvCacheConnector::getFinished(
+    std::vector<RequestIdType> const& finishedReqIds)
+{
+    return std::make_tuple(std::vector<RequestIdType>(), std::vector<RequestIdType>());
+}
+
+void KvCacheConnector::updateStateAfterAlloc() {}
+
+bool KvCacheConnector::requestFinished(LlmRequest const& request)
+{
+    return false;
+}
+
+} // namespace tensorrt_llm::batch_manager::kv_connector
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index 04faa90c2ff..ac6201eb659 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -19,10 +19,13 @@
 #include "tensorrt_llm/batch_manager/common.h"
 #include "tensorrt_llm/batch_manager/decoderBuffers.h"
+#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
+#include "tensorrt_llm/batch_manager/medusaBuffers.h"
 #include "tensorrt_llm/batch_manager/microBatchScheduler.h"
 #include "tensorrt_llm/batch_manager/peftCacheManager.h"
 #include "tensorrt_llm/batch_manager/rnnStateManager.h"
 #include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
+#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
 #include "tensorrt_llm/pybind/common/bindTypes.h"
 #include "tensorrt_llm/runtime/gptDecoderBatched.h"
 #include "tensorrt_llm/runtime/runtimeKernels.h"
@@ -523,6 +526,25 @@ void initBindings(pybind11::module_& m)
         py::arg("context_requests"), py::arg("generation_requests"), py::arg("logits"), py::arg("beam_width"),
         py::arg("num_context_logits_prefix_sum"), py::arg("decoder_input_buffers"), py::arg("decoder_state"),
         py::arg("buffer_manager"), "Make decoding batch input.");
+
+    py::enum_<tb::kv_connector::KvCacheConnectorRole>(m, "KvCacheConnectorRole")
+        .value("Scheduler", tb::kv_connector::KvCacheConnectorRole::Scheduler)
+        .value("Worker", tb::kv_connector::KvCacheConnectorRole::Worker);
+
+    py::class_<tb::kv_connector::KvCacheConnector, tensorrt_llm::pybind::batch_manager::kv_connector::PyKvCacheConnector>(
+        m, "KvCacheConnector")
+        .def(py::init<tb::kv_connector::KvCacheConnectorRole>(), py::arg("role"))
+        .def("register_kv_caches", &tb::kv_connector::KvCacheConnector::registerKvCaches)
+        .def("start_load_kv", &tb::kv_connector::KvCacheConnector::startLoadKv)
+        .def("wait_for_layer_load", &tb::kv_connector::KvCacheConnector::waitForLayerLoad, py::arg("layer_idx"))
+        .def("save_kv_layer", &tb::kv_connector::KvCacheConnector::saveKvLayer, py::arg("layer_idx"))
+        .def("wait_for_save",
&tb::kv_connector::KvCacheConnector::waitForSave) + .def("get_finished", &tb::kv_connector::KvCacheConnector::getFinished, py::arg("finished_req_ids")) + .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnector::getNumNewMatchedTokens, + py::arg("request"), py::arg("num_computed_tokens")) + .def("update_state_after_alloc", &tb::kv_connector::KvCacheConnector::updateStateAfterAlloc) + .def("request_finished", &tb::kv_connector::KvCacheConnector::requestFinished, py::arg("request")) + .def_property_readonly("role", &tb::kv_connector::KvCacheConnector::role); } } // namespace tensorrt_llm::pybind::batch_manager diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h new file mode 100644 index 00000000000..2ba5693c26b --- /dev/null +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h @@ -0,0 +1,77 @@ +#pragma once + +#include "tensorrt_llm/batch_manager/kvCacheConnector.h" +#include + +namespace py = pybind11; + +namespace tensorrt_llm::pybind::batch_manager::kv_connector +{ + +using namespace tensorrt_llm::batch_manager::kv_connector; + +class PyKvCacheConnector : public KvCacheConnector +{ +public: + using KvCacheConnector::KvCacheConnector; + + // + // WORKER SIDE METHODS + // + + void registerKvCaches() override + { + PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, registerKvCaches); + } + + void startLoadKv() override + { + PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, startLoadKv); + } + + void waitForLayerLoad(SizeType32 layer_idx) override + { + PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, waitForLayerLoad, layer_idx); + } + + void saveKvLayer(SizeType32 layer_idx) override + { + PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, saveKvLayer, layer_idx); + } + + void waitForSave() override + { + PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, waitForSave); + } + + using FinishedReqs = std::tuple, std::vector>; + + FinishedReqs getFinished(std::vector const& finishedReqIds) override + { + PYBIND11_OVERRIDE_PURE(FinishedReqs, KvCacheConnector, getFinished, finishedReqIds); + } + + // + // SCHEDULER SIDE METHODS + // + + using NumNewMatchedTokens = std::tuple; + + NumNewMatchedTokens getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override + { + PYBIND11_OVERRIDE_PURE( + NumNewMatchedTokens, KvCacheConnector, getNumNewMatchedTokens, request, numComputedTokens); + } + + void updateStateAfterAlloc() override + { + PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, updateStateAfterAlloc); + } + + bool requestFinished(LlmRequest const& request) override + { + PYBIND11_OVERRIDE_PURE(bool, KvCacheConnector, requestFinished, request); + } +}; + +} // namespace tensorrt_llm::pybind::batch_manager::kv_connector From 4bd4df48687e8f7b35ed97907a8b4c0f97ee6b4d Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 22 Jul 2025 16:43:17 -0700 Subject: [PATCH 02/50] Basic connector tests Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/connector.py | 21 +++++++++ tensorrt_llm/llmapi/llm_args.py | 8 +++- tensorrt_llm/models/modeling_utils.py | 6 +++ .../bindings/test_connector_bindings.py | 47 +++++++++++++++++++ 4 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 tensorrt_llm/_torch/pyexecutor/connector.py create mode 100644 tests/unittest/bindings/test_connector_bindings.py diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py new file mode 100644 index 00000000000..49bef935630 --- /dev/null +++ 
b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -0,0 +1,21 @@ +from typing import Optional + +from tensorrt_llm.bindings.internal.batch_manager import \ + KvCacheConnector as KvCacheConnectorCpp +from tensorrt_llm.bindings.internal.batch_manager import KvCacheConnectorRole + + +class KvCacheConnector(KvCacheConnectorCpp): + + def __init__(self, role: KvCacheConnectorRole): + super().__init__(role) + self.connector_metadata = None + + def bind_connector_metadata(self, metadata: object): + self.connector_metadata = metadata + + def _get_connector_metadata(self) -> object: + return self.connector_metadata + + def build_connector_metadata(self) -> Optional[object]: + return None diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 1169a779be6..95a0ca66df7 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -52,7 +52,8 @@ from ..logger import logger from ..mapping import Mapping from ..models.automodel import AutoConfig -from ..models.modeling_utils import (PretrainedConfig, QuantAlgo, QuantConfig, +from ..models.modeling_utils import (KvCacheConnectorConfig, PretrainedConfig, + QuantAlgo, QuantConfig, SpeculativeDecodingMode) from ..sampling_params import BatchedLogitsProcessor from .build_cache import BuildCacheConfig @@ -2120,6 +2121,11 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) + connector_config: Optional[KvCacheConnectorConfig] = Field( + default=None, + description="The config for KV cache connector.", + ) + # PrivateVars _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index b2fdc393a02..cc99566b46a 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -124,6 +124,12 @@ def from_arguments(args: argparse.Namespace): assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode +@dataclasses.dataclass +class KvCacheConnectorConfig: + connector_module: str + connector_class: str + + @dataclasses.dataclass class QuantConfig: """ diff --git a/tests/unittest/bindings/test_connector_bindings.py b/tests/unittest/bindings/test_connector_bindings.py new file mode 100644 index 00000000000..6c40014ad11 --- /dev/null +++ b/tests/unittest/bindings/test_connector_bindings.py @@ -0,0 +1,47 @@ +from typing import Optional + +from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnector +from tensorrt_llm.bindings.internal.batch_manager import (KvCacheConnectorRole, + LlmRequest) + + +class BasicConnector(KvCacheConnector): + + def __init__(self, role: KvCacheConnectorRole): + super().__init__(role) + + def build_connector_metadata(self) -> Optional[object]: + return {"test": "test"} + + def start_load_kv(self): + pass + + def wait_for_layer_load(self, layer_idx: int): + pass + + def save_kv_layer(self, layer_idx: int): + pass + + def wait_for_save(self): + pass + + def get_num_new_matched_tokens( + self, request: LlmRequest, + num_computed_tokens: int) -> tuple[int, bool]: + return 16, True + + +def test_basic_init(): + connector = BasicConnector(KvCacheConnectorRole.Scheduler) + + assert connector.role == KvCacheConnectorRole.Scheduler + + assert connector.build_connector_metadata() == {"test": "test"} + + # Try calling some of the other virtual methods. 
+ connector.save_kv_layer(0) + connector.wait_for_save() + + connector_worker = BasicConnector(KvCacheConnectorRole.Worker) + + assert connector_worker.role == KvCacheConnectorRole.Worker From f83f0edb36f29c70e916946f2d550e0703306bba Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 22 Jul 2025 22:04:07 -0700 Subject: [PATCH 03/50] Hook into torch runtime and py executor Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/_util.py | 8 ++++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 26 +++++++++++++++++-- .../_torch/pyexecutor/py_executor_creator.py | 7 ++++- tensorrt_llm/executor/executor.py | 21 ++++++++++----- tensorrt_llm/executor/proxy.py | 5 +++- tensorrt_llm/executor/worker.py | 11 +++++++- tensorrt_llm/llmapi/llm.py | 4 ++- .../bindings/test_connector_bindings.py | 9 +++++-- 8 files changed, 74 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 52bd7089d74..f4f2ac6e037 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -17,6 +17,7 @@ get_default_trtllm_modules_to_hf_modules, load_torch_lora) from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from ..model_config import ModelConfig from ..speculative import get_num_extra_kv_tokens, get_spec_decoder @@ -418,7 +419,9 @@ def create_py_executor_instance( drafter, guided_decoder: Optional[GuidedDecoder] = None, lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor: + garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None +) -> PyExecutor: kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None) spec_config = model_engine.spec_config @@ -558,7 +561,8 @@ def create_py_executor_instance( kv_cache_transceiver=kv_cache_transceiver, guided_decoder=guided_decoder, start_worker=start_worker, - garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config) def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 0dad7ba7817..f6c6a4e1287 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2,6 +2,7 @@ import datetime import functools import gc +import importlib import os import threading import time @@ -27,9 +28,11 @@ RequestStage, RequestStats, SpecDecodingStats, StaticBatchingStats) -from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType, +from tensorrt_llm.bindings.internal.batch_manager import (KvCacheConnectorRole, + LlmRequestType, ReqIdsSet) from tensorrt_llm.logger import logger +from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from tensorrt_llm.runtime.generation import CUASSERT from ..distributed import Distributed @@ -132,6 +135,13 @@ class BatchStatePP(BatchState): scheduled_ctx_reqs: list[LlmRequest] = None +def load_connector_module(kv_connector_config: KvCacheConnectorConfig): + module_name = kv_connector_config.connector_module + class_name = kv_connector_config.connector_class + module = importlib.import_module(module_name) + return getattr(module, class_name) + + class PyExecutor: def __init__(self, @@ -150,7 +160,8 @@ def 
__init__(self, kv_cache_transceiver: Optional[KvCacheTransceiver] = None, guided_decoder: Optional[GuidedDecoder] = None, garbage_collection_gen0_threshold: Optional[int] = None, - start_worker: bool = True): + start_worker: bool = True, + kv_connector_config: Optional[KvCacheConnectorConfig] = None): super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = global_mpi_rank() @@ -263,6 +274,17 @@ def __init__(self, self.worker_started = False self.worker_lock = threading.Lock() + + print("LOADING WITH KV CONNECTOR CONFIG", kv_connector_config) + if kv_connector_config is not None: + connector_cls = load_connector_module(kv_connector_config) + + self.connector_worker = connector_cls(KvCacheConnectorRole.Worker) + + if global_mpi_rank() == 0: + self.connector_scheduler = connector_cls( + KvCacheConnectorRole.Scheduler) + if start_worker: self.start_worker() diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 8fe7d8a1aa3..54c796964cd 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -15,6 +15,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraConfig from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from tensorrt_llm.quantization import QuantAlgo from ..attention_backend.interface import AttentionRuntimeFeatures @@ -195,7 +196,9 @@ def create_py_executor( executor_config: ExecutorConfig, checkpoint_dir: str = None, lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor: + garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None +) -> PyExecutor: _mangle_executor_config(executor_config) pytorch_backend_config = executor_config.pytorch_backend_config @@ -408,6 +411,7 @@ def create_py_executor( guided_decoder=guided_decoder, lora_config=lora_config, garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config, ) if estimating_kv_cache: @@ -451,6 +455,7 @@ def create_py_executor( lora_config=lora_config, garbage_collection_gen0_threshold= garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config, ) _adjust_torch_mem_fraction(executor_config.pytorch_backend_config) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 9ce4ad0d85c..5592132119c 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -27,6 +27,7 @@ from ..llmapi.utils import (AsyncQueue, enable_llm_debug, enable_worker_single_process_for_tp1, print_colored, print_colored_debug) +from ..models.modeling_utils import KvCacheConnectorConfig from ..sampling_params import (BatchedLogitsProcessor, LogprobParams, SamplingParams) from ..scheduling_params import SchedulingParams @@ -355,6 +356,7 @@ def create( is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]: # local imports to avoid cyclic importing from .proxy import GenerationExecutorProxy @@ -400,7 +402,8 @@ def create( postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, 
garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config) # WAR: For the performance of gathering logits, we use single process worker # for TP1 to avoid the large overhead of IPC. @@ -410,10 +413,12 @@ def create( logger.warning( "Using single process worker for TP1, this may hurt streaming generation performance." ) - return GenerationExecutorWorker(**worker_kwargs, - is_llm_executor=is_llm_executor, - garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + return GenerationExecutorWorker( + **worker_kwargs, + is_llm_executor=is_llm_executor, + garbage_collection_gen0_threshold= + garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config) # For single-gpu case: # Partition the workload to multiple process for streaming performance. @@ -427,7 +432,8 @@ def create( postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config) else: ctx = multiprocessing.get_context("spawn") # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot. @@ -440,7 +446,8 @@ def create( postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config) def wait_first_completed( self, futures: List[GenerationResult] diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 1cb86dfdff7..f4bccd966ae 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -10,6 +10,7 @@ import zmq.asyncio from tensorrt_llm.logger import logger +from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from .._utils import customized_gc_thresholds, mpi_rank, nvtx_range_debug from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession, @@ -46,6 +47,7 @@ def __init__( postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, ) -> None: postproc_worker_config = postproc_worker_config or PostprocWorkerConfig( ) @@ -94,7 +96,8 @@ def __init__( postproc_worker_config=postproc_worker_config, is_llm_executor=False, garbage_collection_gen0_threshold=self. 
- garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config) if "log_level" not in worker_kwargs: worker_kwargs["log_level"] = logger.level diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index db8d84fcc89..0fa91d8cc72 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -13,6 +13,7 @@ import torch from tensorrt_llm.logger import logger +from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, nvtx_range_debug) @@ -59,6 +60,7 @@ def __init__( is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, ) -> None: postproc_config = postproc_worker_config or PostprocWorkerConfig() super().__init__( @@ -82,6 +84,10 @@ def __init__( self._is_pytorch_backend = getattr(self._executor_config, "backend", None) == "pytorch" + if not self._is_pytorch_backend and kv_connector_config is not None: + raise ValueError( + "KV connector config is only supported for PyTorch backend") + if global_mpi_size() > 1: logger.set_rank(self.global_rank) @@ -127,6 +133,7 @@ def _create_engine(): args["lora_config"] = lora_config args[ "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold + args["kv_connector_config"] = kv_connector_config elif executor_config.backend == "_autodeploy": from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ create_autodeploy_executor @@ -649,6 +656,7 @@ def worker_main( bool] = True, # whether it's the main executor instance lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, ) -> None: mpi_comm().barrier() print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n", @@ -775,7 +783,8 @@ def notify_proxy_threads_to_quit(): postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, lora_config=lora_config, - garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, + kv_connector_config=kv_connector_config) except Exception as e: logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}") logger.error(traceback.format_exc()) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 12bb079eaf5..a9f5344a9af 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -1054,7 +1054,9 @@ def _build_model(self): is_llm_executor=True, lora_config=self.args.lora_config, garbage_collection_gen0_threshold=self.args. - garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold, + kv_connector_config=self.args.connector_config, + ) def _validate_args_for_torch_backend(self, kwargs: dict) -> None: """Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend. 
diff --git a/tests/unittest/bindings/test_connector_bindings.py b/tests/unittest/bindings/test_connector_bindings.py index 6c40014ad11..0fc5fb8c929 100644 --- a/tests/unittest/bindings/test_connector_bindings.py +++ b/tests/unittest/bindings/test_connector_bindings.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import List, Optional from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnector from tensorrt_llm.bindings.internal.batch_manager import (KvCacheConnectorRole, @@ -25,6 +25,10 @@ def save_kv_layer(self, layer_idx: int): def wait_for_save(self): pass + def get_finished( + self, finished_req_ids: List[int]) -> tuple[List[int], List[int]]: + return [42], [7] + def get_num_new_matched_tokens( self, request: LlmRequest, num_computed_tokens: int) -> tuple[int, bool]: @@ -38,7 +42,8 @@ def test_basic_init(): assert connector.build_connector_metadata() == {"test": "test"} - # Try calling some of the other virtual methods. + assert connector.get_finished([]) == ([42], [7]) + connector.save_kv_layer(0) connector.wait_for_save() From d4e5178e5f55d40e7270179eed5b9395c93bc723 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 24 Jul 2025 10:59:29 -0700 Subject: [PATCH 04/50] Expose block pools as torch tensor Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 49 +++++++ .../batch_manager/kvCacheManager.h | 9 ++ .../batch_manager/kvCacheManager.cpp | 23 +++ cpp/tensorrt_llm/pybind/CMakeLists.txt | 1 + .../pybind/batch_manager/bindings.cpp | 19 --- .../pybind/batch_manager/kvCacheConnector.cpp | 134 ++++++++++++++++++ .../pybind/batch_manager/kvCacheConnector.h | 71 ++-------- .../pybind/batch_manager/kvCacheManager.cpp | 9 +- cpp/tensorrt_llm/pybind/bindings.cpp | 2 + tensorrt_llm/_torch/pyexecutor/py_executor.py | 1 - 10 files changed, 234 insertions(+), 84 deletions(-) create mode 100644 cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 3ad9f21db93..5af392558ae 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -35,6 +35,55 @@ enum KvCacheConnectorRole : std::int8_t Worker }; +class KvCacheConnectorPoolData +{ +public: + KvCacheConnectorPoolData(runtime::ITensor::SharedPtr const& poolTensor, SizeType32 numBlocks) + : mPoolTensor(poolTensor) + , mNumBlocks(numBlocks) + { + } + + runtime::ITensor::SharedPtr const& getPoolTensor() const + { + return mPoolTensor; + } + + SizeType32 getNumBlocks() const + { + return mNumBlocks; + } + +private: + runtime::ITensor::SharedPtr mPoolTensor; + SizeType32 mNumBlocks; +}; + +class KvCacheConnectorPoolsData +{ +public: + explicit KvCacheConnectorPoolsData( + std::vector& poolsData, runtime::ITensor::SharedPtr const& layerToPoolMapping) + : mPoolsData(poolsData) + , mLayerToPoolMapping(layerToPoolMapping) + { + } + + std::vector& getPoolsData() + { + return mPoolsData; + } + + runtime::ITensor::SharedPtr& getLayerToPoolMapping() + { + return mLayerToPoolMapping; + } + +private: + std::vector mPoolsData; + runtime::ITensor::SharedPtr mLayerToPoolMapping; +}; + class KvCacheConnector { public: diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index a0234cbbe49..a9d7ab1143d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -16,6 
+16,7 @@ #pragma once +#include "tensorrt_llm/batch_manager/kvCacheConnector.h" #include "tensorrt_llm/batch_manager/kvCacheEventManager.h" #include "tensorrt_llm/batch_manager/kvCacheType.h" #include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare @@ -747,6 +748,8 @@ class WindowBlockManager return 0; } + [[nodiscard]] kv_connector::KvCacheConnectorPoolData getKvCacheConnectorPoolData() const; + private: //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq. void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx); @@ -1135,6 +1138,8 @@ class BlockManager return mWindowBlockManagers.at(windowSize).getPool(relativePoolIndex); } + [[nodiscard]] std::vector getKvCacheConnectorPoolsData() const; + private: [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const { @@ -1367,6 +1372,8 @@ class BaseKVCacheManager [[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0; [[nodiscard]] virtual CacheType getCacheType() const = 0; + + [[nodiscard]] virtual kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const = 0; }; class KVCacheManager : public BaseKVCacheManager @@ -1666,6 +1673,8 @@ class KVCacheManager : public BaseKVCacheManager [[nodiscard]] static SizeType32 calculateMaxAttentionWindow(SizeType32 inputLength, SizeType32 outputLength, SizeType32 sinkTokenLength, SizeType32 blockCapacity, SizeType32 beamWidth, SizeType32 tokensPerBlock); + [[nodiscard]] kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const override; + private: void cacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); void cacheNewBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 4202ba348ac..558e0a7a481 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1535,6 +1535,22 @@ void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef BlockManager::getKvCacheConnectorPoolsData() const +{ + std::vector poolsData; + poolsData.reserve(mWindowBlockManagers.size()); + for (auto const& [_, manager] : mWindowBlockManagers) + { + poolsData.emplace_back(manager.getKvCacheConnectorPoolData()); + } + return poolsData; +} + +[[nodiscard]] kv_connector::KvCacheConnectorPoolData WindowBlockManager::getKvCacheConnectorPoolData() const +{ + return kv_connector::KvCacheConnectorPoolData(mPools.at(mWindowSize).primaryPtr, mNumPrimaryBlocks); +} + void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) { auto constexpr beamIdx = 0; @@ -2576,4 +2592,11 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength, auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements; return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength); } + +[[nodiscard]] kv_connector::KvCacheConnectorPoolsData KVCacheManager::getKvCacheConnectorPoolsData() const +{ + auto poolsData = mBlockManager.getKvCacheConnectorPoolsData(); + return kv_connector::KvCacheConnectorPoolsData(poolsData, mLayerToPoolMapping); +} + } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt index b4809d5135e..644375e3013 100755 --- 
a/cpp/tensorrt_llm/pybind/CMakeLists.txt +++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt @@ -7,6 +7,7 @@ set(SRCS batch_manager/algorithms.cpp batch_manager/bindings.cpp batch_manager/cacheTransceiver.cpp + batch_manager/kvCacheConnector.cpp batch_manager/kvCacheManager.cpp batch_manager/llmRequest.cpp executor/bindings.cpp diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index ac6201eb659..0ba4fd94c20 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -526,25 +526,6 @@ void initBindings(pybind11::module_& m) py::arg("context_requests"), py::arg("generation_requests"), py::arg("logits"), py::arg("beam_width"), py::arg("num_context_logits_prefix_sum"), py::arg("decoder_input_buffers"), py::arg("decoder_state"), py::arg("buffer_manager"), "Make decoding batch input."); - - py::enum_(m, "KvCacheConnectorRole") - .value("Scheduler", tb::kv_connector::KvCacheConnectorRole::Scheduler) - .value("Worker", tb::kv_connector::KvCacheConnectorRole::Worker); - - py::class_( - m, "KvCacheConnector") - .def(py::init(), py::arg("role")) - .def("register_kv_caches", &tb::kv_connector::KvCacheConnector::registerKvCaches) - .def("start_load_kv", &tb::kv_connector::KvCacheConnector::startLoadKv) - .def("wait_for_layer_load", &tb::kv_connector::KvCacheConnector::waitForLayerLoad, py::arg("layer_idx")) - .def("save_kv_layer", &tb::kv_connector::KvCacheConnector::saveKvLayer, py::arg("layer_idx")) - .def("wait_for_save", &tb::kv_connector::KvCacheConnector::waitForSave) - .def("get_finished", &tb::kv_connector::KvCacheConnector::getFinished, py::arg("finished_req_ids")) - .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnector::getNumNewMatchedTokens, - py::arg("request"), py::arg("num_computed_tokens")) - .def("update_state_after_alloc", &tb::kv_connector::KvCacheConnector::updateStateAfterAlloc) - .def("request_finished", &tb::kv_connector::KvCacheConnector::requestFinished, py::arg("request")) - .def_property_readonly("role", &tb::kv_connector::KvCacheConnector::role); } } // namespace tensorrt_llm::pybind::batch_manager diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp new file mode 100644 index 00000000000..ff0892154f6 --- /dev/null +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -0,0 +1,134 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
+#include "tensorrt_llm/runtime/torch.h"
+
+#include
+
+namespace
+{
+
+using KvCacheConnector = tensorrt_llm::batch_manager::kv_connector::KvCacheConnector;
+namespace tb = tensorrt_llm::batch_manager;
+
+class PyKvCacheConnector : public KvCacheConnector
+{
+public:
+    using KvCacheConnector::KvCacheConnector;
+
+    //
+    // WORKER SIDE METHODS
+    //
+
+    void registerKvCaches() override
+    {
+        PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, registerKvCaches);
+    }
+
+    void startLoadKv() override
+    {
+        PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, startLoadKv);
+    }
+
+    void waitForLayerLoad(SizeType32 layer_idx) override
+    {
+        PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, waitForLayerLoad, layer_idx);
+    }
+
+    void saveKvLayer(SizeType32 layer_idx) override
+    {
+        PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, saveKvLayer, layer_idx);
+    }
+
+    void waitForSave() override
+    {
+        PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, waitForSave);
+    }
+
+    using FinishedReqs = std::tuple<std::vector<RequestIdType>, std::vector<RequestIdType>>;
+
+    FinishedReqs getFinished(std::vector<RequestIdType> const& finishedReqIds) override
+    {
+        PYBIND11_OVERRIDE_PURE(FinishedReqs, KvCacheConnector, getFinished, finishedReqIds);
+    }
+
+    //
+    // SCHEDULER SIDE METHODS
+    //
+
+    using NumNewMatchedTokens = std::tuple<SizeType32, bool>;
+
+    NumNewMatchedTokens getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override
+    {
+        PYBIND11_OVERRIDE_PURE(
+            NumNewMatchedTokens, KvCacheConnector, getNumNewMatchedTokens, request, numComputedTokens);
+    }
+
+    void updateStateAfterAlloc() override
+    {
+        PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, updateStateAfterAlloc);
+    }
+
+    bool requestFinished(LlmRequest const& request) override
+    {
+        PYBIND11_OVERRIDE_PURE(bool, KvCacheConnector, requestFinished, request);
+    }
+};
+
+} // namespace
+
+void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(py::module_& m)
+{
+    py::enum_<tb::kv_connector::KvCacheConnectorRole>(m, "KvCacheConnectorRole")
+        .value("Scheduler", tb::kv_connector::KvCacheConnectorRole::Scheduler)
+        .value("Worker", tb::kv_connector::KvCacheConnectorRole::Worker);
+
+    py::class_<tb::kv_connector::KvCacheConnectorPoolData>(m, "KvCacheConnectorPoolData")
+        .def_property_readonly("tensor",
+            [](tb::kv_connector::KvCacheConnectorPoolData& self)
+            {
+                auto const& poolTensor = self.getPoolTensor();
+
+                return tensorrt_llm::runtime::Torch::tensor(poolTensor);
+            })
+        .def_property_readonly("num_blocks", &tb::kv_connector::KvCacheConnectorPoolData::getNumBlocks);
+
+    py::class_<tb::kv_connector::KvCacheConnectorPoolsData>(m, "KvCacheConnectorPoolsData")
+        .def_property_readonly("pools", &tb::kv_connector::KvCacheConnectorPoolsData::getPoolsData)
+        .def_property_readonly("layer_to_pool_mapping",
+            [](tb::kv_connector::KvCacheConnectorPoolsData& self)
+            {
+                auto const& layerToPoolMapping = self.getLayerToPoolMapping();
+
+                return tensorrt_llm::runtime::Torch::tensor(layerToPoolMapping);
+            });
+
+    py::class_<KvCacheConnector, PyKvCacheConnector>(m, "KvCacheConnector")
+        .def(py::init<tb::kv_connector::KvCacheConnectorRole>(), py::arg("role"))
+        .def("register_kv_caches", &tb::kv_connector::KvCacheConnector::registerKvCaches)
+        .def("start_load_kv", &tb::kv_connector::KvCacheConnector::startLoadKv)
+        .def("wait_for_layer_load", &tb::kv_connector::KvCacheConnector::waitForLayerLoad, py::arg("layer_idx"))
+        .def("save_kv_layer", &tb::kv_connector::KvCacheConnector::saveKvLayer, py::arg("layer_idx"))
+        .def("wait_for_save", &tb::kv_connector::KvCacheConnector::waitForSave)
+        .def("get_finished", &tb::kv_connector::KvCacheConnector::getFinished, py::arg("finished_req_ids"))
+        .def("get_num_new_matched_tokens",
&tb::kv_connector::KvCacheConnector::getNumNewMatchedTokens, + py::arg("request"), py::arg("num_computed_tokens")) + .def("update_state_after_alloc", &tb::kv_connector::KvCacheConnector::updateStateAfterAlloc) + .def("request_finished", &tb::kv_connector::KvCacheConnector::requestFinished, py::arg("request")) + .def_property_readonly("role", &tb::kv_connector::KvCacheConnector::role); +} diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h index 2ba5693c26b..4b1568a3abe 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h @@ -5,73 +5,18 @@ namespace py = pybind11; -namespace tensorrt_llm::pybind::batch_manager::kv_connector +namespace tensorrt_llm::batch_manager::kv_cache_manager { - -using namespace tensorrt_llm::batch_manager::kv_connector; - -class PyKvCacheConnector : public KvCacheConnector +class KVCacheManagerConnectorBindings { public: - using KvCacheConnector::KvCacheConnector; - - // - // WORKER SIDE METHODS - // - - void registerKvCaches() override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, registerKvCaches); - } - - void startLoadKv() override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, startLoadKv); - } - - void waitForLayerLoad(SizeType32 layer_idx) override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, waitForLayerLoad, layer_idx); - } - - void saveKvLayer(SizeType32 layer_idx) override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, saveKvLayer, layer_idx); - } - - void waitForSave() override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, waitForSave); - } - - using FinishedReqs = std::tuple, std::vector>; - - FinishedReqs getFinished(std::vector const& finishedReqIds) override - { - PYBIND11_OVERRIDE_PURE(FinishedReqs, KvCacheConnector, getFinished, finishedReqIds); - } - - // - // SCHEDULER SIDE METHODS - // - - using NumNewMatchedTokens = std::tuple; - - NumNewMatchedTokens getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override - { - PYBIND11_OVERRIDE_PURE( - NumNewMatchedTokens, KvCacheConnector, getNumNewMatchedTokens, request, numComputedTokens); - } + static void initBindings(pybind11::module_& m); +}; +} // namespace tensorrt_llm::batch_manager::kv_cache_manager - void updateStateAfterAlloc() override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, updateStateAfterAlloc); - } +namespace tensorrt_llm::pybind::batch_manager::kv_connector +{ - bool requestFinished(LlmRequest const& request) override - { - PYBIND11_OVERRIDE_PURE(bool, KvCacheConnector, requestFinished, request); - } -}; +using namespace tensorrt_llm::batch_manager::kv_connector; } // namespace tensorrt_llm::pybind::batch_manager::kv_connector diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 54835e81d7f..f9b02b78fca 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -234,6 +234,12 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager { PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, flushIterationEvents); } + + kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const override + { + PYBIND11_OVERLOAD_PURE( + kv_connector::KvCacheConnectorPoolsData, tbk::BaseKVCacheManager, getKvCacheConnectorPoolsData); + } }; // TODO: Deduplicate executor bindings KvCacheStats @@ -423,7 +429,8 
@@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds) .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds) - .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents); + .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents) + .def("get_kv_cache_connector_pools_data", &BaseKVCacheManager::getKvCacheConnectorPoolsData); py::enum_(m, "CacheType") .value("SELF", tbk::CacheType::kSELF) diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 1d2e5b0a951..7a695e86568 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -28,6 +28,7 @@ #include "tensorrt_llm/pybind/batch_manager/algorithms.h" #include "tensorrt_llm/pybind/batch_manager/bindings.h" #include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h" +#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h" #include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h" #include "tensorrt_llm/pybind/batch_manager/llmRequest.h" #include "tensorrt_llm/pybind/executor/bindings.h" @@ -465,6 +466,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime); tensorrt_llm::pybind::testing::initBindings(mInternalTesting); tpb::initBindings(mInternalBatchManager); + tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager); tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager); tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager); tb::CacheTransceiverBindings::initBindings(mInternalBatchManager); diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index f6c6a4e1287..17ad136b886 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -275,7 +275,6 @@ def __init__(self, self.worker_started = False self.worker_lock = threading.Lock() - print("LOADING WITH KV CONNECTOR CONFIG", kv_connector_config) if kv_connector_config is not None: connector_cls = load_connector_module(kv_connector_config) From 0c9fa7a1d4d88f889b2b799a75cff27e64144621 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 24 Jul 2025 14:56:33 -0700 Subject: [PATCH 05/50] Little fixes Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 6 ++-- .../batch_manager/kvCacheManager.cpp | 16 +++++++-- .../pybind/batch_manager/kvCacheConnector.cpp | 9 ++--- tensorrt_llm/_torch/pyexecutor/py_executor.py | 33 +++++++++++-------- .../_torch/pyexecutor/resource_manager.py | 5 +++ 5 files changed, 44 insertions(+), 25 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 5af392558ae..8a2021140fa 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -63,7 +63,7 @@ class KvCacheConnectorPoolsData { public: explicit KvCacheConnectorPoolsData( - std::vector& poolsData, runtime::ITensor::SharedPtr const& layerToPoolMapping) + std::vector& poolsData, std::vector& layerToPoolMapping) : mPoolsData(poolsData) , mLayerToPoolMapping(layerToPoolMapping) { @@ -74,14 +74,14 @@ class KvCacheConnectorPoolsData return mPoolsData; } - runtime::ITensor::SharedPtr& 
getLayerToPoolMapping() + std::vector& getLayerToPoolMapping() { return mLayerToPoolMapping; } private: std::vector mPoolsData; - runtime::ITensor::SharedPtr mLayerToPoolMapping; + std::vector mLayerToPoolMapping; }; class KvCacheConnector diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 558e0a7a481..3de1cfa8643 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1548,7 +1548,7 @@ void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) @@ -2596,7 +2596,19 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength, [[nodiscard]] kv_connector::KvCacheConnectorPoolsData KVCacheManager::getKvCacheConnectorPoolsData() const { auto poolsData = mBlockManager.getKvCacheConnectorPoolsData(); - return kv_connector::KvCacheConnectorPoolsData(poolsData, mLayerToPoolMapping); + + auto layerToPoolView = BufferRange(*mLayerToPoolMapping); + + auto numLayers = mBlockManager.getNumLayers(); + + auto layerToPool = std::vector(numLayers); + + for (size_t layer = 0; layer < static_cast(numLayers); layer++) + { + layerToPool[layer] = layerToPoolView[layer * 2]; + } + + return kv_connector::KvCacheConnectorPoolsData(poolsData, layerToPool); } } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index ff0892154f6..7fe0a73db3e 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -110,13 +110,8 @@ void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindi py::class_(m, "KvCacheConnectorPoolsData") .def_property_readonly("pools", &tb::kv_connector::KvCacheConnectorPoolsData::getPoolsData) - .def_property_readonly("layer_to_pool_mapping", - [](tb::kv_connector::KvCacheConnectorPoolsData& self) - { - auto const& layerToPoolMapping = self.getLayerToPoolMapping(); - - return tensorrt_llm::runtime::Torch::tensor(layerToPoolMapping); - }); + .def_property_readonly( + "layer_to_pool_mapping", &tb::kv_connector::KvCacheConnectorPoolsData::getLayerToPoolMapping); py::class_(m, "KvCacheConnector") .def(py::init(), py::arg("role")) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 17ad136b886..8da85260c4f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -135,13 +135,6 @@ class BatchStatePP(BatchState): scheduled_ctx_reqs: list[LlmRequest] = None -def load_connector_module(kv_connector_config: KvCacheConnectorConfig): - module_name = kv_connector_config.connector_module - class_name = kv_connector_config.connector_class - module = importlib.import_module(module_name) - return getattr(module, class_name) - - class PyExecutor: def __init__(self, @@ -276,13 +269,27 @@ def __init__(self, self.worker_lock = threading.Lock() if kv_connector_config is not None: - connector_cls = load_connector_module(kv_connector_config) - - self.connector_worker = connector_cls(KvCacheConnectorRole.Worker) + logger.info( + f"Initializing kv connector with config: {kv_connector_config}") + module_name = kv_connector_config.connector_module + class_name = kv_connector_config.connector_class - if global_mpi_rank() == 0: - self.connector_scheduler = connector_cls( - 
KvCacheConnectorRole.Scheduler) + try: + module = importlib.import_module(module_name) + connector_cls = getattr(module, class_name) + self.connector_worker = connector_cls( + KvCacheConnectorRole.Worker) + + # Only initialize the scheduler on rank 0. + if global_mpi_rank() == 0: + self.connector_scheduler = connector_cls( + KvCacheConnectorRole.Scheduler) + except Exception as e: + logger.error(f"Error instantiating connector: {e}") + raise e + kv_cache_data = self.kv_cache_manager.get_kv_cache_connector_pools_data( + ) + self.connector_worker.register_kv_caches(kv_cache_data) if start_worker: self.start_worker() diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 56c4871542e..2ddf14a924f 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -10,6 +10,8 @@ import tensorrt_llm import tensorrt_llm.bindings from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE +from tensorrt_llm.bindings.internal.batch_manager import \ + KvCacheConnectorPoolsData from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig from tensorrt_llm.sampling_params import SamplingParams @@ -894,6 +896,9 @@ def _set_temp_attention_window_inputs( else: return None + def get_kv_cache_connector_pools_data(self) -> KvCacheConnectorPoolsData: + return self.impl.get_kv_cache_connector_pools_data() + class MambaCacheManager(BaseResourceManager): From 229e6c5ecaa0b57a037a3f904fbf456c428a750a Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Fri, 25 Jul 2025 08:45:07 -0700 Subject: [PATCH 06/50] Scheduler Output bindings Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 42 +++++++++++++++++++ .../pybind/batch_manager/kvCacheConnector.cpp | 16 +++++++ 2 files changed, 58 insertions(+) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 8a2021140fa..7f9c9930454 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -16,6 +16,7 @@ #pragma once +#include "tensorrt_llm/batch_manager/common.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/runtime/common.h" @@ -35,6 +36,47 @@ enum KvCacheConnectorRole : std::int8_t Worker }; +struct NewRequestData +{ + NewRequestData(RequestIdType requestId, std::vector const& newTokens, + std::vector const& blockIds, SizeType32 numComputedTokens) + : requestId(requestId) + , newTokens(newTokens) + , blockIds(blockIds) + , numComputedTokens(numComputedTokens) + { + } + + RequestIdType requestId; + std::vector newTokens; + std::vector blockIds; + SizeType32 numComputedTokens; +}; + +struct CachedRequestData +{ + CachedRequestData(RequestIdType requestId, std::vector const& newTokens, + std::vector const& newBlockIds, SizeType32 numComputedTokens) + : requestId(requestId) + , newTokens(newTokens) + , newBlockIds(newBlockIds) + , numComputedTokens(numComputedTokens) + { + } + + RequestIdType requestId; + std::vector newTokens; + std::vector newBlockIds; + SizeType32 numComputedTokens; +}; + +class SchedulerOutput +{ +public: + std::vector newRequests; + std::vector cachedRequests; +}; + class KvCacheConnectorPoolData { public: diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index 7fe0a73db3e..7bc9d30edcc 100644 --- 
a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -126,4 +126,20 @@ void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindi .def("update_state_after_alloc", &tb::kv_connector::KvCacheConnector::updateStateAfterAlloc) .def("request_finished", &tb::kv_connector::KvCacheConnector::requestFinished, py::arg("request")) .def_property_readonly("role", &tb::kv_connector::KvCacheConnector::role); + + py::class_(m, "NewRequestData") + .def_readonly("request_id", &tb::kv_connector::NewRequestData::requestId) + .def_readonly("new_tokens", &tb::kv_connector::NewRequestData::newTokens) + .def_readonly("block_ids", &tb::kv_connector::NewRequestData::blockIds) + .def_readonly("num_computed_tokens", &tb::kv_connector::NewRequestData::numComputedTokens); + + py::class_(m, "CachedRequestData") + .def_readonly("request_id", &tb::kv_connector::CachedRequestData::requestId) + .def_readonly("new_tokens", &tb::kv_connector::CachedRequestData::newTokens) + .def_readonly("new_block_ids", &tb::kv_connector::CachedRequestData::newBlockIds) + .def_readonly("num_computed_tokens", &tb::kv_connector::CachedRequestData::numComputedTokens); + + py::class_(m, "SchedulerOutput") + .def_readonly("new_requests", &tb::kv_connector::SchedulerOutput::newRequests) + .def_readonly("cached_requests", &tb::kv_connector::SchedulerOutput::cachedRequests); } From 614cb01c5b62a44181950b759b64d9854b4523e5 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Fri, 25 Jul 2025 15:12:49 -0700 Subject: [PATCH 07/50] more little fixes - dont instantiate twice Signed-off-by: jthomson04 --- cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp | 5 +++++ tensorrt_llm/_torch/pyexecutor/connector.py | 6 ++++-- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py | 3 ++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 3de1cfa8643..6f173165638 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1537,6 +1537,11 @@ void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef BlockManager::getKvCacheConnectorPoolsData() const { + if (mWindowBlockManagers.size() > 1) + { + throw std::runtime_error("KV Cache connector is not supported with multiple window sizes"); + } + std::vector poolsData; poolsData.reserve(mWindowBlockManagers.size()); for (auto const& [_, manager] : mWindowBlockManagers) diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 49bef935630..b3b78848dd6 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -2,7 +2,8 @@ from tensorrt_llm.bindings.internal.batch_manager import \ KvCacheConnector as KvCacheConnectorCpp -from tensorrt_llm.bindings.internal.batch_manager import KvCacheConnectorRole +from tensorrt_llm.bindings.internal.batch_manager import (KvCacheConnectorRole, + SchedulerOutput) class KvCacheConnector(KvCacheConnectorCpp): @@ -17,5 +18,6 @@ def bind_connector_metadata(self, metadata: object): def _get_connector_metadata(self) -> object: return self.connector_metadata - def build_connector_metadata(self) -> Optional[object]: + def build_connector_metadata( + self, scheduler_output: SchedulerOutput) -> Optional[object]: return None diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py 
b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 54c796964cd..b21f5d4ab0f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -411,7 +411,8 @@ def create_py_executor( guided_decoder=guided_decoder, lora_config=lora_config, garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, - kv_connector_config=kv_connector_config, + kv_connector_config=kv_connector_config + if not estimating_kv_cache else None, ) if estimating_kv_cache: From 6c26369e697824c116c2987ecde013e538c58033 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Sun, 27 Jul 2025 11:35:53 -0700 Subject: [PATCH 08/50] MEGA REFACTOR, move scheduler and worker into their own class, do init inside py_executor_creator, pass connector manager to kv cache manager Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 58 ++++++--- .../batch_manager/kvCacheConnector.cpp | 18 +-- .../pybind/batch_manager/kvCacheConnector.cpp | 114 +++++++++++------- tensorrt_llm/_torch/pyexecutor/_util.py | 24 +++- tensorrt_llm/_torch/pyexecutor/connector.py | 34 +++--- tensorrt_llm/_torch/pyexecutor/py_executor.py | 30 +---- .../_torch/pyexecutor/py_executor_creator.py | 50 ++++++-- .../_torch/pyexecutor/resource_manager.py | 4 + tensorrt_llm/models/modeling_utils.py | 3 +- .../bindings/test_connector_bindings.py | 45 ++++--- 10 files changed, 226 insertions(+), 154 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 7f9c9930454..052cd20cb18 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -30,11 +30,6 @@ using namespace tensorrt_llm::batch_manager; namespace tensorrt_llm::batch_manager::kv_connector { -enum KvCacheConnectorRole : std::int8_t -{ - Scheduler, - Worker -}; struct NewRequestData { @@ -126,20 +121,29 @@ class KvCacheConnectorPoolsData std::vector mLayerToPoolMapping; }; -class KvCacheConnector +class KvCacheConnectorScheduler { public: - explicit KvCacheConnector(KvCacheConnectorRole role); - virtual ~KvCacheConnector() = default; + explicit KvCacheConnectorScheduler() = default; + virtual ~KvCacheConnectorScheduler() = default; - [[nodiscard]] KvCacheConnectorRole role() const; + virtual std::tuple getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) + = 0; - // - // WORKER SIDE METHODS - // + // TODO(jothomson): Need arguments here. Also, is this even needed? + virtual void updateStateAfterAlloc(); + + virtual bool requestFinished(LlmRequest const& request); +}; + +class KvCacheConnectorWorker +{ +public: + explicit KvCacheConnectorWorker() = default; + virtual ~KvCacheConnectorWorker() = default; // TODO(jothomson): Need arguments here. - virtual void registerKvCaches(); + virtual void registerKvCaches(KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData); // TODO(jothomson): Need arguments here. 
virtual void startLoadKv() = 0; @@ -153,20 +157,34 @@ class KvCacheConnector virtual std::tuple, std::vector> getFinished( std::vector const& finishedReqIds); +}; - // - // SCHEDULER SIDE METHODS - // +class KvCacheConnectorManager +{ +public: + KvCacheConnectorManager(std::shared_ptr const& worker, + std::optional> const& scheduler) + : mWorker(worker) + , mScheduler(scheduler) + { + } virtual std::tuple getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) = 0; - // TODO(jothomson): Need arguments here. Also, is this even needed? - virtual void updateStateAfterAlloc(); + std::optional> getScheduler() const + { + return mScheduler; + } - virtual bool requestFinished(LlmRequest const& request); + std::shared_ptr getWorker() const + { + return mWorker; + } private: - KvCacheConnectorRole mRole; + std::shared_ptr mWorker; + std::optional> mScheduler; }; + } // namespace tensorrt_llm::batch_manager::kv_connector diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp index e564716e76b..d296d69171a 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp @@ -3,27 +3,17 @@ namespace tensorrt_llm::batch_manager::kv_connector { -KvCacheConnector::KvCacheConnector(KvCacheConnectorRole role) - : mRole(role) -{ -} - -KvCacheConnectorRole KvCacheConnector::role() const -{ - return mRole; -} - -void KvCacheConnector::registerKvCaches() {} +void KvCacheConnectorWorker::registerKvCaches(KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData) {} -std::tuple, std::vector> KvCacheConnector::getFinished( +std::tuple, std::vector> KvCacheConnectorWorker::getFinished( std::vector const& finishedReqIds) { return std::make_tuple(std::vector(), std::vector()); } -void KvCacheConnector::updateStateAfterAlloc() {} +void KvCacheConnectorScheduler::updateStateAfterAlloc() {} -bool KvCacheConnector::requestFinished(LlmRequest const& request) +bool KvCacheConnectorScheduler::requestFinished(LlmRequest const& request) { return false; } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index 7bc9d30edcc..96b41f237da 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -23,70 +23,83 @@ namespace { -using KvCacheConnector = tensorrt_llm::batch_manager::kv_connector::KvCacheConnector; +using KvCacheConnectorScheduler = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorScheduler; +using KvCacheConnectorWorker = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorWorker; +using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager; + +using NumNewMatchedTokens = std::tuple; + namespace tb = tensorrt_llm::batch_manager; -class PyKvCacheConnector : public KvCacheConnector +class PyKvCacheConnectorScheduler : public KvCacheConnectorScheduler { public: - using KvCacheConnector::KvCacheConnector; + using KvCacheConnectorScheduler::KvCacheConnectorScheduler; + + NumNewMatchedTokens getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override + { + PYBIND11_OVERRIDE_PURE( + NumNewMatchedTokens, KvCacheConnectorScheduler, getNumNewMatchedTokens, request, numComputedTokens); + } - // - // WORKER SIDE METHODS - // + void updateStateAfterAlloc() override + { + PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorScheduler, 
updateStateAfterAlloc); + } - void registerKvCaches() override + bool requestFinished(LlmRequest const& request) override { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, registerKvCaches); + PYBIND11_OVERRIDE_PURE(bool, KvCacheConnectorScheduler, requestFinished, request); + } +}; + +class PyKvCacheConnectorWorker : public KvCacheConnectorWorker +{ +public: + using KvCacheConnectorWorker::KvCacheConnectorWorker; + + void registerKvCaches(kv_connector::KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData) override + { + PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, registerKvCaches, kvCacheConnectorPoolsData); } void startLoadKv() override { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, startLoadKv); + PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, startLoadKv); } void waitForLayerLoad(SizeType32 layer_idx) override { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, waitForLayerLoad, layer_idx); + PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, waitForLayerLoad, layer_idx); } void saveKvLayer(SizeType32 layer_idx) override { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, saveKvLayer, layer_idx); + PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, saveKvLayer, layer_idx); } void waitForSave() override { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, waitForSave); + PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, waitForSave); } using FinishedReqs = std::tuple, std::vector>; FinishedReqs getFinished(std::vector const& finishedReqIds) override { - PYBIND11_OVERRIDE_PURE(FinishedReqs, KvCacheConnector, getFinished, finishedReqIds); + PYBIND11_OVERRIDE_PURE(FinishedReqs, KvCacheConnectorWorker, getFinished, finishedReqIds); } +}; - // - // SCHEDULER SIDE METHODS - // - - using NumNewMatchedTokens = std::tuple; +class PyKvCacheConnectorManager : public KvCacheConnectorManager +{ +public: + using KvCacheConnectorManager::KvCacheConnectorManager; NumNewMatchedTokens getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override { PYBIND11_OVERRIDE_PURE( - NumNewMatchedTokens, KvCacheConnector, getNumNewMatchedTokens, request, numComputedTokens); - } - - void updateStateAfterAlloc() override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnector, updateStateAfterAlloc); - } - - bool requestFinished(LlmRequest const& request) override - { - PYBIND11_OVERRIDE_PURE(bool, KvCacheConnector, requestFinished, request); + NumNewMatchedTokens, KvCacheConnectorManager, getNumNewMatchedTokens, request, numComputedTokens); } }; @@ -94,10 +107,6 @@ class PyKvCacheConnector : public KvCacheConnector void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(py::module_& m) { - py::enum_(m, "KvCacheConnectorRole") - .value("Scheduler", tb::kv_connector::KvCacheConnectorRole::Scheduler) - .value("Worker", tb::kv_connector::KvCacheConnectorRole::Worker); - py::class_(m, "KvCacheConnectorPoolData") .def_property_readonly("tensor", [](tb::kv_connector::KvCacheConnectorPoolData& self) @@ -113,19 +122,34 @@ void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindi .def_property_readonly( "layer_to_pool_mapping", &tb::kv_connector::KvCacheConnectorPoolsData::getLayerToPoolMapping); - py::class_(m, "KvCacheConnector") - .def(py::init(), py::arg("role")) - .def("register_kv_caches", &tb::kv_connector::KvCacheConnector::registerKvCaches) - .def("start_load_kv", &tb::kv_connector::KvCacheConnector::startLoadKv) - .def("wait_for_layer_load", &tb::kv_connector::KvCacheConnector::waitForLayerLoad, 
py::arg("layer_idx")) - .def("save_kv_layer", &tb::kv_connector::KvCacheConnector::saveKvLayer, py::arg("layer_idx")) - .def("wait_for_save", &tb::kv_connector::KvCacheConnector::waitForSave) - .def("get_finished", &tb::kv_connector::KvCacheConnector::getFinished, py::arg("finished_req_ids")) - .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnector::getNumNewMatchedTokens, + py::class_( + m, "KvCacheConnectorWorker") + .def(py::init<>()) + .def( + "register_kv_caches", &tb::kv_connector::KvCacheConnectorWorker::registerKvCaches, py::arg("kv_cache_data")) + .def("start_load_kv", &tb::kv_connector::KvCacheConnectorWorker::startLoadKv) + .def("wait_for_layer_load", &tb::kv_connector::KvCacheConnectorWorker::waitForLayerLoad, py::arg("layer_idx")) + .def("save_kv_layer", &tb::kv_connector::KvCacheConnectorWorker::saveKvLayer, py::arg("layer_idx")) + .def("wait_for_save", &tb::kv_connector::KvCacheConnectorWorker::waitForSave) + .def("get_finished", &tb::kv_connector::KvCacheConnectorWorker::getFinished, py::arg("finished_req_ids")); + + py::class_( + m, "KvCacheConnectorScheduler") + .def(py::init<>()) + .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorScheduler::getNumNewMatchedTokens, py::arg("request"), py::arg("num_computed_tokens")) - .def("update_state_after_alloc", &tb::kv_connector::KvCacheConnector::updateStateAfterAlloc) - .def("request_finished", &tb::kv_connector::KvCacheConnector::requestFinished, py::arg("request")) - .def_property_readonly("role", &tb::kv_connector::KvCacheConnector::role); + .def("update_state_after_alloc", &tb::kv_connector::KvCacheConnectorScheduler::updateStateAfterAlloc) + .def("request_finished", &tb::kv_connector::KvCacheConnectorScheduler::requestFinished, py::arg("request")); + + py::class_( + m, "KvCacheConnectorManager") + .def(py::init, + std::optional>>(), + py::arg("worker"), py::arg("scheduler")) + .def_property_readonly("scheduler", &tb::kv_connector::KvCacheConnectorManager::getScheduler) + .def_property_readonly("worker", &tb::kv_connector::KvCacheConnectorManager::getWorker) + .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens, + py::arg("request"), py::arg("num_computed_tokens")); py::class_(m, "NewRequestData") .def_readonly("request_id", &tb::kv_connector::NewRequestData::requestId) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index f4f2ac6e037..b06897f42be 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -17,12 +17,12 @@ get_default_trtllm_modules_to_hf_modules, load_torch_lora) from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from ..model_config import ModelConfig from ..speculative import get_num_extra_kv_tokens, get_spec_decoder from .config import PyTorchConfig from .config_utils import is_mla, is_nemotron_hybrid +from .connector import KvCacheConnectorManager from .guided_decoder import GuidedDecoder from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver from .llm_request import ExecutorResponse @@ -45,7 +45,8 @@ class KvCacheCreator: def __init__(self, *, executor_config: ExecutorConfig, model_engine: PyTorchModelEngine, draft_model_engine: Optional[PyTorchModelEngine], - mapping: Mapping, net_max_seq_len: int): + mapping: Mapping, net_max_seq_len: int, + kv_connector_manager: Optional[KvCacheConnectorManager]): self._executor_config = executor_config 
self._model_engine = model_engine self._draft_model_engine = draft_model_engine @@ -53,6 +54,7 @@ def __init__(self, *, executor_config: ExecutorConfig, self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len - 1) + self._kv_connector_manager = kv_connector_manager @staticmethod def _get_cache_size_per_token(model_config: ModelConfig, @@ -315,12 +317,19 @@ def _create_kv_cache_manager( dtype=kv_cache_dtype, spec_config=spec_config, max_beam_width=executor_config.max_beam_width, + kv_connector_manager=self._kv_connector_manager, ) elif is_nemotron_hybrid(config): if executor_config.max_beam_width > 1: raise ValueError( "MambaHybridCacheManager + beam search is not supported yet." ) + + if self._kv_connector_manager is not None: + raise ValueError( + "Connector manager is not supported for MambaHybridCacheManager." + ) + config = model_engine.model.model_config.pretrained_config num_layers = config.hybrid_override_pattern.count("*") layer_mask = [ @@ -377,6 +386,7 @@ def _create_kv_cache_manager( max_num_tokens=executor_config.max_num_tokens, model_config=binding_model_config, max_beam_width=executor_config.max_beam_width, + kv_connector_manager=self._kv_connector_manager, ) # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER: @@ -387,9 +397,15 @@ def _create_kv_cache_manager( def build_managers(self, resources: Dict) -> None: """Construct KV caches for model and draft model (if applicable).""" kv_cache_manager = self._create_kv_cache_manager(self._model_engine) + + if self._kv_connector_manager is not None and self._draft_model_engine is not None: + raise ValueError( + "Connector manager is not supported for draft model.") + draft_kv_cache_manager = self._create_kv_cache_manager( self._draft_model_engine ) if self._draft_model_engine is not None else None + resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager resources[ ResourceManagerType.DRAFT_KV_CACHE_MANAGER] = draft_kv_cache_manager @@ -420,7 +436,7 @@ def create_py_executor_instance( guided_decoder: Optional[GuidedDecoder] = None, lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, - kv_connector_config: Optional[KvCacheConnectorConfig] = None + kv_connector_manager: Optional[KvCacheConnectorManager] = None ) -> PyExecutor: kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None) @@ -562,7 +578,7 @@ def create_py_executor_instance( guided_decoder=guided_decoder, start_worker=start_worker, garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, - kv_connector_config=kv_connector_config) + kv_connector_manager=kv_connector_manager) def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping, diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index b3b78848dd6..d466c55a7f1 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -1,23 +1,27 @@ from typing import Optional +from tensorrt_llm._utils import mpi_broadcast, mpi_rank from tensorrt_llm.bindings.internal.batch_manager import \ - KvCacheConnector as KvCacheConnectorCpp -from tensorrt_llm.bindings.internal.batch_manager import (KvCacheConnectorRole, - SchedulerOutput) + KvCacheConnectorManager as KvCacheConnectorManagerCpp +from 
tensorrt_llm.bindings.internal.batch_manager import ( + KvCacheConnectorScheduler, KvCacheConnectorWorker, LlmRequest) -class KvCacheConnector(KvCacheConnectorCpp): +class KvCacheConnectorManager(KvCacheConnectorManagerCpp): - def __init__(self, role: KvCacheConnectorRole): - super().__init__(role) - self.connector_metadata = None + def __init__(self, worker: KvCacheConnectorWorker, + scheduler: Optional[KvCacheConnectorScheduler]): + assert (scheduler is not None) == ( + mpi_rank() == 0), "The scheduler may only exist on rank 0!" + super().__init__(worker, scheduler) - def bind_connector_metadata(self, metadata: object): - self.connector_metadata = metadata + def get_num_new_matched_tokens( + self, request: LlmRequest, + num_computed_tokens: int) -> tuple[int, bool]: + if self.scheduler is not None: + result = self.scheduler.getNumNewMatchedTokens( + request, num_computed_tokens) + else: + result = None - def _get_connector_metadata(self) -> object: - return self.connector_metadata - - def build_connector_metadata( - self, scheduler_output: SchedulerOutput) -> Optional[object]: - return None + return mpi_broadcast(result, root=0) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 8da85260c4f..e3e058c57e1 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2,7 +2,6 @@ import datetime import functools import gc -import importlib import os import threading import time @@ -28,16 +27,15 @@ RequestStage, RequestStats, SpecDecodingStats, StaticBatchingStats) -from tensorrt_llm.bindings.internal.batch_manager import (KvCacheConnectorRole, - LlmRequestType, +from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType, ReqIdsSet) from tensorrt_llm.logger import logger -from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from tensorrt_llm.runtime.generation import CUASSERT from ..distributed import Distributed from ..models.modeling_utils import DecoderModelForCausalLM from ..speculative.drafter import Drafter +from .connector import KvCacheConnectorManager from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem from .guided_decoder import GuidedDecoder from .kv_cache_transceiver import KvCacheTransceiver @@ -154,7 +152,7 @@ def __init__(self, guided_decoder: Optional[GuidedDecoder] = None, garbage_collection_gen0_threshold: Optional[int] = None, start_worker: bool = True, - kv_connector_config: Optional[KvCacheConnectorConfig] = None): + kv_connector_manager: Optional[KvCacheConnectorManager] = None): super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = global_mpi_rank() @@ -268,28 +266,12 @@ def __init__(self, self.worker_started = False self.worker_lock = threading.Lock() - if kv_connector_config is not None: - logger.info( - f"Initializing kv connector with config: {kv_connector_config}") - module_name = kv_connector_config.connector_module - class_name = kv_connector_config.connector_class + self.kv_connector_manager = kv_connector_manager - try: - module = importlib.import_module(module_name) - connector_cls = getattr(module, class_name) - self.connector_worker = connector_cls( - KvCacheConnectorRole.Worker) - - # Only initialize the scheduler on rank 0. 
- if global_mpi_rank() == 0: - self.connector_scheduler = connector_cls( - KvCacheConnectorRole.Scheduler) - except Exception as e: - logger.error(f"Error instantiating connector: {e}") - raise e + if self.kv_connector_manager is not None: kv_cache_data = self.kv_cache_manager.get_kv_cache_connector_pools_data( ) - self.connector_worker.register_kv_caches(kv_cache_data) + self.kv_connector_manager.worker.register_kv_caches(kv_cache_data) if start_worker: self.start_worker() diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index b21f5d4ab0f..6fb8fbd0b3a 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -1,5 +1,6 @@ import copy import enum +import importlib from contextlib import contextmanager from dataclasses import dataclass from itertools import chain @@ -11,7 +12,8 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType from tensorrt_llm._utils import get_sm_version from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig -from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig +from tensorrt_llm.bindings.internal.batch_manager import ( + ContextChunkingConfig, KvCacheConnectorManager) from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraConfig from tensorrt_llm.mapping import Mapping @@ -361,16 +363,48 @@ def create_py_executor( pytorch_backend_config, mapping) logger.info(f"Using Sampler: {type(sampler).__name__}") + if kv_connector_config is not None: + logger.info( + f"Initializing kv connector with config: {kv_connector_config}") + + try: + module = importlib.import_module( + kv_connector_config.connector_module) + worker_cls = getattr(module, + kv_connector_config.connector_worker_class) + scheduler_cls = getattr( + module, kv_connector_config.connector_scheduler_class) + + connector_worker = worker_cls() + + # Only initialize the scheduler on rank 0. 
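+            # The worker is constructed on every rank; its KV pools are registered
+            # later, when the PyExecutor passes it the pool data from the cache manager.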
+ rank = tensorrt_llm.mpi_rank() + if rank == 0: + connector_scheduler = scheduler_cls() + else: + connector_scheduler = None + + kv_connector_manager = KvCacheConnectorManager( + connector_worker, connector_scheduler) + + except Exception as e: + logger.error(f"Error instantiating connector: {e}") + raise e + else: + kv_connector_manager = None + resources = {} estimating_kv_cache = False kv_cache_creator = None if model_engine.model.model_config.is_generation: #NOTE: non-generation models do not have kv cache - kv_cache_creator = KvCacheCreator(executor_config=executor_config, - model_engine=model_engine, - draft_model_engine=draft_model_engine, - mapping=mapping, - net_max_seq_len=net_max_seq_len) + kv_cache_creator = KvCacheCreator( + executor_config=executor_config, + model_engine=model_engine, + draft_model_engine=draft_model_engine, + mapping=mapping, + net_max_seq_len=net_max_seq_len, + kv_connector_manager=kv_connector_manager) estimating_kv_cache = kv_cache_creator.try_prepare_estimation() with mem_monitor.observe_creation_stage( _ExecutorCreationStage.INIT_KV_CACHE @@ -411,7 +445,7 @@ def create_py_executor( guided_decoder=guided_decoder, lora_config=lora_config, garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, - kv_connector_config=kv_connector_config + kv_connector_manager=kv_connector_manager if not estimating_kv_cache else None, ) @@ -456,7 +490,7 @@ def create_py_executor( lora_config=lora_config, garbage_collection_gen0_threshold= garbage_collection_gen0_threshold, - kv_connector_config=kv_connector_config, + kv_connector_manager=kv_connector_manager, ) _adjust_torch_mem_fraction(executor_config.pytorch_backend_config) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 2ddf14a924f..3e6eb1c43a6 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -18,6 +18,7 @@ from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range from ...logger import logger from ...mapping import Mapping +from .connector import KvCacheConnectorManager from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig, get_draft_token_length) from .scheduler import ScheduledRequests @@ -134,6 +135,7 @@ def __init__( max_num_tokens: int = 8192, model_config: Optional[ModelConfig] = None, max_beam_width: int = 1, + kv_connector_manager: Optional[KvCacheConnectorManager] = None, ) -> None: self.mapping = mapping self.dtype = dtype @@ -150,6 +152,8 @@ def __init__( for offset, idx in enumerate(self.pp_layers) } + self.kv_connector_manager = kv_connector_manager + tp_size = mapping.tp_size if mapping.enable_attention_dp: tp_size = 1 diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index cc99566b46a..6fdeb0163b2 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -127,7 +127,8 @@ def from_arguments(args: argparse.Namespace): @dataclasses.dataclass class KvCacheConnectorConfig: connector_module: str - connector_class: str + connector_scheduler_class: str + connector_worker_class: str @dataclasses.dataclass diff --git a/tests/unittest/bindings/test_connector_bindings.py b/tests/unittest/bindings/test_connector_bindings.py index 0fc5fb8c929..531f74c10f4 100644 --- a/tests/unittest/bindings/test_connector_bindings.py +++ b/tests/unittest/bindings/test_connector_bindings.py @@ -1,52 +1,51 @@ -from typing import List, Optional +from typing import 
List -from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnector -from tensorrt_llm.bindings.internal.batch_manager import (KvCacheConnectorRole, - LlmRequest) +from tensorrt_llm.bindings.internal.batch_manager import ( + KvCacheConnectorPoolsData, KvCacheConnectorScheduler, + KvCacheConnectorWorker, LlmRequest) -class BasicConnector(KvCacheConnector): +class BasicConnectorWorker(KvCacheConnectorWorker): - def __init__(self, role: KvCacheConnectorRole): - super().__init__(role) - - def build_connector_metadata(self) -> Optional[object]: - return {"test": "test"} + def register_kv_caches(self, kv_cache_data: KvCacheConnectorPoolsData): + pass def start_load_kv(self): pass - def wait_for_layer_load(self, layer_idx: int): + def wait_for_save(self): pass - def save_kv_layer(self, layer_idx: int): + def wait_for_layer_load(self, layer_idx: int): pass - def wait_for_save(self): + def save_kv_layer(self, layer_idx: int): pass def get_finished( self, finished_req_ids: List[int]) -> tuple[List[int], List[int]]: return [42], [7] + +class BasicConnectorScheduler(KvCacheConnectorScheduler): + def get_num_new_matched_tokens( self, request: LlmRequest, num_computed_tokens: int) -> tuple[int, bool]: return 16, True + def update_state_after_alloc(self): + pass -def test_basic_init(): - connector = BasicConnector(KvCacheConnectorRole.Scheduler) - - assert connector.role == KvCacheConnectorRole.Scheduler - assert connector.build_connector_metadata() == {"test": "test"} +def test_basic_init(): + connector_scheduler = BasicConnectorScheduler() - assert connector.get_finished([]) == ([42], [7]) + connector_scheduler.update_state_after_alloc() - connector.save_kv_layer(0) - connector.wait_for_save() + connector_worker = BasicConnectorWorker() - connector_worker = BasicConnector(KvCacheConnectorRole.Worker) + assert connector_worker.get_finished([]) == ([42], [7]) - assert connector_worker.role == KvCacheConnectorRole.Worker + connector_worker.save_kv_layer(0) + connector_worker.wait_for_save() From d545a5d3480ee033a42b1472fc91f6b890c56ccf Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Mon, 28 Jul 2025 13:33:57 -0700 Subject: [PATCH 09/50] Get num new matched tokens Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 3 +- .../batch_manager/kvCacheManager.h | 16 ++++-- .../batch_manager/kvCacheManager.cpp | 51 +++++++++++++------ .../pybind/batch_manager/kvCacheConnector.cpp | 20 ++++---- .../pybind/batch_manager/kvCacheManager.cpp | 12 +++-- tensorrt_llm/_torch/pyexecutor/connector.py | 16 +++--- .../_torch/pyexecutor/py_executor_creator.py | 4 +- .../_torch/pyexecutor/resource_manager.py | 5 +- 8 files changed, 80 insertions(+), 47 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 052cd20cb18..c176e025cb1 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -169,8 +169,7 @@ class KvCacheConnectorManager { } - virtual std::tuple getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) - = 0; + virtual SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) = 0; std::optional> getScheduler() const { diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index a9d7ab1143d..c0fedf058fb 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ 
b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -549,8 +549,9 @@ class WindowBlockManager void startScheduling(); //! \brief Assign blocks for new sequence. Try to reuse blocks. - void addSequence( - GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest); + void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, + LlmRequest& llmRequest, + std::optional> kvCacheConnectorManager); //! \brief Assign blocks for new sequence. Does not try to reuse blocks. void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx); @@ -884,7 +885,9 @@ class BlockManager void allocatePools(bool useUvm); void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, SizeType32 windowSize); + LlmRequest& llmRequest, + std::optional> kvCacheConnectorManager, + SizeType32 windowSize); void addSequence( GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize); @@ -1245,7 +1248,8 @@ class BaseKVCacheManager /// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for /// inputLength - 1 tokens and populate prepopulatedPromptLen. virtual void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - OptionalRef llmRequest = std::nullopt) + OptionalRef llmRequest = std::nullopt, + std::optional> kvCacheConnectorManager = std::nullopt) = 0; virtual void removeSequence( @@ -1545,7 +1549,9 @@ class KVCacheManager : public BaseKVCacheManager /// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for /// inputLength - 1 tokens and populate prepopulatedPromptLen. 
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - OptionalRef llmRequest = std::nullopt) override; + OptionalRef llmRequest = std::nullopt, + std::optional> kvCacheConnectorManager + = std::nullopt) override; void removeSequence( LlmRequest::RequestIdType requestId, OptionalRef llmRequest = std::nullopt) override; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 6f173165638..6bf15ebf144 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1198,13 +1198,17 @@ void WindowBlockManager::refreshBlocks() } void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, SizeType32 windowSize) + LlmRequest& llmRequest, + std::optional> kvCacheConnectorManager, + SizeType32 windowSize) { - mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest); + mWindowBlockManagers.at(windowSize) + .addSequence(sequence, inputLength, numContextBlocks, llmRequest, kvCacheConnectorManager); } -void WindowBlockManager::addSequence( - GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest) +void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, + LlmRequest& llmRequest, + std::optional> kvCacheConnectorManager) { auto const requestId = sequence.getRequestId(); auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); @@ -1235,9 +1239,19 @@ void WindowBlockManager::addSequence( auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions); mReusedTokens += static_cast(prepopulatedPromptLen); mTotalInputTokens += static_cast(uniqueTokens.size()); - llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock()); - TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.mRequestId, - inputLength, prepopulatedPromptLen); + + SizeType32 numNewMatchedTokens = 0; + + if (kvCacheConnectorManager.has_value()) + { + numNewMatchedTokens = kvCacheConnectorManager->get()->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); + TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numNewMatchedTokens %d", + llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numNewMatchedTokens); + } + + llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numNewMatchedTokens, getTokensPerBlock()); + TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numNewMatchedTokens %d", + llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numNewMatchedTokens); } void BlockManager::addSequence( @@ -1537,13 +1551,10 @@ void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef BlockManager::getKvCacheConnectorPoolsData() const { - if (mWindowBlockManagers.size() > 1) - { - throw std::runtime_error("KV Cache connector is not supported with multiple window sizes"); - } - + TLLM_CHECK_WITH_INFO( + mWindowBlockManagers.size() == 1, "KV Cache connector is not supported with multiple window sizes"); std::vector poolsData; - poolsData.reserve(mWindowBlockManagers.size()); + poolsData.reserve(1); for (auto const& [_, manager] : mWindowBlockManagers) { 
poolsData.emplace_back(manager.getKvCacheConnectorPoolData()); @@ -2043,8 +2054,9 @@ std::optional KVCacheManager::findNewContextBlock( return newContextBlockOpt; } -void KVCacheManager::addSequence( - RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef llmRequest) +void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, + OptionalRef llmRequest, + std::optional> kvCacheConnectorManager) { // Need to add the bubble after the sink tokens to use even block size inputLength += mSinkBubbleLength; @@ -2092,7 +2104,8 @@ void KVCacheManager::addSequence( auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock()); if (!sequence.isCyclic() && mEnableBlockReuse) { - mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize); + mBlockManager.addSequence( + sequence, effectiveInputLength, numContextBlocks, *llmRequest, kvCacheConnectorManager, windowSize); } else { @@ -2104,6 +2117,12 @@ void KVCacheManager::addSequence( "will " "have no effect.", llmRequest->mRequestId); + if (kvCacheConnectorManager.has_value()) + { + TLLM_LOG_WARNING( + "KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be " + "ignored."); + } } mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize); if (mEnableHashKey && llmRequest.has_value() && beamWidth == 1) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index 96b41f237da..60c21c2eb9c 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -31,7 +31,7 @@ using NumNewMatchedTokens = std::tuple; namespace tb = tensorrt_llm::batch_manager; -class PyKvCacheConnectorScheduler : public KvCacheConnectorScheduler +class PyKvCacheConnectorScheduler : public KvCacheConnectorScheduler, py::trampoline_self_life_support { public: using KvCacheConnectorScheduler::KvCacheConnectorScheduler; @@ -44,23 +44,23 @@ class PyKvCacheConnectorScheduler : public KvCacheConnectorScheduler void updateStateAfterAlloc() override { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorScheduler, updateStateAfterAlloc); + PYBIND11_OVERRIDE(void, KvCacheConnectorScheduler, updateStateAfterAlloc); } bool requestFinished(LlmRequest const& request) override { - PYBIND11_OVERRIDE_PURE(bool, KvCacheConnectorScheduler, requestFinished, request); + PYBIND11_OVERRIDE(bool, KvCacheConnectorScheduler, requestFinished, request); } }; -class PyKvCacheConnectorWorker : public KvCacheConnectorWorker +class PyKvCacheConnectorWorker : public KvCacheConnectorWorker, py::trampoline_self_life_support { public: using KvCacheConnectorWorker::KvCacheConnectorWorker; void registerKvCaches(kv_connector::KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData) override { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, registerKvCaches, kvCacheConnectorPoolsData); + PYBIND11_OVERRIDE(void, KvCacheConnectorWorker, registerKvCaches, kvCacheConnectorPoolsData); } void startLoadKv() override @@ -87,19 +87,19 @@ class PyKvCacheConnectorWorker : public KvCacheConnectorWorker FinishedReqs getFinished(std::vector const& finishedReqIds) override { - PYBIND11_OVERRIDE_PURE(FinishedReqs, KvCacheConnectorWorker, getFinished, finishedReqIds); + PYBIND11_OVERRIDE(FinishedReqs, KvCacheConnectorWorker, getFinished, finishedReqIds); } }; -class PyKvCacheConnectorManager : 
public KvCacheConnectorManager +class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline_self_life_support { public: using KvCacheConnectorManager::KvCacheConnectorManager; - NumNewMatchedTokens getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override + SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override { - PYBIND11_OVERRIDE_PURE( - NumNewMatchedTokens, KvCacheConnectorManager, getNumNewMatchedTokens, request, numComputedTokens); + PYBIND11_OVERRIDE_PURE_NAME(SizeType32, KvCacheConnectorManager, "get_num_new_matched_tokens", + getNumNewMatchedTokens, request, numComputedTokens); } }; diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index f9b02b78fca..c9fd2f1fae3 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -96,10 +96,12 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager } void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, + std::optional> kvCacheConnectorManager + = std::nullopt) override { - PYBIND11_OVERLOAD_PURE( - void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, llmRequest); + PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, + llmRequest, kvCacheConnectorManager); } void removeSequence(tb::LlmRequest::RequestIdType requestId, @@ -348,7 +350,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep) .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion) .def("add_token", &BaseKVCacheManager::addToken) - .def("add_sequence", &BaseKVCacheManager::addSequence) + .def("add_sequence", &BaseKVCacheManager::addSequence, py::arg("request_id"), py::arg("input_length"), + py::arg("beam_width"), py::arg("llm_request") = std::nullopt, + py::arg("kv_cache_connector_manager") = std::nullopt) .def("remove_sequence", &BaseKVCacheManager::removeSequence) .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence) .def("get_block_pool_pointers", diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index d466c55a7f1..8c54ac9044c 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -15,13 +15,17 @@ def __init__(self, worker: KvCacheConnectorWorker, mpi_rank() == 0), "The scheduler may only exist on rank 0!" super().__init__(worker, scheduler) - def get_num_new_matched_tokens( - self, request: LlmRequest, - num_computed_tokens: int) -> tuple[int, bool]: + def get_num_new_matched_tokens(self, + request: LlmRequest = None, + num_computed_tokens: int = None) -> int: if self.scheduler is not None: - result = self.scheduler.getNumNewMatchedTokens( + num_tokens, load_kv_async = self.scheduler.get_num_new_matched_tokens( request, num_computed_tokens) + + mpi_broadcast((num_tokens, load_kv_async), root=0) else: - result = None + num_tokens, load_kv_async = mpi_broadcast(None, root=0) + + # TODO: Do some stuff in the future to handle load_kv_async. 
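+        # All ranks return the same value here: rank 0 computes the pair and the
+        # broadcast above distributes it to the other ranks.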
- return mpi_broadcast(result, root=0) + return num_tokens diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 6fb8fbd0b3a..c562ad57961 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -12,8 +12,7 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType from tensorrt_llm._utils import get_sm_version from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig -from tensorrt_llm.bindings.internal.batch_manager import ( - ContextChunkingConfig, KvCacheConnectorManager) +from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraConfig from tensorrt_llm.mapping import Mapping @@ -28,6 +27,7 @@ create_py_executor_instance, instantiate_sampler, is_mla) from .config import PyTorchConfig from .config_utils import is_mla +from .connector import KvCacheConnectorManager from .guided_decoder import GuidedDecoder from .model_engine import PyTorchModelEngine from .py_executor import PyExecutor diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 3e6eb1c43a6..87ce0688221 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -386,11 +386,12 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req.py_request_id, seq_len + (len(req.query_id) if self.mapping.cp_rank == self.mapping.cp_size - 1 else 0), - req_beam_width, req) + req_beam_width, req, self.kv_connector_manager) else: if req.is_first_context_chunk: self.impl.add_sequence(req.py_request_id, req.prompt_len, - req_beam_width, req) + req_beam_width, req, + self.kv_connector_manager) for _ in range(self.num_extra_kv_tokens): self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): From 9a1ba686484e01b77826b18013c7b10dcf9e33f9 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Mon, 28 Jul 2025 18:11:13 -0700 Subject: [PATCH 10/50] Suspend requests for async onboard Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 22 +---- .../batch_manager/kvCacheManager.cpp | 7 ++ .../pybind/batch_manager/kvCacheConnector.cpp | 6 +- tensorrt_llm/_torch/pyexecutor/connector.py | 84 +++++++++++++++++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 19 +++++ 5 files changed, 106 insertions(+), 32 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index c176e025cb1..fedba8bb0cc 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -162,28 +162,10 @@ class KvCacheConnectorWorker class KvCacheConnectorManager { public: - KvCacheConnectorManager(std::shared_ptr const& worker, - std::optional> const& scheduler) - : mWorker(worker) - , mScheduler(scheduler) - { - } + KvCacheConnectorManager() = default; + virtual ~KvCacheConnectorManager() = default; virtual SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) = 0; - - std::optional> getScheduler() const - { - return mScheduler; - } - - std::shared_ptr getWorker() const - { - return mWorker; - } - -private: - std::shared_ptr mWorker; - std::optional> mScheduler; }; } // namespace 
tensorrt_llm::batch_manager::kv_connector diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 6bf15ebf144..b10492f2e2d 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1247,6 +1247,8 @@ void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inp numNewMatchedTokens = kvCacheConnectorManager->get()->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numNewMatchedTokens %d", llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numNewMatchedTokens); + TLLM_CHECK_WITH_INFO(prepopulatedPromptLen + numNewMatchedTokens < llmRequest.getPromptLen(), + "There must be at least one uncomputed token in the prompt!"); } llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numNewMatchedTokens, getTokensPerBlock()); @@ -2061,6 +2063,11 @@ void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength // Need to add the bubble after the sink tokens to use even block size inputLength += mSinkBubbleLength; + if (kvCacheConnectorManager.has_value()) + { + TLLM_CHECK_WITH_INFO(beamWidth == 1, "KV Cache Connector is not supported with beam search"); + } + auto kvCacheRetentionConfig = llmRequest ? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig()) : executor::KvCacheRetentionConfig(); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index 60c21c2eb9c..5f97f701156 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -143,11 +143,7 @@ void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindi py::class_( m, "KvCacheConnectorManager") - .def(py::init, - std::optional>>(), - py::arg("worker"), py::arg("scheduler")) - .def_property_readonly("scheduler", &tb::kv_connector::KvCacheConnectorManager::getScheduler) - .def_property_readonly("worker", &tb::kv_connector::KvCacheConnectorManager::getWorker) + .def(py::init<>()) .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens, py::arg("request"), py::arg("num_computed_tokens")); diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 8c54ac9044c..18d80dffda8 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -1,10 +1,40 @@ from typing import Optional from tensorrt_llm._utils import mpi_broadcast, mpi_rank +from tensorrt_llm.bindings import LlmRequestState from tensorrt_llm.bindings.internal.batch_manager import \ KvCacheConnectorManager as KvCacheConnectorManagerCpp -from tensorrt_llm.bindings.internal.batch_manager import ( - KvCacheConnectorScheduler, KvCacheConnectorWorker, LlmRequest) +from tensorrt_llm.bindings.internal.batch_manager import \ + KvCacheConnectorScheduler as KvCacheConnectorSchedulerCpp +from tensorrt_llm.bindings.internal.batch_manager import \ + KvCacheConnectorWorker as KvCacheConnectorWorkerCpp +from tensorrt_llm.bindings.internal.batch_manager import LlmRequest + +from .scheduler import ScheduledRequests + + +class KvCacheConnectorWorker(KvCacheConnectorWorkerCpp): + + def __init__(self): + super().__init__() + + def bind_connector_meta(self, metadata: object): + self._metadata = 
metadata + + def get_connector_meta(self) -> object: + return self._metadata + + def _clear_connector_meta(self): + self._metadata = None + + +class KvCacheConnectorScheduler(KvCacheConnectorSchedulerCpp): + + def __init__(self): + super().__init__() + + def build_connector_metadata(self, metadata: object): + return None class KvCacheConnectorManager(KvCacheConnectorManagerCpp): @@ -13,11 +43,18 @@ def __init__(self, worker: KvCacheConnectorWorker, scheduler: Optional[KvCacheConnectorScheduler]): assert (scheduler is not None) == ( mpi_rank() == 0), "The scheduler may only exist on rank 0!" - super().__init__(worker, scheduler) - def get_num_new_matched_tokens(self, - request: LlmRequest = None, - num_computed_tokens: int = None) -> int: + super().__init__() + + self.worker = worker + self.scheduler = scheduler + + self.requests_awaiting_async_load_init = set() + + self.requests_awaiting_async_load_complete = [] + + def get_num_new_matched_tokens(self, request: LlmRequest, + num_computed_tokens: int) -> int: if self.scheduler is not None: num_tokens, load_kv_async = self.scheduler.get_num_new_matched_tokens( request, num_computed_tokens) @@ -26,6 +63,39 @@ def get_num_new_matched_tokens(self, else: num_tokens, load_kv_async = mpi_broadcast(None, root=0) - # TODO: Do some stuff in the future to handle load_kv_async. + if num_tokens == 0 and load_kv_async: + raise RuntimeError( + "load_kv_async must be False when num_tokens is 0!") + + if load_kv_async: + self.requests_awaiting_async_load_init.add(request.request_id) return num_tokens + + def build_connector_metadata(self) -> object: + if self.scheduler is not None: + metadata = self.scheduler.build_connector_metadata() + else: + metadata = None + + metadata = mpi_broadcast(metadata, root=0) + + self.worker.bind_connector_meta(metadata) + + def take_scheduled_requests_pending_transfer( + self, scheduled_requests: ScheduledRequests) -> ScheduledRequests: + allowed_context_requests = [] + async_load_requests = [] + + for req in scheduled_requests.context_requests: + if req.request_id in self.requests_awaiting_async_load_init: + async_load_requests.append(req) + req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS + self.requests_awaiting_async_load_init.remove(req.request_id) + self.requests_awaiting_async_load_complete.append(req) + else: + allowed_context_requests.append(req) + + scheduled_requests.context_requests = allowed_context_requests + + return scheduled_requests diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index e3e058c57e1..fc414724a80 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -269,6 +269,21 @@ def __init__(self, self.kv_connector_manager = kv_connector_manager if self.kv_connector_manager is not None: + if kv_cache_transceiver is not None: + raise NotImplementedError( + "KV Cache Connector is not supported with KvCacheTransceiver." + ) + + if self.dist.pp_size > 1: + raise NotImplementedError( + "KV Cache Connector is not supported with pipeline parallelism." + ) + + if not disable_overlap_scheduler: + raise NotImplementedError( + "KV Cache Connector is not supported with overlap scheduler." 
+ ) + kv_cache_data = self.kv_cache_manager.get_kv_cache_connector_pools_data( ) self.kv_connector_manager.worker.register_kv_caches(kv_cache_data) @@ -896,6 +911,10 @@ def _prepare_and_schedule_batch(self): "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache" ) self.kv_cache_transceiver.check_context_transfer_status(1) + elif self.kv_connector_manager is None: + assert scheduled_batch.batch_size > 0, ( + "fail to schedule any pending request, " + "probably run out of resource.") self.num_scheduled_requests = scheduled_batch.batch_size logger.debug( From 1f0a35b005c1acdd1c3ee6dd09057942054a49f9 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 29 Jul 2025 09:37:15 -0700 Subject: [PATCH 11/50] async load and resume Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 4 +- .../batch_manager/kvCacheConnector.cpp | 4 +- .../batch_manager/kvCacheManager.cpp | 6 +- .../pybind/batch_manager/kvCacheConnector.cpp | 8 +- tensorrt_llm/_torch/pyexecutor/connector.py | 133 ++++++++++++++++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 9 +- 6 files changed, 140 insertions(+), 24 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index fedba8bb0cc..4efa67865f7 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -142,7 +142,6 @@ class KvCacheConnectorWorker explicit KvCacheConnectorWorker() = default; virtual ~KvCacheConnectorWorker() = default; - // TODO(jothomson): Need arguments here. virtual void registerKvCaches(KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData); // TODO(jothomson): Need arguments here. @@ -150,13 +149,12 @@ class KvCacheConnectorWorker virtual void waitForLayerLoad(SizeType32 layer_idx) = 0; - // TODO(jothomson): Need arguments here. 
virtual void saveKvLayer(SizeType32 layer_idx) = 0; virtual void waitForSave() = 0; virtual std::tuple, std::vector> getFinished( - std::vector const& finishedReqIds); + std::vector const& finishedGenReqIds, std::vector const& startedLoadingReqIds); }; class KvCacheConnectorManager diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp index d296d69171a..a1a559416e1 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp @@ -6,9 +6,9 @@ namespace tensorrt_llm::batch_manager::kv_connector void KvCacheConnectorWorker::registerKvCaches(KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData) {} std::tuple, std::vector> KvCacheConnectorWorker::getFinished( - std::vector const& finishedReqIds) + std::vector const& finishedGenReqIds, std::vector const& startedLoadingReqIds) { - return std::make_tuple(std::vector(), std::vector()); + return std::make_tuple(finishedGenReqIds, startedLoadingReqIds); } void KvCacheConnectorScheduler::updateStateAfterAlloc() {} diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index b10492f2e2d..c61bb440c1e 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2078,7 +2078,11 @@ void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength return mSequences.try_emplace(requestId, requestId, inputLength, beamWidth, mBlockManager.getWindowSizesMetadata(), kvCacheRetentionConfig); }(); - TLLM_CHECK(emplaceDone); + TLLM_CHECK(emplaceDone || kvCacheConnectorManager.has_value()); + if (!emplaceDone && kvCacheConnectorManager.has_value()) + { + return; + } auto& sequence = seqIt->second; // Get statistics for block allocations/reuse pre request. 
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index 5f97f701156..dbf592334f8 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -85,9 +85,10 @@ class PyKvCacheConnectorWorker : public KvCacheConnectorWorker, py::trampoline_s using FinishedReqs = std::tuple, std::vector>; - FinishedReqs getFinished(std::vector const& finishedReqIds) override + FinishedReqs getFinished(std::vector const& finishedGenReqIds, + std::vector const& startedLoadingReqIds) override { - PYBIND11_OVERRIDE(FinishedReqs, KvCacheConnectorWorker, getFinished, finishedReqIds); + PYBIND11_OVERRIDE(FinishedReqs, KvCacheConnectorWorker, getFinished, finishedGenReqIds, startedLoadingReqIds); } }; @@ -131,7 +132,8 @@ void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindi .def("wait_for_layer_load", &tb::kv_connector::KvCacheConnectorWorker::waitForLayerLoad, py::arg("layer_idx")) .def("save_kv_layer", &tb::kv_connector::KvCacheConnectorWorker::saveKvLayer, py::arg("layer_idx")) .def("wait_for_save", &tb::kv_connector::KvCacheConnectorWorker::waitForSave) - .def("get_finished", &tb::kv_connector::KvCacheConnectorWorker::getFinished, py::arg("finished_req_ids")); + .def("get_finished", &tb::kv_connector::KvCacheConnectorWorker::getFinished, py::arg("started_loading_req_ids"), + py::arg("finished_gen_req_ids")); py::class_( m, "KvCacheConnectorScheduler") diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 18d80dffda8..7ff02dc2a01 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -1,6 +1,7 @@ +from dataclasses import dataclass from typing import Optional -from tensorrt_llm._utils import mpi_broadcast, mpi_rank +from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank from tensorrt_llm.bindings import LlmRequestState from tensorrt_llm.bindings.internal.batch_manager import \ KvCacheConnectorManager as KvCacheConnectorManagerCpp @@ -37,6 +38,57 @@ def build_connector_metadata(self, metadata: object): return None +@dataclass +class Finished: + saving: dict[int, LlmRequest] + loading: dict[int, LlmRequest] + + def add_from(self, other: 'Finished'): + self.saving.update(other.saving) + self.loading.update(other.loading) + + other.saving = dict() + other.loading = dict() + + def extract_by_id(self, saving_ids: list[int], loading_ids: list[int]): + + new_finished = Finished(dict(), dict()) + + for req_id in saving_ids: + new_finished.saving[req_id] = self.saving[req_id] + del self.saving[req_id] + for req_id in loading_ids: + new_finished.loading[req_id] = self.loading[req_id] + del self.loading[req_id] + + return new_finished + + def saving_ids(self) -> set[int]: + return set(self.saving.keys()) + + def loading_ids(self) -> set[int]: + return set(self.loading.keys()) + + @staticmethod + def intersection(*all_finished: 'Finished') -> 'Finished': + if len(all_finished) == 0: + return Finished(dict(), dict()) + + saving_ids = set.intersection( + *[finished.saving_ids() for finished in all_finished]) + loading_ids = set.intersection( + *[finished.loading_ids() for finished in all_finished]) + return Finished( + dict([(req_id, all_finished[0].saving[req_id]) + for req_id in saving_ids]), + dict([(req_id, all_finished[0].loading[req_id]) + for req_id in loading_ids])) + + def __sub__(self, other: 'Finished') -> 
'Finished': + return Finished(self.saving - other.saving, + self.loading - other.loading) + + class KvCacheConnectorManager(KvCacheConnectorManagerCpp): def __init__(self, worker: KvCacheConnectorWorker, @@ -49,31 +101,38 @@ def __init__(self, worker: KvCacheConnectorWorker, self.worker = worker self.scheduler = scheduler - self.requests_awaiting_async_load_init = set() + # Requests that haven't yet been passed into get_finished. + self.new_finished = Finished(dict(), dict()) + + # Requests that have been passed into get_finished, but haven't yet been returned. + self.pending_finished = Finished(dict(), dict()) - self.requests_awaiting_async_load_complete = [] + # Requests that have been returned from get_finished locally, but haven't yet been returned by all workers. + self.local_finished = Finished(dict(), dict()) def get_num_new_matched_tokens(self, request: LlmRequest, num_computed_tokens: int) -> int: if self.scheduler is not None: - num_tokens, load_kv_async = self.scheduler.get_num_new_matched_tokens( + assert mpi_rank() == 0, "The scheduler may only exist on rank 0!" + res = self.scheduler.get_num_new_matched_tokens( request, num_computed_tokens) - - mpi_broadcast((num_tokens, load_kv_async), root=0) else: - num_tokens, load_kv_async = mpi_broadcast(None, root=0) + res = None + + (num_tokens, load_kv_async) = mpi_broadcast(res, root=0) if num_tokens == 0 and load_kv_async: raise RuntimeError( "load_kv_async must be False when num_tokens is 0!") if load_kv_async: - self.requests_awaiting_async_load_init.add(request.request_id) + self.new_finished.loading[request.request_id] = request return num_tokens def build_connector_metadata(self) -> object: if self.scheduler is not None: + assert mpi_rank() == 0, "The scheduler may only exist on rank 0!" metadata = self.scheduler.build_connector_metadata() else: metadata = None @@ -82,20 +141,66 @@ def build_connector_metadata(self) -> object: self.worker.bind_connector_meta(metadata) - def take_scheduled_requests_pending_transfer( + def request_finished(self, req: LlmRequest) -> bool: + if self.scheduler is not None: + assert mpi_rank() == 0, "The scheduler may only exist on rank 0!" 
+ saving_async = self.scheduler.request_finished(req) + else: + saving_async = None + + saving_async = mpi_broadcast(saving_async, root=0) + + if saving_async: + self.new_finished.saving[req.request_id] = req + + return saving_async + + def take_scheduled_requests_pending_load( self, scheduled_requests: ScheduledRequests) -> ScheduledRequests: allowed_context_requests = [] - async_load_requests = [] for req in scheduled_requests.context_requests: - if req.request_id in self.requests_awaiting_async_load_init: - async_load_requests.append(req) + if req.request_id in self.new_finished.loading.keys(): req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS - self.requests_awaiting_async_load_init.remove(req.request_id) - self.requests_awaiting_async_load_complete.append(req) else: allowed_context_requests.append(req) scheduled_requests.context_requests = allowed_context_requests return scheduled_requests + + def get_finished(self) -> list[LlmRequest]: + started_loading_req_ids = list(self.new_finished.loading_ids()) + finished_gen_req_ids = list(self.new_finished.saving_ids()) + + self.pending_finished.add_from(self.new_finished) + (finished_saving, + finished_loading) = self.worker.get_finished(finished_gen_req_ids, + started_loading_req_ids) + + new_local_finished = self.pending_finished.extract_by_id( + finished_saving, finished_loading) + + # Get all pending finished requests for this worker. + self.local_finished.add_from(new_local_finished) + + # Broadcast this to all other workers. + finished_saving = list(self.local_finished.saving_ids()) + finished_loading = list(self.local_finished.loading_ids()) + + all_results = mpi_allgather((finished_saving, finished_loading)) + + # Find only the requests that have been reported complete by all workers. + intersect_finished_saving = set.intersection( + *[set(res[0]) for res in all_results]) + intersect_finished_loading = set.intersection( + *[set(res[1]) for res in all_results]) + + all_finished = self.local_finished.extract_by_id( + intersect_finished_saving, intersect_finished_loading) + + # For requests that have finished loading, move them back to the context state. 
+ for req in all_finished.loading.values(): + req.state = LlmRequestState.CONTEXT_INIT + + return list(all_finished.saving.values()) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index fc414724a80..ca2fa77d0b3 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -998,6 +998,11 @@ def _executor_loop(self): if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._terminate_ctx_finished_requests() + if self.kv_connector_manager: + reqs_to_terminate = self.kv_connector_manager.get_finished() + for req in reqs_to_terminate: + self._terminate_request(req) + if self.enable_iter_perf_stats: iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ 'num_ctx_tokens'] @@ -1548,7 +1553,9 @@ def _handle_errors(self, self._enqueue_responses(error_responses.items()) def _terminate_request(self, request: LlmRequest): - self.resource_manager.free_resources(request) + if self.kv_connector_manager is None or not self.kv_connector_manager.request_finished( + request): + self.resource_manager.free_resources(request) @nvtx_range("_handle_canceled_requests") def _handle_canceled_requests(self): From e75e6c412d01dfa1b064cf1598395729ce6e49dc Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 29 Jul 2025 11:25:41 -0700 Subject: [PATCH 12/50] Little cleanup Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 41 ------------------- .../tensorrt_llm/batch_manager/llmRequest.h | 12 ++++++ .../batch_manager/kvCacheManager.cpp | 6 +-- .../pybind/batch_manager/bindings.cpp | 4 +- .../pybind/batch_manager/kvCacheConnector.cpp | 16 -------- tensorrt_llm/_torch/pyexecutor/connector.py | 2 + .../_torch/pyexecutor/resource_manager.py | 2 +- 7 files changed, 19 insertions(+), 64 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 4efa67865f7..20468954928 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -31,47 +31,6 @@ using namespace tensorrt_llm::batch_manager; namespace tensorrt_llm::batch_manager::kv_connector { -struct NewRequestData -{ - NewRequestData(RequestIdType requestId, std::vector const& newTokens, - std::vector const& blockIds, SizeType32 numComputedTokens) - : requestId(requestId) - , newTokens(newTokens) - , blockIds(blockIds) - , numComputedTokens(numComputedTokens) - { - } - - RequestIdType requestId; - std::vector newTokens; - std::vector blockIds; - SizeType32 numComputedTokens; -}; - -struct CachedRequestData -{ - CachedRequestData(RequestIdType requestId, std::vector const& newTokens, - std::vector const& newBlockIds, SizeType32 numComputedTokens) - : requestId(requestId) - , newTokens(newTokens) - , newBlockIds(newBlockIds) - , numComputedTokens(numComputedTokens) - { - } - - RequestIdType requestId; - std::vector newTokens; - std::vector newBlockIds; - SizeType32 numComputedTokens; -}; - -class SchedulerOutput -{ -public: - std::vector newRequests; - std::vector cachedRequests; -}; - class KvCacheConnectorPoolData { public: diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 3320c6b0929..f1c9d52763e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1843,6 +1843,16 @@ class GenericLlmRequest 
return mIsDummyRequest; } + void setIsKvCacheConnectorAsyncOnboard(bool isKvCacheConnectorAsyncOnboard) + { + mIsKvCacheConnectorAsyncOnboard = isKvCacheConnectorAsyncOnboard; + } + + [[nodiscard]] bool isKvCacheConnectorAsyncOnboard() const + { + return mIsKvCacheConnectorAsyncOnboard; + } + RequestIdType mRequestId; SizeType32 mPromptLen; SizeType32 mMaxNewTokens; @@ -2017,6 +2027,8 @@ class GenericLlmRequest bool mIsDummyRequest{false}; + bool mIsKvCacheConnectorAsyncOnboard{false}; + private: void initialize(VecTokens const& inputTokens, bool outputLogProbs) { diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index c61bb440c1e..b10492f2e2d 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2078,11 +2078,7 @@ void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength return mSequences.try_emplace(requestId, requestId, inputLength, beamWidth, mBlockManager.getWindowSizesMetadata(), kvCacheRetentionConfig); }(); - TLLM_CHECK(emplaceDone || kvCacheConnectorManager.has_value()); - if (!emplaceDone && kvCacheConnectorManager.has_value()) - { - return; - } + TLLM_CHECK(emplaceDone); auto& sequence = seqIt->second; // Get statistics for block allocations/reuse pre request. diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 0ba4fd94c20..42829f45621 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -256,7 +256,9 @@ void initBindings(pybind11::module_& m) } }) .def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest) - .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics); + .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics) + .def_property("is_kv_cache_connector_async_onboard", &GenLlmReq::isKvCacheConnectorAsyncOnboard, + &GenLlmReq::setIsKvCacheConnectorAsyncOnboard); py::classh(m, "LlmRequest", pybind11::dynamic_attr()) .def(py::init<>( diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index dbf592334f8..1037dd23932 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -148,20 +148,4 @@ void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindi .def(py::init<>()) .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens, py::arg("request"), py::arg("num_computed_tokens")); - - py::class_(m, "NewRequestData") - .def_readonly("request_id", &tb::kv_connector::NewRequestData::requestId) - .def_readonly("new_tokens", &tb::kv_connector::NewRequestData::newTokens) - .def_readonly("block_ids", &tb::kv_connector::NewRequestData::blockIds) - .def_readonly("num_computed_tokens", &tb::kv_connector::NewRequestData::numComputedTokens); - - py::class_(m, "CachedRequestData") - .def_readonly("request_id", &tb::kv_connector::CachedRequestData::requestId) - .def_readonly("new_tokens", &tb::kv_connector::CachedRequestData::newTokens) - .def_readonly("new_block_ids", &tb::kv_connector::CachedRequestData::newBlockIds) - .def_readonly("num_computed_tokens", &tb::kv_connector::CachedRequestData::numComputedTokens); - - py::class_(m, "SchedulerOutput") - .def_readonly("new_requests", 
&tb::kv_connector::SchedulerOutput::newRequests) - .def_readonly("cached_requests", &tb::kv_connector::SchedulerOutput::cachedRequests); } diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 7ff02dc2a01..a2fc280e997 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -127,6 +127,7 @@ def get_num_new_matched_tokens(self, request: LlmRequest, if load_kv_async: self.new_finished.loading[request.request_id] = request + request.is_kv_cache_connector_async_onboard = True return num_tokens @@ -152,6 +153,7 @@ def request_finished(self, req: LlmRequest) -> bool: if saving_async: self.new_finished.saving[req.request_id] = req + req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS return saving_async diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 87ce0688221..2dd1d4648d7 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -388,7 +388,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): == self.mapping.cp_size - 1 else 0), req_beam_width, req, self.kv_connector_manager) else: - if req.is_first_context_chunk: + if req.is_first_context_chunk and not req.is_kv_cache_connector_async_onboard: self.impl.add_sequence(req.py_request_id, req.prompt_len, req_beam_width, req, self.kv_connector_manager) From 66aa5a7fec431dfa5d515fdf7cc7169f207f0d6f Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 29 Jul 2025 14:19:49 -0700 Subject: [PATCH 13/50] scheduler output for build_connector_meta Signed-off-by: jthomson04 --- .../batch_manager/kvCacheManager.h | 6 +-- .../batch_manager/kvCacheManager.cpp | 17 +++++++-- .../pybind/batch_manager/kvCacheManager.cpp | 7 ++-- tensorrt_llm/_torch/pyexecutor/connector.py | 33 +++++++++++++++- tensorrt_llm/_torch/pyexecutor/py_executor.py | 17 +++++++-- .../_torch/pyexecutor/resource_manager.py | 38 ++++++++++++++++++- 6 files changed, 100 insertions(+), 18 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index c0fedf058fb..2978880054a 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -1239,7 +1239,7 @@ class BaseKVCacheManager = 0; /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed. - virtual void addToken(LlmRequest::RequestIdType requestId) = 0; + virtual std::optional addToken(LlmRequest::RequestIdType requestId, bool returnNewBlockId = false) = 0; /// @brief Add new request to the KV cache manager. /// @param inputLength Input length for which KV cache need to be allocated. @@ -1540,7 +1540,7 @@ class KVCacheManager : public BaseKVCacheManager LlmRequest const& req, SizeType32 windowSize) const override; /// @brief Increase size for request with requestId. Allocate new KV cache block(s) if needed. - void addToken(LlmRequest::RequestIdType requestId) override; + std::optional addToken(LlmRequest::RequestIdType requestId, bool returnNewBlockId = false) override; /// @brief Add new request to the KV cache manager. /// @param inputLength Input length for which KV cache need to be allocated. 
@@ -1685,7 +1685,7 @@ class KVCacheManager : public BaseKVCacheManager void cacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); void cacheNewBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); void updateNewBlockPointer(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx); - void updateToken(GenerationRequest& sequence, bool addToken); + std::optional updateToken(GenerationRequest& sequence, bool addToken, bool returnNewBlockId); private: // Maximum number of sequences diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index b10492f2e2d..b87a75a9405 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1983,8 +1983,12 @@ void KVCacheManager::updateNewBlockPointer(GenerationRequest& sequence, SizeType } } -void KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken) +std::optional KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken, bool returnNewBlockId) { + TLLM_CHECK_WITH_INFO( + !returnNewBlockId || (mBlockManager.getWindowSizesMetadata().size() == 1 && sequence.getBeamWidth() == 1), + "KV Cache Connector is not supported with beam search"); + auto currNumTokens = sequence.getNumTokens(); if (addToken) @@ -2024,6 +2028,9 @@ void KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken) { mBlockManager.allocateBlock(sequence, windowSize); cacheNewBlockOffsets(sequence, windowSize); + + return returnNewBlockId ? std::make_optional(sequence.getCacheBlockIds(windowSize).at(0).back()) + : std::nullopt; } else { @@ -2041,12 +2048,14 @@ void KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken) } } } + + return std::nullopt; } -void KVCacheManager::addToken(RequestIdType requestId) +std::optional KVCacheManager::addToken(RequestIdType requestId, bool returnNewBlockId) { auto& sequence = getSequence(requestId); - updateToken(sequence, true); + return updateToken(sequence, true, returnNewBlockId); } std::optional KVCacheManager::findNewContextBlock( @@ -2489,7 +2498,7 @@ void KVCacheManager::removeToken(RequestIdType requestId) { return; } - updateToken(sequence, false); + updateToken(sequence, false, false); } void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index c9fd2f1fae3..e6ef6d76f1b 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -90,9 +90,10 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager PYBIND11_OVERLOAD_PURE(tbk::KvCacheStats, tbk::BaseKVCacheManager, getKvCacheStats); } - void addToken(tb::LlmRequest::RequestIdType requestId) override + std::optional addToken(tb::LlmRequest::RequestIdType requestId, bool returnNewBlockId = false) override { - PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addToken, requestId); + PYBIND11_OVERLOAD_PURE( + std::optional, tbk::BaseKVCacheManager, addToken, requestId, returnNewBlockId); } void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, @@ -349,7 +350,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) [](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; }) .def("get_needed_blocks_one_step", 
&BaseKVCacheManager::getNeededBlocksOneStep) .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion) - .def("add_token", &BaseKVCacheManager::addToken) + .def("add_token", &BaseKVCacheManager::addToken, py::arg("request_id"), py::arg("return_new_block_id") = false) .def("add_sequence", &BaseKVCacheManager::addSequence, py::arg("request_id"), py::arg("input_length"), py::arg("beam_width"), py::arg("llm_request") = std::nullopt, py::arg("kv_cache_connector_manager") = std::nullopt) diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index a2fc280e997..84ed31d22d5 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank @@ -89,6 +89,25 @@ def __sub__(self, other: 'Finished') -> 'Finished': self.loading - other.loading) +@dataclass +class RequestData: + request_id: int + new_tokens: list[int] + new_block_ids: list[int] + computed_position: int + + +@dataclass +class SchedulerOutput: + requests: list[RequestData] = field(default_factory=list) + + def add_request(self, request_id: int, new_tokens: list[int], + new_block_ids: list[int], computed_position: int): + self.requests.append( + RequestData(request_id, new_tokens, new_block_ids, + computed_position)) + + class KvCacheConnectorManager(KvCacheConnectorManagerCpp): def __init__(self, worker: KvCacheConnectorWorker, @@ -110,6 +129,8 @@ def __init__(self, worker: KvCacheConnectorWorker, # Requests that have been returned from get_finished locally, but haven't yet been returned by all workers. self.local_finished = Finished(dict(), dict()) + self._scheduler_output = None + def get_num_new_matched_tokens(self, request: LlmRequest, num_computed_tokens: int) -> int: if self.scheduler is not None: @@ -134,10 +155,15 @@ def get_num_new_matched_tokens(self, request: LlmRequest, def build_connector_metadata(self) -> object: if self.scheduler is not None: assert mpi_rank() == 0, "The scheduler may only exist on rank 0!" 
- metadata = self.scheduler.build_connector_metadata() + if self._scheduler_output is None: + raise RuntimeError("Scheduler output not set!") + metadata = self.scheduler.build_connector_metadata( + self._scheduler_output) else: metadata = None + self._scheduler_output = None + metadata = mpi_broadcast(metadata, root=0) self.worker.bind_connector_meta(metadata) @@ -206,3 +232,6 @@ def get_finished(self) -> list[LlmRequest]: req.state = LlmRequestState.CONTEXT_INIT return list(all_finished.saving.values()) + + def set_scheduler_output(self, scheduler_output: SchedulerOutput): + self._scheduler_output = scheduler_output diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index ca2fa77d0b3..f93252eac92 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -34,6 +34,7 @@ from ..distributed import Distributed from ..models.modeling_utils import DecoderModelForCausalLM +from ..modules.decoder_layer import DecoderLayer from ..speculative.drafter import Drafter from .connector import KvCacheConnectorManager from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem @@ -268,8 +269,14 @@ def __init__(self, self.kv_connector_manager = kv_connector_manager + self._maybe_init_kv_connector_manager() + + if start_worker: + self.start_worker() + + def _maybe_init_kv_connector_manager(self): if self.kv_connector_manager is not None: - if kv_cache_transceiver is not None: + if self.kv_cache_transceiver is not None: raise NotImplementedError( "KV Cache Connector is not supported with KvCacheTransceiver." ) @@ -279,17 +286,19 @@ def __init__(self, "KV Cache Connector is not supported with pipeline parallelism." ) - if not disable_overlap_scheduler: + if not self.disable_overlap_scheduler: raise NotImplementedError( "KV Cache Connector is not supported with overlap scheduler." 
) kv_cache_data = self.kv_cache_manager.get_kv_cache_connector_pools_data( ) + self.kv_connector_manager.worker.register_kv_caches(kv_cache_data) - if start_worker: - self.start_worker() + for name, module in self.model_engine.model.named_modules(): + if isinstance(module, DecoderLayer): + print("LAYER", name, module) def _event_loop_wrapper(self): try: diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 2dd1d4648d7..dd8d83145d1 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -18,7 +18,7 @@ from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range from ...logger import logger from ...mapping import Mapping -from .connector import KvCacheConnectorManager +from .connector import KvCacheConnectorManager, SchedulerOutput from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig, get_draft_token_length) from .scheduler import ScheduledRequests @@ -374,6 +374,9 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int: def prepare_resources(self, scheduled_batch: ScheduledRequests): context_batch = scheduled_batch.context_requests generation_batch = scheduled_batch.generation_requests + + scheduler_output = SchedulerOutput() + # allocate KV Cache for req in context_batch: req_beam_width = req.sampling_config.beam_width @@ -397,11 +400,42 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) + if not req.is_kv_cache_connector_async_onboard: + scheduler_output.add_request( + req.request_id, req.get_tokens(0), + self.impl.get_cache_block_ids( + req.request_id, + self.max_attention_window_vec[0]), + req.context_current_position) + else: + if req.is_first_context_chunk or req.is_kv_cache_connector_async_onboard: + req.is_kv_cache_connector_async_onboard = False + scheduler_output.add_request( + req.request_id, req.get_tokens(0), + self.impl.get_cache_block_ids( + req.request_id, + self.max_attention_window_vec[0]), + req.context_current_position) + else: + scheduler_output.add_request( + req.request_id, [], [], + req.context_current_position) + for req in generation_batch: - self.impl.add_token(req.py_request_id) + new_block_id = self.impl.add_token(req.py_request_id, True) + for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) + tokens = req.get_tokens(0) + + scheduler_output.add_request( + req.request_id, tokens[-1:], + [new_block_id] if new_block_id is not None else [], len(tokens)) + + if self.kv_connector_manager is not None: + self.kv_connector_manager.set_scheduler_output(scheduler_output) + def add_dummy_requests( self, request_ids: List[int], From 2f80a23d028ca0a4b502b5fc7ea6ed88467c16ad Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 29 Jul 2025 14:48:51 -0700 Subject: [PATCH 14/50] Worker-side hooks Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/connector.py | 9 +++++++++ tensorrt_llm/_torch/pyexecutor/py_executor.py | 10 ++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 84ed31d22d5..297e91fdf52 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -235,3 +235,12 @@ def get_finished(self) -> list[LlmRequest]: def set_scheduler_output(self, scheduler_output: SchedulerOutput): 
self._scheduler_output = scheduler_output + + def layer_pre_hook(self, module, *args): + self.worker.wait_for_layer_load(module.layer_idx) + + def layer_post_hook(self, module, *args): + self.worker.save_kv_layer(module.layer_idx) + + def model_post_hook(self, module, *args): + self.worker.wait_for_save() diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index f93252eac92..422c20c075f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -296,9 +296,15 @@ def _maybe_init_kv_connector_manager(self): self.kv_connector_manager.worker.register_kv_caches(kv_cache_data) - for name, module in self.model_engine.model.named_modules(): + for _name, module in self.model_engine.model.named_modules(): if isinstance(module, DecoderLayer): - print("LAYER", name, module) + module.register_forward_pre_hook( + self.kv_connector_manager.layer_pre_hook) + module.register_forward_hook( + self.kv_connector_manager.layer_post_hook) + + self.model_engine.model.register_forward_hook( + self.kv_connector_manager.model_post_hook) def _event_loop_wrapper(self): try: From a521b6acdeaab41f452923f32349bf4a726bd5bc Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 29 Jul 2025 15:25:07 -0700 Subject: [PATCH 15/50] Move a ton of stuff out of c++ into python Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 38 +------- cpp/tensorrt_llm/batch_manager/CMakeLists.txt | 1 - .../batch_manager/kvCacheConnector.cpp | 21 ----- .../pybind/batch_manager/kvCacheConnector.cpp | 86 ------------------ tensorrt_llm/_torch/pyexecutor/connector.py | 88 ++++++++++++------- 5 files changed, 60 insertions(+), 174 deletions(-) delete mode 100644 cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 20468954928..a4c0649d0d8 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -31,6 +31,7 @@ using namespace tensorrt_llm::batch_manager; namespace tensorrt_llm::batch_manager::kv_connector { +// @brief Data used to provide the KV cache tensors to the connector worker for a single pool. class KvCacheConnectorPoolData { public: @@ -80,42 +81,7 @@ class KvCacheConnectorPoolsData std::vector mLayerToPoolMapping; }; -class KvCacheConnectorScheduler -{ -public: - explicit KvCacheConnectorScheduler() = default; - virtual ~KvCacheConnectorScheduler() = default; - - virtual std::tuple getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) - = 0; - - // TODO(jothomson): Need arguments here. Also, is this even needed? - virtual void updateStateAfterAlloc(); - - virtual bool requestFinished(LlmRequest const& request); -}; - -class KvCacheConnectorWorker -{ -public: - explicit KvCacheConnectorWorker() = default; - virtual ~KvCacheConnectorWorker() = default; - - virtual void registerKvCaches(KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData); - - // TODO(jothomson): Need arguments here. - virtual void startLoadKv() = 0; - - virtual void waitForLayerLoad(SizeType32 layer_idx) = 0; - - virtual void saveKvLayer(SizeType32 layer_idx) = 0; - - virtual void waitForSave() = 0; - - virtual std::tuple, std::vector> getFinished( - std::vector const& finishedGenReqIds, std::vector const& startedLoadingReqIds); -}; - +// @brief The KV connector manager. 
This is passed into the C++ KV Cache Manager when adding sequences. class KvCacheConnectorManager { public: diff --git a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt index 75f1e0fa20b..5f7d774c0b0 100644 --- a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt +++ b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt @@ -30,7 +30,6 @@ set(SRCS guidedDecoder.cpp handleContextLogits.cpp handleGenerationLogits.cpp - kvCacheConnector.cpp kvCacheManager.cpp kvCacheEventManager.cpp kvCacheTransferManager.cpp diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp deleted file mode 100644 index a1a559416e1..00000000000 --- a/cpp/tensorrt_llm/batch_manager/kvCacheConnector.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "tensorrt_llm/batch_manager/kvCacheConnector.h" - -namespace tensorrt_llm::batch_manager::kv_connector -{ - -void KvCacheConnectorWorker::registerKvCaches(KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData) {} - -std::tuple, std::vector> KvCacheConnectorWorker::getFinished( - std::vector const& finishedGenReqIds, std::vector const& startedLoadingReqIds) -{ - return std::make_tuple(finishedGenReqIds, startedLoadingReqIds); -} - -void KvCacheConnectorScheduler::updateStateAfterAlloc() {} - -bool KvCacheConnectorScheduler::requestFinished(LlmRequest const& request) -{ - return false; -} - -} // namespace tensorrt_llm::batch_manager::kv_connector diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index 1037dd23932..5e3648fdff2 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -22,76 +22,10 @@ namespace { - -using KvCacheConnectorScheduler = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorScheduler; -using KvCacheConnectorWorker = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorWorker; using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager; -using NumNewMatchedTokens = std::tuple; - namespace tb = tensorrt_llm::batch_manager; -class PyKvCacheConnectorScheduler : public KvCacheConnectorScheduler, py::trampoline_self_life_support -{ -public: - using KvCacheConnectorScheduler::KvCacheConnectorScheduler; - - NumNewMatchedTokens getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override - { - PYBIND11_OVERRIDE_PURE( - NumNewMatchedTokens, KvCacheConnectorScheduler, getNumNewMatchedTokens, request, numComputedTokens); - } - - void updateStateAfterAlloc() override - { - PYBIND11_OVERRIDE(void, KvCacheConnectorScheduler, updateStateAfterAlloc); - } - - bool requestFinished(LlmRequest const& request) override - { - PYBIND11_OVERRIDE(bool, KvCacheConnectorScheduler, requestFinished, request); - } -}; - -class PyKvCacheConnectorWorker : public KvCacheConnectorWorker, py::trampoline_self_life_support -{ -public: - using KvCacheConnectorWorker::KvCacheConnectorWorker; - - void registerKvCaches(kv_connector::KvCacheConnectorPoolsData const& kvCacheConnectorPoolsData) override - { - PYBIND11_OVERRIDE(void, KvCacheConnectorWorker, registerKvCaches, kvCacheConnectorPoolsData); - } - - void startLoadKv() override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, startLoadKv); - } - - void waitForLayerLoad(SizeType32 layer_idx) override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, waitForLayerLoad, layer_idx); - } - 
- void saveKvLayer(SizeType32 layer_idx) override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, saveKvLayer, layer_idx); - } - - void waitForSave() override - { - PYBIND11_OVERRIDE_PURE(void, KvCacheConnectorWorker, waitForSave); - } - - using FinishedReqs = std::tuple, std::vector>; - - FinishedReqs getFinished(std::vector const& finishedGenReqIds, - std::vector const& startedLoadingReqIds) override - { - PYBIND11_OVERRIDE(FinishedReqs, KvCacheConnectorWorker, getFinished, finishedGenReqIds, startedLoadingReqIds); - } -}; - class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline_self_life_support { public: @@ -123,26 +57,6 @@ void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindi .def_property_readonly( "layer_to_pool_mapping", &tb::kv_connector::KvCacheConnectorPoolsData::getLayerToPoolMapping); - py::class_( - m, "KvCacheConnectorWorker") - .def(py::init<>()) - .def( - "register_kv_caches", &tb::kv_connector::KvCacheConnectorWorker::registerKvCaches, py::arg("kv_cache_data")) - .def("start_load_kv", &tb::kv_connector::KvCacheConnectorWorker::startLoadKv) - .def("wait_for_layer_load", &tb::kv_connector::KvCacheConnectorWorker::waitForLayerLoad, py::arg("layer_idx")) - .def("save_kv_layer", &tb::kv_connector::KvCacheConnectorWorker::saveKvLayer, py::arg("layer_idx")) - .def("wait_for_save", &tb::kv_connector::KvCacheConnectorWorker::waitForSave) - .def("get_finished", &tb::kv_connector::KvCacheConnectorWorker::getFinished, py::arg("started_loading_req_ids"), - py::arg("finished_gen_req_ids")); - - py::class_( - m, "KvCacheConnectorScheduler") - .def(py::init<>()) - .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorScheduler::getNumNewMatchedTokens, - py::arg("request"), py::arg("num_computed_tokens")) - .def("update_state_after_alloc", &tb::kv_connector::KvCacheConnectorScheduler::updateStateAfterAlloc) - .def("request_finished", &tb::kv_connector::KvCacheConnectorScheduler::requestFinished, py::arg("request")); - py::class_( m, "KvCacheConnectorManager") .def(py::init<>()) diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 297e91fdf52..09e57cd3c55 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Optional @@ -5,16 +6,32 @@ from tensorrt_llm.bindings import LlmRequestState from tensorrt_llm.bindings.internal.batch_manager import \ KvCacheConnectorManager as KvCacheConnectorManagerCpp -from tensorrt_llm.bindings.internal.batch_manager import \ - KvCacheConnectorScheduler as KvCacheConnectorSchedulerCpp -from tensorrt_llm.bindings.internal.batch_manager import \ - KvCacheConnectorWorker as KvCacheConnectorWorkerCpp -from tensorrt_llm.bindings.internal.batch_manager import LlmRequest +from tensorrt_llm.bindings.internal.batch_manager import ( + KvCacheConnectorPoolsData, LlmRequest) from .scheduler import ScheduledRequests -class KvCacheConnectorWorker(KvCacheConnectorWorkerCpp): +@dataclass +class RequestData: + request_id: int + new_tokens: list[int] + new_block_ids: list[int] + computed_position: int + + +@dataclass +class SchedulerOutput: + requests: list[RequestData] = field(default_factory=list) + + def add_request(self, request_id: int, new_tokens: list[int], + new_block_ids: list[int], computed_position: int): + self.requests.append( + RequestData(request_id, new_tokens, 
new_block_ids, + computed_position)) + + +class KvCacheConnectorWorker(ABC): def __init__(self): super().__init__() @@ -28,14 +45,44 @@ def get_connector_meta(self) -> object: def _clear_connector_meta(self): self._metadata = None + @abstractmethod + def register_kv_caches(self, kv_cache_data: KvCacheConnectorPoolsData): + pass + + @abstractmethod + def start_load_kv(self): + pass + + @abstractmethod + def wait_for_layer_load(self, layer_idx: int): + pass + + @abstractmethod + def save_kv_layer(self, layer_idx: int): + pass -class KvCacheConnectorScheduler(KvCacheConnectorSchedulerCpp): + @abstractmethod + def wait_for_save(self): + pass + + +class KvCacheConnectorScheduler(ABC): def __init__(self): super().__init__() - def build_connector_metadata(self, metadata: object): - return None + @abstractmethod + def build_connector_meta(self, scheduler_output: SchedulerOutput): + pass + + def get_num_new_matched_tokens( + self, request: LlmRequest, + num_computed_tokens: int) -> tuple[int, bool]: + pass + + @abstractmethod + def request_finished(self, request: LlmRequest) -> bool: + pass @dataclass @@ -89,25 +136,6 @@ def __sub__(self, other: 'Finished') -> 'Finished': self.loading - other.loading) -@dataclass -class RequestData: - request_id: int - new_tokens: list[int] - new_block_ids: list[int] - computed_position: int - - -@dataclass -class SchedulerOutput: - requests: list[RequestData] = field(default_factory=list) - - def add_request(self, request_id: int, new_tokens: list[int], - new_block_ids: list[int], computed_position: int): - self.requests.append( - RequestData(request_id, new_tokens, new_block_ids, - computed_position)) - - class KvCacheConnectorManager(KvCacheConnectorManagerCpp): def __init__(self, worker: KvCacheConnectorWorker, @@ -152,12 +180,12 @@ def get_num_new_matched_tokens(self, request: LlmRequest, return num_tokens - def build_connector_metadata(self) -> object: + def build_connector_meta(self) -> object: if self.scheduler is not None: assert mpi_rank() == 0, "The scheduler may only exist on rank 0!" 
if self._scheduler_output is None: raise RuntimeError("Scheduler output not set!") - metadata = self.scheduler.build_connector_metadata( + metadata = self.scheduler.build_connector_meta( self._scheduler_output) else: metadata = None From b80ed13c5fd5db2771abaeb91c04962bbde2b49e Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 29 Jul 2025 20:19:01 -0700 Subject: [PATCH 16/50] small refactorings and docs Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 12 +- .../batch_manager/kvCacheManager.h | 15 +- .../batch_manager/kvCacheManager.cpp | 28 +- .../pybind/batch_manager/bindings.cpp | 1 - .../pybind/batch_manager/kvCacheManager.cpp | 2 +- tensorrt_llm/_torch/pyexecutor/connector.py | 298 +++++++++++++----- tensorrt_llm/_torch/pyexecutor/py_executor.py | 11 +- .../_torch/pyexecutor/resource_manager.py | 14 + 8 files changed, 264 insertions(+), 117 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index a4c0649d0d8..8724eb0a109 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -20,7 +20,7 @@ #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/runtime/common.h" -#include +#include #include using SizeType32 = tensorrt_llm::runtime::SizeType32; @@ -31,12 +31,11 @@ using namespace tensorrt_llm::batch_manager; namespace tensorrt_llm::batch_manager::kv_connector { -// @brief Data used to provide the KV cache tensors to the connector worker for a single pool. class KvCacheConnectorPoolData { public: - KvCacheConnectorPoolData(runtime::ITensor::SharedPtr const& poolTensor, SizeType32 numBlocks) - : mPoolTensor(poolTensor) + KvCacheConnectorPoolData(runtime::ITensor::SharedPtr poolTensor, SizeType32 numBlocks) + : mPoolTensor(std::move(poolTensor)) , mNumBlocks(numBlocks) { } @@ -56,6 +55,7 @@ class KvCacheConnectorPoolData SizeType32 mNumBlocks; }; +/// @brief Data used to provide the KV cache tensors to the connector worker for all the pools. class KvCacheConnectorPoolsData { public: @@ -81,13 +81,15 @@ class KvCacheConnectorPoolsData std::vector mLayerToPoolMapping; }; -// @brief The KV connector manager. This is passed into the C++ KV Cache Manager when adding sequences. +/// @brief The KV connector manager. This is passed into the C++ KV Cache Manager when adding sequences. class KvCacheConnectorManager { public: KvCacheConnectorManager() = default; virtual ~KvCacheConnectorManager() = default; + /// @brief Handle the getNumNewMatchedTokens call inside the C++ KV Cache Manager. + /// @return The number of tokens that can be loaded from remote KV cache. virtual SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) = 0; }; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 2978880054a..651f09c65a1 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -550,8 +550,7 @@ class WindowBlockManager //! \brief Assign blocks for new sequence. Try to reuse blocks. void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, - std::optional> kvCacheConnectorManager); + LlmRequest& llmRequest, OptionalRef kvCacheConnectorManager); //! \brief Assign blocks for new sequence. Does not try to reuse blocks. 
void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx); @@ -885,8 +884,7 @@ class BlockManager void allocatePools(bool useUvm); void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, - std::optional> kvCacheConnectorManager, + LlmRequest& llmRequest, OptionalRef kvCacheConnectorManager, SizeType32 windowSize); void addSequence( @@ -1239,6 +1237,8 @@ class BaseKVCacheManager = 0; /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed. + /// @param returnNewBlockId If true, return the id of the newly allocated block (if any). Only supported when VSWA + /// and beam search are disabled. virtual std::optional addToken(LlmRequest::RequestIdType requestId, bool returnNewBlockId = false) = 0; /// @brief Add new request to the KV cache manager. @@ -1249,7 +1249,7 @@ class BaseKVCacheManager /// inputLength - 1 tokens and populate prepopulatedPromptLen. virtual void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef llmRequest = std::nullopt, - std::optional> kvCacheConnectorManager = std::nullopt) + OptionalRef kvCacheConnectorManager = std::nullopt) = 0; virtual void removeSequence( @@ -1540,6 +1540,8 @@ class KVCacheManager : public BaseKVCacheManager LlmRequest const& req, SizeType32 windowSize) const override; /// @brief Increase size for request with requestId. Allocate new KV cache block(s) if needed. + /// @param returnNewBlockId If true, return the id of the newly allocated block (if any). Only supported when VSWA + /// and beam search are disabled. std::optional addToken(LlmRequest::RequestIdType requestId, bool returnNewBlockId = false) override; /// @brief Add new request to the KV cache manager. @@ -1550,8 +1552,7 @@ class KVCacheManager : public BaseKVCacheManager /// inputLength - 1 tokens and populate prepopulatedPromptLen. 
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef llmRequest = std::nullopt, - std::optional> kvCacheConnectorManager - = std::nullopt) override; + OptionalRef kvCacheConnectorManager = std::nullopt) override; void removeSequence( LlmRequest::RequestIdType requestId, OptionalRef llmRequest = std::nullopt) override; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index b87a75a9405..744a9e1106a 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1198,8 +1198,7 @@ void WindowBlockManager::refreshBlocks() } void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, - std::optional> kvCacheConnectorManager, + LlmRequest& llmRequest, OptionalRef kvCacheConnectorManager, SizeType32 windowSize) { mWindowBlockManagers.at(windowSize) @@ -1207,8 +1206,7 @@ void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLeng } void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, - std::optional> kvCacheConnectorManager) + LlmRequest& llmRequest, OptionalRef kvCacheConnectorManager) { auto const requestId = sequence.getRequestId(); auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); @@ -1240,20 +1238,19 @@ void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inp mReusedTokens += static_cast(prepopulatedPromptLen); mTotalInputTokens += static_cast(uniqueTokens.size()); - SizeType32 numNewMatchedTokens = 0; + SizeType32 numConnectorMatchedTokens = 0; - if (kvCacheConnectorManager.has_value()) + // If we're using a KV cache connector, check if any additional blocks can be loaded. 
+ if (kvCacheConnectorManager) { - numNewMatchedTokens = kvCacheConnectorManager->get()->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); - TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numNewMatchedTokens %d", - llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numNewMatchedTokens); - TLLM_CHECK_WITH_INFO(prepopulatedPromptLen + numNewMatchedTokens < llmRequest.getPromptLen(), + numConnectorMatchedTokens = kvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); + TLLM_CHECK_WITH_INFO(prepopulatedPromptLen + numConnectorMatchedTokens < llmRequest.getPromptLen(), "There must be at least one uncomputed token in the prompt!"); } - llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numNewMatchedTokens, getTokensPerBlock()); - TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numNewMatchedTokens %d", - llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numNewMatchedTokens); + llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock()); + TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d", + llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens); } void BlockManager::addSequence( @@ -2066,13 +2063,12 @@ std::optional KVCacheManager::findNewContextBlock( } void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - OptionalRef llmRequest, - std::optional> kvCacheConnectorManager) + OptionalRef llmRequest, OptionalRef kvCacheConnectorManager) { // Need to add the bubble after the sink tokens to use even block size inputLength += mSinkBubbleLength; - if (kvCacheConnectorManager.has_value()) + if (kvCacheConnectorManager) { TLLM_CHECK_WITH_INFO(beamWidth == 1, "KV Cache Connector is not supported with beam search"); } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 42829f45621..6c30a350594 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -25,7 +25,6 @@ #include "tensorrt_llm/batch_manager/peftCacheManager.h" #include "tensorrt_llm/batch_manager/rnnStateManager.h" #include "tensorrt_llm/batch_manager/sequenceSlotManager.h" -#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h" #include "tensorrt_llm/pybind/common/bindTypes.h" #include "tensorrt_llm/runtime/gptDecoderBatched.h" #include "tensorrt_llm/runtime/runtimeKernels.h" diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index e6ef6d76f1b..94aef0197b4 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -98,7 +98,7 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, - std::optional> kvCacheConnectorManager + tensorrt_llm::common::OptionalRef kvCacheConnectorManager = std::nullopt) override { PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 09e57cd3c55..2607e20eaaf 100644 --- 
a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional +from typing import Any, Callable, Optional from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank from tensorrt_llm.bindings import LlmRequestState @@ -10,16 +10,42 @@ KvCacheConnectorPoolsData, LlmRequest) from .scheduler import ScheduledRequests +""" +This file contains the primary interface for the KV Cache Connector. +The KV Cache Connector is a component that allows for remote KV cache access. +It is responsible for: +- Orchestrating the loading and saving of KV cache blocks. +- Managing asynchronous block tx/rx. +It can be used to provide functionalities such as: +1. Disagg +2. KV offload/onboard +3. KV cache sharing +4. P2P KV cache transfer +etc. + +The Connector API is split into two parts: +1. The scheduler, which is responsible for orchestration, and building metadata for the workers. +2. The worker, which performs and monitors transfers indicated by the scheduler's metadata. +""" + + +# Used to store data for a single inflight request. @dataclass class RequestData: + # The request ID. request_id: int + # The new tokens that were generated in the prior forward pass. new_tokens: list[int] + # The new block IDs allocated in the prior forward pass. new_block_ids: list[int] + # The position of the latest token with computed (valid) kv cache values. computed_position: int +# A class to store some basic data regarding all inflight requests. +# This is used when calling `build_connector_meta` on the scheduler. @dataclass class SchedulerOutput: requests: list[RequestData] = field(default_factory=list) @@ -47,23 +73,64 @@ def _clear_connector_meta(self): @abstractmethod def register_kv_caches(self, kv_cache_data: KvCacheConnectorPoolsData): - pass + """ + Register the KV cache tensors to the worker. + This can be used for something like NIXL registration. + + Args: + kv_cache_data: The data for all the KV cache pools. + """ @abstractmethod def start_load_kv(self): - pass + """ + Begin loading the KV cache in preparation for the next forward pass. + Specific blocks to transfer are indicated by the scheduler's metadata. + """ @abstractmethod def wait_for_layer_load(self, layer_idx: int): - pass + """ + Wait for a layer to finish being loaded before proceeding with the forward pass on the layer. + + Args: + layer_idx: The index of the layer to wait for. + """ @abstractmethod def save_kv_layer(self, layer_idx: int): - pass + """ + Begin saving the KV cache for a layer. + This is called after the forward pass on the layer has completed. + + Args: + layer_idx: The index of the layer to save. + """ @abstractmethod def wait_for_save(self): - pass + """ + Block until all synchronous saving operations are complete. Called at the end of the forward pass. + """ + + @abstractmethod + def get_finished( + self, finished_gen_req_ids: list[int], + started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]: + """ + Get the requests that have finished loading and saving. + + Args: + finished_gen_req_ids: The IDs of the requests that have finished generating tokens, and are now asynchronously saving. + started_loading_req_ids: The IDs of the requests that have started asynchronously loading. + + Returns: + The IDs of the requests that have finished saving. + The IDs of the requests that have finished loading. 
+ + Note: IDs may only be returned from this call after they've been provided in the `finished_gen_req_ids` and `started_loading_req_ids` arguments. + Additionally, the runtime will only take action based on these returned IDs once they've been returned by ALL workers. This allows some workers to take longer than others to complete the operations. + """ class KvCacheConnectorScheduler(ABC): @@ -73,68 +140,94 @@ def __init__(self): @abstractmethod def build_connector_meta(self, scheduler_output: SchedulerOutput): - pass + """ + Build the metadata for the worker. + This is called by the KV Cache Manager when adding a sequence. + Args: + scheduler_output: The data for all inflight requests. + + Returns: + The metadata for the workers. + """ def get_num_new_matched_tokens( self, request: LlmRequest, num_computed_tokens: int) -> tuple[int, bool]: - pass + """ + Get the number of tokens that can be loaded from remote KV cache. + This does not include the tokens already matched on device (indicated by `num_computed_tokens`). + + Args: + request: The request to get the number of tokens for. + num_computed_tokens: The number of tokens already matched on device. + + Returns: + The number of tokens that can be loaded from remote KV cache. + Whether the tokens will be loaded asynchronously. + """ @abstractmethod def request_finished(self, request: LlmRequest) -> bool: - pass + """ + Called when a request is finished generating tokens. + + Args: + request: The request that finished generating tokens. + Returns: + Whether the request is performing asynchronous saving operations. + If true, this indicates that the kv cache manager should wait to deallocate the blocks until the saving has completed (determined by `get_finished` on the workers). + """ + +# An internal dataclass to handle async saving/loading requests. @dataclass -class Finished: +class AsyncRequests: saving: dict[int, LlmRequest] loading: dict[int, LlmRequest] - def add_from(self, other: 'Finished'): + def add_from(self, other: 'AsyncRequests'): + """ + Remove requests from the other `AsyncRequests` object, and add them to this one. + """ self.saving.update(other.saving) self.loading.update(other.loading) other.saving = dict() other.loading = dict() - def extract_by_id(self, saving_ids: list[int], loading_ids: list[int]): + def extract_by_id(self, saving_ids: list[int], + loading_ids: list[int]) -> 'AsyncRequests': + """ + Extract the requests with the given IDs from this `AsyncRequests` object. - new_finished = Finished(dict(), dict()) + Args: + saving_ids: The IDs of the requests to extract. + loading_ids: The IDs of the requests to extract. + """ + new_async_requests = AsyncRequests(dict(), dict()) for req_id in saving_ids: - new_finished.saving[req_id] = self.saving[req_id] + new_async_requests.saving[req_id] = self.saving[req_id] del self.saving[req_id] for req_id in loading_ids: - new_finished.loading[req_id] = self.loading[req_id] + new_async_requests.loading[req_id] = self.loading[req_id] del self.loading[req_id] - return new_finished + return new_async_requests def saving_ids(self) -> set[int]: + """ + Get the IDs of the requests that are being saved asynchronously. + """ return set(self.saving.keys()) def loading_ids(self) -> set[int]: + """ + Get the IDs of the requests that are being loaded asynchronously. 
+ """ return set(self.loading.keys()) - @staticmethod - def intersection(*all_finished: 'Finished') -> 'Finished': - if len(all_finished) == 0: - return Finished(dict(), dict()) - - saving_ids = set.intersection( - *[finished.saving_ids() for finished in all_finished]) - loading_ids = set.intersection( - *[finished.loading_ids() for finished in all_finished]) - return Finished( - dict([(req_id, all_finished[0].saving[req_id]) - for req_id in saving_ids]), - dict([(req_id, all_finished[0].loading[req_id]) - for req_id in loading_ids])) - - def __sub__(self, other: 'Finished') -> 'Finished': - return Finished(self.saving - other.saving, - self.loading - other.loading) - class KvCacheConnectorManager(KvCacheConnectorManagerCpp): @@ -149,100 +242,134 @@ def __init__(self, worker: KvCacheConnectorWorker, self.scheduler = scheduler # Requests that haven't yet been passed into get_finished. - self.new_finished = Finished(dict(), dict()) + self.new_async_requests = AsyncRequests(dict(), dict()) # Requests that have been passed into get_finished, but haven't yet been returned. - self.pending_finished = Finished(dict(), dict()) + self.pending_async_requests = AsyncRequests(dict(), dict()) # Requests that have been returned from get_finished locally, but haven't yet been returned by all workers. - self.local_finished = Finished(dict(), dict()) + self.local_finished_async_requests = AsyncRequests(dict(), dict()) self._scheduler_output = None - def get_num_new_matched_tokens(self, request: LlmRequest, - num_computed_tokens: int) -> int: + def _run_on_leader(self, f: Callable[[], Any]) -> Any: + """ + Run a function on the leader rank, and broadcast the result to all other ranks. + """ if self.scheduler is not None: assert mpi_rank() == 0, "The scheduler may only exist on rank 0!" - res = self.scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + res = f() else: res = None + return mpi_broadcast(res, root=0) - (num_tokens, load_kv_async) = mpi_broadcast(res, root=0) + def get_num_new_matched_tokens(self, request: LlmRequest, + num_computed_tokens: int) -> int: + num_tokens, load_kv_async = self._run_on_leader( + lambda: self.scheduler.get_num_new_matched_tokens( + request, num_computed_tokens)) if num_tokens == 0 and load_kv_async: raise RuntimeError( "load_kv_async must be False when num_tokens is 0!") + # TODO(jthomson04): This part is a bit ugly. + # When the connector indicates that a request will be loaded asynchronously, we need to suspend it's execution. + # This is problematic, since at this point when this function is called, the request has already been scheduled! + # Because of this, we need to remove it from our list of scheduled requests (see `take_scheduled_requests_pending_load`). if load_kv_async: - self.new_finished.loading[request.request_id] = request + self.new_async_requests.loading[request.request_id] = request request.is_kv_cache_connector_async_onboard = True return num_tokens + def take_scheduled_requests_pending_load( + self, scheduled_requests: ScheduledRequests) -> ScheduledRequests: + """ + Remove context requests from our list of scheduled requests that are being loaded asynchronously. + This is done to prevent the runtime from attempting to load the KV cache for these requests. + + Args: + scheduled_requests: The scheduled requests. + + Returns: + The scheduled requests with the context requests that are being loaded asynchronously removed. 
+ """ + allowed_context_requests = [] + + for req in scheduled_requests.context_requests: + # If this request is being loaded asynchronously, in addition to removing it from the list of scheduled requests, + # we also need to update it's state. + if req.request_id in self.new_async_requests.loading.keys(): + req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS + else: + allowed_context_requests.append(req) + + # Update the list of scheduled requests. + scheduled_requests.context_requests = allowed_context_requests + + return scheduled_requests + def build_connector_meta(self) -> object: - if self.scheduler is not None: - assert mpi_rank() == 0, "The scheduler may only exist on rank 0!" - if self._scheduler_output is None: - raise RuntimeError("Scheduler output not set!") - metadata = self.scheduler.build_connector_meta( - self._scheduler_output) - else: - metadata = None + metadata = self._run_on_leader( + lambda: self.scheduler.build_connector_meta(self._scheduler_output)) self._scheduler_output = None - metadata = mpi_broadcast(metadata, root=0) - self.worker.bind_connector_meta(metadata) def request_finished(self, req: LlmRequest) -> bool: - if self.scheduler is not None: - assert mpi_rank() == 0, "The scheduler may only exist on rank 0!" - saving_async = self.scheduler.request_finished(req) - else: - saving_async = None + """ + Called when a request is finished generating tokens. - saving_async = mpi_broadcast(saving_async, root=0) + Args: + req: The request that finished generating tokens. + Returns: + Whether the request is performing asynchronous saving operations. If true, we do not immediately call free_resources on the request. + """ + + saving_async = self._run_on_leader( + lambda: self.scheduler.request_finished(req)) + + # This is similar to take_scheduled_requests_pending_load. + # We need to update the request's state to indicate that it's still being used, but isn't schedulable. if saving_async: - self.new_finished.saving[req.request_id] = req + self.new_async_requests.saving[req.request_id] = req req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS return saving_async - def take_scheduled_requests_pending_load( - self, scheduled_requests: ScheduledRequests) -> ScheduledRequests: - allowed_context_requests = [] - - for req in scheduled_requests.context_requests: - if req.request_id in self.new_finished.loading.keys(): - req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS - else: - allowed_context_requests.append(req) - - scheduled_requests.context_requests = allowed_context_requests + def get_finished(self) -> list[LlmRequest]: + """ + Process requests that have finished loading and saving. - return scheduled_requests + Returns: + The requests that have newly finished saving. + """ + started_loading_req_ids = list(self.new_async_requests.loading_ids()) + finished_gen_req_ids = list(self.new_async_requests.saving_ids()) - def get_finished(self) -> list[LlmRequest]: - started_loading_req_ids = list(self.new_finished.loading_ids()) - finished_gen_req_ids = list(self.new_finished.saving_ids()) + # Add the requests to our list of outstanding (still in progress) requests. + self.pending_async_requests.add_from(self.new_async_requests) - self.pending_finished.add_from(self.new_finished) + # Pass these newly finished requests into get_finished, and get the list of requests that have finished saving and loading. 
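For orientation, the bookkeeping in this method moves request ids through three holding areas before anything is acted on; a sketch of the flow (the attribute names are the ones used in this class):

    # new_async_requests            -- announced by the scheduler this iteration
    #     | ids are handed to worker.get_finished exactly once (the call below)
    #     v
    # pending_async_requests        -- waiting for the local worker to report them
    #     | the worker returns an id from get_finished
    #     v
    # local_finished_async_requests -- done on this rank; allgathered every iteration
    #     | the id shows up in every rank's gathered list
    #     v
    # globally finished: loading requests go back to CONTEXT_INIT,
    # saving requests are returned so the executor can terminate them.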
(finished_saving, finished_loading) = self.worker.get_finished(finished_gen_req_ids, started_loading_req_ids) - new_local_finished = self.pending_finished.extract_by_id( + # Remove the requests from our pending list that have finished locally. + new_local_finished_async_requests = self.pending_async_requests.extract_by_id( finished_saving, finished_loading) - # Get all pending finished requests for this worker. - self.local_finished.add_from(new_local_finished) + # Add these requests to our list of locally finished requests. + self.local_finished_async_requests.add_from( + new_local_finished_async_requests) - # Broadcast this to all other workers. - finished_saving = list(self.local_finished.saving_ids()) - finished_loading = list(self.local_finished.loading_ids()) + # Broadcast this whole list to all other workers. + finished_saving = list(self.local_finished_async_requests.saving_ids()) + finished_loading = list( + self.local_finished_async_requests.loading_ids()) all_results = mpi_allgather((finished_saving, finished_loading)) @@ -252,13 +379,16 @@ def get_finished(self) -> list[LlmRequest]: intersect_finished_loading = set.intersection( *[set(res[1]) for res in all_results]) - all_finished = self.local_finished.extract_by_id( + # Remove these requests from our list of locally finished requests. + all_finished = self.local_finished_async_requests.extract_by_id( intersect_finished_saving, intersect_finished_loading) # For requests that have finished loading, move them back to the context state. for req in all_finished.loading.values(): req.state = LlmRequestState.CONTEXT_INIT + # Return the requests that have finished saving. + # The execution loop will call _terminate_request on these requests. return list(all_finished.saving.values()) def set_scheduler_output(self, scheduler_output: SchedulerOutput): diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 422c20c075f..4f10cbe3eef 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -296,6 +296,8 @@ def _maybe_init_kv_connector_manager(self): self.kv_connector_manager.worker.register_kv_caches(kv_cache_data) + # For each of our layers, we need to register the pre/post hooks. + # These are used for methods like `wait_for_layer_load` and `save_kv_layer`. for _name, module in self.model_engine.model.named_modules(): if isinstance(module, DecoderLayer): module.register_forward_pre_hook( @@ -303,6 +305,9 @@ def _maybe_init_kv_connector_manager(self): module.register_forward_hook( self.kv_connector_manager.layer_post_hook) + # We also need a hook that runs once the model is complete. + # In theory, we could do this at the end of _forward_step, but this is more convenient, + # and may give slightly better perf. self.model_engine.model.register_forward_hook( self.kv_connector_manager.model_post_hook) @@ -927,6 +932,8 @@ def _prepare_and_schedule_batch(self): ) self.kv_cache_transceiver.check_context_transfer_status(1) elif self.kv_connector_manager is None: + # The kv cache connector also puts requests to sleep similar to the transceiver. + # Thus, this assertion is only applicable when both the cache transceiver and connector are disabled. 
assert scheduled_batch.batch_size > 0, ( "fail to schedule any pending request, " "probably run out of resource.") @@ -974,7 +981,6 @@ def _executor_loop(self): # Return the first token to the client self._handle_first_token_response(scheduled_batch) - self.resource_manager.prepare_resources(scheduled_batch) if self.kv_cache_transceiver and self.guided_decoder: @@ -1012,8 +1018,7 @@ def _executor_loop(self): if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._terminate_ctx_finished_requests() - - if self.kv_connector_manager: + elif self.kv_connector_manager: reqs_to_terminate = self.kv_connector_manager.get_finished() for req in reqs_to_terminate: self._terminate_request(req) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index dd8d83145d1..c97c59c3a06 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -375,6 +375,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): context_batch = scheduled_batch.context_requests generation_batch = scheduled_batch.generation_requests + # Build the scheduler output for the connector. scheduler_output = SchedulerOutput() # allocate KV Cache @@ -391,6 +392,13 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): == self.mapping.cp_size - 1 else 0), req_beam_width, req, self.kv_connector_manager) else: + # TODO(jthomson04): This is begging for a mega refactor, and can likely be significantly simplified. + # In add sequence, the connector API's get_num_new_matched_tokens is called. + # The result of this call may be that blocks will be loaded asynchronously. + # If so, we set the is_kv_cache_connector_async_onboard flag, and set the request state to be DISAGG_GENERATION_TRANS_IN_PROGRESS. + # When the async load is complete, we set the request state back to CONTEXT_INIT. + # When that happens, the request will go through this same code path, but with is_kv_cache_connector_async_onboard set to True. + # Because of this, we need to filter this case out to avoid adding the same sequence twice. if req.is_first_context_chunk and not req.is_kv_cache_connector_async_onboard: self.impl.add_sequence(req.py_request_id, req.prompt_len, req_beam_width, req, @@ -400,6 +408,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) + # If this is not an async load, we can add the new tokens and blocks right away. if not req.is_kv_cache_connector_async_onboard: scheduler_output.add_request( req.request_id, req.get_tokens(0), @@ -408,6 +417,10 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): self.max_attention_window_vec[0]), req.context_current_position) else: + # When using the connector, this code path will be hit after the async load is complete. + # Alternatively, with no connector, this is hit after the first chunk of prefill. + + # If this is the first actual prefill, we can add all of our new tokens and blocks. if req.is_first_context_chunk or req.is_kv_cache_connector_async_onboard: req.is_kv_cache_connector_async_onboard = False scheduler_output.add_request( @@ -417,6 +430,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): self.max_attention_window_vec[0]), req.context_current_position) else: + # Otherwise, we just provide the new context position. No new blocks are allocated. 
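Everything recorded into this SchedulerOutput is what the connector scheduler receives in build_connector_meta. As a rough sketch of the consuming side (the ConnectorMetadata dataclass is made up for illustration, and the request_id/new_block_ids field names are inferred from the add_request calls in prepare_resources above; only .requests and .new_tokens are exercised by the tests later in this series):

    from dataclasses import dataclass, field

    from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnectorScheduler

    @dataclass
    class ConnectorMetadata:
        # request_id -> block ids whose contents should be offloaded after this pass
        blocks_to_save: dict = field(default_factory=dict)

    class SketchScheduler(KvCacheConnectorScheduler):
        # get_num_new_matched_tokens / request_finished omitted for brevity.
        def build_connector_meta(self, scheduler_output):
            meta = ConnectorMetadata()
            for req in scheduler_output.requests:
                # New block ids only appear on the first prefill chunk (or the first
                # pass after an async onboard completes); later context chunks and
                # decode steps mostly just advance the computed position.
                if req.new_block_ids:
                    meta.blocks_to_save[req.request_id] = list(req.new_block_ids)
            return meta

The add_request call just below handles the remaining case, a later context chunk, which carries only the advanced position.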
scheduler_output.add_request( req.request_id, [], [], req.context_current_position) From 50bcec39cebc2e985ea2c3cd59463ab9349ab298 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 29 Jul 2025 22:00:28 -0700 Subject: [PATCH 17/50] A whole bunch of unit tests Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 2 - .../pybind/batch_manager/kvCacheConnector.cpp | 2 +- .../pybind/batch_manager/kvCacheManager.cpp | 6 +- tensorrt_llm/_torch/pyexecutor/connector.py | 4 +- tensorrt_llm/_torch/pyexecutor/py_executor.py | 7 + tests/unittest/_torch/test_connector.py | 153 ++++++++++++++++++ .../bindings/test_connector_bindings.py | 51 ------ 7 files changed, 165 insertions(+), 60 deletions(-) create mode 100644 tests/unittest/_torch/test_connector.py delete mode 100644 tests/unittest/bindings/test_connector_bindings.py diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 8724eb0a109..b69bfa8615c 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -26,8 +26,6 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32; using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType; -using namespace tensorrt_llm::batch_manager; - namespace tensorrt_llm::batch_manager::kv_connector { diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index 5e3648fdff2..c92056db370 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -31,7 +31,7 @@ class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline public: using KvCacheConnectorManager::KvCacheConnectorManager; - SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override + SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override { PYBIND11_OVERRIDE_PURE_NAME(SizeType32, KvCacheConnectorManager, "get_num_new_matched_tokens", getNumNewMatchedTokens, request, numComputedTokens); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 94aef0197b4..dfaed0f6c5d 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -98,7 +98,7 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, - tensorrt_llm::common::OptionalRef kvCacheConnectorManager + tensorrt_llm::common::OptionalRef kvCacheConnectorManager = std::nullopt) override { PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, @@ -238,10 +238,10 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, flushIterationEvents); } - kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const override + [[nodiscard]] tb::kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const override { PYBIND11_OVERLOAD_PURE( - kv_connector::KvCacheConnectorPoolsData, tbk::BaseKVCacheManager, getKvCacheConnectorPoolsData); + tb::kv_connector::KvCacheConnectorPoolsData, tbk::BaseKVCacheManager, 
getKvCacheConnectorPoolsData); } }; diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 2607e20eaaf..8ae8eb1f06b 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -275,7 +275,7 @@ def get_num_new_matched_tokens(self, request: LlmRequest, # TODO(jthomson04): This part is a bit ugly. # When the connector indicates that a request will be loaded asynchronously, we need to suspend it's execution. - # This is problematic, since at this point when this function is called, the request has already been scheduled! + # This is problematic, since at the point when this function is called, the request has already been scheduled! # Because of this, we need to remove it from our list of scheduled requests (see `take_scheduled_requests_pending_load`). if load_kv_async: self.new_async_requests.loading[request.request_id] = request @@ -308,8 +308,6 @@ def take_scheduled_requests_pending_load( # Update the list of scheduled requests. scheduled_requests.context_requests = allowed_context_requests - return scheduled_requests - def build_connector_meta(self) -> object: metadata = self._run_on_leader( lambda: self.scheduler.build_connector_meta(self._scheduler_output)) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 4f10cbe3eef..0abeb9d9718 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -986,6 +986,13 @@ def _executor_loop(self): if self.kv_cache_transceiver and self.guided_decoder: self.guided_decoder.init_disagg_gen_requests( scheduled_batch) + elif self.kv_connector_manager: + self.kv_connector_manager.take_scheduled_requests_pending_load(scheduled_batch) + + + if scheduled_batch.batch_size > 0 or ( + self.enable_attention_dp and self.dist.tp_size > 1): + if self.drafter is not None and self.use_spec_decode: if self.guided_decoder is not None: self.guided_decoder.rollback_rejected_tokens( diff --git a/tests/unittest/_torch/test_connector.py b/tests/unittest/_torch/test_connector.py new file mode 100644 index 00000000000..039ce7e00a5 --- /dev/null +++ b/tests/unittest/_torch/test_connector.py @@ -0,0 +1,153 @@ +import pickle +import sys +from unittest.mock import MagicMock + +import cloudpickle +import mpi4py +import pytest + +from tensorrt_llm import mpi_rank +from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnectorManager +from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests + +cloudpickle.register_pickle_by_value(sys.modules[__name__]) +mpi4py.MPI.pickle.__init__( + cloudpickle.dumps, + cloudpickle.loads, + pickle.HIGHEST_PROTOCOL, +) + + +def run_across_mpi(executor, fun, num_ranks): + return list(executor.starmap(fun, [() for i in range(num_ranks)])) + + +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_connector_manager_get_finished_allgather(mpi_pool_executor): + + def test(): + worker = MagicMock() + + if mpi_rank() == 0: + scheduler = MagicMock() + + scheduler.request_finished.return_value = True + else: + scheduler = None + + manager = KvCacheConnectorManager(worker, scheduler=scheduler) + + req = MagicMock() + + req.request_id = 42 + + manager.request_finished(req) + + # To start, make both workers return nothing. 
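The test below exercises the agreement rule from the connector docstring: an id only takes effect once every rank has reported it back. Stripped of the MPI plumbing, that reduces to a set intersection over the allgathered per-rank results; a standalone sketch (rank_results stands in for the output of mpi_allgather):

    # (finished_saving, finished_loading) as reported by each rank's worker.
    rank_results = [
        ({42}, set()),   # rank 0: the save for request 42 has completed
        (set(), set()),  # rank 1: still saving
    ]
    globally_saved = set.intersection(*[saved for saved, _ in rank_results])
    assert globally_saved == set()      # nothing is freed until all ranks agree

    rank_results[1] = ({42}, set())     # rank 1 catches up on a later iteration
    globally_saved = set.intersection(*[saved for saved, _ in rank_results])
    assert globally_saved == {42}       # now the request can be terminated

The mocked get_finished return values below walk through exactly this progression.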
+ worker.get_finished.return_value = ([], []) + + assert manager.get_finished() == [] + + assert worker.get_finished.call_count == 1 + assert worker.get_finished.call_args[0] == ([42], []) + + worker.get_finished.reset_mock() + + # Now, only return the request id on one worker. + if mpi_rank() == 0: + worker.get_finished.return_value = ([42], []) + else: + worker.get_finished.return_value = ([], []) + + # It should still return nothing, since rank 1 is still saving. + assert manager.get_finished() == [] + + assert worker.get_finished.call_count == 1 + assert worker.get_finished.call_args[0] == ([], []) + + # Now, also return it on worker 1. + if mpi_rank() == 0: + worker.get_finished.return_value = ([], []) + else: + worker.get_finished.return_value = ([42], []) + + assert manager.get_finished() == [req] + + run_across_mpi(mpi_pool_executor, test, 2) + + +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_connector_manager_num_matched_tokens(mpi_pool_executor): + + def test(): + worker = MagicMock() + + if mpi_rank() == 0: + scheduler = MagicMock() + scheduler.get_num_new_matched_tokens.return_value = (16, True) + else: + scheduler = None + + manager = KvCacheConnectorManager(worker, scheduler=scheduler) + + req = MagicMock() + + req.request_id = 42 + + assert manager.get_num_new_matched_tokens(req, 32) == 16 + assert req.is_kv_cache_connector_async_onboard + + if mpi_rank() == 0: + assert scheduler.get_num_new_matched_tokens.call_count == 1 + assert scheduler.get_num_new_matched_tokens.call_args[0] == (req, + 32) + + run_across_mpi(mpi_pool_executor, test, 2) + + +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_connector_manager_take_scheduled_requests(mpi_pool_executor): + + def test(): + worker = MagicMock() + + if mpi_rank() == 0: + scheduler = MagicMock() + else: + scheduler = None + + manager = KvCacheConnectorManager(worker, scheduler=scheduler) + + scheduled_requests = ScheduledRequests() + + req0 = MagicMock() + req0.request_id = 0 + + req1 = MagicMock() + req1.request_id = 1 + + if mpi_rank() == 0: + scheduler.get_num_new_matched_tokens.return_value = (16, True) + + assert manager.get_num_new_matched_tokens(req0, 0) == 16 + if mpi_rank() == 0: + assert scheduler.get_num_new_matched_tokens.call_count == 1 + assert scheduler.get_num_new_matched_tokens.call_args[0] == (req0, + 0) + + scheduler.get_num_new_matched_tokens.reset_mock() + scheduler.get_num_new_matched_tokens.return_value = (32, False) + + assert manager.get_num_new_matched_tokens(req1, 0) == 32 + if mpi_rank() == 0: + assert scheduler.get_num_new_matched_tokens.call_count == 1 + assert scheduler.get_num_new_matched_tokens.call_args[0] == (req1, + 0) + + scheduled_requests.context_requests = [req0, req1] + + manager.take_scheduled_requests_pending_load(scheduled_requests) + + assert scheduled_requests.context_requests == [req1] + + run_across_mpi(mpi_pool_executor, test, 2) diff --git a/tests/unittest/bindings/test_connector_bindings.py b/tests/unittest/bindings/test_connector_bindings.py deleted file mode 100644 index 531f74c10f4..00000000000 --- a/tests/unittest/bindings/test_connector_bindings.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import List - -from tensorrt_llm.bindings.internal.batch_manager import ( - KvCacheConnectorPoolsData, KvCacheConnectorScheduler, - KvCacheConnectorWorker, LlmRequest) - - -class BasicConnectorWorker(KvCacheConnectorWorker): - - def register_kv_caches(self, kv_cache_data: KvCacheConnectorPoolsData): - pass - - def 
start_load_kv(self): - pass - - def wait_for_save(self): - pass - - def wait_for_layer_load(self, layer_idx: int): - pass - - def save_kv_layer(self, layer_idx: int): - pass - - def get_finished( - self, finished_req_ids: List[int]) -> tuple[List[int], List[int]]: - return [42], [7] - - -class BasicConnectorScheduler(KvCacheConnectorScheduler): - - def get_num_new_matched_tokens( - self, request: LlmRequest, - num_computed_tokens: int) -> tuple[int, bool]: - return 16, True - - def update_state_after_alloc(self): - pass - - -def test_basic_init(): - connector_scheduler = BasicConnectorScheduler() - - connector_scheduler.update_state_after_alloc() - - connector_worker = BasicConnectorWorker() - - assert connector_worker.get_finished([]) == ([42], [7]) - - connector_worker.save_kv_layer(0) - connector_worker.wait_for_save() From e305010ed9933518c807cbf4b7432894121d3662 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 30 Jul 2025 09:56:48 -0700 Subject: [PATCH 18/50] precommit Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/_util.py | 30 ++++++------- tensorrt_llm/_torch/pyexecutor/py_executor.py | 44 ++++++++++--------- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index b06897f42be..e690814211a 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -422,21 +422,21 @@ def teardown_managers(self, resources: Dict) -> None: def create_py_executor_instance( - *, - dist, - resources, - mapping, - pytorch_backend_config, - executor_config, - ctx_chunk_config, - model_engine, - start_worker, - sampler, - drafter, - guided_decoder: Optional[GuidedDecoder] = None, - lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, - kv_connector_manager: Optional[KvCacheConnectorManager] = None + *, + dist, + resources, + mapping, + pytorch_backend_config, + executor_config, + ctx_chunk_config, + model_engine, + start_worker, + sampler, + drafter, + guided_decoder: Optional[GuidedDecoder] = None, + lora_config: Optional[LoraConfig] = None, + garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_manager: Optional[KvCacheConnectorManager] = None ) -> PyExecutor: kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 0abeb9d9718..e499f534be1 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -136,24 +136,25 @@ class BatchStatePP(BatchState): class PyExecutor: - def __init__(self, - resource_manager, - scheduler: RequestScheduler, - model_engine: ModelEngine, - sampler: Sampler, - dist: Distributed, - max_num_sequences: int, - drafter: Optional[Drafter] = None, - disable_overlap_scheduler: bool = False, - max_input_len: int = 2048, - max_batch_size: int = 8, - max_beam_width: int = 1, - max_draft_len: int = 0, - kv_cache_transceiver: Optional[KvCacheTransceiver] = None, - guided_decoder: Optional[GuidedDecoder] = None, - garbage_collection_gen0_threshold: Optional[int] = None, - start_worker: bool = True, - kv_connector_manager: Optional[KvCacheConnectorManager] = None): + def __init__( + self, + resource_manager, + scheduler: RequestScheduler, + model_engine: ModelEngine, + sampler: Sampler, + dist: Distributed, + max_num_sequences: int, + drafter: Optional[Drafter] = None, + 
disable_overlap_scheduler: bool = False, + max_input_len: int = 2048, + max_batch_size: int = 8, + max_beam_width: int = 1, + max_draft_len: int = 0, + kv_cache_transceiver: Optional[KvCacheTransceiver] = None, + guided_decoder: Optional[GuidedDecoder] = None, + garbage_collection_gen0_threshold: Optional[int] = None, + start_worker: bool = True, + kv_connector_manager: Optional[KvCacheConnectorManager] = None): super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = global_mpi_rank() @@ -986,9 +987,10 @@ def _executor_loop(self): if self.kv_cache_transceiver and self.guided_decoder: self.guided_decoder.init_disagg_gen_requests( scheduled_batch) - elif self.kv_connector_manager: - self.kv_connector_manager.take_scheduled_requests_pending_load(scheduled_batch) - + + if self.kv_connector_manager: + self.kv_connector_manager.take_scheduled_requests_pending_load( + scheduled_batch) if scheduled_batch.batch_size > 0 or ( self.enable_attention_dp and self.dist.tp_size > 1): From fe45192a9aadec1a3b530cbab676f04f468d6c19 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 30 Jul 2025 10:35:51 -0700 Subject: [PATCH 19/50] Fix wait_for_save Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index e499f534be1..20cd10f7ccb 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -306,12 +306,6 @@ def _maybe_init_kv_connector_manager(self): module.register_forward_hook( self.kv_connector_manager.layer_post_hook) - # We also need a hook that runs once the model is complete. - # In theory, we could do this at the end of _forward_step, but this is more convenient, - # and may give slightly better perf. 
- self.model_engine.model.register_forward_hook( - self.kv_connector_manager.model_post_hook) - def _event_loop_wrapper(self): try: with customized_gc_thresholds( @@ -1474,6 +1468,10 @@ def forward(scheduled_requests, resource_manager, new_tensors_device, outputs = forward(scheduled_requests, self.resource_manager, new_tensors_device, gather_context_logits, cache_indirection_buffer) + + if self.kv_connector_manager is not None: + self.kv_connector_manager.worker.wait_for_save() + return outputs except Exception as e: traceback.print_exc() From 7ca84a20df086d0e1a7a7795eeff9f84047785c8 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 30 Jul 2025 17:03:26 -0700 Subject: [PATCH 20/50] start on integration tests Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 3 + .../_torch/pyexecutor/resource_manager.py | 10 +- tensorrt_llm/_torch/pyexecutor/scheduler.py | 1 + .../defs/llmapi/test_llm_api_connector.py | 158 ++++++++++++++++++ 4 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 tests/integration/defs/llmapi/test_llm_api_connector.py diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 20cd10f7ccb..cd5c67d9d55 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -976,6 +976,7 @@ def _executor_loop(self): # Return the first token to the client self._handle_first_token_response(scheduled_batch) + scheduled_batch.is_warmup = self.is_warmup self.resource_manager.prepare_resources(scheduled_batch) if self.kv_cache_transceiver and self.guided_decoder: @@ -985,6 +986,8 @@ def _executor_loop(self): if self.kv_connector_manager: self.kv_connector_manager.take_scheduled_requests_pending_load( scheduled_batch) + self.kv_connector_manager.build_connector_meta() + self.kv_connector_manager.worker.start_load_kv() if scheduled_batch.batch_size > 0 or ( self.enable_attention_dp and self.dist.tp_size > 1): diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index c97c59c3a06..7b44072e15e 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -390,7 +390,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req.py_request_id, seq_len + (len(req.query_id) if self.mapping.cp_rank == self.mapping.cp_size - 1 else 0), - req_beam_width, req, self.kv_connector_manager) + req_beam_width, req, self.kv_connector_manager + if not scheduled_batch.is_warmup else None) else: # TODO(jthomson04): This is begging for a mega refactor, and can likely be significantly simplified. # In add sequence, the connector API's get_num_new_matched_tokens is called. @@ -400,9 +401,10 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): # When that happens, the request will go through this same code path, but with is_kv_cache_connector_async_onboard set to True. # Because of this, we need to filter this case out to avoid adding the same sequence twice. 
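Traced out, the state machine this comment describes looks like this (a sketch; the states are the LlmRequestState values used elsewhere in this series):

    # get_num_new_matched_tokens(...) returns (n, load_kv_async=True)
    #   -> is_kv_cache_connector_async_onboard = True
    #   -> take_scheduled_requests_pending_load parks the request in
    #      DISAGG_GENERATION_TRANS_IN_PROGRESS
    # every rank reports the load complete via get_finished
    #   -> state goes back to CONTEXT_INIT, so the request is schedulable again
    # the next prepare_resources call sees the flag still set
    #   -> skip add_sequence (the blocks already exist), clear the flag, and emit
    #      the request into the SchedulerOutput as if it were the first prefill chunk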
if req.is_first_context_chunk and not req.is_kv_cache_connector_async_onboard: - self.impl.add_sequence(req.py_request_id, req.prompt_len, - req_beam_width, req, - self.kv_connector_manager) + self.impl.add_sequence( + req.py_request_id, req.prompt_len, req_beam_width, req, + self.kv_connector_manager + if not scheduled_batch.is_warmup else None) for _ in range(self.num_extra_kv_tokens): self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index ef86fc1f49e..85c847ad650 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -21,6 +21,7 @@ def __init__(self): self.context_requests: RequestList = [] self.generation_requests: RequestList = [] self.paused_requests: RequestList = [] + self.is_warmup: bool = False @property def is_generation_only(self) -> bool: diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py new file mode 100644 index 00000000000..3527c10e8c8 --- /dev/null +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from collections import defaultdict +from unittest.mock import DEFAULT, MagicMock, patch + +import pytest + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +KvConnectorScheduler = MagicMock() +KvConnectorWorker = MagicMock() + + +def init_connector_classes(): + KvConnectorScheduler.reset_mock() + KvConnectorWorker.reset_mock() + + scheduler = MagicMock() + worker = MagicMock() + + KvConnectorScheduler.return_value = scheduler + KvConnectorWorker.return_value = worker + + return scheduler, worker + + +# Makes sure everything is called in the right order. 
+class CallTimeMonitor: + + def __init__(self): + self.call_times = defaultdict(list) + self.counter = 0 + + def monitor_fn(self, mock_fn, name): + + def wrapper(*args, **kwargs): + self.call_times[name].append(self.counter) + self.counter += 1 + + return DEFAULT + + mock_fn.side_effect = wrapper + + def __getitem__(self, name): + return self.call_times[name] + + +@pytest.fixture +def connector(): + with patch("tensorrt_llm._torch.pyexecutor.py_executor_creator.importlib" + ) as importlib_mock: + mock_scheduler = MagicMock() + mock_worker = MagicMock() + + importlib_mock.import_module.return_value.KvConnectorScheduler.return_value = mock_scheduler + importlib_mock.import_module.return_value.KvConnectorWorker.return_value = mock_worker + + connector_config = KvCacheConnectorConfig( + connector_module="", + connector_scheduler_class="KvConnectorScheduler", + connector_worker_class="KvConnectorWorker", + ) + + call_time_monitor = CallTimeMonitor() + + call_time_monitor.monitor_fn(mock_scheduler.build_connector_meta, + "build_connector_meta") + call_time_monitor.monitor_fn(mock_scheduler.get_num_new_matched_tokens, + "get_num_new_matched_tokens") + call_time_monitor.monitor_fn(mock_scheduler.request_finished, + "request_finished") + + call_time_monitor.monitor_fn(mock_worker.start_load_kv, "start_load_kv") + call_time_monitor.monitor_fn(mock_worker.wait_for_layer_load, + "wait_for_layer_load") + call_time_monitor.monitor_fn(mock_worker.save_kv_layer, "save_kv_layer") + call_time_monitor.monitor_fn(mock_worker.wait_for_save, "wait_for_save") + call_time_monitor.monitor_fn(mock_worker.get_finished, "get_finished") + + yield connector_config, mock_scheduler, mock_worker, call_time_monitor + + +@pytest.mark.threadleak(enabled=False) +def test_llm_api_connector_simple(connector): + connector_config, scheduler, worker, call_time_monitor = connector + + os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1" + + NUM_TOKENS = 8 + + model = LLM(model="Qwen/Qwen2-0.5B", + backend="pytorch", + disable_overlap_scheduler=True, + connector_config=connector_config, + cuda_graph_config=None) + + assert worker.register_kv_caches.call_count == 1 + + scheduler.get_num_new_matched_tokens.return_value = 0, False + + worker.get_finished.return_value = [], [] + + sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True) + + model.generate(["Hello, world"], sampling_params) + + assert scheduler.build_connector_meta.call_count == NUM_TOKENS + + for i, call in enumerate(scheduler.build_connector_meta.call_args_list): + scheduler_output = call[0][0] + assert len(scheduler_output.requests) == 1 + if i != 0: + assert len(scheduler_output.requests[0].new_tokens) == 1 + + assert worker.start_load_kv.call_count == NUM_TOKENS + + # We should have always built our metadata before loading kv. 
+ for load_kv_call_time, build_metadata_call_time in zip( + call_time_monitor["start_load_kv"], + call_time_monitor["build_connector_meta"]): + assert build_metadata_call_time < load_kv_call_time + + assert scheduler.get_num_new_matched_tokens.call_count == 1 + + num_layers = max(call.args[0] + for call in worker.wait_for_layer_load.call_args_list) + 1 + + assert worker.wait_for_layer_load.call_count == num_layers * NUM_TOKENS + assert worker.save_kv_layer.call_count == num_layers * NUM_TOKENS + + for i, call in enumerate(worker.wait_for_layer_load.call_args_list): + assert call.args[0] == i % num_layers + + for i, call in enumerate(worker.save_kv_layer.call_args_list): + assert call.args[0] == i % num_layers + + assert worker.wait_for_save.call_count == NUM_TOKENS + + assert scheduler.request_finished.call_count == 1 + assert worker.get_finished.call_count == NUM_TOKENS From b85f74900bf0ab99885c142124ee9707a8abdb18 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 30 Jul 2025 22:32:36 -0700 Subject: [PATCH 21/50] Integration tests for async save and load Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- .../defs/llmapi/test_llm_api_connector.py | 129 +++++++++++------- 2 files changed, 83 insertions(+), 48 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index cd5c67d9d55..845d7e7d5ca 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1027,7 +1027,7 @@ def _executor_loop(self): elif self.kv_connector_manager: reqs_to_terminate = self.kv_connector_manager.get_finished() for req in reqs_to_terminate: - self._terminate_request(req) + self.resource_manager.free_resources(req) if self.enable_iter_perf_stats: iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 3527c10e8c8..640820c8145 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -15,8 +15,7 @@ import os import sys -from collections import defaultdict -from unittest.mock import DEFAULT, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -42,27 +41,6 @@ def init_connector_classes(): return scheduler, worker -# Makes sure everything is called in the right order. 
-class CallTimeMonitor: - - def __init__(self): - self.call_times = defaultdict(list) - self.counter = 0 - - def monitor_fn(self, mock_fn, name): - - def wrapper(*args, **kwargs): - self.call_times[name].append(self.counter) - self.counter += 1 - - return DEFAULT - - mock_fn.side_effect = wrapper - - def __getitem__(self, name): - return self.call_times[name] - - @pytest.fixture def connector(): with patch("tensorrt_llm._torch.pyexecutor.py_executor_creator.importlib" @@ -79,30 +57,17 @@ def connector(): connector_worker_class="KvConnectorWorker", ) - call_time_monitor = CallTimeMonitor() - - call_time_monitor.monitor_fn(mock_scheduler.build_connector_meta, - "build_connector_meta") - call_time_monitor.monitor_fn(mock_scheduler.get_num_new_matched_tokens, - "get_num_new_matched_tokens") - call_time_monitor.monitor_fn(mock_scheduler.request_finished, - "request_finished") + yield connector_config, mock_scheduler, mock_worker - call_time_monitor.monitor_fn(mock_worker.start_load_kv, "start_load_kv") - call_time_monitor.monitor_fn(mock_worker.wait_for_layer_load, - "wait_for_layer_load") - call_time_monitor.monitor_fn(mock_worker.save_kv_layer, "save_kv_layer") - call_time_monitor.monitor_fn(mock_worker.wait_for_save, "wait_for_save") - call_time_monitor.monitor_fn(mock_worker.get_finished, "get_finished") - yield connector_config, mock_scheduler, mock_worker, call_time_monitor +# Needed because MagicMocks don't work across processes. +# TODO(jthomson04): This limits us to testing only TP1 for now. +os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1" @pytest.mark.threadleak(enabled=False) def test_llm_api_connector_simple(connector): - connector_config, scheduler, worker, call_time_monitor = connector - - os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1" + connector_config, scheduler, worker = connector NUM_TOKENS = 8 @@ -124,25 +89,25 @@ def test_llm_api_connector_simple(connector): assert scheduler.build_connector_meta.call_count == NUM_TOKENS + # We should have a single `SchedulerOutput` per forward pass. for i, call in enumerate(scheduler.build_connector_meta.call_args_list): scheduler_output = call[0][0] assert len(scheduler_output.requests) == 1 + + # If this is not prefill, we should always be adding a single token. if i != 0: assert len(scheduler_output.requests[0].new_tokens) == 1 + # We call `start_load_kv` once at the beginning of each forward pass. assert worker.start_load_kv.call_count == NUM_TOKENS - # We should have always built our metadata before loading kv. - for load_kv_call_time, build_metadata_call_time in zip( - call_time_monitor["start_load_kv"], - call_time_monitor["build_connector_meta"]): - assert build_metadata_call_time < load_kv_call_time - + # Only called once when the request is received. assert scheduler.get_num_new_matched_tokens.call_count == 1 num_layers = max(call.args[0] for call in worker.wait_for_layer_load.call_args_list) + 1 + # Called num_layers * num_forward_passes times. 
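For reference, the per-iteration call pattern these counts come from looks roughly like this (a sketch of how the executor loop and the layer hooks drive the worker, not the actual implementation):

    def one_forward_pass(worker, num_layers):
        # Issued once per pass, before the model runs.
        worker.start_load_kv()
        for layer_idx in range(num_layers):
            # Forward pre-hook on each decoder layer.
            worker.wait_for_layer_load(layer_idx)
            # ... attention for this layer runs here ...
            # Forward post-hook on the same layer.
            worker.save_kv_layer(layer_idx)
        # After the full forward pass.
        worker.wait_for_save()
        # Polled once per iteration by the executor loop.
        worker.get_finished([], [])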
assert worker.wait_for_layer_load.call_count == num_layers * NUM_TOKENS assert worker.save_kv_layer.call_count == num_layers * NUM_TOKENS @@ -156,3 +121,73 @@ def test_llm_api_connector_simple(connector): assert scheduler.request_finished.call_count == 1 assert worker.get_finished.call_count == NUM_TOKENS + + +@pytest.mark.threadleak(enabled=False) +def test_llm_api_connector_async_onboard(connector): + connector_config, scheduler, worker = connector + + NUM_TOKENS = 8 + + model = LLM(model="Qwen/Qwen2-0.5B", + backend="pytorch", + disable_overlap_scheduler=True, + connector_config=connector_config, + cuda_graph_config=None) + + assert worker.register_kv_caches.call_count == 1 + + scheduler.get_num_new_matched_tokens.return_value = 16, True + + worker.get_finished.side_effect = lambda finished_gen, load_async: ( + finished_gen, load_async) + + model.generate([ + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + ], SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True)) + + # Once for the initial poll, then once for each token. + assert worker.get_finished.call_count == NUM_TOKENS + 1 + + # In the first iteration, there should be a single request id provided. + assert len(worker.get_finished.call_args_list[0].args[1]) == 1 + + +@pytest.mark.threadleak(enabled=False) +def test_llm_api_connector_async_save(connector): + connector_config, scheduler, worker = connector + + NUM_TOKENS = 8 + + model = LLM(model="Qwen/Qwen2-0.5B", + backend="pytorch", + disable_overlap_scheduler=True, + connector_config=connector_config, + cuda_graph_config=None) + + assert worker.register_kv_caches.call_count == 1 + + scheduler.get_num_new_matched_tokens.return_value = 0, False + + scheduler.request_finished.return_value = True + + worker.get_finished.side_effect = lambda finished_gen, load_async: ( + finished_gen, load_async) + + sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True) + + model.generate(["Hello, world"], sampling_params) + + assert scheduler.request_finished.call_count == 1 + + # On the last call to get_finished, we should be providing the async saving request. 
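Spelled out, the contract the loop below checks is the following (a sketch; req_id stands for the finished request's id):

    # scheduler.request_finished(req) returned True, so the request is parked in
    # DISAGG_CONTEXT_TRANS_IN_PROGRESS instead of being freed immediately.
    saved, loaded = worker.get_finished([req_id], [])  # the id is handed over exactly once
    saved, loaded = worker.get_finished([], [])        # later iterations: empty lists again
    # free_resources(req) only runs once every rank has included req_id in the
    # `saved` half of its return value.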
+ assert worker.get_finished.call_count == NUM_TOKENS + + for i in range(NUM_TOKENS): + args = worker.get_finished.call_args_list[i].args + if i != NUM_TOKENS - 1: + assert args == ([], []) + else: + assert len(args[0]) == 1 + assert args[0][0] == scheduler.request_finished.call_args.args[ + 0].request_id From e16a38d16aae606ce47088bbb0f8d744ccc3bdd3 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 31 Jul 2025 12:06:36 -0700 Subject: [PATCH 22/50] Simplify add token stuff Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 2 ++ .../batch_manager/kvCacheManager.h | 8 +++---- .../batch_manager/kvCacheManager.cpp | 15 ++++--------- .../pybind/batch_manager/kvCacheManager.cpp | 7 +++--- tensorrt_llm/_torch/pyexecutor/connector.py | 14 +++++++++++- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- .../_torch/pyexecutor/resource_manager.py | 22 ++++++++++--------- .../defs/llmapi/test_llm_api_connector.py | 10 ++++++--- 8 files changed, 45 insertions(+), 35 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index b69bfa8615c..1fdccfe53a1 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -26,6 +26,8 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32; using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType; +/// See tensorrt_llm/_torch/pyexecutor/connector.py for details on the Connector API. + namespace tensorrt_llm::batch_manager::kv_connector { diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 651f09c65a1..7ccbc9b5302 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -1237,9 +1237,7 @@ class BaseKVCacheManager = 0; /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed. - /// @param returnNewBlockId If true, return the id of the newly allocated block (if any). Only supported when VSWA - /// and beam search are disabled. - virtual std::optional addToken(LlmRequest::RequestIdType requestId, bool returnNewBlockId = false) = 0; + virtual std::optional addToken(LlmRequest::RequestIdType requestId) = 0; /// @brief Add new request to the KV cache manager. /// @param inputLength Input length for which KV cache need to be allocated. @@ -1542,7 +1540,7 @@ class KVCacheManager : public BaseKVCacheManager /// @brief Increase size for request with requestId. Allocate new KV cache block(s) if needed. /// @param returnNewBlockId If true, return the id of the newly allocated block (if any). Only supported when VSWA /// and beam search are disabled. - std::optional addToken(LlmRequest::RequestIdType requestId, bool returnNewBlockId = false) override; + std::optional addToken(LlmRequest::RequestIdType requestId) override; /// @brief Add new request to the KV cache manager. /// @param inputLength Input length for which KV cache need to be allocated. 
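With returnNewBlockId gone, a caller that needs to know which blocks a token allocation created can simply diff the block list before and after, which is what the Python resource manager does later in this patch (a sketch; resource_manager stands for the pyexecutor KVCacheManager wrapper):

    # Remember which blocks the request already holds.
    old_block_ids = resource_manager.get_cache_indices(req)
    resource_manager.impl.add_token(req.py_request_id)
    new_block_ids = resource_manager.get_cache_indices(req)

    # Block lists only grow at the tail, so the suffix is exactly the set of
    # blocks this allocation added (usually empty or a single id).
    delta_block_ids = new_block_ids[len(old_block_ids):]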
@@ -1686,7 +1684,7 @@ class KVCacheManager : public BaseKVCacheManager void cacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); void cacheNewBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); void updateNewBlockPointer(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx); - std::optional updateToken(GenerationRequest& sequence, bool addToken, bool returnNewBlockId); + std::optional updateToken(GenerationRequest& sequence, bool addToken); private: // Maximum number of sequences diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 744a9e1106a..f0e4177645c 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1980,12 +1980,8 @@ void KVCacheManager::updateNewBlockPointer(GenerationRequest& sequence, SizeType } } -std::optional KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken, bool returnNewBlockId) +std::optional KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken) { - TLLM_CHECK_WITH_INFO( - !returnNewBlockId || (mBlockManager.getWindowSizesMetadata().size() == 1 && sequence.getBeamWidth() == 1), - "KV Cache Connector is not supported with beam search"); - auto currNumTokens = sequence.getNumTokens(); if (addToken) @@ -2025,9 +2021,6 @@ std::optional KVCacheManager::updateToken(GenerationRequest& sequenc { mBlockManager.allocateBlock(sequence, windowSize); cacheNewBlockOffsets(sequence, windowSize); - - return returnNewBlockId ? std::make_optional(sequence.getCacheBlockIds(windowSize).at(0).back()) - : std::nullopt; } else { @@ -2049,10 +2042,10 @@ std::optional KVCacheManager::updateToken(GenerationRequest& sequenc return std::nullopt; } -std::optional KVCacheManager::addToken(RequestIdType requestId, bool returnNewBlockId) +std::optional KVCacheManager::addToken(RequestIdType requestId) { auto& sequence = getSequence(requestId); - return updateToken(sequence, true, returnNewBlockId); + return updateToken(sequence, true); } std::optional KVCacheManager::findNewContextBlock( @@ -2494,7 +2487,7 @@ void KVCacheManager::removeToken(RequestIdType requestId) { return; } - updateToken(sequence, false, false); + updateToken(sequence, false); } void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index dfaed0f6c5d..471c2a4820c 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -90,10 +90,9 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager PYBIND11_OVERLOAD_PURE(tbk::KvCacheStats, tbk::BaseKVCacheManager, getKvCacheStats); } - std::optional addToken(tb::LlmRequest::RequestIdType requestId, bool returnNewBlockId = false) override + std::optional addToken(tb::LlmRequest::RequestIdType requestId) override { - PYBIND11_OVERLOAD_PURE( - std::optional, tbk::BaseKVCacheManager, addToken, requestId, returnNewBlockId); + PYBIND11_OVERLOAD_PURE(std::optional, tbk::BaseKVCacheManager, addToken, requestId); } void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, @@ -350,7 +349,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) [](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; }) .def("get_needed_blocks_one_step", 
&BaseKVCacheManager::getNeededBlocksOneStep) .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion) - .def("add_token", &BaseKVCacheManager::addToken, py::arg("request_id"), py::arg("return_new_block_id") = false) + .def("add_token", &BaseKVCacheManager::addToken) .def("add_sequence", &BaseKVCacheManager::addSequence, py::arg("request_id"), py::arg("input_length"), py::arg("beam_width"), py::arg("llm_request") = std::nullopt, py::arg("kv_cache_connector_manager") = std::nullopt) diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 8ae8eb1f06b..b25ec343ff6 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -28,6 +28,8 @@ The Connector API is split into two parts: 1. The scheduler, which is responsible for orchestration, and building metadata for the workers. 2. The worker, which performs and monitors transfers indicated by the scheduler's metadata. + +To implement a custom KV connector, you need to implement both the scheduler and worker-side interfaces. """ @@ -230,6 +232,16 @@ def loading_ids(self) -> set[int]: class KvCacheConnectorManager(KvCacheConnectorManagerCpp): + """ + The KvCacheConnectorManager is used to manager connector-related state. + + It has the following responsibilities: + 1. Managing the state of async requests (both offload and onboard) + 2. Handling MPI communication. We only run the leader on one rank, but need the results of the leader API on all ranks. + + Note: This class is solely an implementation detail, and is not part of the connector interface itself. + When implementing a connector API, you do not need to implement this class. + """ def __init__(self, worker: KvCacheConnectorWorker, scheduler: Optional[KvCacheConnectorScheduler]): @@ -308,7 +320,7 @@ def take_scheduled_requests_pending_load( # Update the list of scheduled requests. scheduled_requests.context_requests = allowed_context_requests - def build_connector_meta(self) -> object: + def handle_metadata(self) -> object: metadata = self._run_on_leader( lambda: self.scheduler.build_connector_meta(self._scheduler_output)) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 845d7e7d5ca..0d6354d0b08 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -986,7 +986,7 @@ def _executor_loop(self): if self.kv_connector_manager: self.kv_connector_manager.take_scheduled_requests_pending_load( scheduled_batch) - self.kv_connector_manager.build_connector_meta() + self.kv_connector_manager.handle_metadata() self.kv_connector_manager.worker.start_load_kv() if scheduled_batch.batch_size > 0 or ( diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 7b44072e15e..d94a6b80dec 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -414,9 +414,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): if not req.is_kv_cache_connector_async_onboard: scheduler_output.add_request( req.request_id, req.get_tokens(0), - self.impl.get_cache_block_ids( - req.request_id, - self.max_attention_window_vec[0]), + self.get_cache_indices(req), req.context_current_position) else: # When using the connector, this code path will be hit after the async load is complete. 
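For orientation, a minimal do-nothing worker against the Python interface above could look like the following (a sketch only: the class name is hypothetical and the get_finished signature is inferred from the manager's call sites):

    from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnectorWorker

    class NullConnectorWorker(KvCacheConnectorWorker):
        """Satisfies the connector worker interface without moving any KV data."""

        def register_kv_caches(self, kv_cache_data):
            # Called once at startup with the KV pool layout.
            self._kv_cache_data = kv_cache_data

        def start_load_kv(self):
            # Start of every forward pass; nothing to onboard.
            pass

        def wait_for_layer_load(self, layer_idx: int):
            pass

        def save_kv_layer(self, layer_idx: int):
            pass

        def wait_for_save(self):
            pass

        def get_finished(self, finished_gen_req_ids, started_loading_req_ids):
            # Nothing is asynchronous here, so every id handed to us is already done.
            return list(finished_gen_req_ids), list(started_loading_req_ids)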
@@ -427,9 +425,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req.is_kv_cache_connector_async_onboard = False scheduler_output.add_request( req.request_id, req.get_tokens(0), - self.impl.get_cache_block_ids( - req.request_id, - self.max_attention_window_vec[0]), + self.get_cache_indices(req), req.context_current_position) else: # Otherwise, we just provide the new context position. No new blocks are allocated. @@ -438,16 +434,22 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req.context_current_position) for req in generation_batch: - new_block_id = self.impl.add_token(req.py_request_id, True) + + old_block_ids = self.get_cache_indices(req) + + self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) + new_block_ids = self.get_cache_indices(req) + + delta_block_ids = new_block_ids[len(old_block_ids):] + tokens = req.get_tokens(0) - scheduler_output.add_request( - req.request_id, tokens[-1:], - [new_block_id] if new_block_id is not None else [], len(tokens)) + scheduler_output.add_request(req.request_id, tokens[-1:], + delta_block_ids, len(tokens)) if self.kv_connector_manager is not None: self.kv_connector_manager.set_scheduler_output(scheduler_output) diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 640820c8145..3c9bf9a77c0 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -20,6 +20,7 @@ import pytest from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig sys.path.append(os.path.dirname(os.path.abspath(__file__))) @@ -75,7 +76,8 @@ def test_llm_api_connector_simple(connector): backend="pytorch", disable_overlap_scheduler=True, connector_config=connector_config, - cuda_graph_config=None) + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) assert worker.register_kv_caches.call_count == 1 @@ -133,7 +135,8 @@ def test_llm_api_connector_async_onboard(connector): backend="pytorch", disable_overlap_scheduler=True, connector_config=connector_config, - cuda_graph_config=None) + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) assert worker.register_kv_caches.call_count == 1 @@ -163,7 +166,8 @@ def test_llm_api_connector_async_save(connector): backend="pytorch", disable_overlap_scheduler=True, connector_config=connector_config, - cuda_graph_config=None) + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) assert worker.register_kv_caches.call_count == 1 From 7081fe7446c026bffdac6a4234cab47853db06a2 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 31 Jul 2025 14:19:29 -0700 Subject: [PATCH 23/50] Tests for scheduler metadata Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/connector.py | 7 +- .../_torch/pyexecutor/py_executor_creator.py | 4 +- .../_torch/pyexecutor/resource_manager.py | 7 +- .../defs/llmapi/test_llm_api_connector.py | 109 +++++++++++++----- 4 files changed, 90 insertions(+), 37 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index b25ec343ff6..9213d514302 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -4,6 +4,7 @@ from tensorrt_llm._utils 
import mpi_allgather, mpi_broadcast, mpi_rank from tensorrt_llm.bindings import LlmRequestState +from tensorrt_llm.bindings.executor import ExecutorConfig from tensorrt_llm.bindings.internal.batch_manager import \ KvCacheConnectorManager as KvCacheConnectorManagerCpp from tensorrt_llm.bindings.internal.batch_manager import ( @@ -61,7 +62,8 @@ def add_request(self, request_id: int, new_tokens: list[int], class KvCacheConnectorWorker(ABC): - def __init__(self): + def __init__(self, config: ExecutorConfig): + self._config = config super().__init__() def bind_connector_meta(self, metadata: object): @@ -137,7 +139,8 @@ def get_finished( class KvCacheConnectorScheduler(ABC): - def __init__(self): + def __init__(self, executor_config: ExecutorConfig): + self._config = executor_config super().__init__() @abstractmethod diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index c562ad57961..cb4a90f4a47 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -375,12 +375,12 @@ def create_py_executor( scheduler_cls = getattr( module, kv_connector_config.connector_scheduler_class) - connector_worker = worker_cls() + connector_worker = worker_cls(executor_config) # Only initialize the scheduler on rank 0. rank = tensorrt_llm.mpi_rank() if rank == 0: - connector_scheduler = scheduler_cls() + connector_scheduler = scheduler_cls(executor_config) else: connector_scheduler = None diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index d94a6b80dec..575024a7e41 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -434,6 +434,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req.context_current_position) for req in generation_batch: + tokens = req.get_tokens(0) + computed_token_position = len(req.get_tokens(0)) - 1 old_block_ids = self.get_cache_indices(req) @@ -446,10 +448,9 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): delta_block_ids = new_block_ids[len(old_block_ids):] - tokens = req.get_tokens(0) - scheduler_output.add_request(req.request_id, tokens[-1:], - delta_block_ids, len(tokens)) + delta_block_ids, + computed_token_position) if self.kv_connector_manager is not None: self.kv_connector_manager.set_scheduler_output(scheduler_output) diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 3c9bf9a77c0..ed94fff4f91 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math import os import sys from unittest.mock import MagicMock, patch @@ -42,8 +43,8 @@ def init_connector_classes(): return scheduler, worker -@pytest.fixture -def connector(): +@pytest.fixture(scope="function") +def model_with_connector(): with patch("tensorrt_llm._torch.pyexecutor.py_executor_creator.importlib" ) as importlib_mock: mock_scheduler = MagicMock() @@ -58,7 +59,10 @@ def connector(): connector_worker_class="KvConnectorWorker", ) - yield connector_config, mock_scheduler, mock_worker + def model_fn(*args, **kwargs): + return LLM(*args, **kwargs, connector_config=connector_config) + + yield model_fn, mock_scheduler, mock_worker # Needed because MagicMocks don't work across processes. @@ -67,17 +71,17 @@ def connector(): @pytest.mark.threadleak(enabled=False) -def test_llm_api_connector_simple(connector): - connector_config, scheduler, worker = connector - +def test_llm_api_connector_simple(model_with_connector): NUM_TOKENS = 8 - model = LLM(model="Qwen/Qwen2-0.5B", - backend="pytorch", - disable_overlap_scheduler=True, - connector_config=connector_config, - cuda_graph_config=None, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) + model_fn, scheduler, worker = model_with_connector + + model = model_fn( + model="Qwen/Qwen2-0.5B", + backend="pytorch", + disable_overlap_scheduler=True, + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) assert worker.register_kv_caches.call_count == 1 @@ -126,17 +130,17 @@ def test_llm_api_connector_simple(connector): @pytest.mark.threadleak(enabled=False) -def test_llm_api_connector_async_onboard(connector): - connector_config, scheduler, worker = connector - +def test_llm_api_connector_async_onboard(model_with_connector): NUM_TOKENS = 8 - model = LLM(model="Qwen/Qwen2-0.5B", - backend="pytorch", - disable_overlap_scheduler=True, - connector_config=connector_config, - cuda_graph_config=None, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) + model_fn, scheduler, worker = model_with_connector + + model = model_fn( + model="Qwen/Qwen2-0.5B", + backend="pytorch", + disable_overlap_scheduler=True, + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) assert worker.register_kv_caches.call_count == 1 @@ -157,17 +161,17 @@ def test_llm_api_connector_async_onboard(connector): @pytest.mark.threadleak(enabled=False) -def test_llm_api_connector_async_save(connector): - connector_config, scheduler, worker = connector - +def test_llm_api_connector_async_save(model_with_connector): NUM_TOKENS = 8 - model = LLM(model="Qwen/Qwen2-0.5B", - backend="pytorch", - disable_overlap_scheduler=True, - connector_config=connector_config, - cuda_graph_config=None, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) + model_fn, scheduler, worker = model_with_connector + + model = model_fn( + model="Qwen/Qwen2-0.5B", + backend="pytorch", + disable_overlap_scheduler=True, + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) assert worker.register_kv_caches.call_count == 1 @@ -195,3 +199,48 @@ def test_llm_api_connector_async_save(connector): assert len(args[0]) == 1 assert args[0][0] == scheduler.request_finished.call_args.args[ 0].request_id + + +@pytest.mark.threadleak(enabled=False) +def test_llm_api_scheduler_output(model_with_connector): + NUM_INPUT_TOKENS = 48 + NUM_TOKENS = 32 + BLOCK_SIZE = 32 + + model_fn, scheduler, worker = model_with_connector + + model = model_fn( + model="Qwen/Qwen2-0.5B", + backend="pytorch", + 
disable_overlap_scheduler=True, + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) + + assert worker.register_kv_caches.call_count == 1 + + scheduler.get_num_new_matched_tokens.return_value = 0, False + + worker.get_finished.return_value = [], [] + + sampling_params = SamplingParams(max_tokens=32, ignore_eos=True) + + model.generate([0] * NUM_INPUT_TOKENS, sampling_params) + + assert scheduler.build_connector_meta.call_count == NUM_TOKENS + + for i, call in enumerate(scheduler.build_connector_meta.call_args_list): + sched_output = call.args[0] + + assert len(sched_output.requests) == 1 + request = sched_output.requests[0] + if i == 0: + assert len(request.new_tokens) == NUM_INPUT_TOKENS + assert len(request.new_block_ids) == math.ceil(NUM_INPUT_TOKENS / + BLOCK_SIZE) + else: + assert len(request.new_tokens) == 1 + + if request.computed_position % BLOCK_SIZE == 0: + assert len(request.new_block_ids) == 1 + else: + assert request.new_block_ids == [] From 7d7dabeaf44fca4a14b7f8354cef8f81e68a9e80 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 31 Jul 2025 15:06:27 -0700 Subject: [PATCH 24/50] Chunked prefill tests Signed-off-by: jthomson04 --- .../defs/llmapi/test_llm_api_connector.py | 67 +++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index ed94fff4f91..d4e9a89b1cc 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -71,7 +71,7 @@ def model_fn(*args, **kwargs): @pytest.mark.threadleak(enabled=False) -def test_llm_api_connector_simple(model_with_connector): +def test_connector_simple(model_with_connector): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -130,7 +130,7 @@ def test_llm_api_connector_simple(model_with_connector): @pytest.mark.threadleak(enabled=False) -def test_llm_api_connector_async_onboard(model_with_connector): +def test_connector_async_onboard(model_with_connector): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -161,7 +161,7 @@ def test_llm_api_connector_async_onboard(model_with_connector): @pytest.mark.threadleak(enabled=False) -def test_llm_api_connector_async_save(model_with_connector): +def test_connector_async_save(model_with_connector): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -202,7 +202,7 @@ def test_llm_api_connector_async_save(model_with_connector): @pytest.mark.threadleak(enabled=False) -def test_llm_api_scheduler_output(model_with_connector): +def test_connector_scheduler_output(model_with_connector): NUM_INPUT_TOKENS = 48 NUM_TOKENS = 32 BLOCK_SIZE = 32 @@ -237,6 +237,7 @@ def test_llm_api_scheduler_output(model_with_connector): assert len(request.new_tokens) == NUM_INPUT_TOKENS assert len(request.new_block_ids) == math.ceil(NUM_INPUT_TOKENS / BLOCK_SIZE) + assert request.computed_position == 0 else: assert len(request.new_tokens) == 1 @@ -244,3 +245,61 @@ def test_llm_api_scheduler_output(model_with_connector): assert len(request.new_block_ids) == 1 else: assert request.new_block_ids == [] + + scheduler.build_connector_meta.reset_mock() + + scheduler.get_num_new_matched_tokens.return_value = 8, False + + model.generate([0] * NUM_INPUT_TOKENS, sampling_params) + + assert scheduler.build_connector_meta.call_args_list[0].args[0].requests[ + 0].computed_position == 8 + + +@pytest.mark.threadleak(enabled=False) +def 
test_connector_scheduler_output_chunked_context(model_with_connector): + model_fn, scheduler, worker = model_with_connector + + CHUNK_SIZE = 128 + BLOCK_SIZE = 32 + + model = model_fn( + model="Qwen/Qwen2-0.5B", + backend="pytorch", + disable_overlap_scheduler=True, + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1), + enable_chunked_prefill=True, + max_num_tokens=CHUNK_SIZE) + + assert worker.register_kv_caches.call_count == 1 + + scheduler.get_num_new_matched_tokens.return_value = 0, False + + worker.get_finished.return_value = [], [] + + sampling_params = SamplingParams(max_tokens=32, ignore_eos=True) + + model.generate([0] * (CHUNK_SIZE * 2), sampling_params) + + for i, call in enumerate(scheduler.build_connector_meta.call_args_list): + sched_output = call.args[0] + + assert len(sched_output.requests) == 1 + + req = sched_output.requests[0] + + if i == 0: + # The first prefill chunk. + # All of the prefill tokens and all the blocks should be provided upfront. + assert req.computed_position == 0 + assert len(req.new_tokens) == CHUNK_SIZE * 2 + assert len(req.new_block_ids) == math.ceil(CHUNK_SIZE * 2 / + BLOCK_SIZE) + elif i == 1: + # The second prefill chunk. + assert req.computed_position == CHUNK_SIZE + assert len(req.new_tokens) == 0 + assert len(req.new_block_ids) == 0 + else: + assert len(req.new_tokens) == 1 From 65f58a46f0237a0ace0cf602ec283904e968dcb6 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Fri, 1 Aug 2025 12:22:18 -0700 Subject: [PATCH 25/50] simplify register_kv_caches handling Signed-off-by: jthomson04 --- .../batch_manager/kvCacheConnector.h | 50 ------------------- .../batch_manager/kvCacheManager.h | 8 --- .../batch_manager/kvCacheManager.cpp | 36 ------------- .../pybind/batch_manager/kvCacheConnector.cpp | 15 ------ .../pybind/batch_manager/kvCacheManager.cpp | 9 +--- cpp/tensorrt_llm/runtime/torch.h | 10 ++-- tensorrt_llm/_torch/pyexecutor/connector.py | 7 +-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 17 +++++-- .../_torch/pyexecutor/resource_manager.py | 5 -- 9 files changed, 22 insertions(+), 135 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h index 1fdccfe53a1..10ccaf0831b 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h @@ -31,56 +31,6 @@ using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType; namespace tensorrt_llm::batch_manager::kv_connector { -class KvCacheConnectorPoolData -{ -public: - KvCacheConnectorPoolData(runtime::ITensor::SharedPtr poolTensor, SizeType32 numBlocks) - : mPoolTensor(std::move(poolTensor)) - , mNumBlocks(numBlocks) - { - } - - runtime::ITensor::SharedPtr const& getPoolTensor() const - { - return mPoolTensor; - } - - SizeType32 getNumBlocks() const - { - return mNumBlocks; - } - -private: - runtime::ITensor::SharedPtr mPoolTensor; - SizeType32 mNumBlocks; -}; - -/// @brief Data used to provide the KV cache tensors to the connector worker for all the pools. 
-class KvCacheConnectorPoolsData -{ -public: - explicit KvCacheConnectorPoolsData( - std::vector& poolsData, std::vector& layerToPoolMapping) - : mPoolsData(poolsData) - , mLayerToPoolMapping(layerToPoolMapping) - { - } - - std::vector& getPoolsData() - { - return mPoolsData; - } - - std::vector& getLayerToPoolMapping() - { - return mLayerToPoolMapping; - } - -private: - std::vector mPoolsData; - std::vector mLayerToPoolMapping; -}; - /// @brief The KV connector manager. This is passed into the C++ KV Cache Manager when adding sequences. class KvCacheConnectorManager { diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 7ccbc9b5302..45fef623842 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -748,8 +748,6 @@ class WindowBlockManager return 0; } - [[nodiscard]] kv_connector::KvCacheConnectorPoolData getKvCacheConnectorPoolData() const; - private: //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq. void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx); @@ -1139,8 +1137,6 @@ class BlockManager return mWindowBlockManagers.at(windowSize).getPool(relativePoolIndex); } - [[nodiscard]] std::vector getKvCacheConnectorPoolsData() const; - private: [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const { @@ -1374,8 +1370,6 @@ class BaseKVCacheManager [[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0; [[nodiscard]] virtual CacheType getCacheType() const = 0; - - [[nodiscard]] virtual kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const = 0; }; class KVCacheManager : public BaseKVCacheManager @@ -1678,8 +1672,6 @@ class KVCacheManager : public BaseKVCacheManager [[nodiscard]] static SizeType32 calculateMaxAttentionWindow(SizeType32 inputLength, SizeType32 outputLength, SizeType32 sinkTokenLength, SizeType32 blockCapacity, SizeType32 beamWidth, SizeType32 tokensPerBlock); - [[nodiscard]] kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const override; - private: void cacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); void cacheNewBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index f0e4177645c..8fc1a52f31c 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1548,24 +1548,6 @@ void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef BlockManager::getKvCacheConnectorPoolsData() const -{ - TLLM_CHECK_WITH_INFO( - mWindowBlockManagers.size() == 1, "KV Cache connector is not supported with multiple window sizes"); - std::vector poolsData; - poolsData.reserve(1); - for (auto const& [_, manager] : mWindowBlockManagers) - { - poolsData.emplace_back(manager.getKvCacheConnectorPoolData()); - } - return poolsData; -} - -[[nodiscard]] kv_connector::KvCacheConnectorPoolData WindowBlockManager::getKvCacheConnectorPoolData() const -{ - return kv_connector::KvCacheConnectorPoolData(mPools.at(0).primaryPtr, mNumPrimaryBlocks); -} - void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) { auto constexpr beamIdx = 0; @@ -2622,22 +2604,4 @@ SizeType32 
KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength, return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength); } -[[nodiscard]] kv_connector::KvCacheConnectorPoolsData KVCacheManager::getKvCacheConnectorPoolsData() const -{ - auto poolsData = mBlockManager.getKvCacheConnectorPoolsData(); - - auto layerToPoolView = BufferRange(*mLayerToPoolMapping); - - auto numLayers = mBlockManager.getNumLayers(); - - auto layerToPool = std::vector(numLayers); - - for (size_t layer = 0; layer < static_cast(numLayers); layer++) - { - layerToPool[layer] = layerToPoolView[layer * 2]; - } - - return kv_connector::KvCacheConnectorPoolsData(poolsData, layerToPool); -} - } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index c92056db370..a9bd519688b 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -42,21 +42,6 @@ class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(py::module_& m) { - py::class_(m, "KvCacheConnectorPoolData") - .def_property_readonly("tensor", - [](tb::kv_connector::KvCacheConnectorPoolData& self) - { - auto const& poolTensor = self.getPoolTensor(); - - return tensorrt_llm::runtime::Torch::tensor(poolTensor); - }) - .def_property_readonly("num_blocks", &tb::kv_connector::KvCacheConnectorPoolData::getNumBlocks); - - py::class_(m, "KvCacheConnectorPoolsData") - .def_property_readonly("pools", &tb::kv_connector::KvCacheConnectorPoolsData::getPoolsData) - .def_property_readonly( - "layer_to_pool_mapping", &tb::kv_connector::KvCacheConnectorPoolsData::getLayerToPoolMapping); - py::class_( m, "KvCacheConnectorManager") .def(py::init<>()) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 471c2a4820c..0ae931be5fe 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -236,12 +236,6 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager { PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, flushIterationEvents); } - - [[nodiscard]] tb::kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const override - { - PYBIND11_OVERLOAD_PURE( - tb::kv_connector::KvCacheConnectorPoolsData, tbk::BaseKVCacheManager, getKvCacheConnectorPoolsData); - } }; // TODO: Deduplicate executor bindings KvCacheStats @@ -433,8 +427,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds) .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds) - .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents) - .def("get_kv_cache_connector_pools_data", &BaseKVCacheManager::getKvCacheConnectorPoolsData); + .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents); py::enum_(m, "CacheType") .value("SELF", tbk::CacheType::kSELF) diff --git a/cpp/tensorrt_llm/runtime/torch.h b/cpp/tensorrt_llm/runtime/torch.h index c6863890665..4f8162a1682 100644 --- a/cpp/tensorrt_llm/runtime/torch.h +++ 
b/cpp/tensorrt_llm/runtime/torch.h @@ -43,15 +43,11 @@ class Torch .deleter( [ptr = std::move(tensor)](void* data) mutable { - try + if (data != ptr->data()) { - TLLM_CHECK(data == ptr->data()); - ptr.reset(); - } - catch (std::exception const& e) - { - TLLM_LOG_EXCEPTION(e); + TLLM_LOG_WARNING("Torch tensor refers to deallocated memory."); } + ptr.reset(); }) .make_tensor(); } diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 9213d514302..442f9be7815 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -2,13 +2,14 @@ from dataclasses import dataclass, field from typing import Any, Callable, Optional +import torch + from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank from tensorrt_llm.bindings import LlmRequestState from tensorrt_llm.bindings.executor import ExecutorConfig from tensorrt_llm.bindings.internal.batch_manager import \ KvCacheConnectorManager as KvCacheConnectorManagerCpp -from tensorrt_llm.bindings.internal.batch_manager import ( - KvCacheConnectorPoolsData, LlmRequest) +from tensorrt_llm.bindings.internal.batch_manager import LlmRequest from .scheduler import ScheduledRequests """ @@ -76,7 +77,7 @@ def _clear_connector_meta(self): self._metadata = None @abstractmethod - def register_kv_caches(self, kv_cache_data: KvCacheConnectorPoolsData): + def register_kv_caches(self, kv_cache_data: dict[int, torch.Tensor]): """ Register the KV cache tensors to the worker. This can be used for something like NIXL registration. diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 0d6354d0b08..100970d4ee4 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -292,10 +292,21 @@ def _maybe_init_kv_connector_manager(self): "KV Cache Connector is not supported with overlap scheduler." ) - kv_cache_data = self.kv_cache_manager.get_kv_cache_connector_pools_data( - ) + # TODO: This does NOT support pipeline parallel. + layer_kv_tensors = { + layer_idx: self.kv_cache_manager.get_buffers(layer_idx) + for layer_idx in self.kv_cache_manager.pp_layers + } + + kv_shape = layer_kv_tensors[list(layer_kv_tensors.keys())[0]].shape + + if not all(t.shape == kv_shape for t in layer_kv_tensors.values()): + return ValueError( + "KV Cache Connector is not supported with Variable sliding window attention!" + ) - self.kv_connector_manager.worker.register_kv_caches(kv_cache_data) + self.kv_connector_manager.worker.register_kv_caches( + layer_kv_tensors) # For each of our layers, we need to register the pre/post hooks. # These are used for methods like `wait_for_layer_load` and `save_kv_layer`. 
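Note: the hunk above ends by stating that per-layer pre/post hooks are registered so that `wait_for_layer_load` and `save_kv_layer` run around each attention layer, but the registration itself is not shown here. A minimal sketch of how it could be wired, assuming each attention module exposes a `layer_idx` attribute and reusing the manager's `layer_pre_hook`/`layer_post_hook` defined in connector.py (the helper name is made up for illustration):

    import torch.nn as nn

    def register_connector_layer_hooks(model: nn.Module, kv_connector_manager) -> None:
        # Hypothetical helper: attach the connector manager's hooks to every
        # module that carries a layer_idx (i.e. the attention layers).
        for module in model.modules():
            if hasattr(module, "layer_idx"):
                # Runs just before the layer's forward: block until its KV is loaded.
                module.register_forward_pre_hook(kv_connector_manager.layer_pre_hook)
                # Runs just after the layer's forward: begin saving its KV.
                module.register_forward_hook(kv_connector_manager.layer_post_hook)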
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 575024a7e41..fb760488922 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -10,8 +10,6 @@ import tensorrt_llm import tensorrt_llm.bindings from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE -from tensorrt_llm.bindings.internal.batch_manager import \ - KvCacheConnectorPoolsData from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig from tensorrt_llm.sampling_params import SamplingParams @@ -954,9 +952,6 @@ def _set_temp_attention_window_inputs( else: return None - def get_kv_cache_connector_pools_data(self) -> KvCacheConnectorPoolsData: - return self.impl.get_kv_cache_connector_pools_data() - class MambaCacheManager(BaseResourceManager): From 4140d52b771531ef18ece5effa971fa30b78bfce Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Fri, 1 Aug 2025 13:27:55 -0700 Subject: [PATCH 26/50] remove changes to add token and update token Signed-off-by: jthomson04 --- cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h | 8 +++----- cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp | 8 +++----- cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp | 4 ++-- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 45fef623842..b5982978123 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -1233,7 +1233,7 @@ class BaseKVCacheManager = 0; /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed. - virtual std::optional addToken(LlmRequest::RequestIdType requestId) = 0; + virtual void addToken(LlmRequest::RequestIdType requestId) = 0; /// @brief Add new request to the KV cache manager. /// @param inputLength Input length for which KV cache need to be allocated. @@ -1532,9 +1532,7 @@ class KVCacheManager : public BaseKVCacheManager LlmRequest const& req, SizeType32 windowSize) const override; /// @brief Increase size for request with requestId. Allocate new KV cache block(s) if needed. - /// @param returnNewBlockId If true, return the id of the newly allocated block (if any). Only supported when VSWA - /// and beam search are disabled. - std::optional addToken(LlmRequest::RequestIdType requestId) override; + void addToken(LlmRequest::RequestIdType requestId) override; /// @brief Add new request to the KV cache manager. /// @param inputLength Input length for which KV cache need to be allocated. 
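With addToken reverted to a void return, a caller that needs the ids of any newly allocated blocks recovers them by diffing the cache indices around the call, which is the pattern prepare_resources in resource_manager.py already uses. A condensed sketch of that pattern (the helper name is hypothetical):

    def delta_block_ids_after_add_token(kv_cache_manager, req) -> list[int]:
        # Snapshot the block ids, let the C++ manager append a token (and possibly
        # allocate a block), then take the tail of the new list as the delta.
        old_block_ids = kv_cache_manager.get_cache_indices(req)
        kv_cache_manager.impl.add_token(req.py_request_id)
        new_block_ids = kv_cache_manager.get_cache_indices(req)
        return new_block_ids[len(old_block_ids):]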
@@ -1676,7 +1674,7 @@ class KVCacheManager : public BaseKVCacheManager void cacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); void cacheNewBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); void updateNewBlockPointer(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx); - std::optional updateToken(GenerationRequest& sequence, bool addToken); + void updateToken(GenerationRequest& sequence, bool addToken); private: // Maximum number of sequences diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 8fc1a52f31c..93f569115a9 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1962,7 +1962,7 @@ void KVCacheManager::updateNewBlockPointer(GenerationRequest& sequence, SizeType } } -std::optional KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken) +void KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken) { auto currNumTokens = sequence.getNumTokens(); @@ -2020,14 +2020,12 @@ std::optional KVCacheManager::updateToken(GenerationRequest& sequenc } } } - - return std::nullopt; } -std::optional KVCacheManager::addToken(RequestIdType requestId) +void KVCacheManager::addToken(RequestIdType requestId) { auto& sequence = getSequence(requestId); - return updateToken(sequence, true); + updateToken(sequence, true); } std::optional KVCacheManager::findNewContextBlock( diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 0ae931be5fe..5ef7ca34485 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -90,9 +90,9 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager PYBIND11_OVERLOAD_PURE(tbk::KvCacheStats, tbk::BaseKVCacheManager, getKvCacheStats); } - std::optional addToken(tb::LlmRequest::RequestIdType requestId) override + void addToken(tb::LlmRequest::RequestIdType requestId) override { - PYBIND11_OVERLOAD_PURE(std::optional, tbk::BaseKVCacheManager, addToken, requestId); + PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addToken, requestId); } void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, From 812fcf428720ad33ad96efda5df0d93eae522214 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Fri, 1 Aug 2025 15:41:50 -0700 Subject: [PATCH 27/50] add support for the overlap scheduler + little refactoring Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/_util.py | 25 ++++--- tensorrt_llm/_torch/pyexecutor/py_executor.py | 40 ++++++----- .../_torch/pyexecutor/py_executor_creator.py | 4 +- .../_torch/pyexecutor/resource_manager.py | 10 ++- tensorrt_llm/_torch/pyexecutor/scheduler.py | 1 - .../defs/llmapi/test_llm_api_connector.py | 69 ++++++++++++------- 6 files changed, 89 insertions(+), 60 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index e690814211a..9cf17d7f33c 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -276,7 +276,9 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None: executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens def _create_kv_cache_manager( - self, model_engine: PyTorchModelEngine) -> KVCacheManager: + self, + model_engine: PyTorchModelEngine, + for_estimation: bool = False) -> KVCacheManager: executor_config 
= self._executor_config mapping = self._mapping assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models." @@ -317,7 +319,8 @@ def _create_kv_cache_manager( dtype=kv_cache_dtype, spec_config=spec_config, max_beam_width=executor_config.max_beam_width, - kv_connector_manager=self._kv_connector_manager, + kv_connector_manager=self._kv_connector_manager + if not for_estimation else None, ) elif is_nemotron_hybrid(config): if executor_config.max_beam_width > 1: @@ -325,7 +328,7 @@ def _create_kv_cache_manager( "MambaHybridCacheManager + beam search is not supported yet." ) - if self._kv_connector_manager is not None: + if not for_estimation and self._kv_connector_manager is not None: raise ValueError( "Connector manager is not supported for MambaHybridCacheManager." ) @@ -386,7 +389,8 @@ def _create_kv_cache_manager( max_num_tokens=executor_config.max_num_tokens, model_config=binding_model_config, max_beam_width=executor_config.max_beam_width, - kv_connector_manager=self._kv_connector_manager, + kv_connector_manager=self._kv_connector_manager + if not for_estimation else None, ) # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER: @@ -394,17 +398,20 @@ def _create_kv_cache_manager( return kv_cache_manager - def build_managers(self, resources: Dict) -> None: + def build_managers(self, + resources: Dict, + for_estimation: bool = False) -> None: """Construct KV caches for model and draft model (if applicable).""" - kv_cache_manager = self._create_kv_cache_manager(self._model_engine) + kv_cache_manager = self._create_kv_cache_manager( + self._model_engine, for_estimation) - if self._kv_connector_manager is not None and self._draft_model_engine is not None: + if not for_estimation and self._kv_connector_manager is not None and self._draft_model_engine is not None: raise ValueError( "Connector manager is not supported for draft model.") draft_kv_cache_manager = self._create_kv_cache_manager( - self._draft_model_engine - ) if self._draft_model_engine is not None else None + self._draft_model_engine, + for_estimation) if self._draft_model_engine is not None else None resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager resources[ diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 100970d4ee4..834c91ce79f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -287,11 +287,6 @@ def _maybe_init_kv_connector_manager(self): "KV Cache Connector is not supported with pipeline parallelism." ) - if not self.disable_overlap_scheduler: - raise NotImplementedError( - "KV Cache Connector is not supported with overlap scheduler." - ) - # TODO: This does NOT support pipeline parallel. 
layer_kv_tensors = { layer_idx: self.kv_cache_manager.get_buffers(layer_idx) @@ -957,6 +952,19 @@ def _execute_guided_decoder(self, scheduled_batch: ScheduledRequests, self.guided_decoder.build(scheduled_batch) self.guided_decoder.execute(scheduled_batch, logits) + def _execute_kv_connector(self, scheduled_batch): + if self.kv_connector_manager: + self.kv_connector_manager.take_scheduled_requests_pending_load( + scheduled_batch) + self.kv_connector_manager.handle_metadata() + self.kv_connector_manager.worker.start_load_kv() + + def _terminate_async_save_requests(self): + if self.kv_connector_manager: + reqs_to_terminate = self.kv_connector_manager.get_finished() + for req in reqs_to_terminate: + self.resource_manager.free_resources(req) + def _executor_loop(self): torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. @@ -987,18 +995,13 @@ def _executor_loop(self): # Return the first token to the client self._handle_first_token_response(scheduled_batch) - scheduled_batch.is_warmup = self.is_warmup self.resource_manager.prepare_resources(scheduled_batch) if self.kv_cache_transceiver and self.guided_decoder: self.guided_decoder.init_disagg_gen_requests( scheduled_batch) - - if self.kv_connector_manager: - self.kv_connector_manager.take_scheduled_requests_pending_load( - scheduled_batch) - self.kv_connector_manager.handle_metadata() - self.kv_connector_manager.worker.start_load_kv() + + self._execute_kv_connector(scheduled_batch) if scheduled_batch.batch_size > 0 or ( self.enable_attention_dp and self.dist.tp_size > 1): @@ -1035,10 +1038,8 @@ def _executor_loop(self): if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._terminate_ctx_finished_requests() - elif self.kv_connector_manager: - reqs_to_terminate = self.kv_connector_manager.get_finished() - for req in reqs_to_terminate: - self.resource_manager.free_resources(req) + + self._terminate_async_save_requests() if self.enable_iter_perf_stats: iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ @@ -1106,9 +1107,12 @@ def _executor_loop_overlap(self): # For generation requests which have completed KV cache transfer self._prepare_disagg_gen_transmission_complete( scheduled_batch) - self.resource_manager.prepare_resources(scheduled_batch) + self._execute_kv_connector(scheduled_batch) + + if scheduled_batch.batch_size > 0: + # The generation requests that are do not have batch_idx, # needs to be in front of the batch due to the assumptions # made in model_engine.py::_forward_step. 
This is only important @@ -1164,6 +1168,8 @@ def _executor_loop_overlap(self): if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._terminate_ctx_finished_requests() + self._terminate_async_save_requests() + def _process_previous_batch(self): if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs: for req in self.previous_batch.ctx_transmission_reqs: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index cb4a90f4a47..81f0396743b 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -409,7 +409,7 @@ def create_py_executor( with mem_monitor.observe_creation_stage( _ExecutorCreationStage.INIT_KV_CACHE if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE): - kv_cache_creator.build_managers(resources) + kv_cache_creator.build_managers(resources, estimating_kv_cache) # Resource managers for speculative decoding # For user-specified drafters, use extra_resource_managers in PyTorchBackend config @@ -463,7 +463,7 @@ def create_py_executor( # create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring # the original value before creating the final KV cache. executor_config.max_seq_len = max_seq_len - kv_cache_creator.build_managers(resources) + kv_cache_creator.build_managers(resources, False) for eng in [model_engine, draft_model_engine]: if eng is None: diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index fb760488922..03c40a3f723 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -388,8 +388,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req.py_request_id, seq_len + (len(req.query_id) if self.mapping.cp_rank == self.mapping.cp_size - 1 else 0), - req_beam_width, req, self.kv_connector_manager - if not scheduled_batch.is_warmup else None) + req_beam_width, req, self.kv_connector_manager) else: # TODO(jthomson04): This is begging for a mega refactor, and can likely be significantly simplified. # In add sequence, the connector API's get_num_new_matched_tokens is called. @@ -399,10 +398,9 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): # When that happens, the request will go through this same code path, but with is_kv_cache_connector_async_onboard set to True. # Because of this, we need to filter this case out to avoid adding the same sequence twice. 
if req.is_first_context_chunk and not req.is_kv_cache_connector_async_onboard: - self.impl.add_sequence( - req.py_request_id, req.prompt_len, req_beam_width, req, - self.kv_connector_manager - if not scheduled_batch.is_warmup else None) + self.impl.add_sequence(req.py_request_id, req.prompt_len, + req_beam_width, req, + self.kv_connector_manager) for _ in range(self.num_extra_kv_tokens): self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 85c847ad650..ef86fc1f49e 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -21,7 +21,6 @@ def __init__(self): self.context_requests: RequestList = [] self.generation_requests: RequestList = [] self.paused_requests: RequestList = [] - self.is_warmup: bool = False @property def is_generation_only(self) -> bool: diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index d4e9a89b1cc..3d678585df3 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -71,7 +71,8 @@ def model_fn(*args, **kwargs): @pytest.mark.threadleak(enabled=False) -def test_connector_simple(model_with_connector): +@pytest.mark.parametrize("use_overlap_scheduler", [True, False]) +def test_connector_simple(model_with_connector, use_overlap_scheduler): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -79,7 +80,7 @@ def test_connector_simple(model_with_connector): model = model_fn( model="Qwen/Qwen2-0.5B", backend="pytorch", - disable_overlap_scheduler=True, + disable_overlap_scheduler=not use_overlap_scheduler, cuda_graph_config=None, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) @@ -93,7 +94,9 @@ def test_connector_simple(model_with_connector): model.generate(["Hello, world"], sampling_params) - assert scheduler.build_connector_meta.call_count == NUM_TOKENS + # With the overlap scheduler, we generate one extra token. + assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int( + use_overlap_scheduler) # We should have a single `SchedulerOutput` per forward pass. for i, call in enumerate(scheduler.build_connector_meta.call_args_list): @@ -105,7 +108,8 @@ def test_connector_simple(model_with_connector): assert len(scheduler_output.requests[0].new_tokens) == 1 # We call `start_load_kv` once at the beginning of each forward pass. - assert worker.start_load_kv.call_count == NUM_TOKENS + assert worker.start_load_kv.call_count == NUM_TOKENS + int( + use_overlap_scheduler) # Only called once when the request is received. assert scheduler.get_num_new_matched_tokens.call_count == 1 @@ -114,8 +118,10 @@ def test_connector_simple(model_with_connector): for call in worker.wait_for_layer_load.call_args_list) + 1 # Called num_layers * num_forward_passes times. 
- assert worker.wait_for_layer_load.call_count == num_layers * NUM_TOKENS - assert worker.save_kv_layer.call_count == num_layers * NUM_TOKENS + assert worker.wait_for_layer_load.call_count == num_layers * ( + NUM_TOKENS + int(use_overlap_scheduler)) + assert worker.save_kv_layer.call_count == num_layers * ( + NUM_TOKENS + int(use_overlap_scheduler)) for i, call in enumerate(worker.wait_for_layer_load.call_args_list): assert call.args[0] == i % num_layers @@ -123,14 +129,17 @@ def test_connector_simple(model_with_connector): for i, call in enumerate(worker.save_kv_layer.call_args_list): assert call.args[0] == i % num_layers - assert worker.wait_for_save.call_count == NUM_TOKENS + assert worker.wait_for_save.call_count == NUM_TOKENS + int( + use_overlap_scheduler) assert scheduler.request_finished.call_count == 1 - assert worker.get_finished.call_count == NUM_TOKENS + assert worker.get_finished.call_count == NUM_TOKENS + int( + use_overlap_scheduler) @pytest.mark.threadleak(enabled=False) -def test_connector_async_onboard(model_with_connector): +@pytest.mark.parametrize("use_overlap_scheduler", [True, False]) +def test_connector_async_onboard(model_with_connector, use_overlap_scheduler): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -138,7 +147,7 @@ def test_connector_async_onboard(model_with_connector): model = model_fn( model="Qwen/Qwen2-0.5B", backend="pytorch", - disable_overlap_scheduler=True, + disable_overlap_scheduler=not use_overlap_scheduler, cuda_graph_config=None, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) @@ -153,15 +162,17 @@ def test_connector_async_onboard(model_with_connector): "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." ], SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True)) - # Once for the initial poll, then once for each token. - assert worker.get_finished.call_count == NUM_TOKENS + 1 + # Once for the initial poll, then once for each token. One extra token when using the overlap scheduler. + assert worker.get_finished.call_count == NUM_TOKENS + 1 + int( + use_overlap_scheduler) # In the first iteration, there should be a single request id provided. assert len(worker.get_finished.call_args_list[0].args[1]) == 1 @pytest.mark.threadleak(enabled=False) -def test_connector_async_save(model_with_connector): +@pytest.mark.parametrize("use_overlap_scheduler", [True, False]) +def test_connector_async_save(model_with_connector, use_overlap_scheduler): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -169,7 +180,7 @@ def test_connector_async_save(model_with_connector): model = model_fn( model="Qwen/Qwen2-0.5B", backend="pytorch", - disable_overlap_scheduler=True, + disable_overlap_scheduler=not use_overlap_scheduler, cuda_graph_config=None, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) @@ -188,12 +199,13 @@ def test_connector_async_save(model_with_connector): assert scheduler.request_finished.call_count == 1 - # On the last call to get_finished, we should be providing the async saving request. - assert worker.get_finished.call_count == NUM_TOKENS + # On the last call to get_finished, we should be providing the async saving request. One extra token when using the overlap scheduler. 
+ assert worker.get_finished.call_count == NUM_TOKENS + int( + use_overlap_scheduler) - for i in range(NUM_TOKENS): - args = worker.get_finished.call_args_list[i].args - if i != NUM_TOKENS - 1: + for i, call in enumerate(worker.get_finished.call_args_list): + args = call.args + if i != len(worker.get_finished.call_args_list) - 1: assert args == ([], []) else: assert len(args[0]) == 1 @@ -202,7 +214,9 @@ def test_connector_async_save(model_with_connector): @pytest.mark.threadleak(enabled=False) -def test_connector_scheduler_output(model_with_connector): +@pytest.mark.parametrize("use_overlap_scheduler", [True, False]) +def test_connector_scheduler_output(model_with_connector, + use_overlap_scheduler): NUM_INPUT_TOKENS = 48 NUM_TOKENS = 32 BLOCK_SIZE = 32 @@ -212,7 +226,7 @@ def test_connector_scheduler_output(model_with_connector): model = model_fn( model="Qwen/Qwen2-0.5B", backend="pytorch", - disable_overlap_scheduler=True, + disable_overlap_scheduler=not use_overlap_scheduler, cuda_graph_config=None, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) @@ -226,7 +240,9 @@ def test_connector_scheduler_output(model_with_connector): model.generate([0] * NUM_INPUT_TOKENS, sampling_params) - assert scheduler.build_connector_meta.call_count == NUM_TOKENS + # Additional token when using the overlap scheduler. + assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int( + use_overlap_scheduler) for i, call in enumerate(scheduler.build_connector_meta.call_args_list): sched_output = call.args[0] @@ -241,7 +257,8 @@ def test_connector_scheduler_output(model_with_connector): else: assert len(request.new_tokens) == 1 - if request.computed_position % BLOCK_SIZE == 0: + if (request.computed_position + + int(use_overlap_scheduler)) % BLOCK_SIZE == 0: assert len(request.new_block_ids) == 1 else: assert request.new_block_ids == [] @@ -257,7 +274,9 @@ def test_connector_scheduler_output(model_with_connector): @pytest.mark.threadleak(enabled=False) -def test_connector_scheduler_output_chunked_context(model_with_connector): +@pytest.mark.parametrize("use_overlap_scheduler", [True, False]) +def test_connector_scheduler_output_chunked_context(model_with_connector, + use_overlap_scheduler): model_fn, scheduler, worker = model_with_connector CHUNK_SIZE = 128 @@ -266,7 +285,7 @@ def test_connector_scheduler_output_chunked_context(model_with_connector): model = model_fn( model="Qwen/Qwen2-0.5B", backend="pytorch", - disable_overlap_scheduler=True, + disable_overlap_scheduler=not use_overlap_scheduler, cuda_graph_config=None, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1), enable_chunked_prefill=True, From 7b3795f8d2a621f9f7a9b815e9f3fd7d359bdc1e Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Fri, 1 Aug 2025 17:33:13 -0700 Subject: [PATCH 28/50] little cleanup Signed-off-by: jthomson04 --- .../pybind/batch_manager/kvCacheConnector.cpp | 3 --- tensorrt_llm/_torch/pyexecutor/connector.py | 3 --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 8 ++++---- .../defs/llmapi/test_llm_api_connector.py | 19 ------------------- 4 files changed, 4 insertions(+), 29 deletions(-) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp index a9bd519688b..9c1507a3454 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp @@ -16,9 +16,6 @@ */ #include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h" -#include 
"tensorrt_llm/runtime/torch.h" - -#include namespace { diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 442f9be7815..dd7c9dac9f5 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -413,6 +413,3 @@ def layer_pre_hook(self, module, *args): def layer_post_hook(self, module, *args): self.worker.save_kv_layer(module.layer_idx) - - def model_post_hook(self, module, *args): - self.worker.wait_for_save() diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 834c91ce79f..463ef75c732 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -296,7 +296,7 @@ def _maybe_init_kv_connector_manager(self): kv_shape = layer_kv_tensors[list(layer_kv_tensors.keys())[0]].shape if not all(t.shape == kv_shape for t in layer_kv_tensors.values()): - return ValueError( + raise ValueError( "KV Cache Connector is not supported with Variable sliding window attention!" ) @@ -952,7 +952,7 @@ def _execute_guided_decoder(self, scheduled_batch: ScheduledRequests, self.guided_decoder.build(scheduled_batch) self.guided_decoder.execute(scheduled_batch, logits) - def _execute_kv_connector(self, scheduled_batch): + def _handle_kv_connector(self, scheduled_batch): if self.kv_connector_manager: self.kv_connector_manager.take_scheduled_requests_pending_load( scheduled_batch) @@ -1001,7 +1001,7 @@ def _executor_loop(self): self.guided_decoder.init_disagg_gen_requests( scheduled_batch) - self._execute_kv_connector(scheduled_batch) + self._handle_kv_connector(scheduled_batch) if scheduled_batch.batch_size > 0 or ( self.enable_attention_dp and self.dist.tp_size > 1): @@ -1109,7 +1109,7 @@ def _executor_loop_overlap(self): scheduled_batch) self.resource_manager.prepare_resources(scheduled_batch) - self._execute_kv_connector(scheduled_batch) + self._handle_kv_connector(scheduled_batch) if scheduled_batch.batch_size > 0: diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 3d678585df3..43b54c0bea6 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -15,7 +15,6 @@ import math import os -import sys from unittest.mock import MagicMock, patch import pytest @@ -24,24 +23,6 @@ from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -KvConnectorScheduler = MagicMock() -KvConnectorWorker = MagicMock() - - -def init_connector_classes(): - KvConnectorScheduler.reset_mock() - KvConnectorWorker.reset_mock() - - scheduler = MagicMock() - worker = MagicMock() - - KvConnectorScheduler.return_value = scheduler - KvConnectorWorker.return_value = worker - - return scheduler, worker - @pytest.fixture(scope="function") def model_with_connector(): From 48e08ed3d53763573b5a8c7121a266f03486199c Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 5 Aug 2025 14:28:15 -0700 Subject: [PATCH 29/50] Little refactor, provide kv cache as a single contiguous tensor Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/connector.py | 80 ++++++++++++++----- tensorrt_llm/_torch/pyexecutor/py_executor.py | 21 ++--- .../_torch/pyexecutor/resource_manager.py | 32 +++----- tests/unittest/_torch/test_connector.py | 17 ++++ 4 files changed, 96 
insertions(+), 54 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index dd7c9dac9f5..e7d6caa8093 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -1,17 +1,17 @@ -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Callable, Optional - -import torch - -from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank -from tensorrt_llm.bindings import LlmRequestState -from tensorrt_llm.bindings.executor import ExecutorConfig -from tensorrt_llm.bindings.internal.batch_manager import \ - KvCacheConnectorManager as KvCacheConnectorManagerCpp -from tensorrt_llm.bindings.internal.batch_manager import LlmRequest - -from .scheduler import ScheduledRequests +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This file contains the primary interface for the KV Cache Connector. @@ -34,6 +34,21 @@ To implement a custom KV connector, you need to implement both the scheduler and worker-side interfaces. """ +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +import torch + +from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank +from tensorrt_llm.bindings import LlmRequestState +from tensorrt_llm.bindings.executor import ExecutorConfig +from tensorrt_llm.bindings.internal.batch_manager import \ + KvCacheConnectorManager as KvCacheConnectorManagerCpp +from tensorrt_llm.bindings.internal.batch_manager import LlmRequest + +from .scheduler import ScheduledRequests + # Used to store data for a single inflight request. @dataclass @@ -60,6 +75,26 @@ def add_request(self, request_id: int, new_tokens: list[int], RequestData(request_id, new_tokens, new_block_ids, computed_position)) + def record_first_prefill_chunk(self, req: LlmRequest, block_ids: list[int]): + if not req.is_kv_cache_connector_async_onboard: + self.requests.append( + RequestData(req.request_id, req.get_tokens(0), block_ids, + req.context_current_position)) + + def record_nth_prefill_chunk(self, req: LlmRequest): + self.requests.append( + RequestData(req.request_id, [], [], req.context_current_position)) + + def record_generation_req(self, req: LlmRequest, + delta_block_ids: list[int]): + + tokens = req.get_tokens(0) + computed_position = len(tokens) - 1 + + self.requests.append( + RequestData(req.request_id, tokens[-1:], delta_block_ids, + computed_position)) + class KvCacheConnectorWorker(ABC): @@ -77,13 +112,13 @@ def _clear_connector_meta(self): self._metadata = None @abstractmethod - def register_kv_caches(self, kv_cache_data: dict[int, torch.Tensor]): + def register_kv_caches(self, kv_cache_tensor: torch.Tensor): """ Register the KV cache tensors to the worker. This can be used for something like NIXL registration. 
Args: - kv_cache_data: The data for all the KV cache pools. + kv_cache_tensor: The contiguous KV cache tensor. """ @abstractmethod @@ -222,12 +257,14 @@ def extract_by_id(self, saving_ids: list[int], return new_async_requests + @property def saving_ids(self) -> set[int]: """ Get the IDs of the requests that are being saved asynchronously. """ return set(self.saving.keys()) + @property def loading_ids(self) -> set[int]: """ Get the IDs of the requests that are being loaded asynchronously. @@ -300,7 +337,7 @@ def get_num_new_matched_tokens(self, request: LlmRequest, return num_tokens def take_scheduled_requests_pending_load( - self, scheduled_requests: ScheduledRequests) -> ScheduledRequests: + self, scheduled_requests: ScheduledRequests): """ Remove context requests from our list of scheduled requests that are being loaded asynchronously. This is done to prevent the runtime from attempting to load the KV cache for these requests. @@ -361,8 +398,8 @@ def get_finished(self) -> list[LlmRequest]: Returns: The requests that have newly finished saving. """ - started_loading_req_ids = list(self.new_async_requests.loading_ids()) - finished_gen_req_ids = list(self.new_async_requests.saving_ids()) + started_loading_req_ids = list(self.new_async_requests.loading_ids) + finished_gen_req_ids = list(self.new_async_requests.saving_ids) # Add the requests to our list of outstanding (still in progress) requests. self.pending_async_requests.add_from(self.new_async_requests) @@ -381,9 +418,8 @@ def get_finished(self) -> list[LlmRequest]: new_local_finished_async_requests) # Broadcast this whole list to all other workers. - finished_saving = list(self.local_finished_async_requests.saving_ids()) - finished_loading = list( - self.local_finished_async_requests.loading_ids()) + finished_saving = list(self.local_finished_async_requests.saving_ids) + finished_loading = list(self.local_finished_async_requests.loading_ids) all_results = mpi_allgather((finished_saving, finished_loading)) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 463ef75c732..029f274d3c0 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -287,21 +287,24 @@ def _maybe_init_kv_connector_manager(self): "KV Cache Connector is not supported with pipeline parallelism." ) - # TODO: This does NOT support pipeline parallel. - layer_kv_tensors = { - layer_idx: self.kv_cache_manager.get_buffers(layer_idx) + all_layers = [ + self.kv_cache_manager.get_buffers(layer_idx) for layer_idx in self.kv_cache_manager.pp_layers - } + ] + + if not all(t.shape == all_layers[0].shape for t in all_layers): + raise ValueError( + "KV Cache Connector is not supported with sliding window attention." + ) - kv_shape = layer_kv_tensors[list(layer_kv_tensors.keys())[0]].shape + full_kv_tensor = torch.cat(all_layers, dim=1) - if not all(t.shape == kv_shape for t in layer_kv_tensors.values()): + if not full_kv_tensor.is_contiguous(): raise ValueError( - "KV Cache Connector is not supported with Variable sliding window attention!" + "KV Cache Connector is not supported with non-contiguous KV cache." ) - self.kv_connector_manager.worker.register_kv_caches( - layer_kv_tensors) + self.kv_connector_manager.worker.register_kv_caches(full_kv_tensor) # For each of our layers, we need to register the pre/post hooks. # These are used for methods like `wait_for_layer_load` and `save_kv_layer`. 
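With register_kv_caches now receiving the single contiguous tensor built above, a do-nothing worker against the interface as it stands at this point in the series looks roughly like the sketch below. Only the worker-side method names and the `kv_cache_tensor` argument come from the diffs; the class name and the `get_finished` parameter names are assumptions.

    import torch

    from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnectorWorker

    class NoopConnectorWorker(KvCacheConnectorWorker):
        def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
            # One contiguous tensor covering every layer's KV pool,
            # e.g. handed to NIXL for registration.
            self.kv_cache_tensor = kv_cache_tensor

        def start_load_kv(self):
            pass

        def wait_for_layer_load(self, layer_idx: int):
            pass

        def save_kv_layer(self, layer_idx: int):
            pass

        def wait_for_save(self):
            pass

        def get_finished(self, finished_gen_req_ids, started_loading_req_ids):
            # Nothing is loaded or saved asynchronously, so nothing ever finishes.
            return [], []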
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 03c40a3f723..6a271d4371a 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -390,13 +390,13 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): == self.mapping.cp_size - 1 else 0), req_beam_width, req, self.kv_connector_manager) else: - # TODO(jthomson04): This is begging for a mega refactor, and can likely be significantly simplified. - # In add sequence, the connector API's get_num_new_matched_tokens is called. + # In add_sequence, the connector API's get_num_new_matched_tokens is called. # The result of this call may be that blocks will be loaded asynchronously. # If so, we set the is_kv_cache_connector_async_onboard flag, and set the request state to be DISAGG_GENERATION_TRANS_IN_PROGRESS. # When the async load is complete, we set the request state back to CONTEXT_INIT. # When that happens, the request will go through this same code path, but with is_kv_cache_connector_async_onboard set to True. # Because of this, we need to filter this case out to avoid adding the same sequence twice. + # NOTE(jthomson04): Surely there's a better way to do this. if req.is_first_context_chunk and not req.is_kv_cache_connector_async_onboard: self.impl.add_sequence(req.py_request_id, req.prompt_len, req_beam_width, req, @@ -406,33 +406,21 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) - # If this is not an async load, we can add the new tokens and blocks right away. - if not req.is_kv_cache_connector_async_onboard: - scheduler_output.add_request( - req.request_id, req.get_tokens(0), - self.get_cache_indices(req), - req.context_current_position) + scheduler_output.record_first_prefill_chunk( + req, self.get_cache_indices(req)) else: # When using the connector, this code path will be hit after the async load is complete. # Alternatively, with no connector, this is hit after the first chunk of prefill. - # If this is the first actual prefill, we can add all of our new tokens and blocks. + # If this is the first prefill chunk, we can add all of our new tokens and blocks. if req.is_first_context_chunk or req.is_kv_cache_connector_async_onboard: req.is_kv_cache_connector_async_onboard = False - scheduler_output.add_request( - req.request_id, req.get_tokens(0), - self.get_cache_indices(req), - req.context_current_position) + scheduler_output.record_first_prefill_chunk( + req, self.get_cache_indices(req)) else: - # Otherwise, we just provide the new context position. No new blocks are allocated. 
- scheduler_output.add_request( - req.request_id, [], [], - req.context_current_position) + scheduler_output.record_nth_prefill_chunk(req) for req in generation_batch: - tokens = req.get_tokens(0) - computed_token_position = len(req.get_tokens(0)) - 1 - old_block_ids = self.get_cache_indices(req) self.impl.add_token(req.py_request_id) @@ -444,9 +432,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): delta_block_ids = new_block_ids[len(old_block_ids):] - scheduler_output.add_request(req.request_id, tokens[-1:], - delta_block_ids, - computed_token_position) + scheduler_output.record_generation_req(req, delta_block_ids) if self.kv_connector_manager is not None: self.kv_connector_manager.set_scheduler_output(scheduler_output) diff --git a/tests/unittest/_torch/test_connector.py b/tests/unittest/_torch/test_connector.py index 039ce7e00a5..79a09b46fad 100644 --- a/tests/unittest/_torch/test_connector.py +++ b/tests/unittest/_torch/test_connector.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pickle import sys from unittest.mock import MagicMock @@ -23,6 +38,8 @@ def run_across_mpi(executor, fun, num_ranks): @pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +# TODO(jthomson04): I don't have the slightest idea why this test is leaking threads. +@pytest.mark.threadleak(enabled=False) def test_connector_manager_get_finished_allgather(mpi_pool_executor): def test(): From 1c3fe6f70f14785397d892fd20d4c9515a80099b Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 6 Aug 2025 11:36:04 -0700 Subject: [PATCH 30/50] Gate cuda graph support Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/_util.py | 4 ++-- tensorrt_llm/_torch/pyexecutor/connector.py | 20 +++++++++---------- tensorrt_llm/_torch/pyexecutor/py_executor.py | 4 ++-- .../_torch/pyexecutor/py_executor_creator.py | 3 +++ .../_torch/pyexecutor/resource_manager.py | 2 +- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 9cf17d7f33c..24c9fb9bde5 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -329,7 +329,7 @@ def _create_kv_cache_manager( ) if not for_estimation and self._kv_connector_manager is not None: - raise ValueError( + raise NotImplementedError( "Connector manager is not supported for MambaHybridCacheManager." 
) @@ -406,7 +406,7 @@ def build_managers(self, self._model_engine, for_estimation) if not for_estimation and self._kv_connector_manager is not None and self._draft_model_engine is not None: - raise ValueError( + raise NotImplementedError( "Connector manager is not supported for draft model.") draft_kv_cache_manager = self._create_kv_cache_manager( diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index e7d6caa8093..40208967787 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -69,12 +69,6 @@ class RequestData: class SchedulerOutput: requests: list[RequestData] = field(default_factory=list) - def add_request(self, request_id: int, new_tokens: list[int], - new_block_ids: list[int], computed_position: int): - self.requests.append( - RequestData(request_id, new_tokens, new_block_ids, - computed_position)) - def record_first_prefill_chunk(self, req: LlmRequest, block_ids: list[int]): if not req.is_kv_cache_connector_async_onboard: self.requests.append( @@ -129,22 +123,25 @@ def start_load_kv(self): """ @abstractmethod - def wait_for_layer_load(self, layer_idx: int): + def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream): """ Wait for a layer to finish being loaded before proceeding with the forward pass on the layer. + Note: This function is called immediately before the layer's work is enqueued into the stream. Args: layer_idx: The index of the layer to wait for. + stream: The stream the forward pass is being executed on. """ @abstractmethod - def save_kv_layer(self, layer_idx: int): + def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream): """ Begin saving the KV cache for a layer. - This is called after the forward pass on the layer has completed. + Note: This function is called immediately after the layer's work is enqueued into the stream. Args: layer_idx: The index of the layer to save. + stream: The stream the forward pass is being executed on. """ @abstractmethod @@ -445,7 +442,8 @@ def set_scheduler_output(self, scheduler_output: SchedulerOutput): self._scheduler_output = scheduler_output def layer_pre_hook(self, module, *args): - self.worker.wait_for_layer_load(module.layer_idx) + self.worker.wait_for_layer_load(module.layer_idx, + torch.cuda.current_stream()) def layer_post_hook(self, module, *args): - self.worker.save_kv_layer(module.layer_idx) + self.worker.save_kv_layer(module.layer_idx, torch.cuda.current_stream()) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 029f274d3c0..6e59b331664 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -293,14 +293,14 @@ def _maybe_init_kv_connector_manager(self): ] if not all(t.shape == all_layers[0].shape for t in all_layers): - raise ValueError( + raise NotImplementedError( "KV Cache Connector is not supported with sliding window attention." ) full_kv_tensor = torch.cat(all_layers, dim=1) if not full_kv_tensor.is_contiguous(): - raise ValueError( + raise NotImplementedError( "KV Cache Connector is not supported with non-contiguous KV cache." 
) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 81f0396743b..ad38afd96c6 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -367,6 +367,9 @@ def create_py_executor( logger.info( f"Initializing kv connector with config: {kv_connector_config}") + if pytorch_backend_config.use_cuda_graph: + raise NotImplementedError( + "CUDA graphs are not supported with KV connector hooks.") try: module = importlib.import_module( kv_connector_config.connector_module) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 6a271d4371a..4d745cd7567 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -388,7 +388,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req.py_request_id, seq_len + (len(req.query_id) if self.mapping.cp_rank == self.mapping.cp_size - 1 else 0), - req_beam_width, req, self.kv_connector_manager) + req_beam_width, req, None) else: # In add_sequence, the connector API's get_num_new_matched_tokens is called. # The result of this call may be that blocks will be loaded asynchronously. From 914b34b86f2574faa4af17547977e80ef139b836 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 6 Aug 2025 13:03:00 -0700 Subject: [PATCH 31/50] Include cache block ids in request_finished Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/connector.py | 8 +++++--- tensorrt_llm/_torch/pyexecutor/py_executor.py | 9 +++++++-- .../defs/llmapi/test_llm_api_connector.py | 13 ++++++++++++- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index 40208967787..c515246eedd 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -205,7 +205,8 @@ def get_num_new_matched_tokens( """ @abstractmethod - def request_finished(self, request: LlmRequest) -> bool: + def request_finished(self, request: LlmRequest, + cache_block_ids: list[int]) -> bool: """ Called when a request is finished generating tokens. @@ -366,7 +367,8 @@ def handle_metadata(self) -> object: self.worker.bind_connector_meta(metadata) - def request_finished(self, req: LlmRequest) -> bool: + def request_finished(self, req: LlmRequest, + cache_block_ids: list[int]) -> bool: """ Called when a request is finished generating tokens. @@ -378,7 +380,7 @@ def request_finished(self, req: LlmRequest) -> bool: """ saving_async = self._run_on_leader( - lambda: self.scheduler.request_finished(req)) + lambda: self.scheduler.request_finished(req, cache_block_ids)) # This is similar to take_scheduled_requests_pending_load. # We need to update the request's state to indicate that it's still being used, but isn't schedulable. 
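With this change the scheduler-side request_finished also receives the request's device cache block IDs. A minimal sketch of what an implementation could do with them, assuming an offload-on-finish policy; the PendingSave bookkeeping is illustrative only, and a bare request id stands in for the LlmRequest so the sketch stays self-contained:

from dataclasses import dataclass


@dataclass
class PendingSave:
    request_id: int
    block_ids: list[int]


class OffloadOnFinishScheduler:
    """Illustrative only: tracks device blocks that still need to be written out."""

    def __init__(self):
        self.pending_saves: dict[int, PendingSave] = {}

    def request_finished(self, request_id: int, cache_block_ids: list[int]) -> bool:
        if not cache_block_ids:
            # Nothing to persist, so the executor can free the blocks right away.
            return False
        # Returning True keeps the blocks allocated until get_finished() later
        # reports this request id as done saving.
        self.pending_saves[request_id] = PendingSave(request_id, cache_block_ids)
        return True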
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 6e59b331664..ac7d7aca98c 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1603,9 +1603,14 @@ def _handle_errors(self, self._enqueue_responses(error_responses.items()) def _terminate_request(self, request: LlmRequest): - if self.kv_connector_manager is None or not self.kv_connector_manager.request_finished( - request): + if self.kv_connector_manager is None: self.resource_manager.free_resources(request) + else: + cache_block_ids = self.kv_cache_manager.get_cache_indices(request) + + if not self.kv_connector_manager.request_finished( + request, cache_block_ids): + self.resource_manager.free_resources(request) @nvtx_range("_handle_canceled_requests") def _handle_canceled_requests(self): diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 43b54c0bea6..d1dc31fd480 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -114,6 +114,9 @@ def test_connector_simple(model_with_connector, use_overlap_scheduler): use_overlap_scheduler) assert scheduler.request_finished.call_count == 1 + + assert len(scheduler.request_finished.call_args.args[1]) == 1 + assert worker.get_finished.call_count == NUM_TOKENS + int( use_overlap_scheduler) @@ -180,6 +183,8 @@ def test_connector_async_save(model_with_connector, use_overlap_scheduler): assert scheduler.request_finished.call_count == 1 + assert len(scheduler.request_finished.call_args.args[1]) == 1 + # On the last call to get_finished, we should be providing the async saving request. One extra token when using the overlap scheduler. 
assert worker.get_finished.call_count == NUM_TOKENS + int( use_overlap_scheduler) @@ -248,6 +253,9 @@ def test_connector_scheduler_output(model_with_connector, scheduler.get_num_new_matched_tokens.return_value = 8, False + assert len(scheduler.request_finished.call_args.args[1]) == math.ceil( + (NUM_INPUT_TOKENS + NUM_TOKENS) / BLOCK_SIZE) + model.generate([0] * NUM_INPUT_TOKENS, sampling_params) assert scheduler.build_connector_meta.call_args_list[0].args[0].requests[ @@ -278,7 +286,7 @@ def test_connector_scheduler_output_chunked_context(model_with_connector, worker.get_finished.return_value = [], [] - sampling_params = SamplingParams(max_tokens=32, ignore_eos=True) + sampling_params = SamplingParams(max_tokens=BLOCK_SIZE, ignore_eos=True) model.generate([0] * (CHUNK_SIZE * 2), sampling_params) @@ -303,3 +311,6 @@ def test_connector_scheduler_output_chunked_context(model_with_connector, assert len(req.new_block_ids) == 0 else: assert len(req.new_tokens) == 1 + + assert len(scheduler.request_finished.call_args.args[1]) == math.ceil( + (CHUNK_SIZE * 2 + BLOCK_SIZE) / BLOCK_SIZE) From 5a5ea470141e8db58840c318e5039e014dd64833 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 6 Aug 2025 18:00:16 -0700 Subject: [PATCH 32/50] Little bugfixes and implement a basic example Signed-off-by: jthomson04 --- .../batch_manager/kvCacheManager.h | 2 + .../batch_manager/kvCacheManager.cpp | 7 + .../pybind/batch_manager/kvCacheManager.cpp | 11 +- examples/llm-api/connector.py | 232 ++++++++++++++++++ tensorrt_llm/_torch/pyexecutor/connector.py | 13 +- tensorrt_llm/_torch/pyexecutor/py_executor.py | 26 +- .../_torch/pyexecutor/resource_manager.py | 3 + .../defs/llmapi/test_llm_api_connector.py | 36 ++- 8 files changed, 291 insertions(+), 39 deletions(-) create mode 100644 examples/llm-api/connector.py diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index b5982978123..46b8f7b52cf 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -1303,6 +1303,7 @@ class BaseKVCacheManager LlmRequest::RequestIdType requestId, SizeType32 windowSize) const = 0; + [[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0; [[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const = 0; [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0; @@ -1640,6 +1641,7 @@ class KVCacheManager : public BaseKVCacheManager std::vector getNewlyAllocatedBlockIds( LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override; + runtime::ITensor::SharedPtr getUniquePrimaryPool() const override; runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override; SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 93f569115a9..340eec25fa1 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2522,6 +2522,13 @@ std::vector KVCacheManager::getNewlyAllocatedBlockIds( return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize); } +runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const +{ + TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1, + "getUniquePrimaryPool is only supported for a single window size"); + return 
mBlockManager.getPrimaryPool(0); +} + runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const { return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx)); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 5ef7ca34485..3861f3f5c18 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -216,10 +216,16 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager std::deque, tbk::BaseKVCacheManager, getLatestEvents, timeout); } - tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override + tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const override { PYBIND11_OVERLOAD_PURE( - tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getPrimaryPool, layer_idx); + tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getPrimaryPool, poolIdx); + } + + tensorrt_llm::runtime::ITensor::SharedPtr getUniquePrimaryPool() const override + { + PYBIND11_OVERLOAD_PURE( + tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getUniquePrimaryPool); } SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override @@ -380,6 +386,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) auto pool_layer_idx = self.getPoolLayerIdx(layer_idx); return pool.index({torch::indexing::Slice(), pool_layer_idx}); }) + .def("get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); }) .def("get_block_offsets_of_batch", [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) diff --git a/examples/llm-api/connector.py b/examples/llm-api/connector.py new file mode 100644 index 00000000000..be6aa3e168c --- /dev/null +++ b/examples/llm-api/connector.py @@ -0,0 +1,232 @@ +import os +import sys +from dataclasses import dataclass, field +from pathlib import Path +from tempfile import TemporaryDirectory + +import torch + +from tensorrt_llm import LLM, SamplingParams, logger +from tensorrt_llm._torch.pyexecutor.connector import (KvCacheConnectorScheduler, + KvCacheConnectorWorker, + SchedulerOutput) +from tensorrt_llm.bindings.executor import ExecutorConfig +from tensorrt_llm.bindings.internal.batch_manager import LlmRequest +from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig + +# This is a simple example of the use of the KV cache connector. +# It persists KV cache contents into a folder, and can load them back on subsequent runs. +# See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface. +# NOTE: This example connector implementation is NOT suitable for production use. + + +@dataclass +class PersistentKvCacheConnectorMetadata: + load: list[tuple[str, int]] = field(default_factory=list) + save: list[tuple[str, int]] = field(default_factory=list) + + +class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker): + + def __init__(self, executor_config: ExecutorConfig): + super().__init__(executor_config) + + self.kv_cache_tensor = None + + def register_kv_caches(self, kv_cache_tensor: torch.Tensor): + assert self.kv_cache_tensor is None, "KV cache tensor already registered" + self.kv_cache_tensor = kv_cache_tensor + + def start_load_kv(self, stream: torch.cuda.Stream): + # Do all loads synchronously, and blockwise. 
+ for path, block_id in self._metadata.load: + cpu_tensor = torch.load(path, map_location="cpu") + + # Copy into the device block. + self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False) + + def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream): + pass + + def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream): + pass + + def wait_for_save(self, stream: torch.cuda.Stream): + + # Make sure the forward pass is complete before beginning our save. + stream.synchronize() + + for path, block_id in self._metadata.save: + cpu_tensor = self.kv_cache_tensor[block_id].cpu() + + # Don't write anything if this specific block already exists. + if Path(path).exists(): + continue + + # Do a blocking save to the file. This way, we only return once all saves are complete. + torch.save(cpu_tensor, path) + + def get_finished( + self, finished_gen_req_ids: list[int], + started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]: + + return [], [] + + +class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler): + + def __init__(self, executor_config: ExecutorConfig): + super().__init__(executor_config) + + self.block_size = self._config.tokens_per_block + self.pending_loads = {} + + self.cache_folder = os.environ.get("CONNECTOR_CACHE_FOLDER", + "./connector_cache") + + os.makedirs(self.cache_folder, exist_ok=True) + + def build_connector_meta(self, scheduler_output: SchedulerOutput): + # NOTE: This is a simplified implementation, and does not work with chunked prefill. + + metadata = PersistentKvCacheConnectorMetadata() + + for req in scheduler_output.new_requests: + # If we don't have any pending loads for this request, we can skip it. + if req.request_id not in self.pending_loads: + continue + + num_computed_blocks = req.computed_position // self.block_size + block_ids = req.new_block_ids + + pending_load = self.pending_loads[req.request_id] + + # TODO: The `computed_position` field in the scheduler output counts both the device cache hits and onboarded device blocks. + # This is inconsistent with vLLM. + for file_path, block_pos in zip( + pending_load, + range(num_computed_blocks - len(pending_load), + len(block_ids))): + metadata.load.append((file_path, block_ids[block_pos])) + + # Break up the remainder of the token sequence into chunks. + chunks = self._chunk_tokens(req.new_tokens) + + # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it. + for block_pos in range(num_computed_blocks + len(pending_load), + len(block_ids)): + if len(chunks[block_pos]) == self.block_size: + hashed_tokens = self._hash_tokens(chunks[block_pos]) + + file_path = self._file_path(hashed_tokens) + + metadata.save.append((file_path, block_ids[block_pos])) + + self.pending_loads = {} + + return metadata + + def _hash_tokens(self, tokens: list[int]) -> int: + return abs(hash(tuple(tokens))) + + def _file_path(self, hash_value: int) -> Path: + return Path(self.cache_folder) / f"{hash_value}.pt" + + def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]: + return [ + tokens[i:i + self.block_size] + for i in range(0, len(tokens), self.block_size) + ] + + def get_num_new_matched_tokens( + self, request: LlmRequest, + num_computed_tokens: int) -> tuple[int, bool]: + self.pending_loads[request.request_id] = [] + + # Don't bother with sequences with partial matches. 
+ if (num_computed_tokens % self.block_size) != 0: + return 0, False + + computed_blocks = num_computed_tokens // self.block_size + + # Get all the tokens that don't have a cache hit on device. + remaining_tokens = request.get_tokens(0)[computed_blocks * + self.block_size:] + + remaining_chunks = self._chunk_tokens(remaining_tokens) + + # For each chunk, check if it exists in our cache. + for chunk in remaining_chunks: + # Only do full blocks. + if len(chunk) == self.block_size: + hashed_tokens = self._hash_tokens(chunk) + + file_path = self._file_path(hashed_tokens) + + # If we get a cache hit, we want to load it into device. + # Otherwise, we can stop looking. + if file_path.exists(): + self.pending_loads[request.request_id].append(file_path) + else: + break + + logger.info( + f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}" + ) + + return len( + self.pending_loads[request.request_id]) * self.block_size, False + + def request_finished(self, request: LlmRequest, + cache_block_ids: list[int]) -> bool: + # We don't do any asynchronous saving, so always return False + return False + + +if __name__ == "__main__": + + sys.path.append(os.path.join( + os.path.dirname(__file__), + "..", + )) + + this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")] + + connector_config = KvCacheConnectorConfig( + connector_module=this_module, + connector_scheduler_class="PersistentKvCacheConnectorLeader", + connector_worker_class="PersistentKvCacheConnectorWorker", + ) + + connector_cache_dir = TemporaryDirectory() + os.environ["CONNECTOR_CACHE_FOLDER"] = connector_cache_dir.name + + model = LLM(model="Qwen/Qwen2-0.5B", + backend="pytorch", + cuda_graph_config=None, + connector_config=connector_config) + + test_text = ( + "Nvidia Corporation is an American technology company headquartered in Santa Clara, California." + "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), " + "system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, " + "and mobile and automotive applications. Tell me about the company.") + + sampling_params = SamplingParams(max_tokens=32) + + output = model.generate([test_text], sampling_params) + + print("First output: ", output[0].outputs[0].text) + print("Loading new LLM instance...") + + del model + + model = LLM(model="Qwen/Qwen2-0.5B", + backend="pytorch", + cuda_graph_config=None, + connector_config=connector_config) + + output = model.generate([test_text], sampling_params) + print("Second output (using connector cache): ", output[0].outputs[0].text) + + connector_cache_dir.cleanup() diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index c515246eedd..d966894c664 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -67,16 +67,17 @@ class RequestData: # This is used when calling `build_connector_meta` on the scheduler. 
@dataclass class SchedulerOutput: - requests: list[RequestData] = field(default_factory=list) + new_requests: list[RequestData] = field(default_factory=list) + cached_requests: list[RequestData] = field(default_factory=list) def record_first_prefill_chunk(self, req: LlmRequest, block_ids: list[int]): if not req.is_kv_cache_connector_async_onboard: - self.requests.append( + self.new_requests.append( RequestData(req.request_id, req.get_tokens(0), block_ids, req.context_current_position)) def record_nth_prefill_chunk(self, req: LlmRequest): - self.requests.append( + self.cached_requests.append( RequestData(req.request_id, [], [], req.context_current_position)) def record_generation_req(self, req: LlmRequest, @@ -85,7 +86,7 @@ def record_generation_req(self, req: LlmRequest, tokens = req.get_tokens(0) computed_position = len(tokens) - 1 - self.requests.append( + self.cached_requests.append( RequestData(req.request_id, tokens[-1:], delta_block_ids, computed_position)) @@ -116,7 +117,7 @@ def register_kv_caches(self, kv_cache_tensor: torch.Tensor): """ @abstractmethod - def start_load_kv(self): + def start_load_kv(self, stream: torch.cuda.Stream): """ Begin loading the KV cache in preparation for the next forward pass. Specific blocks to transfer are indicated by the scheduler's metadata. @@ -145,7 +146,7 @@ def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream): """ @abstractmethod - def wait_for_save(self): + def wait_for_save(self, stream: torch.cuda.Stream): """ Block until all synchronous saving operations are complete. Called at the end of the forward pass. """ diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index ac7d7aca98c..981bd5ae067 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -287,24 +287,8 @@ def _maybe_init_kv_connector_manager(self): "KV Cache Connector is not supported with pipeline parallelism." ) - all_layers = [ - self.kv_cache_manager.get_buffers(layer_idx) - for layer_idx in self.kv_cache_manager.pp_layers - ] - - if not all(t.shape == all_layers[0].shape for t in all_layers): - raise NotImplementedError( - "KV Cache Connector is not supported with sliding window attention." - ) - - full_kv_tensor = torch.cat(all_layers, dim=1) - - if not full_kv_tensor.is_contiguous(): - raise NotImplementedError( - "KV Cache Connector is not supported with non-contiguous KV cache." - ) - - self.kv_connector_manager.worker.register_kv_caches(full_kv_tensor) + kv_tensor = self.kv_cache_manager.get_unique_primary_pool() + self.kv_connector_manager.worker.register_kv_caches(kv_tensor) # For each of our layers, we need to register the pre/post hooks. # These are used for methods like `wait_for_layer_load` and `save_kv_layer`. 
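The pre/post hooks referred to above wrap each decoder layer so the worker sees wait_for_layer_load right before a layer's kernels are enqueued and save_kv_layer right after, both on the current CUDA stream. A minimal sketch of how such hooks could be attached; only the hook bodies (layer_pre_hook/layer_post_hook) come from this patch series, while walking the model for a layer_idx attribute is an assumption made for the sketch:

import torch


def attach_connector_hooks(model: torch.nn.Module, manager) -> None:
    # `manager` is expected to provide layer_pre_hook/layer_post_hook, which call
    # worker.wait_for_layer_load / worker.save_kv_layer with the current stream.
    for module in model.modules():
        if not hasattr(module, "layer_idx"):
            continue
        # Fires right before the layer's work is enqueued on the current stream.
        module.register_forward_pre_hook(manager.layer_pre_hook)
        # Fires right after the layer's work has been enqueued.
        module.register_forward_hook(manager.layer_post_hook)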
@@ -960,7 +944,8 @@ def _handle_kv_connector(self, scheduled_batch): self.kv_connector_manager.take_scheduled_requests_pending_load( scheduled_batch) self.kv_connector_manager.handle_metadata() - self.kv_connector_manager.worker.start_load_kv() + self.kv_connector_manager.worker.start_load_kv( + torch.cuda.current_stream()) def _terminate_async_save_requests(self): if self.kv_connector_manager: @@ -1493,7 +1478,8 @@ def forward(scheduled_requests, resource_manager, new_tensors_device, cache_indirection_buffer) if self.kv_connector_manager is not None: - self.kv_connector_manager.worker.wait_for_save() + self.kv_connector_manager.worker.wait_for_save( + torch.cuda.current_stream()) return outputs except Exception as e: diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 4d745cd7567..b80011f7bf0 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -632,6 +632,9 @@ def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]: self.head_dim, ) + def get_unique_primary_pool(self) -> torch.Tensor: + return self.impl.get_unique_primary_pool() + def get_block_ids_per_seq(self, request_ids: List[int]) -> torch.Tensor: block_ids_per_seq = self.get_batch_cache_indices(request_ids) block_ids_per_seq_tensors = [ diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index d1dc31fd480..5b6df019086 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -82,11 +82,14 @@ def test_connector_simple(model_with_connector, use_overlap_scheduler): # We should have a single `SchedulerOutput` per forward pass. for i, call in enumerate(scheduler.build_connector_meta.call_args_list): scheduler_output = call[0][0] - assert len(scheduler_output.requests) == 1 + if i == 0: + assert len(scheduler_output.new_requests) == 1 + assert len(scheduler_output.cached_requests) == 0 + else: + assert len(scheduler_output.new_requests) == 0 + assert len(scheduler_output.cached_requests) == 1 - # If this is not prefill, we should always be adding a single token. - if i != 0: - assert len(scheduler_output.requests[0].new_tokens) == 1 + assert len(scheduler_output.cached_requests[0].new_tokens) == 1 # We call `start_load_kv` once at the beginning of each forward pass. 
assert worker.start_load_kv.call_count == NUM_TOKENS + int( @@ -233,14 +236,20 @@ def test_connector_scheduler_output(model_with_connector, for i, call in enumerate(scheduler.build_connector_meta.call_args_list): sched_output = call.args[0] - assert len(sched_output.requests) == 1 - request = sched_output.requests[0] if i == 0: + assert len(sched_output.new_requests) == 1 + assert len(sched_output.cached_requests) == 0 + request = sched_output.new_requests[0] + assert len(request.new_tokens) == NUM_INPUT_TOKENS assert len(request.new_block_ids) == math.ceil(NUM_INPUT_TOKENS / BLOCK_SIZE) assert request.computed_position == 0 else: + assert len(sched_output.cached_requests) == 1 + assert len(sched_output.new_requests) == 0 + request = sched_output.cached_requests[0] + assert len(request.new_tokens) == 1 if (request.computed_position + @@ -258,8 +267,8 @@ def test_connector_scheduler_output(model_with_connector, model.generate([0] * NUM_INPUT_TOKENS, sampling_params) - assert scheduler.build_connector_meta.call_args_list[0].args[0].requests[ - 0].computed_position == 8 + assert scheduler.build_connector_meta.call_args_list[0].args[ + 0].new_requests[0].computed_position == 8 @pytest.mark.threadleak(enabled=False) @@ -293,9 +302,14 @@ def test_connector_scheduler_output_chunked_context(model_with_connector, for i, call in enumerate(scheduler.build_connector_meta.call_args_list): sched_output = call.args[0] - assert len(sched_output.requests) == 1 - - req = sched_output.requests[0] + if i == 0: + assert len(sched_output.new_requests) == 1 + assert len(sched_output.cached_requests) == 0 + req = sched_output.new_requests[0] + else: + assert len(sched_output.cached_requests) == 1 + assert len(sched_output.new_requests) == 0 + req = sched_output.cached_requests[0] if i == 0: # The first prefill chunk. From 2056e707c451629a0458191e1a474f2bf462133d Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 7 Aug 2025 13:16:16 -0700 Subject: [PATCH 33/50] Address reviewer comments Signed-off-by: jthomson04 --- .../tensorrt_llm/batch_manager/llmRequest.h | 1 + .../batch_manager/kvCacheManager.cpp | 2 -- tensorrt_llm/_torch/pyexecutor/_util.py | 18 +++++------ tensorrt_llm/_torch/pyexecutor/connector.py | 32 +++++++++++-------- tensorrt_llm/_torch/pyexecutor/py_executor.py | 18 ++++------- .../_torch/pyexecutor/resource_manager.py | 20 +++++++----- tests/unittest/_torch/test_connector.py | 2 +- 7 files changed, 47 insertions(+), 46 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index f1c9d52763e..595e92b5948 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -2027,6 +2027,7 @@ class GenericLlmRequest bool mIsDummyRequest{false}; + /// Whether any blocks for this request are being asynchronously onboarded via the kv cache connector. 
bool mIsKvCacheConnectorAsyncOnboard{false}; private: diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 340eec25fa1..a8455376890 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1244,8 +1244,6 @@ void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inp if (kvCacheConnectorManager) { numConnectorMatchedTokens = kvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); - TLLM_CHECK_WITH_INFO(prepopulatedPromptLen + numConnectorMatchedTokens < llmRequest.getPromptLen(), - "There must be at least one uncomputed token in the prompt!"); } llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock()); diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 24c9fb9bde5..f566cb810bc 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -278,7 +278,7 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None: def _create_kv_cache_manager( self, model_engine: PyTorchModelEngine, - for_estimation: bool = False) -> KVCacheManager: + estimating_kv_cache: bool = False) -> KVCacheManager: executor_config = self._executor_config mapping = self._mapping assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models." @@ -320,7 +320,7 @@ def _create_kv_cache_manager( spec_config=spec_config, max_beam_width=executor_config.max_beam_width, kv_connector_manager=self._kv_connector_manager - if not for_estimation else None, + if not estimating_kv_cache else None, ) elif is_nemotron_hybrid(config): if executor_config.max_beam_width > 1: @@ -328,7 +328,7 @@ def _create_kv_cache_manager( "MambaHybridCacheManager + beam search is not supported yet." ) - if not for_estimation and self._kv_connector_manager is not None: + if not estimating_kv_cache and self._kv_connector_manager is not None: raise NotImplementedError( "Connector manager is not supported for MambaHybridCacheManager." 
) @@ -390,7 +390,7 @@ def _create_kv_cache_manager( model_config=binding_model_config, max_beam_width=executor_config.max_beam_width, kv_connector_manager=self._kv_connector_manager - if not for_estimation else None, + if not estimating_kv_cache else None, ) # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER: @@ -400,18 +400,18 @@ def _create_kv_cache_manager( def build_managers(self, resources: Dict, - for_estimation: bool = False) -> None: + estimating_kv_cache: bool = False) -> None: """Construct KV caches for model and draft model (if applicable).""" kv_cache_manager = self._create_kv_cache_manager( - self._model_engine, for_estimation) + self._model_engine, estimating_kv_cache) - if not for_estimation and self._kv_connector_manager is not None and self._draft_model_engine is not None: + if not estimating_kv_cache and self._kv_connector_manager is not None and self._draft_model_engine is not None: raise NotImplementedError( "Connector manager is not supported for draft model.") draft_kv_cache_manager = self._create_kv_cache_manager( - self._draft_model_engine, - for_estimation) if self._draft_model_engine is not None else None + self._draft_model_engine, estimating_kv_cache + ) if self._draft_model_engine is not None else None resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager resources[ diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/connector.py index d966894c664..a0b85664e8a 100644 --- a/tensorrt_llm/_torch/pyexecutor/connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connector.py @@ -36,7 +36,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional import torch @@ -56,9 +56,9 @@ class RequestData: # The request ID. request_id: int # The new tokens that were generated in the prior forward pass. - new_tokens: list[int] + new_tokens: List[int] # The new block IDs allocated in the prior forward pass. - new_block_ids: list[int] + new_block_ids: List[int] # The position of the latest token with computed (valid) kv cache values. computed_position: int @@ -67,10 +67,13 @@ class RequestData: # This is used when calling `build_connector_meta` on the scheduler. @dataclass class SchedulerOutput: - new_requests: list[RequestData] = field(default_factory=list) - cached_requests: list[RequestData] = field(default_factory=list) + # Requests being scheduled for the first time. Requests will show up in `new_request` exactly once. + new_requests: List[RequestData] = field(default_factory=list) - def record_first_prefill_chunk(self, req: LlmRequest, block_ids: list[int]): + # Requests being scheduled, that have already shown up in `new_requests`. 
+ cached_requests: List[RequestData] = field(default_factory=list) + + def record_first_prefill_chunk(self, req: LlmRequest, block_ids: List[int]): if not req.is_kv_cache_connector_async_onboard: self.new_requests.append( RequestData(req.request_id, req.get_tokens(0), block_ids, @@ -81,7 +84,7 @@ def record_nth_prefill_chunk(self, req: LlmRequest): RequestData(req.request_id, [], [], req.context_current_position)) def record_generation_req(self, req: LlmRequest, - delta_block_ids: list[int]): + delta_block_ids: List[int]): tokens = req.get_tokens(0) computed_position = len(tokens) - 1 @@ -153,8 +156,8 @@ def wait_for_save(self, stream: torch.cuda.Stream): @abstractmethod def get_finished( - self, finished_gen_req_ids: list[int], - started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]: + self, finished_gen_req_ids: List[int], + started_loading_req_ids: List[int]) -> tuple[List[int], List[int]]: """ Get the requests that have finished loading and saving. @@ -189,6 +192,7 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput): The metadata for the workers. """ + @abstractmethod def get_num_new_matched_tokens( self, request: LlmRequest, num_computed_tokens: int) -> tuple[int, bool]: @@ -207,7 +211,7 @@ def get_num_new_matched_tokens( @abstractmethod def request_finished(self, request: LlmRequest, - cache_block_ids: list[int]) -> bool: + cache_block_ids: List[int]) -> bool: """ Called when a request is finished generating tokens. @@ -236,8 +240,8 @@ def add_from(self, other: 'AsyncRequests'): other.saving = dict() other.loading = dict() - def extract_by_id(self, saving_ids: list[int], - loading_ids: list[int]) -> 'AsyncRequests': + def extract_by_id(self, saving_ids: List[int], + loading_ids: List[int]) -> 'AsyncRequests': """ Extract the requests with the given IDs from this `AsyncRequests` object. @@ -369,7 +373,7 @@ def handle_metadata(self) -> object: self.worker.bind_connector_meta(metadata) def request_finished(self, req: LlmRequest, - cache_block_ids: list[int]) -> bool: + cache_block_ids: List[int]) -> bool: """ Called when a request is finished generating tokens. @@ -391,7 +395,7 @@ def request_finished(self, req: LlmRequest, return saving_async - def get_finished(self) -> list[LlmRequest]: + def get_finished(self) -> List[LlmRequest]: """ Process requests that have finished loading and saving. diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 981bd5ae067..ae02c366845 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -919,12 +919,6 @@ def _prepare_and_schedule_batch(self): "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache" ) self.kv_cache_transceiver.check_context_transfer_status(1) - elif self.kv_connector_manager is None: - # The kv cache connector also puts requests to sleep similar to the transceiver. - # Thus, this assertion is only applicable when both the cache transceiver and connector are disabled. 
- assert scheduled_batch.batch_size > 0, ( - "fail to schedule any pending request, " - "probably run out of resource.") self.num_scheduled_requests = scheduled_batch.batch_size logger.debug( @@ -939,7 +933,7 @@ def _execute_guided_decoder(self, scheduled_batch: ScheduledRequests, self.guided_decoder.build(scheduled_batch) self.guided_decoder.execute(scheduled_batch, logits) - def _handle_kv_connector(self, scheduled_batch): + def _kv_connector_start_batch(self, scheduled_batch): if self.kv_connector_manager: self.kv_connector_manager.take_scheduled_requests_pending_load( scheduled_batch) @@ -947,7 +941,7 @@ def _handle_kv_connector(self, scheduled_batch): self.kv_connector_manager.worker.start_load_kv( torch.cuda.current_stream()) - def _terminate_async_save_requests(self): + def _kv_connector_terminate_requests(self): if self.kv_connector_manager: reqs_to_terminate = self.kv_connector_manager.get_finished() for req in reqs_to_terminate: @@ -989,7 +983,7 @@ def _executor_loop(self): self.guided_decoder.init_disagg_gen_requests( scheduled_batch) - self._handle_kv_connector(scheduled_batch) + self._kv_connector_start_batch(scheduled_batch) if scheduled_batch.batch_size > 0 or ( self.enable_attention_dp and self.dist.tp_size > 1): @@ -1027,7 +1021,7 @@ def _executor_loop(self): if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._terminate_ctx_finished_requests() - self._terminate_async_save_requests() + self._kv_connector_terminate_requests() if self.enable_iter_perf_stats: iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ @@ -1097,7 +1091,7 @@ def _executor_loop_overlap(self): scheduled_batch) self.resource_manager.prepare_resources(scheduled_batch) - self._handle_kv_connector(scheduled_batch) + self._kv_connector_start_batch(scheduled_batch) if scheduled_batch.batch_size > 0: @@ -1156,7 +1150,7 @@ def _executor_loop_overlap(self): if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._terminate_ctx_finished_requests() - self._terminate_async_save_requests() + self._kv_connector_terminate_requests() def _process_previous_batch(self): if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs: diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index b80011f7bf0..314cfb9a939 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -374,7 +374,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): generation_batch = scheduled_batch.generation_requests # Build the scheduler output for the connector. - scheduler_output = SchedulerOutput() + scheduler_output = SchedulerOutput( + ) if self.kv_connector_manager is not None else None # allocate KV Cache for req in context_batch: @@ -406,9 +407,10 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) - scheduler_output.record_first_prefill_chunk( - req, self.get_cache_indices(req)) - else: + if self.kv_connector_manager is not None: + scheduler_output.record_first_prefill_chunk( + req, self.get_cache_indices(req)) + elif self.kv_connector_manager is not None: # When using the connector, this code path will be hit after the async load is complete. # Alternatively, with no connector, this is hit after the first chunk of prefill. 
@@ -421,18 +423,20 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): scheduler_output.record_nth_prefill_chunk(req) for req in generation_batch: - old_block_ids = self.get_cache_indices(req) + if self.kv_connector_manager is not None: + old_block_ids = self.get_cache_indices(req) self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) - new_block_ids = self.get_cache_indices(req) + if self.kv_connector_manager is not None: + new_block_ids = self.get_cache_indices(req) - delta_block_ids = new_block_ids[len(old_block_ids):] + delta_block_ids = new_block_ids[len(old_block_ids):] - scheduler_output.record_generation_req(req, delta_block_ids) + scheduler_output.record_generation_req(req, delta_block_ids) if self.kv_connector_manager is not None: self.kv_connector_manager.set_scheduler_output(scheduler_output) diff --git a/tests/unittest/_torch/test_connector.py b/tests/unittest/_torch/test_connector.py index 79a09b46fad..377d3ad3a63 100644 --- a/tests/unittest/_torch/test_connector.py +++ b/tests/unittest/_torch/test_connector.py @@ -58,7 +58,7 @@ def test(): req.request_id = 42 - manager.request_finished(req) + manager.request_finished(req, []) # To start, make both workers return nothing. worker.get_finished.return_value = ([], []) From 96b71c4fe2acf3e74c89c9d3c152207e9eb872e6 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 7 Aug 2025 15:42:09 -0700 Subject: [PATCH 34/50] more improvements + refactoring + docstrings Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/_util.py | 2 +- .../{connector.py => kv_cache_connector.py} | 0 tensorrt_llm/_torch/pyexecutor/py_executor.py | 11 +++++++---- .../_torch/pyexecutor/py_executor_creator.py | 4 ++-- .../_torch/pyexecutor/resource_manager.py | 2 +- tensorrt_llm/executor/executor.py | 2 +- tensorrt_llm/executor/proxy.py | 2 +- tensorrt_llm/executor/worker.py | 3 +-- tensorrt_llm/llmapi/llm_args.py | 18 ++++++++++++++++-- tensorrt_llm/models/modeling_utils.py | 7 ------- .../defs/llmapi/test_llm_api_connector.py | 3 +-- tests/unittest/_torch/test_connector.py | 3 ++- 12 files changed, 33 insertions(+), 24 deletions(-) rename tensorrt_llm/_torch/pyexecutor/{connector.py => kv_cache_connector.py} (100%) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index f566cb810bc..256117df8c8 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -22,8 +22,8 @@ from ..speculative import get_num_extra_kv_tokens, get_spec_decoder from .config import PyTorchConfig from .config_utils import is_mla, is_nemotron_hybrid -from .connector import KvCacheConnectorManager from .guided_decoder import GuidedDecoder +from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver from .llm_request import ExecutorResponse from .model_engine import PyTorchModelEngine diff --git a/tensorrt_llm/_torch/pyexecutor/connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py similarity index 100% rename from tensorrt_llm/_torch/pyexecutor/connector.py rename to tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index ae02c366845..03e01da9953 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -36,9 +36,9 @@ from ..models.modeling_utils import 
DecoderModelForCausalLM from ..modules.decoder_layer import DecoderLayer from ..speculative.drafter import Drafter -from .connector import KvCacheConnectorManager from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem from .guided_decoder import GuidedDecoder +from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, LlmResponse) @@ -947,6 +947,11 @@ def _kv_connector_terminate_requests(self): for req in reqs_to_terminate: self.resource_manager.free_resources(req) + def _kv_connector_wait_for_save(self): + if self.kv_connector_manager is not None: + self.kv_connector_manager.worker.wait_for_save( + torch.cuda.current_stream()) + def _executor_loop(self): torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. @@ -1471,9 +1476,7 @@ def forward(scheduled_requests, resource_manager, new_tensors_device, new_tensors_device, gather_context_logits, cache_indirection_buffer) - if self.kv_connector_manager is not None: - self.kv_connector_manager.worker.wait_for_save( - torch.cuda.current_stream()) + self._kv_connector_wait_for_save() return outputs except Exception as e: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index ad38afd96c6..546d096af9f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -13,10 +13,10 @@ from tensorrt_llm._utils import get_sm_version from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig +from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraConfig from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from tensorrt_llm.quantization import QuantAlgo from ..attention_backend.interface import AttentionRuntimeFeatures @@ -27,8 +27,8 @@ create_py_executor_instance, instantiate_sampler, is_mla) from .config import PyTorchConfig from .config_utils import is_mla -from .connector import KvCacheConnectorManager from .guided_decoder import GuidedDecoder +from .kv_cache_connector import KvCacheConnectorManager from .model_engine import PyTorchModelEngine from .py_executor import PyExecutor diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 314cfb9a939..f38b9654d2d 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -16,7 +16,7 @@ from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range from ...logger import logger from ...mapping import Mapping -from .connector import KvCacheConnectorManager, SchedulerOutput +from .kv_cache_connector import KvCacheConnectorManager, SchedulerOutput from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig, get_draft_token_length) from .scheduler import ScheduledRequests diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 5592132119c..aff813a58bc 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -21,13 +21,13 @@ from ..bindings import executor as tllm from ..builder import Engine from ..disaggregated_params 
import DisaggregatedParams +from ..llmapi.llm_args import KvCacheConnectorConfig from ..llmapi.llm_utils import KvCacheRetentionConfig from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available, need_spawn_mpi_workers) from ..llmapi.utils import (AsyncQueue, enable_llm_debug, enable_worker_single_process_for_tp1, print_colored, print_colored_debug) -from ..models.modeling_utils import KvCacheConnectorConfig from ..sampling_params import (BatchedLogitsProcessor, LogprobParams, SamplingParams) from ..scheduling_params import SchedulingParams diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index f4bccd966ae..6d1e1afd87f 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -10,9 +10,9 @@ import zmq.asyncio from tensorrt_llm.logger import logger -from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from .._utils import customized_gc_thresholds, mpi_rank, nvtx_range_debug +from ..llmapi.llm_args import KvCacheConnectorConfig from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession, RemoteMpiCommSessionClient) from ..llmapi.tracer import enable_llm_tracer, get_tracer, global_tracer diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 0fa91d8cc72..cd7056b0ee7 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -13,13 +13,12 @@ import torch from tensorrt_llm.logger import logger -from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, nvtx_range_debug) from ..bindings import executor as tllm from ..builder import ConfigEncoder, Engine, EngineConfig -from ..llmapi.llm_args import PybindMirror +from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror from ..llmapi.mpi_session import set_mpi_session_cpp from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 95a0ca66df7..4201ffba4c3 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -52,8 +52,7 @@ from ..logger import logger from ..mapping import Mapping from ..models.automodel import AutoConfig -from ..models.modeling_utils import (KvCacheConnectorConfig, PretrainedConfig, - QuantAlgo, QuantConfig, +from ..models.modeling_utils import (PretrainedConfig, QuantAlgo, QuantConfig, SpeculativeDecodingMode) from ..sampling_params import BatchedLogitsProcessor from .build_cache import BuildCacheConfig @@ -396,6 +395,21 @@ def spec_dec_mode(self): self.decoding_type.upper()) +class KvCacheConnectorConfig(StrictBaseModel): + """ + Configuration for the KV Cache Connector. + """ + connector_module: str = Field( + ..., + description= + "The import path to the connector module. It will be imported with `importlib.import_module`." 
+ ) + connector_scheduler_class: str = Field( + ..., description="The class name of the scheduler within the module.") + connector_worker_class: str = Field( + ..., description="The class name of the worker within the module.") + + class MedusaDecodingConfig(DecodingBaseConfig): medusa_choices: Optional[List[List[int]]] = None num_medusa_heads: Optional[int] = None diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 6fdeb0163b2..b2fdc393a02 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -124,13 +124,6 @@ def from_arguments(args: argparse.Namespace): assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode -@dataclasses.dataclass -class KvCacheConnectorConfig: - connector_module: str - connector_scheduler_class: str - connector_worker_class: str - - @dataclasses.dataclass class QuantConfig: """ diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 5b6df019086..a66a2734aad 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -20,8 +20,7 @@ import pytest from tensorrt_llm import LLM, SamplingParams -from tensorrt_llm.llmapi.llm_args import KvCacheConfig -from tensorrt_llm.models.modeling_utils import KvCacheConnectorConfig +from tensorrt_llm.llmapi.llm_args import KvCacheConfig, KvCacheConnectorConfig @pytest.fixture(scope="function") diff --git a/tests/unittest/_torch/test_connector.py b/tests/unittest/_torch/test_connector.py index 377d3ad3a63..2c450b24265 100644 --- a/tests/unittest/_torch/test_connector.py +++ b/tests/unittest/_torch/test_connector.py @@ -22,7 +22,8 @@ import pytest from tensorrt_llm import mpi_rank -from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnectorManager +from tensorrt_llm._torch.pyexecutor.kv_cache_connector import \ + KvCacheConnectorManager from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests cloudpickle.register_pickle_by_value(sys.modules[__name__]) From d0ad8a6a75ffebe2e96263b54fa92efc62c8bad7 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Sun, 10 Aug 2025 17:15:44 -0700 Subject: [PATCH 35/50] Nanobind support (finally) Signed-off-by: jthomson04 --- .../tensorrt_llm/batch_manager/llmRequest.h | 13 --- cpp/tensorrt_llm/nanobind/CMakeLists.txt | 1 + .../batch_manager/kvCacheConnector.cpp | 48 ++++++++ .../nanobind/batch_manager/kvCacheConnector.h | 22 ++++ .../nanobind/batch_manager/kvCacheManager.cpp | 11 +- cpp/tensorrt_llm/nanobind/bindings.cpp | 3 + .../pybind/batch_manager/bindings.cpp | 4 +- .../_torch/pyexecutor/kv_cache_connector.py | 104 ++++++++++++++---- .../_torch/pyexecutor/resource_manager.py | 40 ++----- .../defs/llmapi/test_llm_api_connector.py | 62 +++++------ 10 files changed, 203 insertions(+), 105 deletions(-) create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.cpp create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 595e92b5948..3320c6b0929 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1843,16 +1843,6 @@ class GenericLlmRequest return mIsDummyRequest; } - void setIsKvCacheConnectorAsyncOnboard(bool isKvCacheConnectorAsyncOnboard) - { - mIsKvCacheConnectorAsyncOnboard = 
isKvCacheConnectorAsyncOnboard; - } - - [[nodiscard]] bool isKvCacheConnectorAsyncOnboard() const - { - return mIsKvCacheConnectorAsyncOnboard; - } - RequestIdType mRequestId; SizeType32 mPromptLen; SizeType32 mMaxNewTokens; @@ -2027,9 +2017,6 @@ class GenericLlmRequest bool mIsDummyRequest{false}; - /// Whether any blocks for this request are being asynchronously onboarded via the kv cache connector. - bool mIsKvCacheConnectorAsyncOnboard{false}; - private: void initialize(VecTokens const& inputTokens, bool outputLogProbs) { diff --git a/cpp/tensorrt_llm/nanobind/CMakeLists.txt b/cpp/tensorrt_llm/nanobind/CMakeLists.txt index aa5b3cf45da..6bce76021ef 100755 --- a/cpp/tensorrt_llm/nanobind/CMakeLists.txt +++ b/cpp/tensorrt_llm/nanobind/CMakeLists.txt @@ -7,6 +7,7 @@ set(SRCS batch_manager/algorithms.cpp batch_manager/bindings.cpp batch_manager/cacheTransceiver.cpp + batch_manager/kvCacheConnector.cpp batch_manager/kvCacheManager.cpp batch_manager/llmRequest.cpp executor/bindings.cpp diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.cpp new file mode 100644 index 00000000000..b843b5802c9 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.cpp @@ -0,0 +1,48 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h" + +#include +#include + +namespace +{ +using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager; + +namespace tb = tensorrt_llm::batch_manager; + +class PyKvCacheConnectorManager : KvCacheConnectorManager +{ +public: + NB_TRAMPOLINE(KvCacheConnectorManager, 1); + + SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override + { + NB_OVERRIDE_PURE_NAME("get_num_new_matched_tokens", getNumNewMatchedTokens, request, numComputedTokens); + } +}; + +} // namespace + +void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(nb::module_& m) +{ + nb::class_(m, "KvCacheConnectorManager") + .def(nb::init<>()) + .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens, + nb::arg("request"), nb::arg("num_computed_tokens")); +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h new file mode 100644 index 00000000000..63c183d0b2b --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h @@ -0,0 +1,22 @@ +#pragma once + +#include "tensorrt_llm/batch_manager/kvCacheConnector.h" +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::batch_manager::kv_cache_manager +{ +class KVCacheManagerConnectorBindings +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::batch_manager::kv_cache_manager + +namespace tensorrt_llm::pybind::batch_manager::kv_connector +{ + +using namespace tensorrt_llm::batch_manager::kv_connector; + +} // namespace tensorrt_llm::pybind::batch_manager::kv_connector diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 412698215aa..f60ddccb0f8 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -110,9 +110,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager } void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, + tensorrt_llm::common::OptionalRef kvCacheConnectorManager + = std::nullopt) override { - NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest); + NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest, kvCacheConnectorManager); } void removeSequence(tb::LlmRequest::RequestIdType requestId, @@ -346,7 +348,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep) .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion) .def("add_token", &BaseKVCacheManager::addToken) - .def("add_sequence", &BaseKVCacheManager::addSequence) + .def("add_sequence", &BaseKVCacheManager::addSequence, nb::arg("request_id"), nb::arg("input_length"), + nb::arg("beam_width"), nb::arg("llm_request") = std::nullopt, + nb::arg("kv_cache_connector_manager") = std::nullopt) .def("remove_sequence", &BaseKVCacheManager::removeSequence) .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence) .def("get_block_pool_pointers", @@ -380,6 +384,7 @@ void 
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) auto pool_layer_idx = self.getPoolLayerIdx(layer_idx); return pool.index({torch::indexing::Slice(), pool_layer_idx}); }) + .def("get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); }) .def("get_block_offsets_of_batch", [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index 460e330fa58..5ed586c345f 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -34,6 +34,7 @@ #include "tensorrt_llm/nanobind/batch_manager/algorithms.h" #include "tensorrt_llm/nanobind/batch_manager/bindings.h" #include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h" +#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h" #include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h" #include "tensorrt_llm/nanobind/batch_manager/llmRequest.h" #include "tensorrt_llm/nanobind/executor/bindings.h" @@ -477,6 +478,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m) tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime); tensorrt_llm::nanobind::testing::initBindings(mInternalTesting); tpb::initBindings(mInternalBatchManager); + + tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager); tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager); tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager); tb::CacheTransceiverBindings::initBindings(mInternalBatchManager); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 6c30a350594..17ce11ac414 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -255,9 +255,7 @@ void initBindings(pybind11::module_& m) } }) .def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest) - .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics) - .def_property("is_kv_cache_connector_async_onboard", &GenLlmReq::isKvCacheConnectorAsyncOnboard, - &GenLlmReq::setIsKvCacheConnectorAsyncOnboard); + .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics); py::classh(m, "LlmRequest", pybind11::dynamic_attr()) .def(py::init<>( diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py index a0b85664e8a..62041707377 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py @@ -35,8 +35,9 @@ """ from abc import ABC, abstractmethod +from collections import defaultdict from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, List, Optional import torch @@ -49,6 +50,9 @@ from .scheduler import ScheduledRequests +if TYPE_CHECKING: + from .resource_manager import KVCacheManager + # Used to store data for a single inflight request. @dataclass @@ -73,26 +77,6 @@ class SchedulerOutput: # Requests being scheduled, that have already shown up in `new_requests`. 
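    # Illustrative note on the delta encoding used for cached requests: a
    # request reported with block ids [0, 1] at one step and holding
    # [0, 1, 2] at the next appears here with new_block_ids == [2] and
    # new_tokens == only the tokens appended since the previous step
    # (see KvCacheConnectorSchedulerOutputRequest further down in this file).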
cached_requests: List[RequestData] = field(default_factory=list) - def record_first_prefill_chunk(self, req: LlmRequest, block_ids: List[int]): - if not req.is_kv_cache_connector_async_onboard: - self.new_requests.append( - RequestData(req.request_id, req.get_tokens(0), block_ids, - req.context_current_position)) - - def record_nth_prefill_chunk(self, req: LlmRequest): - self.cached_requests.append( - RequestData(req.request_id, [], [], req.context_current_position)) - - def record_generation_req(self, req: LlmRequest, - delta_block_ids: List[int]): - - tokens = req.get_tokens(0) - computed_position = len(tokens) - 1 - - self.cached_requests.append( - RequestData(req.request_id, tokens[-1:], delta_block_ids, - computed_position)) - class KvCacheConnectorWorker(ABC): @@ -275,6 +259,64 @@ def loading_ids(self) -> set[int]: return set(self.loading.keys()) +class KvCacheConnectorSchedulerOutputRequest: + + def __init__(self): + self.block_ids = [] + self.tokens = [] + + def update_and_build_data(self, req: LlmRequest, + kv_cache_manager: "KVCacheManager"): + block_ids = kv_cache_manager.get_cache_indices(req) + tokens = req.get_tokens(0) + + new_block_ids = block_ids[len(self.block_ids):] + new_tokens = tokens[len(self.tokens):] + + self.block_ids.extend(new_block_ids) + self.tokens.extend(new_tokens) + + computed_position = len( + tokens + ) - 1 if req.state != LlmRequestState.CONTEXT_INIT else req.context_current_position + + return RequestData(req.request_id, new_tokens, new_block_ids, + computed_position) + + +class KvCacheConnectorSchedulerOutputManager: + + def __init__(self): + self.requests = defaultdict(KvCacheConnectorSchedulerOutputRequest) + + def build_scheduler_output(self, scheduled_batch: ScheduledRequests, + new_async_requests: AsyncRequests, + kv_cache_manager: "KVCacheManager"): + scheduler_output = SchedulerOutput() + + for req in scheduled_batch.context_requests: + if req.request_id in new_async_requests.loading_ids: + continue + + is_new = req.request_id not in self.requests + + request_data = self.requests[req.request_id].update_and_build_data( + req, kv_cache_manager) + + if is_new: + scheduler_output.new_requests.append(request_data) + else: + scheduler_output.cached_requests.append(request_data) + + for req in scheduled_batch.generation_requests: + request_data = self.requests[req.request_id].update_and_build_data( + req, kv_cache_manager) + + scheduler_output.cached_requests.append(request_data) + + return scheduler_output + + class KvCacheConnectorManager(KvCacheConnectorManagerCpp): """ The KvCacheConnectorManager is used to manager connector-related state. @@ -306,7 +348,11 @@ def __init__(self, worker: KvCacheConnectorWorker, # Requests that have been returned from get_finished locally, but haven't yet been returned by all workers. self.local_finished_async_requests = AsyncRequests(dict(), dict()) + # Requests that have finished loading asynchronously. 
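        # Lifecycle sketch, as implemented by the methods below: a request id
        # enters new_async_requests.loading when the scheduler returns
        # load_kv_async=True, is held out of the context batch while loading,
        # and is moved here once every worker reports it via get_finished,
        # at which point its state returns to CONTEXT_INIT; request_finished
        # removes the entry again when the request completes.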
+ self.finished_async_loading_requests = dict() + self._scheduler_output = None + self.scheduler_output_manager = KvCacheConnectorSchedulerOutputManager() def _run_on_leader(self, f: Callable[[], Any]) -> Any: """ @@ -339,6 +385,15 @@ def get_num_new_matched_tokens(self, request: LlmRequest, return num_tokens + def should_add_sequence(self, request: LlmRequest) -> bool: + req_id = request.request_id + return req_id not in self.finished_async_loading_requests + + def build_scheduler_output(self, scheduled_batch: ScheduledRequests, + kv_cache_manager: "KVCacheManager"): + self._scheduler_output = self.scheduler_output_manager.build_scheduler_output( + scheduled_batch, self.new_async_requests, kv_cache_manager) + def take_scheduled_requests_pending_load( self, scheduled_requests: ScheduledRequests): """ @@ -358,6 +413,7 @@ def take_scheduled_requests_pending_load( # we also need to update it's state. if req.request_id in self.new_async_requests.loading.keys(): req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS + self.new_async_requests.loading[req.request_id] = req else: allowed_context_requests.append(req) @@ -384,6 +440,9 @@ def request_finished(self, req: LlmRequest, Whether the request is performing asynchronous saving operations. If true, we do not immediately call free_resources on the request. """ + if req.request_id in self.finished_async_loading_requests: + del self.finished_async_loading_requests[req.request_id] + saving_async = self._run_on_leader( lambda: self.scheduler.request_finished(req, cache_block_ids)) @@ -438,8 +497,9 @@ def get_finished(self) -> List[LlmRequest]: intersect_finished_saving, intersect_finished_loading) # For requests that have finished loading, move them back to the context state. - for req in all_finished.loading.values(): + for id, req in all_finished.loading.items(): req.state = LlmRequestState.CONTEXT_INIT + self.finished_async_loading_requests[id] = req # Return the requests that have finished saving. # The execution loop will call _terminate_request on these requests. diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index f38b9654d2d..f0c32fc9523 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -16,7 +16,7 @@ from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range from ...logger import logger from ...mapping import Mapping -from .kv_cache_connector import KvCacheConnectorManager, SchedulerOutput +from .kv_cache_connector import KvCacheConnectorManager from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig, get_draft_token_length) from .scheduler import ScheduledRequests @@ -373,10 +373,6 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): context_batch = scheduled_batch.context_requests generation_batch = scheduled_batch.generation_requests - # Build the scheduler output for the connector. - scheduler_output = SchedulerOutput( - ) if self.kv_connector_manager is not None else None - # allocate KV Cache for req in context_batch: req_beam_width = req.sampling_config.beam_width @@ -398,7 +394,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): # When that happens, the request will go through this same code path, but with is_kv_cache_connector_async_onboard set to True. # Because of this, we need to filter this case out to avoid adding the same sequence twice. # NOTE(jthomson04): Surely there's a better way to do this. 
- if req.is_first_context_chunk and not req.is_kv_cache_connector_async_onboard: + if req.is_first_context_chunk and self._kv_connector_should_add_sequence( + req): self.impl.add_sequence(req.py_request_id, req.prompt_len, req_beam_width, req, self.kv_connector_manager) @@ -407,39 +404,20 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) - if self.kv_connector_manager is not None: - scheduler_output.record_first_prefill_chunk( - req, self.get_cache_indices(req)) - elif self.kv_connector_manager is not None: - # When using the connector, this code path will be hit after the async load is complete. - # Alternatively, with no connector, this is hit after the first chunk of prefill. - - # If this is the first prefill chunk, we can add all of our new tokens and blocks. - if req.is_first_context_chunk or req.is_kv_cache_connector_async_onboard: - req.is_kv_cache_connector_async_onboard = False - scheduler_output.record_first_prefill_chunk( - req, self.get_cache_indices(req)) - else: - scheduler_output.record_nth_prefill_chunk(req) - for req in generation_batch: - if self.kv_connector_manager is not None: - old_block_ids = self.get_cache_indices(req) self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) - if self.kv_connector_manager is not None: - new_block_ids = self.get_cache_indices(req) - - delta_block_ids = new_block_ids[len(old_block_ids):] - - scheduler_output.record_generation_req(req, delta_block_ids) - if self.kv_connector_manager is not None: - self.kv_connector_manager.set_scheduler_output(scheduler_output) + self.kv_connector_manager.build_scheduler_output( + scheduled_batch, self) + + def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool: + return self.kv_connector_manager is None or self.kv_connector_manager.should_add_sequence( + request) def add_dummy_requests( self, diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index a66a2734aad..071830eaaeb 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -40,7 +40,16 @@ def model_with_connector(): ) def model_fn(*args, **kwargs): - return LLM(*args, **kwargs, connector_config=connector_config) + return LLM( + *args, + **kwargs, + model="Qwen/Qwen2-0.5B", + backend="pytorch", + connector_config=connector_config, + cuda_graph_config=None, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1), + use_torch_sampler=True, + ) yield model_fn, mock_scheduler, mock_worker @@ -57,12 +66,7 @@ def test_connector_simple(model_with_connector, use_overlap_scheduler): model_fn, scheduler, worker = model_with_connector - model = model_fn( - model="Qwen/Qwen2-0.5B", - backend="pytorch", - disable_overlap_scheduler=not use_overlap_scheduler, - cuda_graph_config=None, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) + model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, ) assert worker.register_kv_caches.call_count == 1 @@ -84,6 +88,11 @@ def test_connector_simple(model_with_connector, use_overlap_scheduler): if i == 0: assert len(scheduler_output.new_requests) == 1 assert len(scheduler_output.cached_requests) == 0 + elif i == 1 and use_overlap_scheduler: + assert len(scheduler_output.new_requests) == 0 + assert len(scheduler_output.cached_requests) == 1 + + assert 
len(scheduler_output.cached_requests[0].new_tokens) == 0 else: assert len(scheduler_output.new_requests) == 0 assert len(scheduler_output.cached_requests) == 1 @@ -130,12 +139,7 @@ def test_connector_async_onboard(model_with_connector, use_overlap_scheduler): model_fn, scheduler, worker = model_with_connector - model = model_fn( - model="Qwen/Qwen2-0.5B", - backend="pytorch", - disable_overlap_scheduler=not use_overlap_scheduler, - cuda_graph_config=None, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) + model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, ) assert worker.register_kv_caches.call_count == 1 @@ -163,12 +167,7 @@ def test_connector_async_save(model_with_connector, use_overlap_scheduler): model_fn, scheduler, worker = model_with_connector - model = model_fn( - model="Qwen/Qwen2-0.5B", - backend="pytorch", - disable_overlap_scheduler=not use_overlap_scheduler, - cuda_graph_config=None, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) + model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, ) assert worker.register_kv_caches.call_count == 1 @@ -211,12 +210,7 @@ def test_connector_scheduler_output(model_with_connector, model_fn, scheduler, worker = model_with_connector - model = model_fn( - model="Qwen/Qwen2-0.5B", - backend="pytorch", - disable_overlap_scheduler=not use_overlap_scheduler, - cuda_graph_config=None, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1)) + model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, ) assert worker.register_kv_caches.call_count == 1 @@ -244,6 +238,11 @@ def test_connector_scheduler_output(model_with_connector, assert len(request.new_block_ids) == math.ceil(NUM_INPUT_TOKENS / BLOCK_SIZE) assert request.computed_position == 0 + elif i == 1 and use_overlap_scheduler: + assert len(sched_output.new_requests) == 0 + assert len(sched_output.cached_requests) == 1 + + assert len(sched_output.cached_requests[0].new_tokens) == 0 else: assert len(sched_output.cached_requests) == 1 assert len(sched_output.new_requests) == 0 @@ -279,14 +278,9 @@ def test_connector_scheduler_output_chunked_context(model_with_connector, CHUNK_SIZE = 128 BLOCK_SIZE = 32 - model = model_fn( - model="Qwen/Qwen2-0.5B", - backend="pytorch", - disable_overlap_scheduler=not use_overlap_scheduler, - cuda_graph_config=None, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1), - enable_chunked_prefill=True, - max_num_tokens=CHUNK_SIZE) + model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, + enable_chunked_prefill=True, + max_num_tokens=CHUNK_SIZE) assert worker.register_kv_caches.call_count == 1 @@ -322,6 +316,8 @@ def test_connector_scheduler_output_chunked_context(model_with_connector, assert req.computed_position == CHUNK_SIZE assert len(req.new_tokens) == 0 assert len(req.new_block_ids) == 0 + elif i == 2 and use_overlap_scheduler: + assert len(req.new_tokens) == 0 else: assert len(req.new_tokens) == 1 From b7d2ee67b860b0bdd68dfa90f3ba3cb768116f3e Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Sun, 10 Aug 2025 19:50:21 -0700 Subject: [PATCH 36/50] coderabbit + refactor Signed-off-by: jthomson04 --- ...connector.py => llm_kv_cache_connector.py} | 20 ++++++++++------ .../_torch/pyexecutor/kv_cache_connector.py | 7 +++--- tensorrt_llm/_torch/pyexecutor/py_executor.py | 4 ++++ .../_torch/pyexecutor/resource_manager.py | 2 +- .../defs/llmapi/test_llm_api_connector.py | 23 +++++++++++-------- 5 files changed, 36 insertions(+), 20 deletions(-) rename 
examples/llm-api/{connector.py => llm_kv_cache_connector.py} (94%) diff --git a/examples/llm-api/connector.py b/examples/llm-api/llm_kv_cache_connector.py similarity index 94% rename from examples/llm-api/connector.py rename to examples/llm-api/llm_kv_cache_connector.py index be6aa3e168c..b74517754a6 100644 --- a/examples/llm-api/connector.py +++ b/examples/llm-api/llm_kv_cache_connector.py @@ -7,9 +7,8 @@ import torch from tensorrt_llm import LLM, SamplingParams, logger -from tensorrt_llm._torch.pyexecutor.connector import (KvCacheConnectorScheduler, - KvCacheConnectorWorker, - SchedulerOutput) +from tensorrt_llm._torch.pyexecutor.kv_cache_connector import ( + KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput) from tensorrt_llm.bindings.executor import ExecutorConfig from tensorrt_llm.bindings.internal.batch_manager import LlmRequest from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig @@ -204,7 +203,8 @@ def request_finished(self, request: LlmRequest, model = LLM(model="Qwen/Qwen2-0.5B", backend="pytorch", cuda_graph_config=None, - connector_config=connector_config) + connector_config=connector_config, + use_torch_sampler=True) test_text = ( "Nvidia Corporation is an American technology company headquartered in Santa Clara, California." @@ -215,8 +215,9 @@ def request_finished(self, request: LlmRequest, sampling_params = SamplingParams(max_tokens=32) output = model.generate([test_text], sampling_params) + text0 = output[0].outputs[0].text - print("First output: ", output[0].outputs[0].text) + print("First output: ", text0) print("Loading new LLM instance...") del model @@ -224,9 +225,14 @@ def request_finished(self, request: LlmRequest, model = LLM(model="Qwen/Qwen2-0.5B", backend="pytorch", cuda_graph_config=None, - connector_config=connector_config) + connector_config=connector_config, + use_torch_sampler=True) output = model.generate([test_text], sampling_params) - print("Second output (using connector cache): ", output[0].outputs[0].text) + text1 = output[0].outputs[0].text + + print("Second output (using connector cache): ", text1) + + assert text0 == text1 connector_cache_dir.cleanup() diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py index 62041707377..33ac552ac7c 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py @@ -37,7 +37,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple import torch @@ -82,6 +82,7 @@ class KvCacheConnectorWorker(ABC): def __init__(self, config: ExecutorConfig): self._config = config + self._metadata = None super().__init__() def bind_connector_meta(self, metadata: object): @@ -141,7 +142,7 @@ def wait_for_save(self, stream: torch.cuda.Stream): @abstractmethod def get_finished( self, finished_gen_req_ids: List[int], - started_loading_req_ids: List[int]) -> tuple[List[int], List[int]]: + started_loading_req_ids: List[int]) -> Tuple[List[int], List[int]]: """ Get the requests that have finished loading and saving. 
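
A connector that performs all of its transfers synchronously can satisfy this
contract by simply echoing the ids back, since nothing is ever left in flight.
A minimal sketch, assuming only the signature shown above:

    def get_finished(self, finished_gen_req_ids, started_loading_req_ids):
        # No background transfers: every pending save and every started
        # async load is already complete by the time this is called.
        return finished_gen_req_ids, started_loading_req_ids
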
@@ -179,7 +180,7 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput): @abstractmethod def get_num_new_matched_tokens( self, request: LlmRequest, - num_computed_tokens: int) -> tuple[int, bool]: + num_computed_tokens: int) -> Tuple[int, bool]: """ Get the number of tokens that can be loaded from remote KV cache. This does not include the tokens already matched on device (indicated by `num_computed_tokens`). diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index a5a33167f81..9833ef123cb 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -288,6 +288,10 @@ def _maybe_init_kv_connector_manager(self): "KV Cache Connector is not supported with pipeline parallelism." ) + if self.kv_cache_manager is None: + raise ValueError( + "KV Cache Connector requires a KV Cache Manager.") + kv_tensor = self.kv_cache_manager.get_unique_primary_pool() self.kv_connector_manager.worker.register_kv_caches(kv_tensor) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 2b7a60ef316..dd8f7e5a39e 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -417,7 +417,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req_beam_width, req, None) else: if req.is_first_context_chunk and self._kv_connector_should_add_sequence( - req): + req): self.impl.add_sequence(req.py_request_id, req.prompt_len, req_beam_width, req, self.kv_connector_manager) diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 071830eaaeb..13ed383fd40 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -14,7 +14,6 @@ # limitations under the License. import math -import os from unittest.mock import MagicMock, patch import pytest @@ -54,14 +53,17 @@ def model_fn(*args, **kwargs): yield model_fn, mock_scheduler, mock_worker -# Needed because MagicMocks don't work across processes. -# TODO(jthomson04): This limits us to testing only TP1 for now. 
-os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1" +@pytest.fixture(scope="function") +def enforce_single_worker(monkeypatch): + monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1") + + yield @pytest.mark.threadleak(enabled=False) @pytest.mark.parametrize("use_overlap_scheduler", [True, False]) -def test_connector_simple(model_with_connector, use_overlap_scheduler): +def test_connector_simple(enforce_single_worker, model_with_connector, + use_overlap_scheduler): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -134,7 +136,8 @@ def test_connector_simple(model_with_connector, use_overlap_scheduler): @pytest.mark.threadleak(enabled=False) @pytest.mark.parametrize("use_overlap_scheduler", [True, False]) -def test_connector_async_onboard(model_with_connector, use_overlap_scheduler): +def test_connector_async_onboard(enforce_single_worker, model_with_connector, + use_overlap_scheduler): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -162,7 +165,8 @@ def test_connector_async_onboard(model_with_connector, use_overlap_scheduler): @pytest.mark.threadleak(enabled=False) @pytest.mark.parametrize("use_overlap_scheduler", [True, False]) -def test_connector_async_save(model_with_connector, use_overlap_scheduler): +def test_connector_async_save(enforce_single_worker, model_with_connector, + use_overlap_scheduler): NUM_TOKENS = 8 model_fn, scheduler, worker = model_with_connector @@ -202,7 +206,7 @@ def test_connector_async_save(model_with_connector, use_overlap_scheduler): @pytest.mark.threadleak(enabled=False) @pytest.mark.parametrize("use_overlap_scheduler", [True, False]) -def test_connector_scheduler_output(model_with_connector, +def test_connector_scheduler_output(enforce_single_worker, model_with_connector, use_overlap_scheduler): NUM_INPUT_TOKENS = 48 NUM_TOKENS = 32 @@ -271,7 +275,8 @@ def test_connector_scheduler_output(model_with_connector, @pytest.mark.threadleak(enabled=False) @pytest.mark.parametrize("use_overlap_scheduler", [True, False]) -def test_connector_scheduler_output_chunked_context(model_with_connector, +def test_connector_scheduler_output_chunked_context(enforce_single_worker, + model_with_connector, use_overlap_scheduler): model_fn, scheduler, worker = model_with_connector From 03fa47013ebaa9b91ca0260c6303c62d549669ba Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Mon, 11 Aug 2025 10:35:54 -0700 Subject: [PATCH 37/50] CI Integration, only support guarantee no evict, various coderabbit suggestions Signed-off-by: jthomson04 --- examples/llm-api/llm_kv_cache_connector.py | 42 +++++++++++-------- .../_torch/pyexecutor/kv_cache_connector.py | 11 ++--- .../_torch/pyexecutor/py_executor_creator.py | 10 ++++- .../defs/llmapi/test_llm_api_connector.py | 3 +- .../defs/llmapi/test_llm_examples.py | 9 ++++ .../integration/test_lists/test-db/l0_a10.yml | 2 + 6 files changed, 53 insertions(+), 24 deletions(-) diff --git a/examples/llm-api/llm_kv_cache_connector.py b/examples/llm-api/llm_kv_cache_connector.py index b74517754a6..c6bb5bfb881 100644 --- a/examples/llm-api/llm_kv_cache_connector.py +++ b/examples/llm-api/llm_kv_cache_connector.py @@ -4,6 +4,7 @@ from pathlib import Path from tempfile import TemporaryDirectory +import click import torch from tensorrt_llm import LLM, SamplingParams, logger @@ -18,6 +19,8 @@ # See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface. # NOTE: This example connector implementation is NOT suitable for production use. 
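#
# A possible way to wire this example up from user code (a sketch only; the
# class names below are placeholders, not necessarily the exact names defined
# in this file):
#
#   from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
#
#   connector_config = KvCacheConnectorConfig(
#       connector_module="llm_kv_cache_connector",
#       connector_scheduler_class="PersistentKvCacheConnectorScheduler",
#       connector_worker_class="PersistentKvCacheConnectorWorker",
#   )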
+CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER" + @dataclass class PersistentKvCacheConnectorMetadata: @@ -80,7 +83,7 @@ def __init__(self, executor_config: ExecutorConfig): self.block_size = self._config.tokens_per_block self.pending_loads = {} - self.cache_folder = os.environ.get("CONNECTOR_CACHE_FOLDER", + self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY, "./connector_cache") os.makedirs(self.cache_folder, exist_ok=True) @@ -182,8 +185,9 @@ def request_finished(self, request: LlmRequest, return False -if __name__ == "__main__": - +@click.command() +@click.argument("model", type=str) +def main(model: str): sys.path.append(os.path.join( os.path.dirname(__file__), "..", @@ -198,13 +202,13 @@ def request_finished(self, request: LlmRequest, ) connector_cache_dir = TemporaryDirectory() - os.environ["CONNECTOR_CACHE_FOLDER"] = connector_cache_dir.name + os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name - model = LLM(model="Qwen/Qwen2-0.5B", - backend="pytorch", - cuda_graph_config=None, - connector_config=connector_config, - use_torch_sampler=True) + llm = LLM(model=model, + backend="pytorch", + cuda_graph_config=None, + connector_config=connector_config, + use_torch_sampler=True) test_text = ( "Nvidia Corporation is an American technology company headquartered in Santa Clara, California." @@ -214,21 +218,21 @@ def request_finished(self, request: LlmRequest, sampling_params = SamplingParams(max_tokens=32) - output = model.generate([test_text], sampling_params) + output = llm.generate([test_text], sampling_params) text0 = output[0].outputs[0].text print("First output: ", text0) print("Loading new LLM instance...") - del model + del llm - model = LLM(model="Qwen/Qwen2-0.5B", - backend="pytorch", - cuda_graph_config=None, - connector_config=connector_config, - use_torch_sampler=True) + llm = LLM(model=model, + backend="pytorch", + cuda_graph_config=None, + connector_config=connector_config, + use_torch_sampler=True) - output = model.generate([test_text], sampling_params) + output = llm.generate([test_text], sampling_params) text1 = output[0].outputs[0].text print("Second output (using connector cache): ", text1) @@ -236,3 +240,7 @@ def request_finished(self, request: LlmRequest, assert text0 == text1 connector_cache_dir.cleanup() + + +if __name__ == "__main__": + main() diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py index 33ac552ac7c..8511cfcad3f 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py @@ -37,7 +37,8 @@ from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple) import torch @@ -212,8 +213,8 @@ def request_finished(self, request: LlmRequest, # An internal dataclass to handle async saving/loading requests. @dataclass class AsyncRequests: - saving: dict[int, LlmRequest] - loading: dict[int, LlmRequest] + saving: Dict[int, LlmRequest] + loading: Dict[int, LlmRequest] def add_from(self, other: 'AsyncRequests'): """ @@ -246,14 +247,14 @@ def extract_by_id(self, saving_ids: List[int], return new_async_requests @property - def saving_ids(self) -> set[int]: + def saving_ids(self) -> Set[int]: """ Get the IDs of the requests that are being saved asynchronously. 
""" return set(self.saving.keys()) @property - def loading_ids(self) -> set[int]: + def loading_ids(self) -> Set[int]: """ Get the IDs of the requests that are being loaded asynchronously. """ diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 546d096af9f..3e1c07ed4a9 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -11,7 +11,9 @@ import tensorrt_llm from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType from tensorrt_llm._utils import get_sm_version -from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig +from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy, + ContextChunkingPolicy, + ExecutorConfig) from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig from tensorrt_llm.logger import logger @@ -370,6 +372,12 @@ def create_py_executor( if pytorch_backend_config.use_cuda_graph: raise NotImplementedError( "CUDA graphs are not supported with KV connector hooks.") + + if executor_config.scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: + raise NotImplementedError( + "KV connector is only supported with guaranteed no evict scheduler policy." + ) + try: module = importlib.import_module( kv_connector_config.connector_module) diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 13ed383fd40..9ac0426c9fc 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -20,6 +20,7 @@ from tensorrt_llm import LLM, SamplingParams from tensorrt_llm.llmapi.llm_args import KvCacheConfig, KvCacheConnectorConfig +from tests.integration.defs.conftest import llm_models_root @pytest.fixture(scope="function") @@ -42,7 +43,7 @@ def model_fn(*args, **kwargs): return LLM( *args, **kwargs, - model="Qwen/Qwen2-0.5B", + model=f"{llm_models_root()}/Qwen2-0.5B", backend="pytorch", connector_config=connector_config, cuda_graph_config=None, diff --git a/tests/integration/defs/llmapi/test_llm_examples.py b/tests/integration/defs/llmapi/test_llm_examples.py index 993372eb540..6e535645512 100644 --- a/tests/integration/defs/llmapi/test_llm_examples.py +++ b/tests/integration/defs/llmapi/test_llm_examples.py @@ -163,3 +163,12 @@ def test_llmapi_sampling(llm_root, engine_dir, llm_venv): @pytest.mark.skip(reason="https://nvbugs/5365825") def test_llmapi_runtime(llm_root, engine_dir, llm_venv): _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_runtime.py") + + +@pytest.mark.parametrize("model", ["Qwen2-0.5B"]) +def test_llmapi_kv_cache_connector(llm_root, llm_venv, model): + script_path = Path( + llm_root) / "examples" / "llm-api" / "llm_kv_cache_connector.py" + model_path = f"{llm_models_root()}/{model}" + + venv_check_call(llm_venv, [str(script_path), model_path]) diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index ce285faa799..a1dbd5d84ed 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -77,12 +77,14 @@ l0_a10: - unittest/trt/model/test_mistral.py - unittest/trt/model/test_llama.py - test_e2e.py::test_gpt3_175b_1layers_build_only # 6 mins + - llmapi/test_llm_api_connector.py - 
llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command[llama-llama-models/llama-7b-hf] # 5min - llmapi/test_llm_e2e.py::test_llmapi_build_command_parameters_align[llama-llama-models-v2/TinyLlama-1.1B-Chat-v1.0] - llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command_with_lora[llama-llama-models-v2/llama-v2-7b-hf] - llmapi/test_llm_examples.py::test_llmapi_chat_example - llmapi/test_llm_e2e.py::test_llmapi_exit - llmapi/test_llm_examples.py::test_llmapi_server_example + - llmapi/test_llm_examples.py::test_llmapi_kv_cache_connector - test_e2e.py::test_trtllm_serve_example - test_e2e.py::test_openai_misc_example[trt] - test_e2e.py::test_openai_completions_example[trt] From 921dd94ad5308df253d5a0575d01593757747faa Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Mon, 11 Aug 2025 12:08:55 -0700 Subject: [PATCH 38/50] update state after alloc Signed-off-by: jthomson04 --- examples/llm-api/llm_kv_cache_connector.py | 4 ++++ .../_torch/pyexecutor/kv_cache_connector.py | 15 +++++++++++++++ .../_torch/pyexecutor/resource_manager.py | 5 +++++ .../defs/llmapi/test_llm_api_connector.py | 19 ++++++++++++++++++- 4 files changed, 42 insertions(+), 1 deletion(-) diff --git a/examples/llm-api/llm_kv_cache_connector.py b/examples/llm-api/llm_kv_cache_connector.py index c6bb5bfb881..24114166611 100644 --- a/examples/llm-api/llm_kv_cache_connector.py +++ b/examples/llm-api/llm_kv_cache_connector.py @@ -184,6 +184,10 @@ def request_finished(self, request: LlmRequest, # We don't do any asynchronous saving, so always return False return False + def update_state_after_alloc(self, request: LlmRequest, + block_ids: list[int]): + pass + @click.command() @click.argument("model", type=str) diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py index 8511cfcad3f..6e31daf8023 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py @@ -209,6 +209,17 @@ def request_finished(self, request: LlmRequest, If true, this indicates that the kv cache manager should wait to deallocate the blocks until the saving has completed (determined by `get_finished` on the workers). """ + @abstractmethod + def update_state_after_alloc(self, request: LlmRequest, + block_ids: List[int]): + """ + Called after get_num_new_matched_tokens is called to provide the block ids to the scheduler. + + Args: + request: The request that was allocated resources. + block_ids: The KV cacheblock IDs that were allocated. + """ + # An internal dataclass to handle async saving/loading requests. @dataclass @@ -507,6 +518,10 @@ def get_finished(self) -> List[LlmRequest]: # The execution loop will call _terminate_request on these requests. 
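        # Note that all_finished only contains requests reported as complete
        # by every rank, so block resources are released collectively rather
        # than as soon as a single worker happens to finish.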
return list(all_finished.saving.values()) + def update_state_after_alloc(self, req: LlmRequest, block_ids: List[int]): + if self.scheduler is not None: + self.scheduler.update_state_after_alloc(req, block_ids) + def set_scheduler_output(self, scheduler_output: SchedulerOutput): self._scheduler_output = scheduler_output diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index dd8f7e5a39e..d61bf97d6dc 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -426,6 +426,11 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) + if self.kv_connector_manager is not None: + block_ids = self.get_cache_indices(req) + self.kv_connector_manager.update_state_after_alloc( + req, block_ids) + for req in generation_batch: self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 9ac0426c9fc..083408e7225 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -20,7 +20,8 @@ from tensorrt_llm import LLM, SamplingParams from tensorrt_llm.llmapi.llm_args import KvCacheConfig, KvCacheConnectorConfig -from tests.integration.defs.conftest import llm_models_root + +from ..conftest import llm_models_root @pytest.fixture(scope="function") @@ -81,6 +82,11 @@ def test_connector_simple(enforce_single_worker, model_with_connector, model.generate(["Hello, world"], sampling_params) + assert scheduler.update_state_after_alloc.call_count == 1 + + # Allocate 1 block. + assert len(scheduler.update_state_after_alloc.call_args.args[1]) == 1 + # With the overlap scheduler, we generate one extra token. assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int( use_overlap_scheduler) @@ -227,6 +233,11 @@ def test_connector_scheduler_output(enforce_single_worker, model_with_connector, model.generate([0] * NUM_INPUT_TOKENS, sampling_params) + assert scheduler.update_state_after_alloc.call_count == 1 + assert len( + scheduler.update_state_after_alloc.call_args.args[1]) == math.ceil( + NUM_INPUT_TOKENS / BLOCK_SIZE) + # Additional token when using the overlap scheduler. 
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int( use_overlap_scheduler) @@ -298,6 +309,12 @@ def test_connector_scheduler_output_chunked_context(enforce_single_worker, model.generate([0] * (CHUNK_SIZE * 2), sampling_params) + assert scheduler.update_state_after_alloc.call_count == 1 + + assert len( + scheduler.update_state_after_alloc.call_args.args[1]) == math.ceil( + CHUNK_SIZE * 2 / BLOCK_SIZE) + for i, call in enumerate(scheduler.build_connector_meta.call_args_list): sched_output = call.args[0] From 35eeb029b62d78fe8ac2d7945c1314c053bdab1b Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Mon, 11 Aug 2025 16:46:19 -0700 Subject: [PATCH 39/50] Fix scheduler output Signed-off-by: jthomson04 --- examples/llm-api/llm_kv_cache_connector.py | 6 +----- .../_torch/pyexecutor/kv_cache_connector.py | 18 +++++++++++++++++- .../defs/llmapi/test_llm_api_connector.py | 3 ++- tests/unittest/_torch/test_connector.py | 1 - 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/examples/llm-api/llm_kv_cache_connector.py b/examples/llm-api/llm_kv_cache_connector.py index 24114166611..20c5ff77053 100644 --- a/examples/llm-api/llm_kv_cache_connector.py +++ b/examples/llm-api/llm_kv_cache_connector.py @@ -103,12 +103,8 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput): pending_load = self.pending_loads[req.request_id] - # TODO: The `computed_position` field in the scheduler output counts both the device cache hits and onboarded device blocks. - # This is inconsistent with vLLM. for file_path, block_pos in zip( - pending_load, - range(num_computed_blocks - len(pending_load), - len(block_ids))): + pending_load, range(num_computed_blocks, len(block_ids))): metadata.load.append((file_path, block_ids[block_pos])) # Break up the remainder of the token sequence into chunks. diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py index 6e31daf8023..d885c993454 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py @@ -301,6 +301,7 @@ class KvCacheConnectorSchedulerOutputManager: def __init__(self): self.requests = defaultdict(KvCacheConnectorSchedulerOutputRequest) + self.external_loads = dict() def build_scheduler_output(self, scheduled_batch: ScheduledRequests, new_async_requests: AsyncRequests, @@ -316,6 +317,11 @@ def build_scheduler_output(self, scheduled_batch: ScheduledRequests, request_data = self.requests[req.request_id].update_and_build_data( req, kv_cache_manager) + # Don't include the connector matched tokens in the initial scheduler output. + if req.request_id in self.external_loads: + request_data.computed_position -= self.external_loads[ + req.request_id] + if is_new: scheduler_output.new_requests.append(request_data) else: @@ -327,8 +333,14 @@ def build_scheduler_output(self, scheduled_batch: ScheduledRequests, scheduler_output.cached_requests.append(request_data) + self.external_loads = dict() + return scheduler_output + def record_new_matched_tokens(self, request: LlmRequest, + num_new_matched_tokens: int): + self.external_loads[request.request_id] = num_new_matched_tokens + class KvCacheConnectorManager(KvCacheConnectorManagerCpp): """ @@ -394,7 +406,9 @@ def get_num_new_matched_tokens(self, request: LlmRequest, # Because of this, we need to remove it from our list of scheduled requests (see `take_scheduled_requests_pending_load`). 
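        # Illustrative example: for a 96-token prompt with 32 tokens already
        # matched on device (num_computed_tokens=32), a scheduler returning
        # (48, True) promises 48 additional tokens from the connector and asks
        # for them to be onboarded asynchronously, so the request is held back
        # from the context batch until that load completes.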
if load_kv_async: self.new_async_requests.loading[request.request_id] = request - request.is_kv_cache_connector_async_onboard = True + + self.scheduler_output_manager.record_new_matched_tokens( + request, num_tokens) return num_tokens @@ -426,6 +440,8 @@ def take_scheduled_requests_pending_load( # we also need to update it's state. if req.request_id in self.new_async_requests.loading.keys(): req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS + + # Replace the request with the canonical request. self.new_async_requests.loading[req.request_id] = req else: allowed_context_requests.append(req) diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 083408e7225..5ed039b238c 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -281,8 +281,9 @@ def test_connector_scheduler_output(enforce_single_worker, model_with_connector, model.generate([0] * NUM_INPUT_TOKENS, sampling_params) + # The initial computed position should be 0, since we haven't yet onboarded any blocks. assert scheduler.build_connector_meta.call_args_list[0].args[ - 0].new_requests[0].computed_position == 8 + 0].new_requests[0].computed_position == 0 @pytest.mark.threadleak(enabled=False) diff --git a/tests/unittest/_torch/test_connector.py b/tests/unittest/_torch/test_connector.py index 2c450b24265..1f8cf33ee4d 100644 --- a/tests/unittest/_torch/test_connector.py +++ b/tests/unittest/_torch/test_connector.py @@ -113,7 +113,6 @@ def test(): req.request_id = 42 assert manager.get_num_new_matched_tokens(req, 32) == 16 - assert req.is_kv_cache_connector_async_onboard if mpi_rank() == 0: assert scheduler.get_num_new_matched_tokens.call_count == 1 From b142ddec2aa3b0ccd3dc5acff09626296bc5a671 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 12 Aug 2025 12:23:14 -0700 Subject: [PATCH 40/50] fix license headers Signed-off-by: jthomson04 --- .../nanobind/batch_manager/kvCacheConnector.h | 17 +++++++++++++++++ .../pybind/batch_manager/kvCacheConnector.h | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h index 63c183d0b2b..44d9b97f12f 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h @@ -1,3 +1,20 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + #pragma once #include "tensorrt_llm/batch_manager/kvCacheConnector.h" diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h index 4b1568a3abe..665fa2fc315 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h @@ -1,3 +1,20 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #pragma once #include "tensorrt_llm/batch_manager/kvCacheConnector.h" From 0b210f03b9b3cc932b5c49f29c347580a3d6ae12 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 12 Aug 2025 18:25:34 -0700 Subject: [PATCH 41/50] fix tests and test list Signed-off-by: jthomson04 --- .../_torch/pyexecutor/kv_cache_connector.py | 2 +- .../defs/llmapi/test_llm_api_connector.py | 2 +- tests/integration/test_lists/test-db/l0_a10.yml | 13 +++++++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py index d885c993454..5e8bf6dfaa3 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py @@ -291,7 +291,7 @@ def update_and_build_data(self, req: LlmRequest, computed_position = len( tokens - ) - 1 if req.state != LlmRequestState.CONTEXT_INIT else req.context_current_position + ) - 1 if req.state != LlmRequestState.CONTEXT_INIT and req.state != LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS else req.context_current_position return RequestData(req.request_id, new_tokens, new_block_ids, computed_position) diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 5ed039b238c..e07947b46dc 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -279,7 +279,7 @@ def test_connector_scheduler_output(enforce_single_worker, model_with_connector, assert len(scheduler.request_finished.call_args.args[1]) == math.ceil( (NUM_INPUT_TOKENS + NUM_TOKENS) / BLOCK_SIZE) - model.generate([0] * NUM_INPUT_TOKENS, sampling_params) + model.generate([1] * NUM_INPUT_TOKENS, sampling_params) # The initial computed position should be 0, since we haven't yet onboarded any blocks. 
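    # (Connector-matched tokens reported by get_num_new_matched_tokens are
    # excluded from the first scheduler output on purpose; they only count
    # toward computed_position once the corresponding blocks are onboarded.)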
assert scheduler.build_connector_meta.call_args_list[0].args[ diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index a1dbd5d84ed..1667fef9794 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -77,14 +77,23 @@ l0_a10: - unittest/trt/model/test_mistral.py - unittest/trt/model/test_llama.py - test_e2e.py::test_gpt3_175b_1layers_build_only # 6 mins - - llmapi/test_llm_api_connector.py + - llmapi/test_llm_api_connector.py::test_connector_simple[True] + - llmapi/test_llm_api_connector.py::test_connector_simple[False] + - llmapi/test_llm_api_connector.py::test_connector_async_onboard[True] + - llmapi/test_llm_api_connector.py::test_connector_async_onboard[False] + - llmapi/test_llm_api_connector.py::test_connector_async_save[True] + - llmapi/test_llm_api_connector.py::test_connector_async_save[False] + - llmapi/test_llm_api_connector.py::test_connector_scheduler_output[True] + - llmapi/test_llm_api_connector.py::test_connector_scheduler_output[False] + - llmapi/test_llm_api_connector.py::test_connector_scheduler_output_chunked_context[True] + - llmapi/test_llm_api_connector.py::test_connector_scheduler_output_chunked_context[False] - llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command[llama-llama-models/llama-7b-hf] # 5min - llmapi/test_llm_e2e.py::test_llmapi_build_command_parameters_align[llama-llama-models-v2/TinyLlama-1.1B-Chat-v1.0] - llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command_with_lora[llama-llama-models-v2/llama-v2-7b-hf] - llmapi/test_llm_examples.py::test_llmapi_chat_example - llmapi/test_llm_e2e.py::test_llmapi_exit - llmapi/test_llm_examples.py::test_llmapi_server_example - - llmapi/test_llm_examples.py::test_llmapi_kv_cache_connector + - llmapi/test_llm_examples.py::test_llmapi_kv_cache_connector[Qwen2-0.5B] - test_e2e.py::test_trtllm_serve_example - test_e2e.py::test_openai_misc_example[trt] - test_e2e.py::test_openai_completions_example[trt] From 209052afe9f4b52b1434cd8788aaa83fb58ae10c Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 13 Aug 2025 11:34:28 -0700 Subject: [PATCH 42/50] Dont pass connector manager through add_sequence Signed-off-by: jthomson04 --- .../batch_manager/kvCacheManager.h | 30 +++++---- .../batch_manager/kvCacheManager.cpp | 67 ++++++++++--------- .../nanobind/batch_manager/kvCacheManager.cpp | 26 ++++--- .../pybind/batch_manager/kvCacheManager.cpp | 18 +++-- .../_torch/pyexecutor/resource_manager.py | 5 +- 5 files changed, 75 insertions(+), 71 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 9e9272aa83f..c75778f0c6d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -537,7 +537,8 @@ class WindowBlockManager SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse); + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, + std::shared_ptr kvCacheConnectorManager); ~WindowBlockManager(); @@ -548,8 +549,8 @@ class WindowBlockManager void startScheduling(); //! \brief Assign blocks for new sequence. 
Try to reuse blocks. - void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, OptionalRef kvCacheConnectorManager); + void addSequence( + GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest); //! \brief Assign blocks for new sequence. Does not try to reuse blocks. void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx); @@ -834,6 +835,8 @@ class WindowBlockManager bool mEnablePartialReuse; // Whether partially matched blocks that are already in use should be copied and reused. bool mCopyOnPartialReuse; + // The kv cache connector manager + std::shared_ptr mKvCacheConnectorManager; }; class BlockManager @@ -851,7 +854,8 @@ class BlockManager SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, - bool copyOnPartialReuse = true); + bool copyOnPartialReuse = true, + std::shared_ptr kvCacheConnectorManager = nullptr); BlockManager(BlockManager const&) = delete; BlockManager& operator=(BlockManager const&) = delete; @@ -868,8 +872,7 @@ class BlockManager void allocatePools(bool useUvm); void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, OptionalRef kvCacheConnectorManager, - SizeType32 windowSize); + LlmRequest& llmRequest, SizeType32 windowSize); void addSequence( GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize); @@ -1213,8 +1216,7 @@ class BaseKVCacheManager /// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for /// inputLength - 1 tokens and populate prepopulatedPromptLen. 
virtual void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - OptionalRef llmRequest = std::nullopt, - OptionalRef kvCacheConnectorManager = std::nullopt) + OptionalRef llmRequest = std::nullopt) = 0; virtual void removeSequence( @@ -1361,7 +1363,8 @@ class KVCacheManager : public BaseKVCacheManager bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, - bool copyOnpartialReuse = true); + bool copyOnpartialReuse = true, + std::shared_ptr kvCacheConnectorManager = nullptr); KVCacheManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1371,7 +1374,8 @@ class KVCacheManager : public BaseKVCacheManager bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, - bool copyOnpartialReuse = true); + bool copyOnpartialReuse = true, + std::shared_ptr kvCacheConnectorManager = nullptr); KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1381,7 +1385,8 @@ class KVCacheManager : public BaseKVCacheManager bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, - bool copyOnpartialReuse = true); + bool copyOnpartialReuse = true, + std::shared_ptr kvCacheConnectorManager = nullptr); KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1513,8 +1518,7 @@ class KVCacheManager : public BaseKVCacheManager /// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for /// inputLength - 1 tokens and populate prepopulatedPromptLen. 
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - OptionalRef llmRequest = std::nullopt, - OptionalRef kvCacheConnectorManager = std::nullopt) override; + OptionalRef llmRequest = std::nullopt) override; void removeSequence( LlmRequest::RequestIdType requestId, OptionalRef llmRequest = std::nullopt) override; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 4211de0f64b..8696ed8eeb8 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -504,7 +504,8 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, + std::shared_ptr kvCacheConnectorManager) : mNumLayers{static_cast(numKvHeadsPerLayer.size())} , mTokensPerBlock{tokensPerBlock} , mEventManager{std::move(eventManager)} @@ -513,6 +514,10 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si { auto const uniqueWindowSizeToLayers = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers); + + TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1, + "KV Cache Connector is not supported with multiple window sizes"); + auto const numUniqueWindowSizes = static_cast(uniqueWindowSizeToLayers.size()); mIsVariableWindow = numUniqueWindowSizes > 1; @@ -530,7 +535,7 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer, sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream, onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, - copyOnPartialReuse); + copyOnPartialReuse, kvCacheConnectorManager); } auto const numAllPools = getNumPools(); @@ -572,7 +577,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, + std::shared_ptr kvCacheConnectorManager) : mDataType{dtype} , mWindowSize{windowSize} , mNumPrimaryBlocks{blocksInPrimaryPool} @@ -596,6 +602,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mTotalInputTokens{0.0} , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} + , mKvCacheConnectorManager{std::move(kvCacheConnectorManager)} { std::map numLayersPerPool; @@ -1147,15 +1154,13 @@ void WindowBlockManager::refreshBlocks() } void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, OptionalRef kvCacheConnectorManager, - SizeType32 windowSize) + LlmRequest& llmRequest, SizeType32 windowSize) { - mWindowBlockManagers.at(windowSize) 
- .addSequence(sequence, inputLength, numContextBlocks, llmRequest, kvCacheConnectorManager); + mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest); } -void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, OptionalRef kvCacheConnectorManager) +void WindowBlockManager::addSequence( + GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest) { auto const requestId = sequence.getRequestId(); auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); @@ -1190,9 +1195,9 @@ void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inp SizeType32 numConnectorMatchedTokens = 0; // If we're using a KV cache connector, check if any additional blocks can be loaded. - if (kvCacheConnectorManager) + if (mKvCacheConnectorManager && !llmRequest.isDummyRequest()) { - numConnectorMatchedTokens = kvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); + numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); } llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock()); @@ -1208,6 +1213,13 @@ void BlockManager::addSequence( void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx) { + if (mKvCacheConnectorManager) + { + TLLM_LOG_WARNING( + "KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be " + "ignored."); + } + auto const requestId = sequence.getRequestId(); auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); TLLM_CHECK(emplaceDone); @@ -1620,12 +1632,13 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, + std::shared_ptr kvCacheConnectorManager) : KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse, - copyOnPartialReuse) + copyOnPartialReuse, kvCacheConnectorManager) { } @@ -1636,7 +1649,8 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, + std::shared_ptr kvCacheConnectorManager) : mMaxBeamWidth(maxBeamWidth) , mDataType(dtype) , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end())) @@ -1646,7 +1660,7 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer , mBlockManager(numKvHeadsPerLayer, sizePerHead, 
tokensPerBlock, blocksPerWindow, maxNumSequences, std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager), - enablePartialReuse, copyOnPartialReuse) + enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager)) // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse} { @@ -1668,11 +1682,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, + std::shared_ptr kvCacheConnectorManager) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, - std::move(eventManager), enablePartialReuse, copyOnPartialReuse) + std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager)) { } @@ -1973,17 +1988,12 @@ std::optional KVCacheManager::findNewContextBlock( return newContextBlockOpt; } -void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - OptionalRef llmRequest, OptionalRef kvCacheConnectorManager) +void KVCacheManager::addSequence( + RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef llmRequest) { // Need to add the bubble after the sink tokens to use even block size inputLength += mSinkBubbleLength; - if (kvCacheConnectorManager) - { - TLLM_CHECK_WITH_INFO(beamWidth == 1, "KV Cache Connector is not supported with beam search"); - } - auto kvCacheRetentionConfig = llmRequest ? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig()) : executor::KvCacheRetentionConfig(); @@ -2027,8 +2037,7 @@ void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock()); if (!sequence.isCyclic() && mEnableBlockReuse) { - mBlockManager.addSequence( - sequence, effectiveInputLength, numContextBlocks, *llmRequest, kvCacheConnectorManager, windowSize); + mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize); } else { @@ -2040,12 +2049,6 @@ void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength "will " "have no effect.", llmRequest->mRequestId); - if (kvCacheConnectorManager.has_value()) - { - TLLM_LOG_WARNING( - "KV Cache Connector specified when block reuse is disabled. 
The KV Cache Connector will be " - "ignored."); - } } mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize); } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index f60ddccb0f8..09159592646 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -39,6 +39,7 @@ #include namespace tb = tensorrt_llm::batch_manager; +namespace tbc = tensorrt_llm::batch_manager::kv_connector; namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager; namespace tr = tensorrt_llm::runtime; namespace nb = nanobind; @@ -110,11 +111,9 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager } void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, - tensorrt_llm::common::OptionalRef kvCacheConnectorManager - = std::nullopt) override + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override { - NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest, kvCacheConnectorManager); + NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest); } void removeSequence(tb::LlmRequest::RequestIdType requestId, @@ -348,9 +347,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep) .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion) .def("add_token", &BaseKVCacheManager::addToken) - .def("add_sequence", &BaseKVCacheManager::addSequence, nb::arg("request_id"), nb::arg("input_length"), - nb::arg("beam_width"), nb::arg("llm_request") = std::nullopt, - nb::arg("kv_cache_connector_manager") = std::nullopt) + .def("add_sequence", &BaseKVCacheManager::addSequence) .def("remove_sequence", &BaseKVCacheManager::removeSequence) .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence) .def("get_block_pool_pointers", @@ -450,12 +447,13 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .value("SELFKONLY", tbk::CacheType::kSELFKONLY); nb::class_(m, "KVCacheManager") - .def(nb::init const&, SizeType32, SizeType32, - std::map> const&, SizeType32, SizeType32, - std::vector const&, std::optional const&, - nvinfer1::DataType, SizeType32, int64_t, std::optional, bool, bool, - tbk::CacheType, std::optional, - std::shared_ptr, bool, bool>(), + .def( + nb::init const&, SizeType32, SizeType32, + std::map> const&, SizeType32, SizeType32, + std::vector const&, std::optional const&, + nvinfer1::DataType, SizeType32, int64_t, std::optional, bool, bool, tbk::CacheType, + std::optional, std::shared_ptr, + bool, bool, std::shared_ptr>(), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), nb::arg("tokens_per_block"), nb::arg("blocks_per_window"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("max_attention_window_vec"), nb::arg("temp_attention_window_inputs").none(), nb::arg("dtype"), @@ -463,7 +461,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) nb::arg("enable_block_reuse") = false, nb::arg("onboard_blocks") = true, nb::arg("cache_type") = tbk::CacheType::kSELF, nb::arg("secondary_offload_min_priority") = std::nullopt, nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true, - nb::arg("copy_on_partial_reuse") 
= true); + nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr); } void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 3861f3f5c18..b49bdb7bbcc 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -30,6 +30,7 @@ #include namespace tb = tensorrt_llm::batch_manager; +namespace tbc = tensorrt_llm::batch_manager::kv_connector; namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager; namespace tr = tensorrt_llm::runtime; namespace py = pybind11; @@ -96,12 +97,10 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager } void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, - tensorrt_llm::common::OptionalRef kvCacheConnectorManager - = std::nullopt) override + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override { - PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, - llmRequest, kvCacheConnectorManager); + PYBIND11_OVERLOAD_PURE( + void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, llmRequest); } void removeSequence(tb::LlmRequest::RequestIdType requestId, @@ -350,9 +349,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep) .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion) .def("add_token", &BaseKVCacheManager::addToken) - .def("add_sequence", &BaseKVCacheManager::addSequence, py::arg("request_id"), py::arg("input_length"), - py::arg("beam_width"), py::arg("llm_request") = std::nullopt, - py::arg("kv_cache_connector_manager") = std::nullopt) + .def("add_sequence", &BaseKVCacheManager::addSequence) .def("remove_sequence", &BaseKVCacheManager::removeSequence) .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence) .def("get_block_pool_pointers", @@ -447,7 +444,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) std::vector const&, std::optional const&, nvinfer1::DataType, SizeType32, bool, int64_t, bool, bool, tbk::CacheType, std::optional, std::shared_ptr, - bool, bool>(), + bool, bool, std::shared_ptr>(), py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), py::arg("tokens_per_block"), py::arg("blocks_per_window"), py::arg("max_num_sequences"), py::arg("max_beam_width"), py::arg("max_attention_window_vec"), py::arg("temp_attention_window_inputs"), py::arg("dtype"), @@ -455,7 +452,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) py::arg("enable_block_reuse") = false, py::arg("onboard_blocks") = true, py::arg_v("cache_type", tbk::CacheType::kSELF, "bindings.internal.batch_manager.CacheType.SELF"), py::arg("secondary_offload_min_priority") = std::nullopt, py::arg("event_manager") = nullptr, - py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true); + py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, + py::arg("kv_connector_manager") = nullptr); } void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py 
b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 61ee3a305a3..92ef01c26f1 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -333,6 +333,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], 'cache_type': kv_cache_type, 'enable_partial_reuse': kv_cache_config.enable_partial_reuse, 'copy_on_partial_reuse': kv_cache_config.copy_on_partial_reuse, + 'kv_connector_manager': self.kv_connector_manager, } if self.event_buffer_max_size > 0: if mapping.enable_attention_dp: @@ -415,13 +416,13 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req.py_request_id, seq_len + (len(req.query_id) if self.mapping.cp_rank == self.mapping.cp_size - 1 else 0), - req_beam_width, req, None) + req_beam_width, req) else: if req.is_first_context_chunk and self._kv_connector_should_add_sequence( req): self.impl.add_sequence(req.py_request_id, req.prompt_len, req_beam_width, - req, self.kv_connector_manager) + req) for _ in range(self.num_extra_kv_tokens): self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): From df6350de64548a1c5a8b1c7e143909d7a4784724 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Fri, 15 Aug 2025 17:29:26 -0700 Subject: [PATCH 43/50] Init scheduler and worker concurrently Signed-off-by: jthomson04 --- .../_torch/pyexecutor/py_executor_creator.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 219c9dcddf1..bf38e50830f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -1,6 +1,7 @@ import copy import enum import importlib +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from dataclasses import dataclass from itertools import chain @@ -386,14 +387,21 @@ def create_py_executor( scheduler_cls = getattr( module, kv_connector_config.connector_scheduler_class) - connector_worker = worker_cls(executor_config) - - # Only initialize the scheduler on rank 0. - rank = tensorrt_llm.mpi_rank() - if rank == 0: - connector_scheduler = scheduler_cls(executor_config) - else: - connector_scheduler = None + # Some connector API implementations may need to establish out-of-band communication between the scheduler and workers. + # In this case, the worker may be dependent on the scheduler, or vice-versa. + # To deal with cases like this, we instantiate them both concurrently. 
+ with ThreadPoolExecutor(max_workers=2) as executor: + connector_worker_task = executor.submit(worker_cls, + executor_config) + + if scheduler_cls is not None: + connector_scheduler_task = executor.submit( + scheduler_cls, executor_config) + connector_scheduler = connector_scheduler_task.result() + else: + connector_scheduler = None + + connector_worker = connector_worker_task.result() kv_connector_manager = KvCacheConnectorManager( connector_worker, connector_scheduler) From d5f7f1de1a967d0ba22d78435480c8ef89a5230a Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 19 Aug 2025 12:46:29 -0700 Subject: [PATCH 44/50] maybe fix CI Signed-off-by: jthomson04 --- examples/llm-api/llm_kv_cache_connector.py | 10 ++++++---- tensorrt_llm/llmapi/llm.py | 3 ++- .../integration/defs/llmapi/test_llm_api_connector.py | 1 - 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/llm-api/llm_kv_cache_connector.py b/examples/llm-api/llm_kv_cache_connector.py index 20c5ff77053..1a0763853ad 100644 --- a/examples/llm-api/llm_kv_cache_connector.py +++ b/examples/llm-api/llm_kv_cache_connector.py @@ -1,3 +1,7 @@ +### :title KV Cache Connector +### :order 6 +### :section Customization + import os import sys from dataclasses import dataclass, field @@ -207,8 +211,7 @@ def main(model: str): llm = LLM(model=model, backend="pytorch", cuda_graph_config=None, - connector_config=connector_config, - use_torch_sampler=True) + connector_config=connector_config) test_text = ( "Nvidia Corporation is an American technology company headquartered in Santa Clara, California." @@ -229,8 +232,7 @@ def main(model: str): llm = LLM(model=model, backend="pytorch", cuda_graph_config=None, - connector_config=connector_config, - use_torch_sampler=True) + connector_config=connector_config) output = llm.generate([test_text], sampling_params) text1 = output[0].outputs[0].text diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 96a683e765d..cccf82ebfe1 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -1054,7 +1054,8 @@ def _build_model(self): lora_config=self.args.lora_config, garbage_collection_gen0_threshold=self.args. 
garbage_collection_gen0_threshold, - kv_connector_config=self.args.connector_config, + # Autodeploy does not support connector_config + kv_connector_config=getattr(self.args, "connector_config", None), ) def _validate_args_for_torch_backend(self, kwargs: dict) -> None: diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index e07947b46dc..7b0c5383d26 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -49,7 +49,6 @@ def model_fn(*args, **kwargs): connector_config=connector_config, cuda_graph_config=None, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1), - use_torch_sampler=True, ) yield model_fn, mock_scheduler, mock_worker From ebfe401548e756f29d0ce160394a6f1c3aa87655 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 19 Aug 2025 18:23:48 -0700 Subject: [PATCH 45/50] Add fix for llm stability Signed-off-by: jthomson04 --- examples/llm-api/llm_kv_cache_connector.py | 6 +++--- tensorrt_llm/llmapi/llm.py | 4 ++-- tensorrt_llm/llmapi/llm_args.py | 3 ++- tests/integration/defs/llmapi/test_llm_api_connector.py | 4 ++-- tests/unittest/api_stability/references/llm.yaml | 4 ++++ 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/llm-api/llm_kv_cache_connector.py b/examples/llm-api/llm_kv_cache_connector.py index 1a0763853ad..bd8bf7fcc7e 100644 --- a/examples/llm-api/llm_kv_cache_connector.py +++ b/examples/llm-api/llm_kv_cache_connector.py @@ -199,7 +199,7 @@ def main(model: str): this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")] - connector_config = KvCacheConnectorConfig( + kv_connector_config = KvCacheConnectorConfig( connector_module=this_module, connector_scheduler_class="PersistentKvCacheConnectorLeader", connector_worker_class="PersistentKvCacheConnectorWorker", @@ -211,7 +211,7 @@ def main(model: str): llm = LLM(model=model, backend="pytorch", cuda_graph_config=None, - connector_config=connector_config) + kv_connector_config=kv_connector_config) test_text = ( "Nvidia Corporation is an American technology company headquartered in Santa Clara, California." @@ -232,7 +232,7 @@ def main(model: str): llm = LLM(model=model, backend="pytorch", cuda_graph_config=None, - connector_config=connector_config) + kv_connector_config=kv_connector_config) output = llm.generate([test_text], sampling_params) text1 = output[0].outputs[0].text diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index cccf82ebfe1..048a3a93f58 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -1054,8 +1054,8 @@ def _build_model(self): lora_config=self.args.lora_config, garbage_collection_gen0_threshold=self.args. 
garbage_collection_gen0_threshold, - # Autodeploy does not support connector_config - kv_connector_config=getattr(self.args, "connector_config", None), + # Autodeploy does not support kv_connector_config + kv_connector_config=getattr(self.args, "kv_connector_config", None), ) def _validate_args_for_torch_backend(self, kwargs: dict) -> None: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index b9d2bd498c2..3806bae294b 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2181,9 +2181,10 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) - connector_config: Optional[KvCacheConnectorConfig] = Field( + kv_connector_config: Optional[KvCacheConnectorConfig] = Field( default=None, description="The config for KV cache connector.", + status="prototype", ) # PrivateVars diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index 7b0c5383d26..e32c1615441 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -34,7 +34,7 @@ def model_with_connector(): importlib_mock.import_module.return_value.KvConnectorScheduler.return_value = mock_scheduler importlib_mock.import_module.return_value.KvConnectorWorker.return_value = mock_worker - connector_config = KvCacheConnectorConfig( + kv_connector_config = KvCacheConnectorConfig( connector_module="", connector_scheduler_class="KvConnectorScheduler", connector_worker_class="KvConnectorWorker", @@ -46,7 +46,7 @@ def model_fn(*args, **kwargs): **kwargs, model=f"{llm_models_root()}/Qwen2-0.5B", backend="pytorch", - connector_config=connector_config, + kv_connector_config=kv_connector_config, cuda_graph_config=None, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1), ) diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index d9dcd0f83d2..d82284efafa 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -159,6 +159,10 @@ methods: annotation: Optional[tensorrt_llm.llmapi.llm_args.DecodingConfig] default: null status: deprecated + kv_connector_config: + annotation: Optional[tensorrt_llm.llmapi.llm_args.KvCacheConnectorConfig] + default: null + status: prototype return_annotation: None generate: parameters: From a383d03bf83bca04f6e2d7d7165eebaf692f9eb4 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Thu, 21 Aug 2025 14:26:18 -0700 Subject: [PATCH 46/50] Dont call request_finished unless request has already been scheduled Signed-off-by: jthomson04 --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index ab8732120ad..5955733ef6e 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1664,11 +1664,16 @@ def _terminate_request(self, request: LlmRequest): if self.kv_connector_manager is None: self.resource_manager.free_resources(request) else: - cache_block_ids = self.kv_cache_manager.get_cache_indices(request) - - if not self.kv_connector_manager.request_finished( - request, cache_block_ids): - self.resource_manager.free_resources(request) + # Only call request_finished on the connector if the request has already been added to the kv cache manager. 
+ try: + cache_block_ids = self.kv_cache_manager.get_cache_indices( + request) + except IndexError: + pass + else: + if not self.kv_connector_manager.request_finished( + request, cache_block_ids): + self.resource_manager.free_resources(request) @nvtx_range("_handle_canceled_requests") def _handle_canceled_requests(self): From 2f99b38080f591c724eeea577b0592cf785ed047 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Tue, 26 Aug 2025 09:29:07 -0700 Subject: [PATCH 47/50] fix commit Signed-off-by: richardhuo-nv --- tensorrt_llm/executor/proxy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index e63a33588f1..ec561bb2918 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -95,6 +95,7 @@ def __init__( worker_kwargs = dict(**worker_kwargs, worker_queues=self._setup_queues(), postproc_worker_config=postproc_worker_config, + is_llm_executor=False, kv_connector_config=kv_connector_config) if "log_level" not in worker_kwargs: From d90d8c676dc0519d9fdd4aea3b4e185bb0794241 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Tue, 26 Aug 2025 09:40:46 -0700 Subject: [PATCH 48/50] fix pre-commit hook Signed-off-by: richardhuo-nv --- tensorrt_llm/executor/executor.py | 3 +-- tensorrt_llm/executor/worker.py | 3 +-- tensorrt_llm/llmapi/llm_args.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 09e1382e324..8be96a8ab9c 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -21,8 +21,7 @@ from ..bindings import executor as tllm from ..builder import Engine from ..disaggregated_params import DisaggregatedParams -from ..llmapi.llm_args import KvCacheConnectorConfig -from ..llmapi.llm_args import TorchLlmArgs +from ..llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs from ..llmapi.llm_utils import KvCacheRetentionConfig from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available, need_spawn_mpi_workers) diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 1dd2577c5f9..01b9f052b9f 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -18,8 +18,7 @@ mpi_comm, mpi_rank, nvtx_range_debug) from ..bindings import executor as tllm from ..builder import ConfigEncoder, Engine, EngineConfig -from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror -from ..llmapi.llm_args import PybindMirror, TorchLlmArgs +from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs from ..llmapi.mpi_session import set_mpi_session_cpp from ..llmapi.tokenizer import TokenizerBase from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 6420d2649c3..62db888a447 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2321,7 +2321,7 @@ class TorchLlmArgs(BaseLlmArgs): default=None, description="The config for KV cache connector.", ) - + mm_encoder_only: bool = Field( default=False, description= From 0126e6fbbfd0c0059ca264b1a73c62b8bbce9104 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Tue, 26 Aug 2025 16:10:01 -0700 Subject: [PATCH 49/50] fix merge Signed-off-by: richardhuo-nv --- tensorrt_llm/executor/worker.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/tensorrt_llm/executor/worker.py 
b/tensorrt_llm/executor/worker.py index 01b9f052b9f..105005ce95a 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -159,29 +159,9 @@ def _create_engine(executor_config): executor_config=executor_config, managed_weights=engine.managed_weights) - if not hasattr(executor_config, "backend"): - return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, - executor_config) - args = { - "executor_config": executor_config, - "checkpoint_dir": executor_config.hf_model_dir, - } - if executor_config.backend == "pytorch": - from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ - create_py_executor - create_executor = create_py_executor - args["lora_config"] = lora_config - args[ - "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold - args["kv_connector_config"] = kv_connector_config - elif executor_config.backend == "_autodeploy": - from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ - create_autodeploy_executor - create_executor = create_autodeploy_executor - else: - raise ValueError( - f"Unsupported backend config: {executor_config.backend}") - return create_executor(**args) + assert not hasattr(executor_config, "backend") + return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, + executor_config) self.engine = _create_py_executor( executor_config) if llm_args is not None else _create_engine( From aa1b9a031ec79b53b21bd877b4fbd4e431738b35 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Wed, 27 Aug 2025 23:11:44 -0700 Subject: [PATCH 50/50] fix the scheduler could start on rank > 0 Signed-off-by: richardhuo-nv --- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 56fbdb6cf25..e824ee02d8d 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -404,6 +404,7 @@ def create_py_executor( scheduler_cls = getattr( module, kv_connector_config.connector_scheduler_class) + rank = tensorrt_llm.mpi_rank() # Some connector API implementations may need to establish out-of-band communication between the scheduler and workers. # In this case, the worker may be dependent on the scheduler, or vice-versa. # To deal with cases like this, we instantiate them both concurrently. @@ -411,7 +412,7 @@ def create_py_executor( connector_worker_task = executor.submit(worker_cls, executor_config) - if scheduler_cls is not None: + if scheduler_cls is not None and rank == 0: connector_scheduler_task = executor.submit( scheduler_cls, executor_config) connector_scheduler = connector_scheduler_task.result()
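
The series leaves the user-facing wiring spread across several files (the `KvCacheConnectorConfig` definition in `llm_args.py`, the `kv_connector_config` keyword threaded through `llm.py`, `executor.py`, and `worker.py`, and the example in `examples/llm-api/llm_kv_cache_connector.py`). The sketch below pulls that wiring into one place. It is illustrative only: the `my_kv_connector` module and the two class names are placeholders standing in for a user-supplied implementation (the shipped example uses `PersistentKvCacheConnectorWorker` and `PersistentKvCacheConnectorLeader`); the rest follows the post-patch-45 API shown in the diffs above.

    # Minimal wiring sketch for the KV cache connector API (assumes a module
    # "my_kv_connector" on PYTHONPATH that defines the two classes named below).
    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig

    kv_connector_config = KvCacheConnectorConfig(
        connector_module="my_kv_connector",                # importable module name (placeholder)
        connector_scheduler_class="MyConnectorScheduler",  # constructed on rank 0 only
        connector_worker_class="MyConnectorWorker",        # constructed on every rank
    )

    llm = LLM(
        model="Qwen/Qwen2-0.5B",
        backend="pytorch",
        cuda_graph_config=None,  # the example and tests run with CUDA graphs disabled
        kv_connector_config=kv_connector_config,
    )

    outputs = llm.generate(
        ["Nvidia Corporation is an American technology company"],
        SamplingParams(max_tokens=32),
    )
    print(outputs[0].outputs[0].text)

As patches 43 and 50 establish, the runtime imports `connector_module`, constructs the worker class on every rank and the scheduler class on rank 0 only, and performs both constructions concurrently in a `ThreadPoolExecutor` so that either side may block on out-of-band communication with the other during initialization.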