Skip to content

Commit e692779

Browse files
Solve underallocation in VSWA+/VGQA (#4667)
Signed-off-by: Netanel Haber <[email protected]>
1 parent 4319237 commit e692779

File tree

20 files changed

+837
-469
lines changed

20 files changed

+837
-469
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 74 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ using FreeBlocksQueue = std::list<BlockPtr>;
6565
using UniqueToken = tensorrt_llm::runtime::UniqueToken;
6666
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
6767
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
68+
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
6869

6970
template <typename T>
7071
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
@@ -78,6 +79,8 @@ struct TempAttentionWindowInputs
7879

7980
struct WindowSizeMetadata
8081
{
82+
SizeType32 allottedPrimaryBlocks; // Number of primary blocks allotted to the windowSize
83+
SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize
8184
SizeType32 absolutePoolsOffset; // cumulative number of pools up to manager
8285
SizeType32 numPools; // number of managed pools
8386
SizeType32 maxTokenNum; // Maximum token length (including bubble)
@@ -90,9 +93,10 @@ struct WindowSizeMetadata
9093
std::string toString()
9194
{
9295
return tensorrt_llm::common::fmtstr(
93-
"WindowSizeMetadata{ .absolutePoolsOffset=%d, .numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, "
94-
".maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
95-
absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
96+
"WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
97+
".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
98+
allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
99+
maxNumBlocks, temporaryAttentionWindow);
96100
}
97101
};
98102

@@ -838,22 +842,23 @@ class BlockManager
838842
using BaseEvictionPolicy = tensorrt_llm::batch_manager::eviction_policy::BaseEvictionPolicy;
839843

840844
explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
841-
SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
842-
SizeType32 maxNumSequences, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
843-
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
845+
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
846+
CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
847+
std::vector<SizeType32> const& maxAttentionWindowVec,
844848
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
845849
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
846850
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
847851
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
848852
bool enablePartialReuse = true, bool copyOnPartialReuse = true);
849853

850-
//! \brief Calculate the number of blocks each window size heap receives of blocksIn{Primary/Secondary}Pool
851-
//! \details Example: (total=16384, uniqueWindowSizeToLayers={1024: [1], 4096: [0, 4, 5], 8192: [2, 3]})
852-
//! Would Return: {1024: 565, 4096: 6780, 8192: 9039} [sums to total].
853-
//! See: TEST_F(KVCacheManagerTest, BlockManagerTestBlocksPerWindowSize).
854-
//! \return Map<windowSize, numBlocks>
855-
static std::map<SizeType32, SizeType32> blocksPerWindowSize(
856-
SizeType32 totalBlocks, std::map<SizeType32, std::vector<SizeType32>> const& uniqueWindowSizeToLayers);
854+
//! \brief Calculate the proportional share each window size receives of the total memory pool
855+
//! \details Example: (uniqueWindowSizeToLayers={1024: [1], 4096: [0, 4, 5], 8192: [2, 3]})
856+
//! Would Return: {1024: 0.0345, 4096: 0.4138, 8192: 0.5517} [sums to 1.0].
857+
//! See: TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare).
858+
//! \return Map<windowSize, share> where share is a float between 0 and 1. Shares sum to 1.0.
859+
static std::map<SizeType32, float> calculateWindowSizeToShare(
860+
std::map<SizeType32, std::vector<SizeType32>> const& uniqueWindowSizeToLayers,
861+
std::map<SizeType32, SizeType32> const& cacheSizePerTokenPerWindowSize);
857862

858863
void allocatePools(bool useUvm);
859864

@@ -1279,21 +1284,54 @@ class BaseKVCacheManager
12791284
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
12801285

12811286
// Sum of numLayers * kvFactor * numKvHeads * sizePerHead for each pool
1282-
[[nodiscard]] static SizeType32 calculateCacheSizePerToken(tensorrt_llm::runtime::ModelConfig const& modelConfig,
1283-
tensorrt_llm::runtime::WorldConfig const& worldConfig, bool isCrossAttention = false, SizeType32 kvFactor = 2)
1287+
[[nodiscard]] static SizeType32 calculateCacheSizePerTokenForSingleWindowSize(
1288+
tensorrt_llm::runtime::ModelConfig const& modelConfig, std::vector<SizeType32> const& windowSizeLayers,
1289+
bool isCrossAttention, SizeType32 kvFactor)
12841290
{
1291+
auto const nkvh = modelConfig.getNumKvHeadsForGivenLayers(windowSizeLayers, isCrossAttention);
1292+
auto const sumLocalHeads = std::reduce(nkvh.cbegin(), nkvh.cend());
12851293
// NOTE: We expect the initialization of modelConfig to have already taken the tp size into account and do not
12861294
// address it here
12871295
// consider only local layers for the calculation
1288-
return modelConfig.getSumLocalKvHeads(
1289-
worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention)
1290-
* kvFactor * modelConfig.getSizePerHead();
1291-
}
1292-
1293-
[[nodiscard]] static std::tuple<SizeType32, SizeType32> calculateMaxNumBlocks(KvCacheConfig const& config,
1296+
return sumLocalHeads * kvFactor * modelConfig.getSizePerHead();
1297+
}
1298+
1299+
/// @brief Groups model layers by their attention window size.
1300+
/// @param maxAttentionWindowVec Vector of maximum attention window sizes per layer (may have fewer elements than
1301+
/// numLayers, in which case it cycles)
1302+
/// @param numLayers Total number of layers in the model
1303+
/// @return Map from window size to vector of layer indices that use that window size
1304+
[[nodiscard]] static std::map<SizeType32, std::vector<SizeType32>> groupLayersByWindowSize(
1305+
std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 numLayers);
1306+
1307+
/// @brief Calculate the free memory available for KV cache allocation.
1308+
/// @param bufferManager Buffer manager for memory operations
1309+
/// @param config KV cache configuration parameters
1310+
/// @return Tuple containing the {.freePrimaryMemBytes, .freeSecondaryMemBytes}
1311+
[[nodiscard]] static std::tuple<uint64_t, uint64_t> calculateFreeMemBytes(
1312+
runtime::BufferManager const& bufferManager, KvCacheConfig const& config);
1313+
1314+
/// @brief Calculate the maximum number of KV cache blocks that can be allocated based on available GPU memory.
1315+
/// @details This function computes how many blocks each WindowBlockManager should receive based on the weighted
1316+
/// share
1317+
/// of memory requirements. The weighting considers both the window size and the number of
1318+
/// layers using each window size, as well as the sum of cache sizes per token for each window.
1319+
/// @param config KV cache configuration parameters
1320+
/// @param isCrossAttention Whether this is for cross-attention KV cache
1321+
/// @param dtype Data type used for KV cache values
1322+
/// @param modelConfig Model configuration containing layer and head information
1323+
/// @param worldConfig World configuration for multi-GPU setups
1324+
/// @param windowSizeToLayers Map from attention window size to vector of layer indices using that window size
1325+
/// @param allottedPrimaryMemBytes Allotted primary memory
1326+
/// @param allottedSecondaryMemBytes Allotted secondary memory
1327+
/// @param extraCostMemory Additional memory cost to account for CacheTransBufferManager::preAllocBufferSize
1328+
/// @param kvFactor Factor for KV cache size calculation (typically 2 for key+value)
1329+
/// @return Map from window size to tuple of (primary blocks, secondary blocks)
1330+
[[nodiscard]] static BlocksPerWindow calculateMaxNumBlocks(KvCacheConfig const& config, bool isCrossAttention,
12941331
nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
1295-
tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
1296-
SizeType32 kvFactor = 2, size_t extraCostMemory = 0);
1332+
tensorrt_llm::runtime::WorldConfig const& worldConfig,
1333+
std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers, uint64_t allottedPrimaryMemBytes,
1334+
uint64_t allottedSecondaryMemBytes, size_t extraCostMemory, SizeType32 kvFactor);
12971335

12981336
/// @brief Calculates the maximum batch size that can fit the kv-cache, given that all sequences in the batch have
12991337
/// the provided input and output length.
@@ -1316,8 +1354,8 @@ class KVCacheManager : public BaseKVCacheManager
13161354
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
13171355

13181356
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
1319-
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
1320-
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
1357+
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
1358+
std::vector<SizeType32> const& maxAttentionWindowVec,
13211359
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
13221360
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
13231361
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1326,8 +1364,8 @@ class KVCacheManager : public BaseKVCacheManager
13261364
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
13271365

13281366
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
1329-
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
1330-
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
1367+
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
1368+
std::vector<SizeType32> const& maxAttentionWindowVec,
13311369
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
13321370
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
13331371
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1336,8 +1374,8 @@ class KVCacheManager : public BaseKVCacheManager
13361374
bool copyOnpartialReuse = true);
13371375

13381376
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
1339-
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
1340-
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
1377+
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
1378+
std::vector<SizeType32> const& maxAttentionWindowVec,
13411379
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
13421380
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
13431381
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1346,8 +1384,8 @@ class KVCacheManager : public BaseKVCacheManager
13461384
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
13471385

13481386
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
1349-
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
1350-
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
1387+
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
1388+
std::vector<SizeType32> const& maxAttentionWindowVec,
13511389
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
13521390
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
13531391
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1538,22 +1576,22 @@ class KVCacheManager : public BaseKVCacheManager
15381576
/// @param inputLength The number of input tokens in the sequence.
15391577
/// @param outputLength The number of output tokens in the sequence.
15401578
/// @param sinkTokenLength The number of sink tokens configured.
1541-
/// @param maxAttentionWindow The maximum attention window allowed by the model.
1579+
/// @param maxAttentionWindow The attention window size allowed by the model.
15421580
/// @param beamWidth The number of beams to consider for the request.
15431581
/// @param tokensPerBlock The number of tokens a single kv-cache block contains.,
15441582
/// @return SizeType32 A number of blocks.
15451583
[[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
1546-
SizeType32 sinkTokenLength, SizeType32 maxAttentionWindow, SizeType32 beamWidth, SizeType32 tokensPerBlock);
1584+
SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);
15471585

15481586
/// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
15491587
///
15501588
/// @param sequenceLength The total length of the sequence (input and output).
15511589
/// @param sinkTokenLength The number of sink tokens configured.
1552-
/// @param maxAttentionWindow The maximum attention window allowed by the model.
1590+
/// @param windowSize The attention window size
15531591
/// @param tokensPerBlock The number of tokens in a single kv-cache block.
15541592
/// @return SizeType32 A number of blocks.
1555-
[[nodiscard]] static SizeType32 calculateMaxBlockRequirementsPerBeam(SizeType32 sequenceLength,
1556-
SizeType32 sinkTokenLength, SizeType32 maxAttentionWindow, SizeType32 tokensPerBlock);
1593+
[[nodiscard]] static SizeType32 calculateMaxBlockRequirementsPerBeam(
1594+
SizeType32 sequenceLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 tokensPerBlock);
15571595

15581596
std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
15591597
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;

cpp/include/tensorrt_llm/runtime/modelConfig.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,18 @@ class ModelConfig
799799
return mNumKvHeadsPerAttentionLayer;
800800
}
801801

802+
[[nodiscard]] std::vector<SizeType32> getNumKvHeadsForGivenLayers(
803+
std::vector<SizeType32> const& layers, bool isCrossAttention) const
804+
{
805+
std::vector<SizeType32> numKvHeads;
806+
numKvHeads.reserve(layers.size());
807+
auto const numKvHeadsAllLayers
808+
= isCrossAttention ? mNumKvHeadsPerCrossAttentionLayer : mNumKvHeadsPerAttentionLayer;
809+
std::transform(layers.begin(), layers.end(), std::back_inserter(numKvHeads),
810+
[&numKvHeadsAllLayers](SizeType32 layer) { return numKvHeadsAllLayers.at(layer); });
811+
return numKvHeads;
812+
}
813+
802814
[[nodiscard]] std::pair<std::vector<SizeType32>::const_iterator, std::vector<SizeType32>::const_iterator>
803815
getNumKvHeadsPerLayerLocalRange(
804816
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0, bool isCrossAttention = false) const
@@ -834,15 +846,6 @@ class ModelConfig
834846
mNumKvHeadsPerCrossAttentionLayer = headsPerLayer;
835847
}
836848

837-
[[nodiscard]] SizeType32 getSumLocalKvHeads(
838-
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0, bool isCrossAttention = false) const
839-
{
840-
auto [cbegin, cend]
841-
= getNumKvHeadsPerLayerLocalRange(pipelineParallelism, pipelineParallelismRank, isCrossAttention);
842-
auto const sumLocalHeads = std::reduce(cbegin, cend);
843-
return sumLocalHeads;
844-
}
845-
846849
[[nodiscard]] bool constexpr skipCrossAttnBlocks() const noexcept
847850
{
848851
return mSkipCrossAttnBlocks;

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,22 +230,17 @@ CacheTransBufferManager::CacheTransBufferManager(
230230
allocateBuffer();
231231
}
232232

233-
size_t CacheTransBufferManager::preAllocBufferSize(
234-
std::optional<size_t> maxNumTokens, std::optional<size_t> kvCacheSizePerToken)
233+
size_t CacheTransBufferManager::preAllocBufferSize(std::optional<size_t> maxNumTokens)
235234
{
236235
bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache();
237236
if (!to_allocate)
238237
{
239238
return 0;
240239
}
241-
if (maxNumTokens.has_value())
242-
{
243-
TLLM_CHECK(kvCacheSizePerToken.has_value());
244-
}
245240
size_t TransferBufferSize = common::getEnvMemSizeForKVCacheTransferBuffer();
246241
if (maxNumTokens.has_value())
247242
{
248-
TransferBufferSize = maxNumTokens.value() * kvCacheSizePerToken.value();
243+
TransferBufferSize = maxNumTokens.value();
249244
}
250245
bool useFabricMemory = FabricMemory::supportFbaricMemory()
251246
&& (!(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer()));

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ class CacheTransBufferManager
5959
CacheTransBufferManager(
6060
KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens = std::nullopt);
6161

62-
static size_t preAllocBufferSize(
63-
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<size_t> kvCacheSizePerToken = std::nullopt);
62+
static size_t preAllocBufferSize(std::optional<size_t> maxNumTokens = std::nullopt);
6463

6564
std::optional<int> assignBufferIndexForSend();
6665
void freeBufferIndexForSend(std::optional<int> bufferId);

0 commit comments

Comments (0)