Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class HandleContextLogits : Algorithm
runtime::SizeType32 operator()(DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
runtime::ITensor::SharedPtr const& logits, std::vector<runtime::SizeType32> const& numContextLogitsVec,
runtime::ModelConfig const& modelConfig, runtime::BufferManager const& manager,
OptionalRef<MedusaBuffers> medusaBuffers) const;
OptionalRef<MedusaBuffers> medusaBuffers, runtime::SizeType32 vocabId = 0) const;
};

} // namespace tensorrt_llm::batch_manager
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class HandleGenerationLogits : Algorithm
void operator()(DecoderInputBuffers& inputBuffers, RequestVector const& generationRequests,
runtime::ITensor::SharedPtr const& logits, runtime::SizeType32 logitsIndex,
runtime::ModelConfig const& modelConfig, runtime::BufferManager const& manager,
OptionalRef<RuntimeBuffers> genRuntimeBuffers, OptionalRef<MedusaBuffers> medusaBuffers) const;
OptionalRef<RuntimeBuffers> genRuntimeBuffers, OptionalRef<MedusaBuffers> medusaBuffers,
runtime::SizeType32 vocabId = 0) const;
};

} // namespace tensorrt_llm::batch_manager
16 changes: 10 additions & 6 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ class WindowBlockManager
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager, SizeType32 numVocabs = 1,
std::shared_ptr<kvc::BaseLoopbackAgent> loopbackAgent = nullptr);

~WindowBlockManager();
Expand Down Expand Up @@ -888,6 +888,9 @@ class WindowBlockManager
bool mEnablePartialReuse;
// Whether partially matched blocks that are already in use should be copied and reused.
bool mCopyOnPartialReuse;

SizeType32 mNumVocabs;

// The kv cache connector manager
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
};
Expand All @@ -907,7 +910,7 @@ class BlockManager
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnPartialReuse = true,
bool copyOnPartialReuse = true, SizeType32 numVocabs = 1,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr,
std::optional<kvc::BaseAgentConfig> agentConfig = std::nullopt);

Expand Down Expand Up @@ -1212,6 +1215,7 @@ class BlockManager
std::vector<SizeType32> mLayerToWindowSize;
std::vector<SizeType32> mAbsolutePoolToWindowSize;
std::vector<SizeType32> mAbsolutePoolToRelativePoolIndex;
SizeType32 mNumVocabs;
};

struct OffsetTableDimensions
Expand Down Expand Up @@ -1433,7 +1437,7 @@ class KVCacheManager : public BaseKVCacheManager
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
bool copyOnpartialReuse = true, SizeType32 numVocabs = 1,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
Expand All @@ -1444,7 +1448,7 @@ class KVCacheManager : public BaseKVCacheManager
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
bool copyOnpartialReuse = true, SizeType32 numVocabs = 1,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
Expand All @@ -1455,7 +1459,7 @@ class KVCacheManager : public BaseKVCacheManager
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
bool copyOnpartialReuse = true, SizeType32 numVocabs = 1,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
Expand All @@ -1464,7 +1468,7 @@ class KVCacheManager : public BaseKVCacheManager
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false,
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);
bool copyOnpartialReuse = true, SizeType32 numVocabs = 1);

~KVCacheManager() override = default;

Expand Down
30 changes: 27 additions & 3 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,35 @@ class BlockRange
return BlockRange(cacheManager, blockIds, requestId);
}

static BlockRange fromNewlyAllocatedBlockIds(
BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
static BlockRange fromAllBlockIds(
BaseKVCacheManager const& cacheManager, LlmRequest const& llmRequest, SizeType32 beam = kFIRST_AND_ONLY_BEAM)
{
assert(kFIRST_AND_ONLY_BEAM == beam);

const LlmRequest::RequestIdType requestId = llmRequest.getSeqSlotId(0);
std::vector<SizeType32> blockIds;
auto const windowSize = firstWindowSize(cacheManager);

for (int i = 0; i < llmRequest.getNumSequences(); i++)
{
auto const& thisBlockIds = cacheManager.getSequence(llmRequest.getSeqSlotId(i))
.getCacheBlockIds(windowSize)
.at(kFIRST_AND_ONLY_BEAM);
blockIds.insert(blockIds.end(), thisBlockIds.begin(), thisBlockIds.end());
}
return BlockRange(cacheManager, blockIds, requestId);
}

static BlockRange fromNewlyAllocatedBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest const& llmRequest)
{
const LlmRequest::RequestIdType requestId = llmRequest.getSeqSlotId(0);
auto const windowSize = firstWindowSize(cacheManager);
auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize);
std::vector<SizeType32> blockIds;
for (int i = 0; i < llmRequest.getNumSequences(); i++)
{
auto const& thisBlockIds = cacheManager.getNewlyAllocatedBlockIds(llmRequest.getSeqSlotId(i), windowSize);
blockIds.insert(blockIds.end(), thisBlockIds.begin(), thisBlockIds.end());
}
return BlockRange(cacheManager, blockIds, requestId);
}

Expand Down
Loading
Loading