From 776f2df82291212a848059c586cb2ea21505cb24 Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Wed, 13 Aug 2025 09:09:31 +0000
Subject: [PATCH 1/3] optimal kvcache transfer

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 .../tensorrt_llm/batch_manager/kvCacheUtils.h | 162 ++++++++++-------
 .../tensorrt_llm/batch_manager/llmRequest.h   |  10 +-
 .../batch_manager/cacheFormatter.cpp          | 166 +++++++++++-------
 .../batch_manager/cacheTransBuffer.cpp        |  20 ++-
 .../batch_manager/cacheTransBuffer.h          |   1 +
 .../batch_manager/dataTransceiver.cpp         |  43 +++--
 .../batch_manager/dataTransceiver.h           |  10 +-
 .../batch_manager/dataTransceiverImpl.cpp     |   5 +-
 .../batch_manager/mlaCacheFormatter.cpp       |  28 +--
 .../trtGptModelInflightBatching.cpp           |   2 +-
 cpp/tensorrt_llm/common/envUtils.cpp          |   6 +
 cpp/tensorrt_llm/common/envUtils.h            |   2 +
 .../pybind/batch_manager/cacheTransceiver.cpp |   3 +-
 .../batch_manager/cacheTransceiverTest.cpp    |  83 +++++----
 .../batch_manager/cacheTransBufferTest.cpp    |   8 +-
 .../batch_manager/kvCacheUtilsTest.cpp        |   6 +-
 .../accuracy/test_disaggregated_serving.py    |   3 -
 17 files changed, 353 insertions(+), 205 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
index 2aebf77b96d..419ffb0902e 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
@@ -17,119 +17,159 @@
 #pragma once

 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
+#include "tensorrt_llm/runtime/iTensor.h"

 namespace tensorrt_llm::batch_manager::kv_cache_manager
 {

 class BlockIterator;

-class BlockRange
+class BlockRangeForWindow
 {
 public:
-    // C++20 std::default_sentinel_t equivalent
+    BlockRangeForWindow(std::vector<SizeType32> blockIds, runtime::ITensor::SharedPtr pool)
+        : mBlockIds(std::move(blockIds))
+        , mPool(std::move(pool))
+    {
+    }
+
     struct Sentinel
     {
     };

-    static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId,
-        SizeType32 beam = kFIRST_AND_ONLY_BEAM)
-    {
-        assert(kFIRST_AND_ONLY_BEAM == beam);
-        auto const windowSize = firstWindowSize(cacheManager);
-        auto const blockIds = cacheManager.getSequence(requestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
-        return BlockRange(cacheManager, blockIds, requestId);
-    }
+    friend class BlockIterator;
+
+    BlockIterator begin() const;

-    static BlockRange fromNewlyAllocatedBlockIds(
-        BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
+    [[nodiscard]] Sentinel end() const
     {
-        auto const windowSize = firstWindowSize(cacheManager);
-        auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize);
-        return BlockRange(cacheManager, blockIds, requestId);
+        return {};
     }

-    BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
-        : mManager{nullptr}
-        , mPool{std::move(pool)}
-        , mWindowSize{0}
-        , mRequestId{0}
-        , mBlockIds{blockIds}
+    [[nodiscard]] size_t size() const
     {
-        TLLM_CHECK(mPool);
+        return mBlockIds.size();
     }

-    [[nodiscard]] BlockIterator begin() const;
+private:
+    std::vector<SizeType32> mBlockIds;
+    runtime::ITensor::SharedPtr mPool;
+};

-    [[nodiscard]] Sentinel end() const
+class BlockRange
+{
+public:
+    static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
     {
-        return {};
+        return BlockRange(cacheManager, requestId);
     }

-    [[nodiscard]] size_t size() const
+    static BlockRange fromNewlyAllocatedBlockIds(
+        BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
     {
-        return mBlockIds.size();
+        std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow;
+
+        auto windowsMetadata = cacheManager.getBlockManager().getWindowSizesMetadata();
+        for (auto const& [windowSize, metadata] : windowsMetadata)
+        {
+            blockIdsPerWindow[windowSize] = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize);
+        }
+        return BlockRange(cacheManager, std::move(blockIdsPerWindow), requestId);
     }

-    [[nodiscard]] std::vector<SizeType32> const& getBlockIds() const
+    void setBlockIdsForWindow(SizeType32 windowSize, std::vector<SizeType32> blockIds)
     {
-        return mBlockIds;
+        TLLM_CHECK_WITH_INFO(mBlockIdsPerWindow.find(windowSize) != mBlockIdsPerWindow.end(),
+            "Window size %d should exist", windowSize);
+        mBlockIdsPerWindow[windowSize] = std::move(blockIds);
     }

-    void setBlockIds(std::vector<SizeType32> blockIds)
+    void setBlockIdsForAllWindows(std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow)
     {
-        mBlockIds = std::move(blockIds);
+        for (auto const& [windowSize, blockIds] : blockIdsPerWindow)
+        {
+            TLLM_CHECK_WITH_INFO(
+                mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d should exist", windowSize);
+        }
+        mBlockIdsPerWindow = std::move(blockIdsPerWindow);
     }

-    [[nodiscard]] std::vector<size_t> getBlockHashes() const
+    [[nodiscard]] std::unordered_map<SizeType32, std::vector<size_t>> getBlockHashesPerWindow() const
     {
         TLLM_CHECK(mManager);
-        std::vector<size_t> blockHashes;
-        blockHashes.reserve(mBlockIds.size());
+        std::unordered_map<SizeType32, std::vector<size_t>> blockHashesPerWindow;
         auto& blockManager = mManager->getBlockManager();
-        for (auto id : mBlockIds)
+        for (auto const& [windowSize, blockIds] : mBlockIdsPerWindow)
         {
-            blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash());
+            for (auto const& blockId : blockIds)
+            {
+                blockHashesPerWindow[windowSize].emplace_back(
+                    blockManager.getBlockById(blockId, windowSize)->getHash());
+            }
         }
-        return blockHashes;
+        return blockHashesPerWindow;
     }

-    void updatePoolIdx(SizeType32 poolIdx)
+    BlockRangeForWindow getBlockRangeForWindow(SizeType32 windowSize) const
     {
-        TLLM_CHECK(mManager);
-        mPool = mManager->getBlockManager().getPrimaryPool(poolIdx);
-        auto const newWindowSize = mManager->getBlockManager().getPoolWindowSize(poolIdx);
-        if (newWindowSize != mWindowSize)
+        TLLM_CHECK_WITH_INFO(
+            mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d not found", windowSize);
+        auto pool = mPoolsPerWindow.at(windowSize).front();
+        auto blockIds = mBlockIdsPerWindow.at(windowSize);
+        return BlockRangeForWindow(std::move(blockIds), std::move(pool));
+    }
+
+    std::vector<SizeType32> getWindowSizes() const
+    {
+        std::vector<SizeType32> windowSizes;
+        for (auto const& [windowSize, _] : mPoolsPerWindow)
         {
-            mWindowSize = newWindowSize;
-            mBlockIds = mManager->getSequence(mRequestId).getCacheBlockIds(mWindowSize).at(kFIRST_AND_ONLY_BEAM);
+            windowSizes.push_back(windowSize);
         }
+        return windowSizes;
     }

-    friend class BlockIterator;
+    std::unordered_map<SizeType32, std::vector<SizeType32>> const& getBlockIdsPerWindow() const
+    {
+        return mBlockIdsPerWindow;
+    }

 private:
-    BlockRange(
-        BaseKVCacheManager const& cacheManager, std::vector<SizeType32> blockIds, LlmRequest::RequestIdType requestId)
+    BlockRange(BaseKVCacheManager const& cacheManager,
+        std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow, LlmRequest::RequestIdType requestId)
         : mManager(&cacheManager)
-        , mPool(cacheManager.getBlockManager().getPrimaryPool(kFIRST_POOL_INDEX))
-        , mWindowSize(firstWindowSize(cacheManager))
         , mRequestId(requestId)
-        , mBlockIds(std::move(blockIds))
+        , mBlockIdsPerWindow(std::move(blockIdsPerWindow))
     {
+        auto poolNum = mManager->getNumPools();
+        for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
+        {
+            auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
+            mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
+        }
     }

-    static SizeType32 firstWindowSize(BaseKVCacheManager const& cacheManager)
+    BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
+        : mManager(&cacheManager)
+        , mRequestId(requestId)
     {
-        constexpr SizeType32 FIRST_POOL_IDX = 0;
-        return cacheManager.getBlockManager().getPoolWindowSize(FIRST_POOL_IDX);
+        auto poolNum = mManager->getNumPools();
+        for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
+        {
+            auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
+            mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
+            mBlockIdsPerWindow[windowSize]
+                = cacheManager.getSequence(mRequestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
+        }
     }

 private:
     BaseKVCacheManager const* mManager;
-    runtime::ITensor::SharedPtr mPool;
-    SizeType32 mWindowSize;
-    const LlmRequest::RequestIdType mRequestId;
-    std::vector<SizeType32> mBlockIds;
+    LlmRequest::RequestIdType const mRequestId;
+    std::unordered_map<SizeType32, std::vector<SizeType32>> mBlockIdsPerWindow;
+    std::unordered_map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> mPoolsPerWindow;

     static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0;
     static constexpr SizeType32 kFIRST_POOL_INDEX = 0;
@@ -144,7 +184,7 @@ class BlockIterator
     using reference = value_type&;
     using SizeType32 = tensorrt_llm::runtime::SizeType32;

-    BlockIterator(BlockRange const* range, size_t idx)
+    BlockIterator(BlockRangeForWindow const* range, size_t idx)
         : mRange{range}
         , mIdx{idx}
     {
@@ -187,7 +227,7 @@ class BlockIterator
         return mIdx == other.mIdx && mRange == other.mRange;
     }

-    [[nodiscard]] bool operator==(BlockRange::Sentinel other) const
+    [[nodiscard]] bool operator==(BlockRangeForWindow::Sentinel other) const
     {
         return mIdx == mRange->mBlockIds.size();
     }
@@ -207,12 +247,12 @@ class BlockIterator
         }
     }

-    BlockRange const* mRange;
+    BlockRangeForWindow const* mRange;
     runtime::ITensor::SharedPtr mCurrent;
     size_t mIdx;
 };

-inline BlockIterator BlockRange::begin() const
+inline BlockIterator BlockRangeForWindow::begin() const
 {
     return {this, 0};
 }
diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index e4d13c9e17b..e973f16f67e 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -1831,14 +1831,14 @@ class GenericLlmRequest
         }
     }

-    void setRequestedBlockHashes(std::vector<size_t> hashes)
+    void setRequestedBlockHashes(std::unordered_map<SizeType32, std::vector<size_t>>&& hashesPerWindow)
     {
-        mRequestedBlockHashes = std::move(hashes);
+        mRequestedBlockHashesPerWindow = std::move(hashesPerWindow);
     }

-    [[nodiscard]] std::vector<size_t> const& getRequestedBlockHashes() const
+    [[nodiscard]] std::unordered_map<SizeType32, std::vector<size_t>> const& getRequestedBlockHashesPerWindow() const
     {
-        return mRequestedBlockHashes;
+        return mRequestedBlockHashesPerWindow;
    }

     void setIsDummyRequest(bool isDummyRequest)
@@ -2033,7 +2033,7 @@ class GenericLlmRequest
     TensorMap mAdditionalGenerationOutputTensors;

     // Context request only. The hashes of the blocks that are requested by the corresponding generation request.
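The request now carries reuse information keyed by attention-window size instead of a single flat hash list. A sketch of the new contract, assuming an llmRequest already in scope; the window sizes and hash values are invented for illustration:

    using tensorrt_llm::runtime::SizeType32;

    // Generation side: report which blocks are already cached, per window.
    std::unordered_map<SizeType32, std::vector<size_t>> hashesPerWindow;
    hashesPerWindow[4096] = {0x1111, 0x2222}; // hypothetical hashes for a 4096-token window
    hashesPerWindow[32768] = {0x3333};        // hypothetical hashes for a 32768-token window
    llmRequest.setRequestedBlockHashes(std::move(hashesPerWindow));

    // Context side: inspect what the peer already holds before deciding what to send.
    for (auto const& [windowSize, hashes] : llmRequest.getRequestedBlockHashesPerWindow())
    {
        TLLM_LOG_DEBUG("window %d: peer already holds %zu blocks", windowSize, hashes.size());
    }

Note the asymmetric naming: the setter keeps its old name, setRequestedBlockHashes, while the getter becomes getRequestedBlockHashesPerWindow.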
- std::vector mRequestedBlockHashes; + std::unordered_map> mRequestedBlockHashesPerWindow; bool mIsDummyRequest{false}; diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 503c2e6c5d0..3320c013dd0 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -19,6 +19,7 @@ #include "mlaCacheFormatter.h" #include "tensorrt_llm/batch_manager/contextProgress.h" +#include "tensorrt_llm/batch_manager/kvCacheEventManager.h" #include "tensorrt_llm/batch_manager/kvCacheUtils.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -41,35 +42,72 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest) { - size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size(); - constexpr SizeType32 beam{0}; - auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); - auto poolNum = cacheManager->getBlockManager().getNumPools(); - if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer()) + bool needSendAllForWindow = common::getEnvKVCacheTransferAllBlocksForWindow(); + + auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId); + // auto inputLen = llmRequest.getPromptLen(); + + auto const& windowsMetadata = cacheManager->getBlockManager().getWindowSizesMetadata(); + + if (common::getEnvDisableSelectiveCacheTransfer() && (windowsMetadata.size() == 1 || needSendAllForWindow)) { - // disable selective cache transfer for poolNum > 1 return blockRange; } - if (requestBlockNum < blockRange.size() && requestBlockNum > 0) + auto const& blockIdsPerWindow = blockRange.getBlockIdsPerWindow(); + + bool needReuse = !common::getEnvDisableSelectiveCacheTransfer(); + auto const& requestedBlockHashesPerWindow = llmRequest.getRequestedBlockHashesPerWindow(); + for (auto const& [windowSize, metadata] : windowsMetadata) { - // handle block reuse, the prefix blocks are reused - // TODO(zhengd): pass the hashes directly instead of from llmRequest; use hash instead of block num - auto const& ids = blockRange.getBlockIds(); - blockRange.setBlockIds({ids.end() - requestBlockNum, ids.end()}); + SizeType32 reuseStartBlockIdx + = (needReuse && requestedBlockHashesPerWindow.at(windowSize).size() > 0 + && requestedBlockHashesPerWindow.at(windowSize).size() < blockIdsPerWindow.at(windowSize).size()) + ? (blockIdsPerWindow.at(windowSize).size() - requestedBlockHashesPerWindow.at(windowSize).size()) + : 0; + auto windowStartBlockIdx = needSendAllForWindow + ? 
0 + : static_cast(blockIdsPerWindow.at(windowSize).size()) + - (windowSize / cacheManager->getBlockManager().getTokensPerBlock() + 1); + // TODO: promptLen to get the startBlockIdx + SizeType32 startBlockIdx = std::max(0, std::max(reuseStartBlockIdx, windowStartBlockIdx)); + TLLM_LOG_DEBUG( + "getBlockRangeForSending windowSize: %d, startBlockIdx: %d reuseStartBlockIdx: %d windowStartBlockIdx: %d", + windowSize, startBlockIdx, reuseStartBlockIdx, windowStartBlockIdx); + blockRange.setBlockIdsForWindow(windowSize, + std::vector( + blockIdsPerWindow.at(windowSize).begin() + startBlockIdx, blockIdsPerWindow.at(windowSize).end())); } + return blockRange; } BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest) { + auto const& windowsMetadata = cacheManager->getBlockManager().getWindowSizesMetadata(); + if (windowsMetadata.size() == 1 || common::getEnvKVCacheTransferAllBlocksForWindow()) + { + if (common::getEnvDisableSelectiveCacheTransfer()) + { + return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId); + } + TLLM_LOG_DEBUG("getBlockRangeForReceiving fromNewlyAllocatedBlockIds from newly allocated block ids"); + return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); + } + bool cacheReuse = !common::getEnvDisableSelectiveCacheTransfer(); + auto blockRange = cacheReuse ? BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId) + : BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId); - auto poolNum = cacheManager->getBlockManager().getNumPools(); - if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer()) + for (auto const& [windowSize, metadata] : windowsMetadata) { - constexpr SizeType32 beam{0}; - return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); + auto const& blockIdsPerWindow = blockRange.getBlockIdsPerWindow(); + auto windowStartBlockIdx = static_cast(blockIdsPerWindow.at(windowSize).size()) + - (windowSize / cacheManager->getBlockManager().getTokensPerBlock() + 1); + SizeType32 startBlockIdx = std::max(0, windowStartBlockIdx); + blockRange.setBlockIdsForWindow(windowSize, + std::vector( + blockIdsPerWindow.at(windowSize).begin() + startBlockIdx, blockIdsPerWindow.at(windowSize).end())); } - return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); + return blockRange; } bool CacheFormatter::needSendCache( @@ -184,20 +222,24 @@ void CacheFormatter::format(TransferSession& session) { progress->wait(layerIdx); } - blockRange.updatePoolIdx(poolIdx); - for (auto it = blockRange.begin(); it != blockRange.end(); ++it) + auto const& windowSizes = blockRange.getWindowSizes(); + for (auto const& windowSize : windowSizes) { - // Block dim: [1, numLayersInPool, ...], offset = {0, layerIndexInPool} - auto layer = runtime::ITensor::slice(it, offset, 1); - if (offset.d[1] == 0) + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) { - TLLM_LOG_DEBUG("Block %p of pool %d shape = %s", it->data(), poolIdx, - runtime::ITensor::toString(it->getShape()).c_str()); - } - for (size_t i = 0; i < connections.size(); i++) - { - TLLM_LOG_DEBUG("Send layer %d(%d-%d)", layerIdx, poolIdx, layerIdxInPool); - session.send(i, layer->data(), layer->getSizeInBytes()); + // Block dim: [1, numLayersInPool, ...], offset = {0, layerIndexInPool} + auto layer = runtime::ITensor::slice(it, offset, 1); + if (offset.d[1] == 0) + { + 
TLLM_LOG_DEBUG("Block %p of pool %d shape = %s", it->data(), poolIdx, + runtime::ITensor::toString(it->getShape()).c_str()); + } + for (size_t i = 0; i < connections.size(); i++) + { + TLLM_LOG_DEBUG("Send layer %d(%d-%d)", layerIdx, poolIdx, layerIdxInPool); + session.send(i, layer->data(), layer->getSizeInBytes()); + } } } } @@ -207,29 +249,29 @@ void CacheFormatter::format(TransferSession& session) int blockNum = 0; size_t allCacheBlockSize = 0; + auto const& windowSizes = blockRange.getWindowSizes(); + TLLM_LOG_DEBUG( + mpi::MpiComm::world().getRank(), " blockRange.getWindowSizes(); windowSizes size: %d", windowSizes.size()); + TLLM_CHECK_WITH_INFO( + static_cast(windowSizes.size()) == numPools, "window sizes should be the same as numPools"); std::map> inputKvCacheBlocks; - for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) + + for (auto const& windowSize : windowSizes) { - blockRange.updatePoolIdx(poolIdx); - SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(poolIdx); - TLLM_CHECK_WITH_INFO(inputKvCacheBlocks.find(window) == inputKvCacheBlocks.end(), - "window size already exists, which is not supported"); - inputKvCacheBlocks.emplace(window, std::vector()); - auto maxBlockThisWindow = window / selfConfig.getModelConfig().mTokensPerBlock; - SizeType32 blockNumThisWindow = 0; - for (auto it = blockRange.begin(); it != blockRange.end(); ++it) + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " format windowSize: %d blockRangeForWindow size: %d", + windowSize, blockRangeForWindow.size()); + inputKvCacheBlocks.emplace(windowSize, std::vector()); + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) { - blockNum++; - inputKvCacheBlocks.at(window).push_back(it); + inputKvCacheBlocks.at(windowSize).push_back(it); allCacheBlockSize += it->getSize(); - blockNumThisWindow++; - if (blockNumThisWindow >= maxBlockThisWindow) - { - break; - } + blockNum++; } } + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "inputKvCacheBlocks size: %ld,blockNum: %d , windowSizes: %ld", + inputKvCacheBlocks.size(), blockNum, windowSizes.size()); if (inputKvCacheBlocks.size() > 1) { @@ -438,27 +480,25 @@ void CacheFormatter::unformat(TransferSession& session) // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1... 
size_t blockNum = 0; size_t cacheBlockSizeSum = 0; - for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) + + auto windowSizes = blockRange.getWindowSizes(); + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " unformat windowSizes size: %d", windowSizes.size()); + for (auto const& windowSize : windowSizes) { - blockRange.updatePoolIdx(poolIdx); - SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(poolIdx); - TLLM_CHECK_WITH_INFO(outputBuffersPerWindow.find(window) == outputBuffersPerWindow.end(), - "window size already exists, which is not supported"); - outputBuffersPerWindow.emplace(window, std::vector()); - auto maxBlockThisWindow = window / selfConfig.getModelConfig().mTokensPerBlock; - SizeType32 blockNumThisWindow = 0; - for (auto it = blockRange.begin(); it != blockRange.end(); ++it) + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " unformat windowSize: %d blockRangeForWindow size: %d", + windowSize, blockRangeForWindow.size()); + outputBuffersPerWindow.emplace(windowSize, std::vector()); + + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) { - blockNum++; - blockNumThisWindow++; - outputBuffersPerWindow.at(window).push_back(it); + outputBuffersPerWindow.at(windowSize).push_back(it); cacheBlockSizeSum += it->getSize(); - if (blockNumThisWindow >= maxBlockThisWindow) - { - break; - } + blockNum++; } } + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "outputBuffersPerWindow size: %ld,blockNum: %d , windowSizes: %ld", + outputBuffersPerWindow.size(), blockNum, windowSizes.size()); TLLM_CHECK(!outputBuffersPerWindow.empty()); if (outputBuffersPerWindow.size() > 1) { @@ -502,8 +542,10 @@ void CacheFormatter::unformat(TransferSession& session) auto const poolIdx = 0; auto const layerIdxInPool = layerIdx; int idx = 0; - blockRange.updatePoolIdx(poolIdx); - for (auto it = blockRange.begin(); it != blockRange.end(); ++it) + // blockRange.updatePoolIdx(poolIdx); + auto const window = mCacheManager->getBlockManager().getPoolLayerIdx(layerIdx); + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(window); + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) { if (layerIdxInPool == 0) { diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp index 1a3aed54f41..693a606fff1 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @@ -210,7 +210,13 @@ CacheTransBufferManager::CacheTransBufferManager( { auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId); auto windowSize = static_cast(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx)); - auto validTokenNum = (windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value()); + auto validTokenNum + = (windowSize < maxNumTokens.value() ? 
(windowSize + tokensPerBlock) : maxNumTokens.value()); + if (common::getEnvKVCacheTransferAllBlocksForWindow()) + { + validTokenNum = maxNumTokens.value(); + } + bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer; } } @@ -238,7 +244,7 @@ CacheTransBufferManager::CacheTransBufferManager( } size_t CacheTransBufferManager::preAllocBufferSize( - std::map const& cacheSizeBytesPerTokenPerWindow, + std::map const& cacheSizeBytesPerTokenPerWindow, SizeType32 tokensPerBlock, std::optional const& cacheTransceiverConfig) { if (!cacheTransceiverConfig.has_value()) @@ -256,9 +262,13 @@ size_t CacheTransBufferManager::preAllocBufferSize( TransferBufferSize = 0; for (auto const& [windowSize, cacheSizeBytesPerToken] : cacheSizeBytesPerTokenPerWindow) { - auto validTokenNum - = (static_cast(windowSize) < maxNumTokens.value() ? static_cast(windowSize) - : maxNumTokens.value()); + auto validTokenNum = (static_cast(windowSize) < maxNumTokens.value() + ? static_cast(windowSize) + tokensPerBlock + : maxNumTokens.value()); + if (common::getEnvKVCacheTransferAllBlocksForWindow()) + { + validTokenNum = maxNumTokens.value(); + } TransferBufferSize += validTokenNum * cacheSizeBytesPerToken; } } diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h index e7b050388fe..f1324a1cb01 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h +++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h @@ -61,6 +61,7 @@ class CacheTransBufferManager KVCacheManager::BaseKVCacheManager* cacheManager, std::optional maxNumTokens = std::nullopt); static size_t preAllocBufferSize(std::map const& cacheSizeBytesPerTokenPerWindow, + SizeType32 tokensPerBlock, std::optional const& cacheTransceiverConfig = std::nullopt); std::optional assignBufferIndexForSend(); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 522ec80f84a..189d099783d 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -23,6 +23,7 @@ #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/utils.h" +#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" #include #include @@ -41,17 +42,19 @@ RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTran { } -RequestInfo::RequestInfo( - LlmRequest::RequestIdType requestId, std::vector blockHashes, executor::DataTransceiverState transState) +RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, + std::unordered_map>&& blockHashesPerWindow, + executor::DataTransceiverState transState) : mRequestId{requestId} - , mBlockHashes{std::move(blockHashes)} + , mBlockHashesPerWindow{std::move(blockHashesPerWindow)} , mTransState{std::move(transState)} { } bool RequestInfo::operator==(RequestInfo const& rhs) const { - return mRequestId == rhs.mRequestId && mBlockHashes == rhs.mBlockHashes && mTransState == rhs.mTransState; + return mRequestId == rhs.mRequestId && mBlockHashesPerWindow == rhs.mBlockHashesPerWindow + && mTransState == rhs.mTransState; } LlmRequest::RequestIdType RequestInfo::getRequestId() const noexcept @@ -68,7 +71,13 @@ void RequestInfo::serialize(RequestInfo const& requestInfo, std::ostream& os) { namespace su = executor::serialize_utils; su::serialize(requestInfo.mRequestId, os); - su::serialize(requestInfo.mBlockHashes, os); + + 
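The added lines below hand-roll a length-prefixed encoding for the per-window map: the element count first, then repeated (windowSize, hashes) pairs, matched by the loop in deserialize() and mirrored again in serializedSize(). A generic helper expressing the same idea; the serialize_utils calls follow the usage visible in this hunk, and treating them as a general-purpose API is an assumption:

    namespace su = tensorrt_llm::executor::serialize_utils;

    // Length-prefixed map encoding: count, then (key, value) pairs.
    // std::unordered_map iteration order is unspecified, so the decoder must
    // rebuild the map rather than assume any particular ordering.
    template <typename K, typename V>
    void serializeMap(std::unordered_map<K, V> const& map, std::ostream& os)
    {
        su::serialize(map.size(), os);
        for (auto const& [key, value] : map)
        {
            su::serialize(key, os);
            su::serialize(value, os);
        }
    }

Note that map.size() is a size_t, so the reader must decode the count at the same width (su::deserialize<size_t>(is) below) for the format to round-trip.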
su::serialize(requestInfo.mBlockHashesPerWindow.size(), os); + for (auto const& [windowSize, blockHashes] : requestInfo.mBlockHashesPerWindow) + { + su::serialize(windowSize, os); + su::serialize(blockHashes, os); + } su::serialize(requestInfo.mTransState, os); } @@ -76,9 +85,16 @@ RequestInfo RequestInfo::deserialize(std::istream& is) { namespace su = executor::serialize_utils; auto requestId = su::deserialize(is); - auto blockHashes = su::deserialize(is); + std::unordered_map> blockHashesPerWindow; + auto size = su::deserialize(is); + for (size_t i = 0; i < size; i++) + { + auto windowSize = su::deserialize(is); + std::vector blockHashes = su::deserialize(is); + blockHashesPerWindow.emplace(windowSize, std::move(blockHashes)); + } auto transState = su::deserialize(is); - return RequestInfo{requestId, std::move(blockHashes), std::move(transState)}; + return RequestInfo{requestId, std::move(blockHashesPerWindow), std::move(transState)}; } std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) @@ -86,7 +102,12 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) namespace su = executor::serialize_utils; std::size_t totalSize = 0; totalSize += su::serializedSize(requestInfo.mRequestId); - totalSize += su::serializedSize(requestInfo.mBlockHashes); + totalSize += su::serializedSize(requestInfo.mBlockHashesPerWindow.size()); + for (auto const& [windowSize, blockHashes] : requestInfo.mBlockHashesPerWindow) + { + totalSize += su::serializedSize(windowSize); + totalSize += su::serializedSize(blockHashes); + } totalSize += su::serializedSize(requestInfo.mTransState); return totalSize; } @@ -214,12 +235,12 @@ class DataResponder::Impl { break; } - std::vector blockHashes; + std::unordered_map> blockHashesPerWindow; if (!isSending() && !mReadyResponses.empty()) { auto const& requestInfo = mSender->recvRequestInfo(); auto reqId = requestInfo.getRequestId(); - blockHashes = requestInfo.getBlockHashes(); + blockHashesPerWindow = requestInfo.getBlockHashesPerWindow(); mCurrentRequest = reqId; if (mRemainSendCount.find(reqId) == mRemainSendCount.end()) @@ -239,7 +260,7 @@ class DataResponder::Impl // TODO(zhengd): pass the hashes directly instead of update llmRequest auto llmRequest = it->second.mRequest; - llmRequest->setRequestedBlockHashes(std::move(blockHashes)); + llmRequest->setRequestedBlockHashes(std::move(blockHashesPerWindow)); if (common::getEnvParallelCacheSend()) { diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index ef66cd1382d..51ae17a56fb 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -21,6 +21,7 @@ #include #include +#include "tensorrt_llm/batch_manager/cacheTransceiver.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/envUtils.h" @@ -49,7 +50,8 @@ class RequestInfo /// @param transState The state of the data transceiver. RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState); - RequestInfo(LlmRequest::RequestIdType requestId, std::vector blockHashes, + RequestInfo(LlmRequest::RequestIdType requestId, + std::unordered_map>&& blockHashesPerWindow, executor::DataTransceiverState transState); RequestInfo() = default; @@ -61,9 +63,9 @@ class RequestInfo /// @return The request ID. 
[[nodiscard]] LlmRequest::RequestIdType getRequestId() const noexcept; - [[nodiscard]] std::vector const& getBlockHashes() const noexcept + [[nodiscard]] std::unordered_map> const& getBlockHashesPerWindow() const noexcept { - return mBlockHashes; + return mBlockHashesPerWindow; } /// @brief Return the state of the data transceiver. @@ -88,7 +90,7 @@ class RequestInfo // The ID used in the context phase of the current request. LlmRequest::RequestIdType mRequestId; - std::vector mBlockHashes; + std::unordered_map> mBlockHashesPerWindow; // The state of the data transceiver. executor::DataTransceiverState mTransState; diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp index 1a5c7fab4dd..2aae7ceab2c 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @@ -178,14 +178,13 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) RequestInfo requestInfo(requestId, mSelfState); - auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer() - || (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1); + auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer(); if (!disableSelectiveCacheTransfer) { auto* cacheManager = mFormatter->getCacheManager(); auto blockRange = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); - requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); + requestInfo = RequestInfo(requestId, blockRange.getBlockHashesPerWindow(), mSelfState); } auto* agentConnectionManager = dynamic_cast(mManager); diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 22756f25527..ac2856fc045 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -105,23 +105,26 @@ void MLACacheFormatter::format(TransferSession& session) // diff end - auto const numPools = mCacheManager->getBlockManager().getNumPools(); - auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest); - auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime; bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point(); int blockNum = 0; std::vector inputKvCacheBlocks; - for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) + auto const numPools = mCacheManager->getBlockManager().getNumPools(); + auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest); + auto const& windowSizes = blockRange.getWindowSizes(); + TLLM_CHECK_WITH_INFO( + static_cast(windowSizes.size()) == numPools, "window sizes should be the same as numPools"); + for (auto const& windowSize : windowSizes) { - blockRange.updatePoolIdx(poolIdx); - for (auto it = blockRange.begin(); it != blockRange.end(); ++it) + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) { - blockNum++; inputKvCacheBlocks.push_back(it); + blockNum++; } } + TLLM_CHECK(blockNum > 0); int deviceId = mCacheManager->getBlockManager().getStreamDevice(); @@ -307,15 +310,18 @@ void MLACacheFormatter::unformat(TransferSession& session) std::vector recvBufferTmps; std::vector outputBuffers; auto const numPools = mCacheManager->getBlockManager().getNumPools(); + auto const& windowSizes = 
blockRange.getWindowSizes(); + TLLM_CHECK_WITH_INFO( + static_cast(windowSizes.size()) == numPools, "window sizes should be the same as numPools"); // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1... size_t blockNum = 0; - for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) + for (auto const& windowSize : windowSizes) { - blockRange.updatePoolIdx(poolIdx); - for (auto it = blockRange.begin(); it != blockRange.end(); ++it) + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) { - blockNum++; outputBuffers.push_back(it); + blockNum++; } } diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 08cb4d407c1..9091dd33426 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -292,7 +292,7 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptrgetBufferManager(), kvCacheConfig); diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp index 59c9d2fffe4..9580407f340 100644 --- a/cpp/tensorrt_llm/common/envUtils.cpp +++ b/cpp/tensorrt_llm/common/envUtils.cpp @@ -446,6 +446,12 @@ size_t getEnvMemSizeForKVCacheTransferBuffer() return memSizeForKVCacheTransferBuffer; } +bool getEnvKVCacheTransferAllBlocksForWindow() +{ + static bool const allBlocksForWindow = getBoolEnv("TRTLLM_KVCACHE_TRANSFER_ALL_BLOCKS_FOR_WINDOW"); + return allBlocksForWindow; +} + uint16_t getEnvNixlPort() { static uint16_t const nixlPort = getUInt64Env("TRTLLM_NIXL_PORT").value_or(0); diff --git a/cpp/tensorrt_llm/common/envUtils.h b/cpp/tensorrt_llm/common/envUtils.h index f5c0d854ba4..614b72021ab 100644 --- a/cpp/tensorrt_llm/common/envUtils.h +++ b/cpp/tensorrt_llm/common/envUtils.h @@ -116,4 +116,6 @@ bool getEnvDisaggBenchmarkGenOnly(); // Whether to disable the chunked-attention in the generation phase. 
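Like the other accessors in envUtils.cpp, the new switch caches its lookup in a function-local static, so TRTLLM_KVCACHE_TRANSFER_ALL_BLOCKS_FOR_WINDOW must be set before the process starts; flipping it at runtime has no effect. A condensed sketch of how the two transfer flags combine on the send path, paraphrased from getBlockRangeForSending above (singleWindow stands for windowsMetadata.size() == 1):

    bool const sendAll = common::getEnvKVCacheTransferAllBlocksForWindow();
    bool const noSelective = common::getEnvDisableSelectiveCacheTransfer();

    if (noSelective && (singleWindow || sendAll))
    {
        return blockRange; // every block of every window
    }
    // Otherwise, per window: start from max(reuseStartBlockIdx, windowStartBlockIdx),
    // where windowStartBlockIdx keeps only the trailing windowSize / tokensPerBlock + 1
    // blocks unless sendAll forces the whole window to be transferred.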
bool getEnvDisableChunkedAttentionInGenPhase(); +bool getEnvKVCacheTransferAllBlocksForWindow(); + } // namespace tensorrt_llm::common diff --git a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp index d92336e6bdf..cd103937b47 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp @@ -97,5 +97,6 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m) .def(py::init>(), py::arg("cache_manager"), py::arg("max_num_tokens") = std::nullopt) .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize, - py::arg("cache_size_bytes_per_token_per_window"), py::arg("cache_transceiver_config") = py::none()); + py::arg("cache_size_bytes_per_token_per_window"), py::arg("tokens_per_block"), + py::arg("cache_transceiver_config") = py::none()); } diff --git a/cpp/tests/batch_manager/cacheTransceiverTest.cpp b/cpp/tests/batch_manager/cacheTransceiverTest.cpp index 99c40f810f6..20514ecb823 100644 --- a/cpp/tests/batch_manager/cacheTransceiverTest.cpp +++ b/cpp/tests/batch_manager/cacheTransceiverTest.cpp @@ -426,10 +426,15 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- if (isSender) { auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId); - for (auto& block : blockRange) + auto const& windowSizes = blockRange.getWindowSizes(); + for (auto const& windowSize : windowSizes) { - // fill cache with tokens (= request length), for reuse test - TLLM_CUDA_CHECK(cudaMemset(block.data(), llmRequest->getPromptLen(), block.getSizeInBytes())); + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) + { + // fill cache with tokens (= request length), for reuse test + TLLM_CUDA_CHECK(cudaMemset(it->data(), llmRequest->getPromptLen(), it->getSizeInBytes())); + } } mFutures.emplace_back(mResponder->respondAndSendAsync(*llmRequest)); } @@ -439,12 +444,17 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- future.get(); TLLM_CUDA_CHECK(cudaDeviceSynchronize()); auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId); - for (auto& block : blockRange) + auto const& windowSizes = blockRange.getWindowSizes(); + for (auto const& windowSize : windowSizes) { - std::vector bytes(block.getSizeInBytes()); - TLLM_CUDA_CHECK(cudaMemcpy(bytes.data(), block.data(), block.getSizeInBytes(), cudaMemcpyDeviceToHost)); - EXPECT_TRUE(std::all_of(bytes.begin(), bytes.end(), - [&llmRequest](uint8_t i) { return i == llmRequest->getPromptLen() & 0xff; })); + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) + { + std::vector bytes(it->getSizeInBytes()); + TLLM_CUDA_CHECK(cudaMemcpy(bytes.data(), it->data(), it->getSizeInBytes(), cudaMemcpyDeviceToHost)); + EXPECT_TRUE(std::all_of(bytes.begin(), bytes.end(), + [&llmRequest](uint8_t i) { return i == llmRequest->getPromptLen() & 0xff; })); + } } } } @@ -636,7 +646,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(); auto maxNumTokens = tokensPerBlock * maxBlocksPerSeq; @@ -875,20 +885,21 @@ class AsymmetricalCacheTest : public ::testing::TestWithParamaddSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest); auto 
blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId); - int blockIdx = 0; int const numPools = mManager->getBlockManager().getNumPools(); TLLM_LOG_DEBUG(" addRequestAndTransportCacheForContext mManager numPools: %d", numPools); - for (int poolIdx = 0; poolIdx < numPools; poolIdx++) + auto const& windowSizes = blockRange.getWindowSizes(); + int blockIdx = 0; + for (auto const& windowSize : windowSizes) { - blockRange.updatePoolIdx(poolIdx); - TLLM_LOG_DEBUG("update poolIdx: %d", poolIdx); - for (auto& block : blockRange) + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + TLLM_LOG_DEBUG("update windowSize: %d", windowSize); + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) { - fillBlockData(block, blockIdx, llmRequest->getPromptLen(), poolIdx); + fillBlockData(*it, blockIdx, llmRequest->getPromptLen(), windowSize); blockIdx++; } - TLLM_LOG_DEBUG("blockPoolIdx: %d finish fill block data", poolIdx); + TLLM_LOG_DEBUG("windowSize: %d finish fill block data", windowSize); } TLLM_LOG_DEBUG( @@ -920,22 +931,29 @@ class AsymmetricalCacheTest : public ::testing::TestWithParammRequestId); - auto const numPools = mManager->getBlockManager().getNumPools(); - for (int poolIdx = 0; poolIdx < numPools; poolIdx++) + auto const& windowSizes = blockRange.getWindowSizes(); + for (auto const& windowSize : windowSizes) { - blockRange.updatePoolIdx(poolIdx); - for (auto& block : blockRange) + auto blockRangeForWindow = blockRange.getBlockRangeForWindow(windowSize); + int maxBlockInWindow = windowSize / mCacheState->getModelConfig().mTokensPerBlock; + int startBlockId = std::max(0, static_cast(blockRangeForWindow.size()) - (maxBlockInWindow + 1)); + int blockIdInWindow = 0; + for (auto it = blockRangeForWindow.begin(); it != blockRangeForWindow.end(); ++it) { - verifyBlockData(block, blockIdx, llmRequest->getPromptLen(), poolIdx); + if (blockIdInWindow >= startBlockId) + { + verifyBlockData(*it, blockIdx, llmRequest->getPromptLen(), windowSize); + } blockIdx++; + blockIdInWindow++; } } } - void fillBlockData(tensorrt_llm::runtime::ITensor& blockData, int blockId, size_t initial, int blockPoolIdx = 0) + void fillBlockData(tensorrt_llm::runtime::ITensor& blockData, int blockId, size_t initial, int windowSize = 0) { auto const& blockManager = mManager->getBlockManager(); - auto const onlyWindowSize = blockManager.getPoolWindowSize(blockPoolIdx); + auto const onlyWindowSize = windowSize == 0 ? 
blockManager.getPoolWindowSize(0) : windowSize; auto const& bufferManager = blockManager.getBufferManager(onlyWindowSize); auto hostTensor = tensorrt_llm::runtime::BufferManager::cpu(blockData.getShape(), blockData.getDataType()); int layerSizePerRank = blockData.getDimension<1>(); @@ -972,7 +990,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(hostTensor->data(keyIndex)); *dataPtr = generateValue; }, - generateExpectedValue(initial, blockPoolIdx, tokenId + startTokenId, layerId + startLayerId, + generateExpectedValue(initial, windowSize, tokenId + startTokenId, layerId + startLayerId, headId + startHeadId, hiddenId, true, blockData.getDataType())); if (kvFactor == 2) { @@ -983,7 +1001,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(hostTensor->data(valueIndex)); *dataPtr = generateValue; }, - generateExpectedValue(initial, blockPoolIdx, tokenId + startTokenId, + generateExpectedValue(initial, windowSize, tokenId + startTokenId, layerId + startLayerId, headId + startHeadId, hiddenId, false, blockData.getDataType())); } @@ -995,10 +1013,11 @@ class AsymmetricalCacheTest : public ::testing::TestWithParamgetBlockManager(); - auto const onlyWindowSize = blockManager.getPoolWindowSize(blockPoolIdx); + + auto const onlyWindowSize = windowSize == 0 ? blockManager.getPoolWindowSize(0) : windowSize; auto const& bufferManager = blockManager.getBufferManager(onlyWindowSize); auto hostTensor = tensorrt_llm::runtime::BufferManager::cpu(blockData.getShape(), blockData.getDataType()); @@ -1039,7 +1058,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(hostTensor->data(keyIndex)); EXPECT_EQ(*dataPtr, generateValue); }, - generateExpectedValue(initial, blockPoolIdx, tokenId + startTokenId, layerId + startLayerId, + generateExpectedValue(initial, windowSize, tokenId + startTokenId, layerId + startLayerId, headId + startHeadId, hiddenId, true, blockData.getDataType())); if (kvFactor == 2) { @@ -1050,7 +1069,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(hostTensor->data(valueIndex)); EXPECT_EQ(*dataPtr, generateValue); }, - generateExpectedValue(initial, blockPoolIdx, tokenId + startTokenId, + generateExpectedValue(initial, windowSize, tokenId + startTokenId, layerId + startLayerId, headId + startHeadId, hiddenId, false, blockData.getDataType())); } @@ -1060,7 +1079,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam generateExpectedValue(size_t initial, int blockPoolIdx, int tokenId, + std::variant generateExpectedValue(size_t initial, int windowSize, int tokenId, int layerId, int headId, int hiddenId, bool key, nvinfer1::DataType dataType) { @@ -1068,7 +1087,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam{}(initial); std::hash hasher{}; seed ^= hashValue + 0x9e3779b9 + (seed << 6) + (seed >> 2); - seed ^= hasher(blockPoolIdx) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= hasher(windowSize) + 0x9e3779b9 + (seed << 6) + (seed >> 2); seed ^= hasher(tokenId) + 0x9e3779b9 + (seed << 6) + (seed >> 2); seed ^= hasher(layerId) + 0x9e3779b9 + (seed << 6) + (seed >> 2); seed ^= hasher(headId) + 0x9e3779b9 + (seed << 6) + (seed >> 2); @@ -1245,7 +1264,7 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase) setUpCacheTransceiver(); std::vector> requests; int requestId = 0; - for (auto len : {30, 10, 60, 30, 60, 10}) + for (auto len : {60, 30, 60, 10}) { requests.emplace_back(makeLlmRequestWithDP(len, requestId, requestId % contextTp)); requestId++; diff --git 
a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp index 27e1590e6a2..6c7d28b46ce 100644 --- a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp @@ -116,8 +116,8 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize) {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}}; tensorrt_llm::executor::CacheTransceiverConfig cacheTransceiverConfig{ tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens}; - size_t bufferSizeBytes - = CacheTransBufferManager::preAllocBufferSize(cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig); + size_t bufferSizeBytes = CacheTransBufferManager::preAllocBufferSize( + cacheSizeBytesPerTokenPerWindow, tokensPerBlock, cacheTransceiverConfig); auto bufferId = mTransBufferManager->assignBufferIndexForSend(); EXPECT_TRUE(bufferId.has_value()); EXPECT_EQ(bufferId.value(), 0); @@ -160,8 +160,8 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize2) tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens}; std::map cacheSizeBytesPerTokenPerWindow{ {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}}; - size_t bufferSizeBytes - = CacheTransBufferManager::preAllocBufferSize(cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig); + size_t bufferSizeBytes = CacheTransBufferManager::preAllocBufferSize( + cacheSizeBytesPerTokenPerWindow, tokensPerBlock, cacheTransceiverConfig); auto bufferId = mTransBufferManager->assignBufferIndexForSend(); EXPECT_TRUE(bufferId.has_value()); EXPECT_EQ(bufferId.value(), 0); diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp index c4de5b6a8c6..99962fcaf4f 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp @@ -52,7 +52,7 @@ TEST_F(BlockIteratorTest, BasicTest) auto blockTensor = tr::ITensor::slice(pool, blockIds.at(idx), 1); std::fill_n(tr::bufferCast(*blockTensor), blockTensor->getSize(), idx); } - auto range = BlockRange(pool, blockIds); + auto range = BlockRangeForWindow(std::move(blockIds), std::move(pool)); auto begin = range.begin(); auto end = range.end(); auto allEqualTo = [](tr::ITensor const& tensor, auto x) -> bool @@ -124,7 +124,9 @@ TEST_F(BlockIteratorTest, CacheManagerTest) auto const pool = blockManager.getPrimaryPool(0); TLLM_CHECK(pool); - auto range = BlockRange(pool, blockIds); + auto blockIdsVec = std::vector(blockIds.begin(), blockIds.end()); + auto poolCopy = pool; + auto range = BlockRangeForWindow(std::move(blockIdsVec), std::move(poolCopy)); size_t cnt{0}; for (auto iter = range.begin(); iter != range.end(); ++iter, ++cnt) { diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 63509cd6984..a5a34afbc6b 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -673,9 +673,6 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): @pytest.mark.parametrize("overlap_scheduler", [False, True]) def test_auto_dtype(self, overlap_scheduler): - pytest.skip( - "Currently we require full kvcache for variable sliding window. 
" - "This test only transfers the kvcache inside the sliding window.") ctx_server_config = { "disable_overlap_scheduler": True, From 39306d6b3b3c563ed27c45600ac5b052a631f4a5 Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Tue, 19 Aug 2025 02:32:14 +0000 Subject: [PATCH 2/3] fix nanobind Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> --- cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp index 8a7f73f3b06..3a5ff6c0270 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp @@ -100,5 +100,6 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m) .def(nb::init>(), nb::arg("cache_manager"), nb::arg("max_num_tokens") = std::nullopt) .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize, - nb::arg("cache_size_bytes_per_token_per_window"), nb::arg("cache_transceiver_config") = nb::none()); + nb::arg("cache_size_bytes_per_token_per_window"), nb::arg("tokens_per_block"), + nb::arg("cache_transceiver_config") = nb::none()); } From 78b2de5a4ad5ee8ef10534eb0e420a323da6de6d Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Mon, 25 Aug 2025 06:55:12 +0000 Subject: [PATCH 3/3] onewindow fix Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 3320c013dd0..3b3de83da7c 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -57,12 +57,19 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest bool needReuse = !common::getEnvDisableSelectiveCacheTransfer(); auto const& requestedBlockHashesPerWindow = llmRequest.getRequestedBlockHashesPerWindow(); + + // if only one window, the context and gen may have different window size, which may be specified by the seq_len; + // so the requested window may be different from the window in metaData. + + bool const onlyOneWindow = requestedBlockHashesPerWindow.size() == 1; + for (auto const& [windowSize, metadata] : windowsMetadata) { + SizeType32 requestedWindow = onlyOneWindow ? requestedBlockHashesPerWindow.begin()->first : windowSize; SizeType32 reuseStartBlockIdx - = (needReuse && requestedBlockHashesPerWindow.at(windowSize).size() > 0 - && requestedBlockHashesPerWindow.at(windowSize).size() < blockIdsPerWindow.at(windowSize).size()) - ? (blockIdsPerWindow.at(windowSize).size() - requestedBlockHashesPerWindow.at(windowSize).size()) + = (needReuse && requestedBlockHashesPerWindow.at(requestedWindow).size() > 0 + && requestedBlockHashesPerWindow.at(requestedWindow).size() < blockIdsPerWindow.at(windowSize).size()) + ? (blockIdsPerWindow.at(windowSize).size() - requestedBlockHashesPerWindow.at(requestedWindow).size()) : 0; auto windowStartBlockIdx = needSendAllForWindow ? 0