@@ -504,7 +504,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
     std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
     SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
     , mTokensPerBlock{tokensPerBlock}
     , mEventManager{std::move(eventManager)}
@@ -513,6 +514,10 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
 {
     auto const uniqueWindowSizeToLayers
         = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);
+
+    TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
+        "KV Cache Connector is not supported with multiple window sizes");
+
     auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());
 
     mIsVariableWindow = numUniqueWindowSizes > 1;
@@ -530,7 +535,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
         mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
             sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
             onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
-            copyOnPartialReuse);
+            copyOnPartialReuse, kvCacheConnectorManager);
     }
 
     auto const numAllPools = getNumPools();
@@ -572,7 +577,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
     SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
     SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : mDataType{dtype}
     , mWindowSize{windowSize}
     , mNumPrimaryBlocks{blocksInPrimaryPool}
@@ -596,6 +602,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
     , mTotalInputTokens{0.0}
     , mEnablePartialReuse{enablePartialReuse}
     , mCopyOnPartialReuse{copyOnPartialReuse}
+    , mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
 {
     std::map<SizeType32, SizeType32> numLayersPerPool;
 
@@ -1188,9 +1195,18 @@ void WindowBlockManager::addSequence(
     auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions);
     mReusedTokens += static_cast<double>(prepopulatedPromptLen);
     mTotalInputTokens += static_cast<double>(uniqueTokens.size());
-    llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock());
-    TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.mRequestId,
-        inputLength, prepopulatedPromptLen);
+
+    SizeType32 numConnectorMatchedTokens = 0;
+
+    // If we're using a KV cache connector, check if any additional blocks can be loaded.
+    if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
+    {
+        numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
+    }
+
+    llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
+    TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
+        llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
 }
 
 // There are two versions of BlockManager::addSequence function.
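Note on the connector call above: only a narrow interface is exercised by this change. Below is a minimal sketch of what kv_connector::KvCacheConnectorManager needs to expose for this path; it is reconstructed from the call site, so the header layout, include paths, and any additional virtuals are assumptions rather than the definitive interface.

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm::batch_manager::kv_connector
{

// Hedged sketch only: the class name and namespace come from this diff; the rest is assumed.
class KvCacheConnectorManager
{
public:
    virtual ~KvCacheConnectorManager() = default;

    // Reports how many additional prompt tokens the connector can populate from external
    // storage, beyond the locally reused prefix of length numComputedTokens. In
    // WindowBlockManager::addSequence above, prepopulatedPromptLen plays that role and the
    // returned count is simply added before setPrepopulatedPromptLen is called.
    virtual runtime::SizeType32 getNumNewMatchedTokens(
        LlmRequest const& request, runtime::SizeType32 numComputedTokens)
        = 0;
};

} // namespace tensorrt_llm::batch_manager::kv_connector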
@@ -1206,6 +1222,13 @@ void BlockManager::addSequence(
 void WindowBlockManager::addSequence(
     GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
 {
+    if (mKvCacheConnectorManager)
+    {
+        TLLM_LOG_WARNING(
+            "KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
+            "ignored.");
+    }
+
     auto const requestId = sequence.getRequestId();
     auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
     TLLM_CHECK(emplaceDone);
@@ -1618,12 +1641,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
     SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
     bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
         maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
         std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
         enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
-        copyOnPartialReuse)
+        copyOnPartialReuse, kvCacheConnectorManager)
 {
 }
 
@@ -1634,7 +1658,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
     SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
     bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : mMaxBeamWidth(maxBeamWidth)
     , mDataType(dtype)
     , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
@@ -1644,7 +1669,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
     , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
         std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
         mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
-        enablePartialReuse, copyOnPartialReuse)
+        enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
     // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
     , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
 {
@@ -1668,11 +1693,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
     SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
     bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
     std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
-    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
+    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
+    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
     : KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
         maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
         std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
-        std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
+        std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
 {
 }
 
@@ -2383,6 +2409,13 @@ std::vector<SizeType32> KVCacheManager::getNewlyAllocatedBlockIds(
     return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize);
 }
 
+runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const
+{
+    TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1,
+        "getUniquePrimaryPool is only supported for a single window size");
+    return mBlockManager.getPrimaryPool(0);
+}
+
 runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const
 {
     return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx));
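Usage note on the new accessor: when every layer shares a single attention window, connector-style code can grab the one primary pool tensor without naming a layer. The snippet below is an illustrative caller-side sketch, not part of this change; the helper name is hypothetical and the include path is the usual kvCacheManager.h header, assumed here.

#include "tensorrt_llm/batch_manager/kvCacheManager.h"

namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;

// Hypothetical helper, not part of this diff: fetch the single primary pool tensor once.
// With multiple window sizes the TLLM_CHECK_WITH_INFO above fires, so callers would fall
// back to the per-layer accessor getPrimaryPool(layerIdx) instead.
tensorrt_llm::runtime::ITensor::SharedPtr primaryPoolForConnector(kvcm::KVCacheManager const& cacheManager)
{
    return cacheManager.getUniquePrimaryPool();
}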
@@ -2462,4 +2495,5 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength,
     auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements;
     return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength);
 }
+
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager