1717#pragma once
1818
1919#include " tensorrt_llm/batch_manager/kvCacheManager.h"
20+ #include " tensorrt_llm/runtime/iTensor.h"
2021
2122namespace tensorrt_llm ::batch_manager::kv_cache_manager
2223{
2324
2425class BlockIterator ;
2526
26- class BlockRange
27+ class BlockRangeForWindow
2728{
2829public:
29- // C++20 std::default_sentinel_t equivalent
30+ BlockRangeForWindow (BaseKVCacheManager const * cacheManager, SizeType32 windowSize, std::vector<SizeType32> blockIds,
31+ runtime::ITensor::SharedPtr pool)
32+ : mCacheManager (cacheManager)
33+ , mWindowSize (windowSize)
34+ , mBlockIds (std::move(blockIds))
35+ , mPool (std::move(pool))
36+ {
37+ }
38+
3039 struct Sentinel
3140 {
3241 };
3342
34- static BlockRange fromAllBlockIds (BaseKVCacheManager const & cacheManager, LlmRequest::RequestIdType requestId,
35- SizeType32 beam = kFIRST_AND_ONLY_BEAM )
43+ friend class BlockIterator ;
44+ BlockIterator begin () const ;
45+
46+ [[nodiscard]] Sentinel end () const
47+ {
48+ return {};
49+ }
50+
51+ [[nodiscard]] size_t size () const
52+ {
53+ return mBlockIds .size ();
54+ }
55+
56+ private:
57+ BaseKVCacheManager const * mCacheManager ;
58+ SizeType32 mWindowSize ;
59+ std::vector<SizeType32> mBlockIds ;
60+ runtime::ITensor::SharedPtr mPool ;
61+ };
62+
63+ class BlockRange
64+ {
65+ public:
66+ static BlockRange fromAllBlockIds (BaseKVCacheManager const & cacheManager, LlmRequest::RequestIdType requestId)
3667 {
37- assert (kFIRST_AND_ONLY_BEAM == beam);
38- auto const windowSize = firstWindowSize (cacheManager);
39- auto const blockIds = cacheManager.getSequence (requestId).getCacheBlockIds (windowSize).at (kFIRST_AND_ONLY_BEAM );
40- return BlockRange (cacheManager, blockIds, requestId);
68+
69+ return BlockRange (cacheManager, requestId);
4170 }
4271
4372 static BlockRange fromReuseTree (
4473 BaseKVCacheManager& cacheManager, BlockKey const & lastBlockKey, int32_t indexFromEnd)
4574 {
46- auto const windowSize = firstWindowSize (cacheManager);
75+
76+ auto poolNum = cacheManager.getNumPools ();
77+ TLLM_CHECK_WITH_INFO (poolNum == 1 , " Reuse tree is not supported for multiple pools or variable window size" );
78+
79+ auto windowSize = cacheManager.getBlockManager ().getWindowSizesMetadata ().begin ()->first ;
4780 // Find the last block in the reuse tree for the provided full sequence of block keys
4881 auto lastBlock = cacheManager.findBlocksInReuseTreeByBlockKey (lastBlockKey, windowSize);
4982 // TODO: handle the case where the last block is not found
@@ -65,78 +98,104 @@ class BlockRange
6598 }
6699 // Reverse to chronological order: oldest to newest
67100 std::reverse (blockIds.begin (), blockIds.end ());
68- return BlockRange (cacheManager, blockIds, 0 );
69- }
70-
71- BlockRange (runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const & blockIds) // Only used in tests
72- : mManager {nullptr }
73- , mPool {std::move (pool)}
74- , mWindowSize {0 }
75- , mRequestId {0 }
76- , mBlockIds {blockIds}
77- {
78- TLLM_CHECK (mPool );
101+ std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow;
102+ blockIdsPerWindow[windowSize] = blockIds;
103+ return BlockRange (cacheManager, blockIdsPerWindow, 0 );
79104 }
80105
81- [[nodiscard]] BlockIterator begin () const ;
82-
83- [[nodiscard]] Sentinel end () const
106+ void setBlockIdsForWindow (SizeType32 windowSize, std::vector<SizeType32> blockIds)
84107 {
85- return {};
108+ TLLM_CHECK_WITH_INFO (mBlockIdsPerWindow .find (windowSize) != mBlockIdsPerWindow .end (),
109+ " Window size %d should exists" , windowSize);
110+ mBlockIdsPerWindow [windowSize] = std::move (blockIds);
86111 }
87112
88- [[nodiscard]] size_t size () const
113+ void setBlockIdsForAllWindows (std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow)
89114 {
90- return mBlockIds .size ();
115+ for (auto const & [windowSize, blockIds] : blockIdsPerWindow)
116+ {
117+ TLLM_CHECK_WITH_INFO (
118+ mPoolsPerWindow .find (windowSize) != mPoolsPerWindow .end (), " Window size %d should exists" , windowSize);
119+ }
120+ mBlockIdsPerWindow = std::move (blockIdsPerWindow);
91121 }
92122
93- [[nodiscard]] std::vector <SizeType32> const & getBlockIds () const
123+ [[nodiscard]] std::unordered_map <SizeType32, std::vector< size_t >> getBlockHashesPerWindow () const
94124 {
95- return mBlockIds ;
125+ TLLM_CHECK (mManager );
126+ std::unordered_map<SizeType32, std::vector<size_t >> blockHashesPerWindow;
127+ auto & blockManager = mManager ->getBlockManager ();
128+ for (auto const & [windowSize, blockIds] : mBlockIdsPerWindow )
129+ {
130+ for (auto const & blockId : blockIds)
131+ {
132+ blockHashesPerWindow[windowSize].emplace_back (
133+ blockManager.getBlockById (blockId, windowSize)->getHash ());
134+ }
135+ }
136+ return blockHashesPerWindow;
96137 }
97138
98- void setBlockIds (std::vector< SizeType32> blockIds)
139+ BlockRangeForWindow getBlockRangeForWindow ( SizeType32 windowSize) const
99140 {
100- mBlockIds = std::move (blockIds);
141+ TLLM_CHECK_WITH_INFO (
142+ mPoolsPerWindow .find (windowSize) != mPoolsPerWindow .end (), " Window size %d not found" , windowSize);
143+ auto pool = mPoolsPerWindow .at (windowSize).front ();
144+ auto blockIds = mBlockIdsPerWindow .at (windowSize);
145+ return BlockRangeForWindow (mManager , windowSize, std::move (blockIds), std::move (pool));
101146 }
102147
103- void updatePoolIdx ( SizeType32 poolIdx)
148+ std::vector< SizeType32> getWindowSizes () const
104149 {
105- TLLM_CHECK (mManager );
106- mPool = mManager ->getBlockManager ().getPrimaryPool (poolIdx);
107- auto const newWindowSize = mManager ->getBlockManager ().getPoolWindowSize (poolIdx);
108- if (newWindowSize != mWindowSize )
150+ std::vector<SizeType32> windowSizes;
151+ for (auto const & [windowSize, _] : mPoolsPerWindow )
109152 {
110- mWindowSize = newWindowSize;
111- mBlockIds = mManager ->getSequence (mRequestId ).getCacheBlockIds (mWindowSize ).at (kFIRST_AND_ONLY_BEAM );
153+ windowSizes.push_back (windowSize);
112154 }
155+ return windowSizes;
113156 }
114157
115- friend class BlockIterator ;
158+ std::unordered_map<SizeType32, std::vector<SizeType32>> const & getBlockIdsPerWindow () const
159+ {
160+ return mBlockIdsPerWindow ;
161+ }
116162
117163private:
118- BlockRange (
119- BaseKVCacheManager const & cacheManager , std::vector<SizeType32> blockIds , LlmRequest::RequestIdType requestId)
164+ BlockRange (BaseKVCacheManager const & cacheManager,
165+ std::unordered_map<SizeType32 , std::vector<SizeType32>> blockIdsPerWindow , LlmRequest::RequestIdType requestId)
120166 : mManager (&cacheManager)
121- , mPool (cacheManager.getBlockManager().getPrimaryPool(kFIRST_POOL_INDEX ))
122- , mWindowSize (firstWindowSize(cacheManager))
123167 , mRequestId (requestId)
124- , mBlockIds (std::move(blockIds ))
168+ , mBlockIdsPerWindow (std::move(blockIdsPerWindow ))
125169 {
170+
171+ // cacheManager.getBlockManager.getPrimaryPool(0);
172+ auto poolNum = mManager ->getNumPools ();
173+ for (SizeType32 poolIdx = 0 ; poolIdx < poolNum; ++poolIdx)
174+ {
175+ auto windowSize = cacheManager.getBlockManager ().getPoolWindowSize (poolIdx);
176+ mPoolsPerWindow [windowSize].push_back (cacheManager.getBlockManager ().getPrimaryPool (poolIdx));
177+ }
126178 }
127179
128- static SizeType32 firstWindowSize (BaseKVCacheManager const & cacheManager)
180+ BlockRange (BaseKVCacheManager const & cacheManager, LlmRequest::RequestIdType requestId)
181+ : mManager (&cacheManager)
182+ , mRequestId (requestId)
129183 {
130- constexpr SizeType32 FIRST_POOL_IDX = 0 ;
131- return cacheManager.getBlockManager ().getPoolWindowSize (FIRST_POOL_IDX);
184+ auto poolNum = mManager ->getNumPools ();
185+ for (SizeType32 poolIdx = 0 ; poolIdx < poolNum; ++poolIdx)
186+ {
187+ auto windowSize = cacheManager.getBlockManager ().getPoolWindowSize (poolIdx);
188+ mPoolsPerWindow [windowSize].push_back (cacheManager.getBlockManager ().getPrimaryPool (poolIdx));
189+ mBlockIdsPerWindow [windowSize]
190+ = cacheManager.getSequence (mRequestId ).getCacheBlockIds (windowSize).at (kFIRST_AND_ONLY_BEAM );
191+ }
132192 }
133193
134194private:
135195 BaseKVCacheManager const * mManager ;
136- runtime::ITensor::SharedPtr mPool ;
137- SizeType32 mWindowSize ;
138- const LlmRequest::RequestIdType mRequestId ;
139- std::vector<SizeType32> mBlockIds ;
196+ LlmRequest::RequestIdType const mRequestId ;
197+ std::unordered_map<SizeType32, std::vector<SizeType32>> mBlockIdsPerWindow ;
198+ std::unordered_map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> mPoolsPerWindow ;
140199
141200 static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0 ;
142201 static constexpr SizeType32 kFIRST_POOL_INDEX = 0 ;
@@ -151,7 +210,7 @@ class BlockIterator
151210 using reference = value_type&;
152211 using SizeType32 = tensorrt_llm::runtime::SizeType32;
153212
154- BlockIterator (BlockRange const * range, size_t idx)
213+ BlockIterator (BlockRangeForWindow const * range, size_t idx)
155214 : mRange {range}
156215 , mIdx {idx}
157216 {
@@ -194,7 +253,7 @@ class BlockIterator
194253 return mIdx == other.mIdx && mRange == other.mRange ;
195254 }
196255
197- [[nodiscard]] bool operator ==(BlockRange ::Sentinel other) const
256+ [[nodiscard]] bool operator ==(BlockRangeForWindow ::Sentinel other) const
198257 {
199258 return mIdx == mRange ->mBlockIds .size ();
200259 }
@@ -210,16 +269,27 @@ class BlockIterator
210269 {
211270 if (mIdx < mRange ->mBlockIds .size ())
212271 {
213- mCurrent = runtime::ITensor::slice (mRange ->mPool , mRange ->mBlockIds .at (mIdx ), 1 );
272+ if (mRange ->mCacheManager != nullptr )
273+ {
274+ BlockPtr const & block = mRange ->mCacheManager ->getBlockManager ().getBlockById (
275+ mRange ->mBlockIds .at (mIdx ), mRange ->mWindowSize );
276+ TLLM_CHECK_WITH_INFO (block->isPrimary (), " cache transceiver only supports primary blocks" );
277+ auto const blockOffset = block->getMemoryPoolBlockIndex ();
278+ mCurrent = runtime::ITensor::slice (mRange ->mPool , blockOffset, 1 );
279+ }
280+ else
281+ {
282+ mCurrent = runtime::ITensor::slice (mRange ->mPool , mRange ->mBlockIds .at (mIdx ), 1 );
283+ }
214284 }
215285 }
216286
217- BlockRange const * mRange ;
287+ BlockRangeForWindow const * mRange ;
218288 runtime::ITensor::SharedPtr mCurrent ;
219289 size_t mIdx ;
220290};
221291
222- inline BlockIterator BlockRange ::begin () const
292+ inline BlockIterator BlockRangeForWindow ::begin () const
223293{
224294 return {this , 0 };
225295}
0 commit comments