Merged

Commits
012f8ad
allocate blocks per window size correctly
netanel-haber May 26, 2025
d5da328
Merge branch 'main' into user/nhaber/fix-variable-window-size-underal…
netanel-haber May 26, 2025
f5265e6
simpler code path for common homogeneous models
netanel-haber May 26, 2025
f94e3c5
shorten: (b|B)locksPerWindowSize -> blocksPerWindow
netanel-haber May 26, 2025
f3c3c63
fix trivial test compile errors
netanel-haber May 26, 2025
5e9c4de
fix non-trivial compile errors
netanel-haber May 26, 2025
3f12bd5
fix resource manager
netanel-haber May 26, 2025
aad71d8
fix extracostmemory
netanel-haber May 27, 2025
0caef2d
minimize diff
netanel-haber May 28, 2025
2e086a0
fix
netanel-haber May 28, 2025
9d867f8
fix
netanel-haber May 28, 2025
52053f1
small fix
netanel-haber May 28, 2025
4cf087c
dynamic batch tuning
netanel-haber May 28, 2025
f0c2427
fix tests
netanel-haber May 28, 2025
7d40928
fix blocks_per_window
netanel-haber May 29, 2025
7119d3a
use windowSizeToLayers for improved clarity instead of managedLayers …
netanel-haber May 29, 2025
08545e8
docs and naming
netanel-haber May 29, 2025
ac5e649
provide free memory to calculateMaxNumBlocks as an argument, so cross…
netanel-haber May 29, 2025
e9ca7b1
remove unused imports
netanel-haber May 29, 2025
0bccc1a
hopefully implement mpi sync
netanel-haber May 29, 2025
6ee3096
only warn when VSWA + config.maxTokens is set
netanel-haber May 29, 2025
93d9816
clamp maxAttentionWindowVec
netanel-haber May 30, 2025
efd35ee
fix test
netanel-haber May 30, 2025
bb6ea84
better logs
netanel-haber May 30, 2025
2cc5566
fix windowSizeToBlocks indexing
netanel-haber May 30, 2025
1945864
Merge branch 'main' into user/nhaber/fix-variable-window-size-underal…
netanel-haber May 30, 2025
ff7388a
fix KVCacheManagerLeafBlockWithDependentTest
netanel-haber Jun 1, 2025
44e6d51
fix WindowSizeMetadata fields ordering
netanel-haber Jun 1, 2025
c885ace
fix KVCacheManagerVariableWindowAttentionWithReuseTest
netanel-haber Jun 1, 2025
002ba6a
Changed type for maxTokens to uint64_t to avoid overflow
netanel-haber Jun 3, 2025
12486a5
*multiply* by crossKvCacheFraction, not divide
netanel-haber Jun 3, 2025
31e4048
fix cross manager window size
netanel-haber Jun 3, 2025
717d9c0
fix test_KvCache_events_binding kvcachemanager init
netanel-haber Jun 3, 2025
a258b7e
fix calculate_max_num_blocks binding
netanel-haber Jun 3, 2025
d33883b
assert freeMemory smaller than totalMemory
netanel-haber Jun 3, 2025
92ed07b
assert freeMemory smaller than totalMemory - after printing them
netanel-haber Jun 3, 2025
755ebf5
metadata.allottedPrimaryBlocks / blockRequirementsPerSequence instead…
netanel-haber Jun 3, 2025
34a7033
fix minor bug
netanel-haber Jun 8, 2025
31f501d
actually use reduced value [blocksWorld] and assign it to blocksPrima…
netanel-haber Jun 8, 2025
8c67618
[Infra] - Update JNLP container config (#5008)
chzblych Jun 8, 2025
f6f030b
Merge branch 'main' into user/nhaber/fix-variable-window-size-underal…
netanel-haber Jun 10, 2025
11402b3
Merge branch 'main' into user/nhaber/fix-variable-window-size-underal…
netanel-haber Jun 10, 2025
892ad3d
pr comments
netanel-haber Jun 10, 2025
13edfc1
make logging quieter
netanel-haber Jun 10, 2025
029d7e6
add ceremony
netanel-haber Jun 10, 2025
ae25a10
Merge branch 'main' into user/nhaber/fix-variable-window-size-underal…
netanel-haber Jun 11, 2025
37c8e56
Merge branch 'main' into user/nhaber/fix-variable-window-size-underal…
netanel-haber Jun 12, 2025
110 changes: 74 additions & 36 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -65,6 +65,7 @@ using FreeBlocksQueue = std::list<BlockPtr>;
using UniqueToken = tensorrt_llm::runtime::UniqueToken;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;

template <typename T>
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
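
The new BlocksPerWindow alias maps each attention window size to its pair of (primary, secondary) block counts. A purely illustrative value, with hypothetical numbers echoing the docstring example later in this file:

// Hypothetical: three window sizes, all blocks in the primary (GPU) pool.
BlocksPerWindow const blocksPerWindow{
    {1024, {565, 0}},
    {4096, {6780, 0}},
    {8192, {9039, 0}},
};
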
@@ -78,6 +79,8 @@ struct TempAttentionWindowInputs

struct WindowSizeMetadata
{
SizeType32 allottedPrimaryBlocks; // Number of primary blocks allotted to the windowSize
SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize
SizeType32 absolutePoolsOffset; // cumulative number of pools up to manager
SizeType32 numPools; // number of managed pools
SizeType32 maxTokenNum; // Maximum token length (including bubble)
@@ -90,9 +93,10 @@ struct WindowSizeMetadata
std::string toString()
{
return tensorrt_llm::common::fmtstr(
"WindowSizeMetadata{ .absolutePoolsOffset=%d, .numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, "
".maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
"WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
maxNumBlocks, temporaryAttentionWindow);
}
};

@@ -838,22 +842,23 @@ class BlockManager
using BaseEvictionPolicy = tensorrt_llm::batch_manager::eviction_policy::BaseEvictionPolicy;

explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
SizeType32 maxNumSequences, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enableHashKey = false,
bool enablePartialReuse = true, bool copyOnPartialReuse = true);

//! \brief Calculate the number of blocks each window size heap receives of blocksIn{Primary/Secondary}Pool
//! \details Example: (total=16384, uniqueWindowSizeToLayers={1024: [1], 4096: [0, 4, 5], 8192: [2, 3]})
//! Would Return: {1024: 565, 4096: 6780, 8192: 9039} [sums to total].
//! See: TEST_F(KVCacheManagerTest, BlockManagerTestBlocksPerWindowSize).
//! \return Map<windowSize, numBlocks>
static std::map<SizeType32, SizeType32> blocksPerWindowSize(
SizeType32 totalBlocks, std::map<SizeType32, std::vector<SizeType32>> const& uniqueWindowSizeToLayers);
//! \brief Calculate the proportional share each window size receives of the total memory pool
//! \details Example: (uniqueWindowSizeToLayers={1024: [1], 4096: [0, 4, 5], 8192: [2, 3]})
//! Would Return: {1024: 0.0345, 4096: 0.4138, 8192: 0.5517} [sums to 1.0].
//! See: TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare).
//! \return Map<windowSize, share> where share is a float between 0 and 1. Shares sum to 1.0.
static std::map<SizeType32, float> calculateWindowSizeToShare(
std::map<SizeType32, std::vector<SizeType32>> const& uniqueWindowSizeToLayers,
std::map<SizeType32, SizeType32> const& cacheSizePerTokenPerWindowSize);
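
To make the proportional split concrete, here is a minimal sketch of how such shares can be computed, assuming each window's weight is its windowSize multiplied by its per-token cache size (which already sums over the layers using that window). It illustrates the documented contract; it is not the library implementation:

#include <cstdint>
#include <map>

using SizeType32 = std::int32_t; // mirrors the alias used in this header

std::map<SizeType32, float> windowSizeToShareSketch(
    std::map<SizeType32, SizeType32> const& cacheSizePerTokenPerWindowSize)
{
    double total = 0.0;
    for (auto const& [windowSize, cacheSizePerToken] : cacheSizePerTokenPerWindowSize)
    {
        total += static_cast<double>(windowSize) * cacheSizePerToken;
    }
    std::map<SizeType32, float> shares;
    for (auto const& [windowSize, cacheSizePerToken] : cacheSizePerTokenPerWindowSize)
    {
        // Each window's share is its weight over the total weight; shares sum to 1.0.
        shares[windowSize] = static_cast<float>(static_cast<double>(windowSize) * cacheSizePerToken / total);
    }
    return shares;
}

With a uniform per-layer cache size, the docstring example reduces to weights 1024*1, 4096*3 and 8192*2 over their sum 29696, i.e. shares 0.0345, 0.4138 and 0.5517.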

void allocatePools(bool useUvm);

@@ -1279,21 +1284,54 @@ class BaseKVCacheManager
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);

// Sum of numLayers * kvFactor * numKvHeads * sizePerHead for each pool
[[nodiscard]] static SizeType32 calculateCacheSizePerToken(tensorrt_llm::runtime::ModelConfig const& modelConfig,
tensorrt_llm::runtime::WorldConfig const& worldConfig, bool isCrossAttention = false, SizeType32 kvFactor = 2)
[[nodiscard]] static SizeType32 calculateCacheSizePerTokenForSingleWindowSize(
tensorrt_llm::runtime::ModelConfig const& modelConfig, std::vector<SizeType32> const& windowSizeLayers,
bool isCrossAttention, SizeType32 kvFactor)
{
auto const nkvh = modelConfig.getNumKvHeadsForGivenLayers(windowSizeLayers, isCrossAttention);
auto const sumLocalHeads = std::reduce(nkvh.cbegin(), nkvh.cend());
// NOTE: We expect the initialization of modelConfig to have already taken the tp size into account and do not
// address it here
// consider only local layers for the calculation
return modelConfig.getSumLocalKvHeads(
worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention)
* kvFactor * modelConfig.getSizePerHead();
}

[[nodiscard]] static std::tuple<SizeType32, SizeType32> calculateMaxNumBlocks(KvCacheConfig const& config,
return sumLocalHeads * kvFactor * modelConfig.getSizePerHead();
}

/// @brief Groups model layers by their attention window size.
/// @param maxAttentionWindowVec Vector of maximum attention window sizes per layer (may have fewer elements than
/// numLayers, in which case it cycles)
/// @param numLayers Total number of layers in the model
/// @return Map from window size to vector of layer indices that use that window size
[[nodiscard]] static std::map<SizeType32, std::vector<SizeType32>> groupLayersByWindowSize(
std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 numLayers);
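
A hedged sketch of the cycling behaviour described above (the real implementation lives in the .cpp file). For example, maxAttentionWindowVec = {1024, 4096} with numLayers = 5 would yield {1024: [0, 2, 4], 4096: [1, 3]}:

std::map<SizeType32, std::vector<SizeType32>> groupLayersByWindowSizeSketch(
    std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 numLayers)
{
    std::map<SizeType32, std::vector<SizeType32>> windowSizeToLayers;
    for (SizeType32 layer = 0; layer < numLayers; ++layer)
    {
        // When the vector has fewer elements than numLayers, it cycles.
        auto const windowSize = maxAttentionWindowVec[layer % maxAttentionWindowVec.size()];
        windowSizeToLayers[windowSize].push_back(layer);
    }
    return windowSizeToLayers;
}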

/// @brief Calculate the free memory available for KV cache allocation.
/// @param bufferManager Buffer manager for memory operations
/// @param config KV cache configuration parameters
/// @return Tuple of {freePrimaryMemBytes, freeSecondaryMemBytes}
[[nodiscard]] static std::tuple<uint64_t, uint64_t> calculateFreeMemBytes(
runtime::BufferManager const& bufferManager, KvCacheConfig const& config);

/// @brief Calculate the maximum number of KV cache blocks that can be allocated based on available GPU memory.
/// @details This function computes how many blocks each WindowBlockManager should receive based on the weighted
/// share of memory requirements. The weighting considers both the window size and the number of layers using
/// each window size, as well as the sum of cache sizes per token for each window.
/// @param config KV cache configuration parameters
/// @param isCrossAttention Whether this is for cross-attention KV cache
/// @param dtype Data type used for KV cache values
/// @param modelConfig Model configuration containing layer and head information
/// @param worldConfig World configuration for multi-GPU setups
/// @param windowSizeToLayers Map from attention window size to vector of layer indices using that window size
/// @param allottedPrimaryMemBytes Allotted primary memory
/// @param allottedSecondaryMemBytes Allotted secondary memory
/// @param extraCostMemory Additional memory cost to account for CacheTransBufferManager::preAllocBufferSize
/// @param kvFactor Factor for KV cache size calculation (typically 2 for key+value)
/// @return Map from window size to tuple of (primary blocks, secondary blocks)
[[nodiscard]] static BlocksPerWindow calculateMaxNumBlocks(KvCacheConfig const& config, bool isCrossAttention,
nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
SizeType32 kvFactor = 2, size_t extraCostMemory = 0);
tensorrt_llm::runtime::WorldConfig const& worldConfig,
std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers, uint64_t allottedPrimaryMemBytes,
uint64_t allottedSecondaryMemBytes, size_t extraCostMemory, SizeType32 kvFactor);
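
Taken together, the refactor replaces the old single (primary, secondary) pair with a per-window allotment. A sketch of the intended call sequence, pieced together from the declarations in this header; kvCacheConfig, bufferManager, modelConfig, worldConfig and the constructor arguments are assumed to be in scope:

// 1. Group layers by their attention window size.
auto const windowSizeToLayers
    = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, numLayers);
// 2. Query how much primary (GPU) and secondary (host) memory may be used.
auto const [freePrimaryMemBytes, freeSecondaryMemBytes]
    = BaseKVCacheManager::calculateFreeMemBytes(bufferManager, kvCacheConfig);
// 3. Split that memory across window sizes and convert it to block counts.
auto const blocksPerWindow = BaseKVCacheManager::calculateMaxNumBlocks(kvCacheConfig,
    /*isCrossAttention=*/false, dtype, modelConfig, worldConfig, windowSizeToLayers, freePrimaryMemBytes,
    freeSecondaryMemBytes, /*extraCostMemory=*/0, /*kvFactor=*/2);
// 4. Construct the manager with the per-window allotment.
KVCacheManager manager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
    maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, stream,
    maxSequenceLength);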

/// @brief Calculates the maximum batch size that can fit the kv-cache, given that all sequences in the batch have
/// the provided input and output length.
@@ -1316,8 +1354,8 @@ class KVCacheManager : public BaseKVCacheManager
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;

KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1326,8 +1364,8 @@
bool enablePartialReuse = true, bool copyOnpartialReuse = true);

KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1336,8 +1374,8 @@
bool copyOnpartialReuse = true);

KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1346,8 +1384,8 @@
bool enablePartialReuse = true, bool copyOnpartialReuse = true);

KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
@@ -1538,22 +1576,22 @@ class KVCacheManager : public BaseKVCacheManager
/// @param inputLength The number of input tokens in the sequence.
/// @param outputLength The number of output tokens in the sequence.
/// @param sinkTokenLength The number of sink tokens configured.
/// @param maxAttentionWindow The maximum attention window allowed by the model.
/// @param windowSize The attention window size allowed by the model.
/// @param beamWidth The number of beams to consider for the request.
/// @param tokensPerBlock The number of tokens a single kv-cache block contains.
/// @return SizeType32 A number of blocks.
[[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
SizeType32 sinkTokenLength, SizeType32 maxAttentionWindow, SizeType32 beamWidth, SizeType32 tokensPerBlock);
SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);

/// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
///
/// @param sequenceLength The total length of the sequence (input and output).
/// @param sinkTokenLength The number of sink tokens configured.
/// @param maxAttentionWindow The maximum attention window allowed by the model.
/// @param windowSize The attention window size.
/// @param tokensPerBlock The number of tokens in a single kv-cache block.
/// @return SizeType32 A number of blocks.
[[nodiscard]] static SizeType32 calculateMaxBlockRequirementsPerBeam(SizeType32 sequenceLength,
SizeType32 sinkTokenLength, SizeType32 maxAttentionWindow, SizeType32 tokensPerBlock);
[[nodiscard]] static SizeType32 calculateMaxBlockRequirementsPerBeam(
SizeType32 sequenceLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 tokensPerBlock);
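
In rough terms (a hedged reading of the documented contract, not the verbatim implementation), a single beam needs enough blocks to hold min(sequenceLength, windowSize) tokens plus any sink bubble, rounded up to whole blocks:

// Sketch only; assumes getSinkBubbleLength returns the padding that aligns sink tokens to a block boundary.
SizeType32 maxBlocksPerBeamSketch(
    SizeType32 sequenceLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 tokensPerBlock)
{
    auto const sinkBubble = BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock);
    auto const tokensToHold = std::min(sequenceLength, windowSize) + sinkBubble;
    return tensorrt_llm::common::ceilDiv(tokensToHold, tokensPerBlock);
}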

std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;
21 changes: 12 additions & 9 deletions cpp/include/tensorrt_llm/runtime/modelConfig.h
@@ -799,6 +799,18 @@ class ModelConfig
return mNumKvHeadsPerAttentionLayer;
}

[[nodiscard]] std::vector<SizeType32> getNumKvHeadsForGivenLayers(
std::vector<SizeType32> const& layers, bool isCrossAttention) const
{
std::vector<SizeType32> numKvHeads;
numKvHeads.reserve(layers.size());
auto const numKvHeadsAllLayers
= isCrossAttention ? mNumKvHeadsPerCrossAttentionLayer : mNumKvHeadsPerAttentionLayer;
std::transform(layers.begin(), layers.end(), std::back_inserter(numKvHeads),
[&numKvHeadsAllLayers](SizeType32 layer) { return numKvHeadsAllLayers.at(layer); });
return numKvHeads;
}
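
A brief usage illustration with hypothetical numbers: for layers {0, 4, 5}, each with 8 KV heads, sizePerHead = 128 and kvFactor = 2, calculateCacheSizePerTokenForSingleWindowSize reduces to (8 + 8 + 8) * 2 * 128 = 6144 cache elements per token for that window size:

auto const nkvh = modelConfig.getNumKvHeadsForGivenLayers({0, 4, 5}, /*isCrossAttention=*/false); // {8, 8, 8}
auto const sumLocalHeads = std::reduce(nkvh.cbegin(), nkvh.cend());                               // 24
auto const cacheSizePerToken = sumLocalHeads * /*kvFactor=*/2 * /*sizePerHead=*/128;              // 6144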

[[nodiscard]] std::pair<std::vector<SizeType32>::const_iterator, std::vector<SizeType32>::const_iterator>
getNumKvHeadsPerLayerLocalRange(
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0, bool isCrossAttention = false) const
@@ -834,15 +846,6 @@ class ModelConfig
mNumKvHeadsPerCrossAttentionLayer = headsPerLayer;
}

[[nodiscard]] SizeType32 getSumLocalKvHeads(
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0, bool isCrossAttention = false) const
{
auto [cbegin, cend]
= getNumKvHeadsPerLayerLocalRange(pipelineParallelism, pipelineParallelismRank, isCrossAttention);
auto const sumLocalHeads = std::reduce(cbegin, cend);
return sumLocalHeads;
}

[[nodiscard]] bool constexpr skipCrossAttnBlocks() const noexcept
{
return mSkipCrossAttnBlocks;
9 changes: 2 additions & 7 deletions cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
@@ -230,22 +230,17 @@ CacheTransBufferManager::CacheTransBufferManager(
allocateBuffer();
}

size_t CacheTransBufferManager::preAllocBufferSize(
std::optional<size_t> maxNumTokens, std::optional<size_t> kvCacheSizePerToken)
size_t CacheTransBufferManager::preAllocBufferSize(std::optional<size_t> maxNumTokens)
{
bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache();
if (!to_allocate)
{
return 0;
}
if (maxNumTokens.has_value())
{
TLLM_CHECK(kvCacheSizePerToken.has_value());
}
size_t TransferBufferSize = common::getEnvMemSizeForKVCacheTransferBuffer();
if (maxNumTokens.has_value())
{
TransferBufferSize = maxNumTokens.value() * kvCacheSizePerToken.value();
TransferBufferSize = maxNumTokens.value();
}
bool useFabricMemory = FabricMemory::supportFbaricMemory()
&& (!(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer()));
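
Note the semantic shift here: preAllocBufferSize no longer scales by kvCacheSizePerToken itself, so the passed value is used as the buffer size directly. A hedged illustration of the corresponding caller-side change (variable names assumed):

// Before: the helper multiplied tokens by the per-token cache size internally.
// auto const bytes = CacheTransBufferManager::preAllocBufferSize(maxNumTokens, kvCacheSizePerToken);
// After: the caller performs the scaling up front.
auto const bytes = CacheTransBufferManager::preAllocBufferSize(maxNumTokens * kvCacheSizePerToken);
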
3 changes: 1 addition & 2 deletions cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h
@@ -59,8 +59,7 @@ class CacheTransBufferManager
CacheTransBufferManager(
KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens = std::nullopt);

static size_t preAllocBufferSize(
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<size_t> kvCacheSizePerToken = std::nullopt);
static size_t preAllocBufferSize(std::optional<size_t> maxNumTokens = std::nullopt);

std::optional<int> assignBufferIndexForSend();
void freeBufferIndexForSend(std::optional<int> bufferId);