Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1d484eb
Add refresh blocks
Tabrizian Nov 3, 2025
307479b
Fix transfer manager synchronization issues
thorjohnsen Nov 6, 2025
c4dc529
Fix merge issues
thorjohnsen Nov 11, 2025
8720584
Bug fix
thorjohnsen Nov 11, 2025
99653bf
Another fix
thorjohnsen Nov 11, 2025
ffb90a4
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 11, 2025
b471cdf
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 11, 2025
05255ff
precommit run
thorjohnsen Nov 11, 2025
9ac9534
Merge branch 'user/tjohnsen/fix_5627710' of github.com:thorjohnsen/Te…
thorjohnsen Nov 11, 2025
9aef49c
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 11, 2025
1369f0f
Fix broken pybind
thorjohnsen Nov 12, 2025
559beb0
Merge branch 'user/tjohnsen/fix_5627710' of github.com:thorjohnsen/Te…
thorjohnsen Nov 12, 2025
1cfa88c
Move refreshBlocks call to account for addToken calls
thorjohnsen Nov 12, 2025
03d4aac
precommit run
thorjohnsen Nov 12, 2025
262d34c
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 13, 2025
f4ae208
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 13, 2025
9f04777
Update cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
thorjohnsen Nov 13, 2025
ba32c3b
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 13, 2025
748cbaf
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 14, 2025
279b038
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Nov 17, 2025
c6ba7b6
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 17, 2025
b34ea28
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Nov 20, 2025
c248977
Merge branch 'user/tjohnsen/fix_5627710' of github.com:thorjohnsen/Te…
thorjohnsen Nov 20, 2025
55a274c
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 20, 2025
0505ef3
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 20, 2025
aab02d2
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Nov 21, 2025
f717528
Merge remote-tracking branch 'thor/user/tjohnsen/fix_5627710' into tr…
eopXD Nov 27, 2025
7214d22
Use current stream for onboard/offload to depend on
eopXD Nov 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,9 @@ class WindowBlockManager
return mBufferManager;
}

//! \brief Sync internal streams used by transfer manager with buffer manager stream
void syncTransferManagerWithBufferManager();

//! \brief Perform per-request bookkeeping
void refreshBlocks();

Expand Down Expand Up @@ -1313,6 +1316,9 @@ class BlockManager
//! \brief Store newest block for reuse
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

//! \brief Sync internal streams used by transfer manager with buffer manager stream
void syncTransferManagerWithBufferManager();

//! \brief Perform per-request bookkeeping
void refreshBlocks();

Expand Down Expand Up @@ -1584,6 +1590,7 @@ class BaseKVCacheManager
[[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0;
[[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;

virtual void syncTransferManagerWithBufferManager() = 0;
virtual void refreshBlocks() = 0;
virtual void flushIterationEvents() = 0;
virtual void resetReuseState() = 0;
Expand Down Expand Up @@ -1965,6 +1972,11 @@ class KVCacheManager : public BaseKVCacheManager
return mBlockManager.getPoolLayerIdx(layer_idx);
}

//! \brief Sync transfer-manager streams with the buffer manager stream.
//! Forwards to BlockManager, which fans out to every per-window-size manager.
void syncTransferManagerWithBufferManager() override
{
    mBlockManager.syncTransferManagerWithBufferManager();
}

//! \brief Perform per-iteration bookkeeping
void refreshBlocks() override
{
Expand Down
16 changes: 13 additions & 3 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,15 @@ class KVCacheTransferManager
int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::string const& directory = "");

//! \brief Synchronize the offload/onboard streams with the bufferManager stream.
//! \brief Synchronize internal streams with bufferManager stream.
//! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the
//! internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing
//! any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step.
void syncWithBufferManager();

//! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode
//! kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method
//! must be called after last call to KVCacheManager::addSequence in every step.
void syncTransfers();

private:
Expand Down Expand Up @@ -75,8 +83,10 @@ class KVCacheTransferManager
runtime::BufferManager mOnboardManager;
runtime::BufferManager mOffloadManager;

// Track the block ids offloaded in this iteration.
std::unordered_map<int32_t, tr::CudaEvent> mPendingOffloads;
// Track reads and writes for blocks. Note that it is the memory pool index that
// identifies the raw memory blocks involved in I/O, not the block Id.
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
// Reference to parent loopback agent
std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
int mDeviceId;
Expand Down
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(allocateKvCache);

kvCacheManager.syncTransferManagerWithBufferManager();

for (auto const& llmReq : contextRequests)
{
if (llmReq->isFirstContextChunk())
Expand Down
13 changes: 13 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,19 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
return numMatchedTokens;
}

//! \brief Sync internal transfer streams with the buffer manager stream.
//! Fans out to each per-window-size WindowBlockManager.
void BlockManager::syncTransferManagerWithBufferManager()
{
    for (auto& entry : mWindowBlockManagers)
    {
        entry.second.syncTransferManagerWithBufferManager();
    }
}

//! \brief Sync internal transfer streams with the buffer manager stream.
//! Delegates to the transfer manager, which makes its offload/onboard streams
//! wait for work already queued on the buffer manager stream.
void WindowBlockManager::syncTransferManagerWithBufferManager()
{
    mTransferManager->syncWithBufferManager();
}

void BlockManager::refreshBlocks()
{
for (auto& [_, manager] : mWindowBlockManagers)
Expand Down
123 changes: 108 additions & 15 deletions cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
}
}

void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block,
//
// Note about recording events to wait for cudaMemcpyAsync calls between blocks:
// The memory copy involves raw memory blocks, which are pointed to by the
// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
// as the raw memory block identifier. Using getBlockId() when recording events is wrong.
// getBlockId() returns the logical block id, which has nothing to do with the raw memory
// block pointers involved in a cudaMemcpy.
//

//
// Notes about need for synchronization:
//
// Relying on decoder syncing GPU with CPU to ensure that blocks are ready
// for offload/onboard/partial copy is dangerous. We have an asynchronous decoder
// that may not synchronize or synchronize at a later point in the execution stream.
// To avoid synchronization issues caused by changes to decoder design we rely on
// KVCacheTransferManager::syncWithBufferManager() that ensures that internal copy streams
// will wait for prefill and decode kernels that have already been scheduled.
//
// Earlier versions of this code did not account for all possible cases where a new block copy
// needed to wait for a previously scheduled copy to finish. For instance, it is possible
// that two primary blocks are offloaded to the same secondary block in a single step,
// scheduling the second offloading without waiting for the first one to finish leads to
// a corrupted block after offloading. It is possible that partial reuse will copy
// from a block that is currently being onboarded, scheduling the partial copy without
// waiting for the onboarding to finish will lead to a corrupted block. To handle all
// possible cases needing synchronization we record separate events for reads and writes
// to a block. When a new block copy is scheduled, we wait for all writes to the source
// block and all reads and writes to a destination block.
//
// As before, syncTransfers() must be called after last call to KVCacheManager::addSequence.
// Failing to do so will lead to corrupted blocks eventually.
//

//! \brief Schedule a copy of offloadedBlock into block on the onboard stream.
//! \param offloadedBlock source block (typically secondary/host memory).
//! \param block destination block (typically primary/device memory).
//! \param pools memory pools that back the blocks.
//! \param numTokensToCopy number of tokens to copy (0 means whole block).
//! \param mode transfer mode (DRAM, GDS, POSIX, ...).
//! \param directory directory used for file-backed transfer modes.
//! \details Events are keyed by getMemoryPoolBlockIndex() — the raw memory
//! identity — never by getBlockId() (logical id; see note above).
// NOTE(review): the previous "skip onboard if the block was never offloaded to
// disk" guard relied on the removed mPendingOffloads map; confirm file-backed
// modes (GDS/POSIX) tolerate onboarding a block with no backing file upstream.
void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
    std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
    std::string const& directory)
{
    // Wait for any pending writes before reading from offloadedBlock.
    auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
    if (offloadedBlockPendingWriteItr != mPendingWrites.end())
    {
        mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second);
        // Don't erase, we are not changing state of offloadedBlock.
    }
    // Wait for any pending reads before overwriting block.
    auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex());
    if (blockPendingReadItr != mPendingReads.end())
    {
        mOnboardManager.getStream().wait(blockPendingReadItr->second);
        mPendingReads.erase(blockPendingReadItr);
    }
    // Wait for any pending writes before overwriting block.
    auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
    if (blockPendingWriteItr != mPendingWrites.end())
    {
        mOnboardManager.getStream().wait(blockPendingWriteItr->second);
        mPendingWrites.erase(blockPendingWriteItr);
    }

    copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);

    // Record new pending read from offloadedBlock.
    mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
    mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
    // Record new pending write to block.
    mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
    mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
}

//! \brief Schedule a copy of block into offloadBlock on the offload stream.
//! \param block source block (typically primary/device memory).
//! \param offloadBlock destination block (typically secondary/host memory).
//! \param pools memory pools that back the blocks.
//! \param numTokensToCopy number of tokens to copy (0 means whole block).
//! \param mode transfer mode (DRAM, GDS, POSIX, ...).
//! \param directory directory used for file-backed transfer modes.
//! \details Events are keyed by getMemoryPoolBlockIndex() — the raw memory
//! identity — never by getBlockId() (logical id; see note above).
void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
    std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
    std::string const& directory)
{
    // Wait for any pending writes before reading from block.
    auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
    if (blockPendingWriteItr != mPendingWrites.end())
    {
        mOffloadManager.getStream().wait(blockPendingWriteItr->second);
        // Don't erase, we are not changing state of block.
    }
    // Wait for any pending reads before overwriting offloadBlock.
    auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex());
    if (offloadBlockPendingReadItr != mPendingReads.end())
    {
        mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
        mPendingReads.erase(offloadBlockPendingReadItr);
    }
    // Wait for any pending writes before overwriting offloadBlock.
    auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex());
    if (offloadBlockPendingWriteItr != mPendingWrites.end())
    {
        mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
        mPendingWrites.erase(offloadBlockPendingWriteItr);
    }

    copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory);

    // Record new pending read from block.
    mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
    mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]);
    // Record new pending write to offloadBlock.
    mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
    mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]);
}

//! \brief Make internal copy streams wait for the buffer manager stream.
//! Ensures offload/onboard copies scheduled this step run after the
//! prefill/decode kernels already queued on the buffer manager stream.
void KVCacheTransferManager::syncWithBufferManager()
{
    // Offload stream waits for everything queued on the buffer manager stream.
    tr::CudaEvent readyForOffloadEvent;
    mBufferManager.getStream().record(readyForOffloadEvent);
    mOffloadManager.getStream().wait(readyForOffloadEvent);

    // Onboard stream likewise.
    tr::CudaEvent readyForOnboardEvent;
    mBufferManager.getStream().record(readyForOnboardEvent);
    mOnboardManager.getStream().wait(readyForOnboardEvent);

    // Once we synchronize, clear our list of pending transfers.
    // NOTE(review): clearing assumes syncTransfers() ran after the last
    // addSequence of the previous step, so the buffer manager stream already
    // orders after those copies and the waits above cover them transitively —
    // confirm the call order in the executor.
    mPendingReads.clear();
    mPendingWrites.clear();
}

//! \brief Make the buffer manager stream wait for all scheduled transfers.
//! Ensures next-step prefill/decode kernels run after every offload/onboard
//! copy queued this step. Must be called after the last addSequence per step.
void KVCacheTransferManager::syncTransfers()
{
    // Mark the tail of the work queued on each internal copy stream.
    tr::CudaEvent offloadEvent;
    mOffloadManager.getStream().record(offloadEvent);

    tr::CudaEvent onboardEvent;
    mOnboardManager.getStream().record(onboardEvent);

    // Buffer manager stream (prefill/decode kernels) waits on both.
    mBufferManager.getStream().wait(offloadEvent);
    mBufferManager.getStream().wait(onboardEvent);

    // Once we synchronize, clear our list of pending transfers.
    mPendingReads.clear();
    mPendingWrites.clear();
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
8 changes: 8 additions & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx);
}

// Trampoline for the pure-virtual BaseKVCacheManager method: dispatches to
// the Python-side override.
void syncTransferManagerWithBufferManager() override
{
    NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager);
}

void refreshBlocks() override
{
NB_OVERRIDE_PURE(refreshBlocks);
Expand Down Expand Up @@ -481,6 +486,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::call_guard<nb::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
nb::call_guard<nb::gil_scoped_release>())
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
nb::call_guard<nb::gil_scoped_release>())
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
Expand Down
8 changes: 8 additions & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
}

// Trampoline for the pure-virtual BaseKVCacheManager method: dispatches to
// the Python-side override.
void syncTransferManagerWithBufferManager() override
{
    PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager);
}

void refreshBlocks() override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
Expand Down Expand Up @@ -485,6 +490,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
py::call_guard<py::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
py::call_guard<py::gil_scoped_release>())
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
py::call_guard<py::gil_scoped_release>())
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());
Expand Down
9 changes: 8 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],

# Note that this stream is unused for now. Will be used for copying to host
# when that feature is enabled.
self._stream = torch.cuda.Stream()
self._stream = torch.cuda.current_stream()
kwargs = {
'num_kv_heads_per_layer': self.num_kv_heads_per_layer,
'size_per_head': head_dim,
Expand Down Expand Up @@ -434,6 +434,10 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
with request_context(self.is_draft, scheduled_batch):
context_batch = scheduled_batch.context_requests
generation_batch = scheduled_batch.generation_requests

# wait for all pending work to finish before launching offload/onboarding/partial copy
self.impl.sync_transfer_manager_with_buffer_manager()

# allocate KV Cache
for req in context_batch:
req_beam_width = req.sampling_config.beam_width
Expand Down Expand Up @@ -468,6 +472,9 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

# prefill and generation kernels wait for scheduled offload/onboard/partial copy work before launching
self.impl.refresh_blocks()

if self.kv_connector_manager is not None:
self.kv_connector_manager.build_scheduler_output(
scheduled_batch, self)
Expand Down
Loading