@@ -65,6 +65,7 @@ using FreeBlocksQueue = std::list<BlockPtr>;
 using UniqueToken = tensorrt_llm::runtime::UniqueToken;
 using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
 using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
+using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
 
 template <typename T>
 using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
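The new `BlocksPerWindow` alias carries one `(primary, secondary)` block budget per distinct attention window size, replacing the former pair of scalars. A minimal, self-contained illustration with invented numbers:

```cpp
#include <cstdint>
#include <map>
#include <tuple>

using SizeType32 = std::int32_t;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;

int main()
{
    // Hypothetical allotment: window size -> {primary blocks, secondary blocks}.
    BlocksPerWindow const blocksPerWindow{
        {4096, {6780, 1024}}, // sliding-window layers
        {8192, {9039, 2048}}, // full-attention layers
    };
    auto const [primary, secondary] = blocksPerWindow.at(4096);
    return primary > 0 && secondary >= 0 ? 0 : 1;
}
```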
@@ -78,6 +79,8 @@ struct TempAttentionWindowInputs
 
 struct WindowSizeMetadata
 {
+    SizeType32 allottedPrimaryBlocks;   // Number of primary blocks allotted to the windowSize
+    SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize
     SizeType32 absolutePoolsOffset;     // cumulative number of pools up to manager
     SizeType32 numPools;                // number of managed pools
     SizeType32 maxTokenNum;             // Maximum token length (including bubble)
@@ -90,9 +93,10 @@ struct WindowSizeMetadata
     std::string toString()
     {
         return tensorrt_llm::common::fmtstr(
-            "WindowSizeMetadata{ .absolutePoolsOffset=%d, .numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, "
-            ".maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
-            absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
+            "WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
+            ".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
+            allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
+            maxNumBlocks, temporaryAttentionWindow);
     }
 };
 
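With the two new allotted-block fields, `toString()` now reports the per-window budgets first. A standalone sketch that reproduces the format (the struct shape is mirrored locally and all values are invented):

```cpp
#include <cstdio>

// Local mirror of the struct fields shown in the hunk above.
struct WindowSizeMetadata
{
    int allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools;
    int maxTokenNum, maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow;
};

int main()
{
    WindowSizeMetadata const m{9039, 2048, 0, 1, 8192, 128, 9039, 0};
    // Mirrors the fmtstr format string used by toString().
    std::printf("WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, "
                ".absolutePoolsOffset=%d, .numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, "
                ".maxNumBlocks=%d, .temporaryAttentionWindow=%d }\n",
        m.allottedPrimaryBlocks, m.allottedSecondaryBlocks, m.absolutePoolsOffset, m.numPools, m.maxTokenNum,
        m.maxBlocksPerSeq, m.maxNumBlocks, m.temporaryAttentionWindow);
    return 0;
}
```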
@@ -838,22 +842,23 @@ class BlockManager
     using BaseEvictionPolicy = tensorrt_llm::batch_manager::eviction_policy::BaseEvictionPolicy;
 
     explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
-        SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
-        SizeType32 maxNumSequences, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
-        SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
+        SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
+        CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
+        std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
         SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
         std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
         std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
         bool enablePartialReuse = true, bool copyOnPartialReuse = true);
 
-    //! \brief Calculate the number of blocks each window size heap receives of blocksIn{Primary/Secondary}Pool
-    //! \details Example: (total=16384, uniqueWindowSizeToLayers={1024: [1], 4096: [0, 4, 5], 8192: [2, 3]})
-    //!          Would Return: {1024: 565, 4096: 6780, 8192: 9039} [sums to total].
-    //!          See: TEST_F(KVCacheManagerTest, BlockManagerTestBlocksPerWindowSize).
-    //! \return Map<windowSize, numBlocks>
-    static std::map<SizeType32, SizeType32> blocksPerWindowSize(
-        SizeType32 totalBlocks, std::map<SizeType32, std::vector<SizeType32>> const& uniqueWindowSizeToLayers);
+    //! \brief Calculate the proportional share each window size receives of the total memory pool
+    //! \details Example: (uniqueWindowSizeToLayers={1024: [1], 4096: [0, 4, 5], 8192: [2, 3]})
+    //!          Would Return: {1024: 0.0345, 4096: 0.4138, 8192: 0.5517} [sums to 1.0].
+    //!          See: TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare).
+    //! \return Map<windowSize, share> where share is a float between 0 and 1. Shares sum to 1.0.
+    static std::map<SizeType32, float> calculateWindowSizeToShare(
+        std::map<SizeType32, std::vector<SizeType32>> const& uniqueWindowSizeToLayers,
+        std::map<SizeType32, SizeType32> const& cacheSizePerTokenPerWindowSize);
 
     void allocatePools(bool useUvm);
 
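The doc example above is consistent with each window being weighted by windowSize × cacheSizePerToken, then normalized so the shares sum to 1. A minimal sketch under that assumption (the helper name is invented, and this is not necessarily the library's exact computation):

```cpp
#include <cstdio>
#include <map>
#include <vector>

using SizeType32 = int;

// Sketch: share_w ∝ windowSize * cacheSizePerToken_w, normalized over all windows.
std::map<SizeType32, float> windowSizeToShareSketch(
    std::map<SizeType32, std::vector<SizeType32>> const& uniqueWindowSizeToLayers,
    std::map<SizeType32, SizeType32> const& cacheSizePerTokenPerWindowSize)
{
    double totalWeight = 0.0;
    for (auto const& [windowSize, layers] : uniqueWindowSizeToLayers)
    {
        totalWeight += static_cast<double>(windowSize) * cacheSizePerTokenPerWindowSize.at(windowSize);
    }
    std::map<SizeType32, float> shares;
    for (auto const& [windowSize, layers] : uniqueWindowSizeToLayers)
    {
        double const weight = static_cast<double>(windowSize) * cacheSizePerTokenPerWindowSize.at(windowSize);
        shares[windowSize] = static_cast<float>(weight / totalWeight); // shares sum to 1.0
    }
    return shares;
}

int main()
{
    // Example from the doc comment; assuming a uniform cache size per token per layer,
    // each window's per-token cost is proportional to its layer count (1, 3, 2 here).
    std::map<SizeType32, std::vector<SizeType32>> const windows{{1024, {1}}, {4096, {0, 4, 5}}, {8192, {2, 3}}};
    std::map<SizeType32, SizeType32> const costPerToken{{1024, 1}, {4096, 3}, {8192, 2}};
    for (auto const& [windowSize, share] : windowSizeToShareSketch(windows, costPerToken))
    {
        std::printf("%d: %.4f\n", windowSize, share); // 1024: 0.0345, 4096: 0.4138, 8192: 0.5517
    }
    return 0;
}
```

Weights are 1024·1, 4096·3, and 8192·2 (total 29696), which reproduces the shares in the doc comment and, scaled by 16384 total blocks, the {565, 6780, 9039} split quoted for the old `blocksPerWindowSize`.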
@@ -1279,21 +1284,54 @@ class BaseKVCacheManager
     [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
 
     // Sum of numLayers * kvFactor * numKvHeads * sizePerHead for each pool
-    [[nodiscard]] static SizeType32 calculateCacheSizePerToken(tensorrt_llm::runtime::ModelConfig const& modelConfig,
-        tensorrt_llm::runtime::WorldConfig const& worldConfig, bool isCrossAttention = false, SizeType32 kvFactor = 2)
+    [[nodiscard]] static SizeType32 calculateCacheSizePerTokenForSingleWindowSize(
+        tensorrt_llm::runtime::ModelConfig const& modelConfig, std::vector<SizeType32> const& windowSizeLayers,
+        bool isCrossAttention, SizeType32 kvFactor)
     {
+        auto const nkvh = modelConfig.getNumKvHeadsForGivenLayers(windowSizeLayers, isCrossAttention);
+        auto const sumLocalHeads = std::reduce(nkvh.cbegin(), nkvh.cend());
         // NOTE: We expect the initialization of modelConfig to have already taken the tp size into account and do not
         // address it here
         // consider only local layers for the calculation
-        return modelConfig.getSumLocalKvHeads(
-                   worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention)
-            * kvFactor * modelConfig.getSizePerHead();
-    }
-
-    [[nodiscard]] static std::tuple<SizeType32, SizeType32> calculateMaxNumBlocks(KvCacheConfig const& config,
+        return sumLocalHeads * kvFactor * modelConfig.getSizePerHead();
+    }
+
+    /// @brief Groups model layers by their attention window size.
+    /// @param maxAttentionWindowVec Vector of maximum attention window sizes per layer (may have fewer elements than
+    ///        numLayers, in which case it cycles)
+    /// @param numLayers Total number of layers in the model
+    /// @return Map from window size to vector of layer indices that use that window size
+    [[nodiscard]] static std::map<SizeType32, std::vector<SizeType32>> groupLayersByWindowSize(
+        std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 numLayers);
+
+    /// @brief Calculate the free memory available for KV cache allocation.
+    /// @param bufferManager Buffer manager for memory operations
+    /// @param config KV cache configuration parameters
+    /// @return Tuple of {freePrimaryMemBytes, freeSecondaryMemBytes}
+    [[nodiscard]] static std::tuple<uint64_t, uint64_t> calculateFreeMemBytes(
+        runtime::BufferManager const& bufferManager, KvCacheConfig const& config);
+
+    /// @brief Calculate the maximum number of KV cache blocks that can be allocated based on available GPU memory.
+    /// @details This function computes how many blocks each WindowBlockManager should receive based on the
+    ///          weighted share of memory requirements. The weighting considers both the window size and the
+    ///          number of layers using each window size, as well as the sum of cache sizes per token for each
+    ///          window.
+    /// @param config KV cache configuration parameters
+    /// @param isCrossAttention Whether this is for cross-attention KV cache
+    /// @param dtype Data type used for KV cache values
+    /// @param modelConfig Model configuration containing layer and head information
+    /// @param worldConfig World configuration for multi-GPU setups
+    /// @param windowSizeToLayers Map from attention window size to vector of layer indices using that window size
+    /// @param allottedPrimaryMemBytes Allotted primary memory
+    /// @param allottedSecondaryMemBytes Allotted secondary memory
+    /// @param extraCostMemory Additional memory cost to account for CacheTransBufferManager::preAllocBufferSize
+    /// @param kvFactor Factor for KV cache size calculation (typically 2 for key+value)
+    /// @return Map from window size to tuple of (primary blocks, secondary blocks)
+    [[nodiscard]] static BlocksPerWindow calculateMaxNumBlocks(KvCacheConfig const& config, bool isCrossAttention,
         nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
-        tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
-        SizeType32 kvFactor = 2, size_t extraCostMemory = 0);
+        tensorrt_llm::runtime::WorldConfig const& worldConfig,
+        std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers, uint64_t allottedPrimaryMemBytes,
+        uint64_t allottedSecondaryMemBytes, size_t extraCostMemory, SizeType32 kvFactor);
 
     /// @brief Calculates the maximum batch size that can fit the kv-cache, given that all sequences in the batch have
     /// the provided input and output length.
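The grouping helper's cycling rule is easiest to see in code. A self-contained sketch of the documented behavior (`groupLayersByWindowSizeSketch` is an invented name, not the library function):

```cpp
#include <cstddef>
#include <cstdio>
#include <map>
#include <vector>

using SizeType32 = int;

std::map<SizeType32, std::vector<SizeType32>> groupLayersByWindowSizeSketch(
    std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 numLayers)
{
    std::map<SizeType32, std::vector<SizeType32>> grouped;
    for (SizeType32 layerIdx = 0; layerIdx < numLayers; ++layerIdx)
    {
        // When the vector has fewer entries than numLayers, it cycles.
        auto const windowSize = maxAttentionWindowVec[static_cast<std::size_t>(layerIdx) % maxAttentionWindowVec.size()];
        grouped[windowSize].push_back(layerIdx);
    }
    return grouped;
}

int main()
{
    auto const grouped = groupLayersByWindowSizeSketch({8192, 4096, 4096}, 6);
    for (auto const& [windowSize, layers] : grouped)
    {
        std::printf("%d:", windowSize);
        for (auto const layer : layers)
        {
            std::printf(" %d", layer);
        }
        std::printf("\n"); // prints "4096: 1 2 4 5" then "8192: 0 3"
    }
    return 0;
}
```

A caller would typically chain `groupLayersByWindowSize` → `calculateFreeMemBytes` → `calculateMaxNumBlocks` and hand the resulting `BlocksPerWindow` to the constructors below.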
@@ -1316,8 +1354,8 @@ class KVCacheManager : public BaseKVCacheManager
     using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
 
     KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
-        SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
+        BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
+        std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
         SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
         bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
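For callers, the constructor change is mechanical: the two scalar pool sizes collapse into one `BlocksPerWindow` map. A hypothetical construction sketch (the helper function, all numeric values, and the include path are invented for illustration):

```cpp
#include <memory>
#include <optional>
#include <vector>
// #include "tensorrt_llm/batch_manager/kvCacheManager.h" // assumed header location

using namespace tensorrt_llm::batch_manager::kv_cache_manager;

std::unique_ptr<KVCacheManager> makeManagerSketch(
    std::vector<SizeType32> const& numKvHeadsPerLayer, CudaStreamPtr stream, nvinfer1::DataType dtype)
{
    // One (primary, secondary) budget per distinct window size, e.g. from calculateMaxNumBlocks().
    BlocksPerWindow const blocksPerWindow{{4096, {6780, 1024}}, {8192, {9039, 2048}}};
    std::vector<SizeType32> const maxAttentionWindowVec{8192, 4096, 4096}; // cycles over layers
    return std::make_unique<KVCacheManager>(numKvHeadsPerLayer, /*sizePerHead=*/128, /*tokensPerBlock=*/64,
        blocksPerWindow, /*maxNumSequences=*/8, /*maxBeamWidth=*/1, maxAttentionWindowVec,
        /*tempAttentionWindowInputs=*/std::nullopt, dtype, /*sinkTokenLength=*/0, stream,
        /*maxSequenceLength=*/std::nullopt);
}
```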
@@ -1326,8 +1364,8 @@ class KVCacheManager : public BaseKVCacheManager
         bool enablePartialReuse = true, bool copyOnpartialReuse = true);
 
     KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
-        SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
+        BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
+        std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
         SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
         bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1336,8 +1374,8 @@ class KVCacheManager : public BaseKVCacheManager
         bool copyOnpartialReuse = true);
 
     KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
-        SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
+        BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
+        std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
         SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
         bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1346,8 +1384,8 @@ class KVCacheManager : public BaseKVCacheManager
         bool enablePartialReuse = true, bool copyOnpartialReuse = true);
 
     KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
-        SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
+        BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
+        std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
         SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
         bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1538,22 +1576,22 @@ class KVCacheManager : public BaseKVCacheManager
     /// @param inputLength The number of input tokens in the sequence.
     /// @param outputLength The number of output tokens in the sequence.
     /// @param sinkTokenLength The number of sink tokens configured.
-    /// @param maxAttentionWindow The maximum attention window allowed by the model.
+    /// @param windowSize The attention window size allowed by the model.
     /// @param beamWidth The number of beams to consider for the request.
     /// @param tokensPerBlock The number of tokens a single kv-cache block contains.
     /// @return SizeType32 A number of blocks.
     [[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
-        SizeType32 sinkTokenLength, SizeType32 maxAttentionWindow, SizeType32 beamWidth, SizeType32 tokensPerBlock);
+        SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);
 
     /// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
     ///
     /// @param sequenceLength The total length of the sequence (input and output).
     /// @param sinkTokenLength The number of sink tokens configured.
-    /// @param maxAttentionWindow The maximum attention window allowed by the model.
+    /// @param windowSize The attention window size.
     /// @param tokensPerBlock The number of tokens in a single kv-cache block.
     /// @return SizeType32 A number of blocks.
-    [[nodiscard]] static SizeType32 calculateMaxBlockRequirementsPerBeam(SizeType32 sequenceLength,
-        SizeType32 sinkTokenLength, SizeType32 maxAttentionWindow, SizeType32 tokensPerBlock);
+    [[nodiscard]] static SizeType32 calculateMaxBlockRequirementsPerBeam(
+        SizeType32 sequenceLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 tokensPerBlock);
 
     std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
         LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;
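The per-beam requirement is what makes windowed attention save memory: only tokens inside the window (plus the sink tokens) must stay cached. A hedged sketch of the bound, assuming ceil division of the windowed token count; the real implementation also rounds the sink tokens to a block boundary via `getSinkBubbleLength`:

```cpp
#include <algorithm>
#include <cstdio>

using SizeType32 = int;

// Sketch: cap the cached tokens at the attention window, then round up to blocks.
SizeType32 maxBlockRequirementsPerBeamSketch(
    SizeType32 sequenceLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 tokensPerBlock)
{
    SizeType32 const cachedTokens = std::min(sequenceLength, windowSize) + sinkTokenLength;
    return (cachedTokens + tokensPerBlock - 1) / tokensPerBlock; // ceilDiv
}

int main()
{
    // A 10000-token sequence with a 4096-token window and 64-token blocks needs
    // ceil(4096 / 64) = 64 blocks per beam, not ceil(10000 / 64) = 157.
    std::printf("%d\n", maxBlockRequirementsPerBeamSketch(10000, 0, 4096, 64)); // 64
    return 0;
}
```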