diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index 30634516bd0..70df824ee8e 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -824,6 +824,9 @@ class WindowBlockManager
         return mBufferManager;
     }

+    //! \brief Sync internal streams used by transfer manager with buffer manager stream
+    void syncTransferManagerWithBufferManager();
+
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();

@@ -1313,6 +1316,9 @@ class BlockManager
     //! \brief Store newest block for reuse
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

+    //! \brief Sync internal streams used by transfer manager with buffer manager stream
+    void syncTransferManagerWithBufferManager();
+
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();

@@ -1584,6 +1590,7 @@ class BaseKVCacheManager
     [[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0;
     [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;

+    virtual void syncTransferManagerWithBufferManager() = 0;
     virtual void refreshBlocks() = 0;
     virtual void flushIterationEvents() = 0;
     virtual void resetReuseState() = 0;
@@ -1965,6 +1972,11 @@ class KVCacheManager : public BaseKVCacheManager
         return mBlockManager.getPoolLayerIdx(layer_idx);
     }

+    void syncTransferManagerWithBufferManager() override
+    {
+        mBlockManager.syncTransferManagerWithBufferManager();
+    }
+
     //! \brief Perform per-iteration bookkeeping
     void refreshBlocks() override
     {
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h
index 45f615cafe7..00540dc671e 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h
@@ -46,7 +46,15 @@ class KVCacheTransferManager
         int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
         std::string const& directory = "");

-    //! \brief Synchronize the offload/onboard streams with the bufferManager stream.
+    //! \brief Synchronize internal streams with the bufferManager stream.
+    //! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that
+    //! the internal kernels used for offloading and onboarding wait for prefill and decode kernels before performing
+    //! any block copies. It must be called before the first call to KVCacheManager::addSequence in every step.
+    void syncWithBufferManager();
+
+    //! \brief Synchronize the bufferManager stream with internal streams. This method ensures that the prefill and
+    //! decode kernels of the next step wait for offloading and onboarding work that has already been scheduled. It
+    //! must be called after the last call to KVCacheManager::addSequence in every step.
     void syncTransfers();

private:
     runtime::BufferManager mOnboardManager;
     runtime::BufferManager mOffloadManager;

-    // Track the block ids offloaded in this iteration.
-    std::unordered_map<SizeType32, runtime::CudaEvent> mPendingOffloads;
+    // Track reads and writes for blocks. Note that it is the memory pool index that
+    // identifies the raw memory blocks involved in I/O, not the block id.
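Both new methods reduce to the same CUDA primitive: record an event on one stream and make the other stream wait on it, so the ordering is enforced on the GPU without blocking the host. A minimal sketch of that primitive, using the plain CUDA runtime API rather than the tr::CudaEvent/BufferManager wrappers this PR uses (the helper name is hypothetical):

#include <cuda_runtime.h>

// Make work queued on `consumer` after this call wait until everything already
// queued on `producer` has finished. Neither call blocks the host.
void makeStreamWait(cudaStream_t producer, cudaStream_t consumer)
{
    cudaEvent_t ev;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming); // timing not needed for ordering
    cudaEventRecord(ev, producer);        // marks all work queued on producer so far
    cudaStreamWaitEvent(consumer, ev, 0); // consumer waits for the mark, on-device
    cudaEventDestroy(ev);                 // safe: destruction is deferred until the event completes
}

syncWithBufferManager() applies this with the buffer manager stream as producer and the offload/onboard streams as consumers; syncTransfers() applies it in the opposite direction.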
+    std::unordered_map<SizeType32, runtime::CudaEvent> mPendingReads;
+    std::unordered_map<SizeType32, runtime::CudaEvent> mPendingWrites;

     // Reference to parent loopback agent
     std::shared_ptr mLoopbackAgent;
     int mDeviceId;
diff --git a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp
index c0482deb554..211abe78186 100644
--- a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp
+++ b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp
@@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE(allocateKvCache);

+    kvCacheManager.syncTransferManagerWithBufferManager();
+
     for (auto const& llmReq : contextRequests)
     {
         if (llmReq->isFirstContextChunk())
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index b69db6d1bcc..0fb5da2c0bc 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1343,6 +1343,19 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
     return numMatchedTokens;
 }

+void BlockManager::syncTransferManagerWithBufferManager()
+{
+    for (auto& [_, manager] : mWindowBlockManagers)
+    {
+        manager.syncTransferManagerWithBufferManager();
+    }
+}
+
+void WindowBlockManager::syncTransferManagerWithBufferManager()
+{
+    mTransferManager->syncWithBufferManager();
+}
+
 void BlockManager::refreshBlocks()
 {
     for (auto& [_, manager] : mWindowBlockManagers)
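Together with the header comments above, these hooks define a per-step protocol: sync the copy streams, schedule all block allocation, then sync back. The sketch below illustrates that ordering contract only; KvCacheLike and runStep are hypothetical stand-ins (the real entry points are AllocateKvCache::operator() above on the C++ path and prepare_resources() in the Python changes at the end of this diff):

#include <vector>

// Hypothetical stand-in for BaseKVCacheManager; only the call order matters here.
struct KvCacheLike
{
    void syncTransferManagerWithBufferManager() {} // copy streams wait on the main stream
    void addSequence(int requestId) {}             // may schedule offload/onboard/partial copies
    void refreshBlocks() {}                        // main stream waits on the copy streams
};

void runStep(KvCacheLike& cache, std::vector<int> const& requestIds)
{
    cache.syncTransferManagerWithBufferManager(); // before the first addSequence of the step
    for (int id : requestIds)
    {
        cache.addSequence(id); // block reuse/allocation, which may enqueue block copies
    }
    cache.refreshBlocks(); // after the last addSequence; presumably invokes syncTransfers()
}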
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
index fd5758a8368..495f6a3ed34 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
@@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
     }
 }

-void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block,
+//
+// Note about recording events to wait for cudaMemcpyAsync calls between blocks:
+// The memory copy involves raw memory blocks, which are pointed to by the
+// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
+// as the raw memory block identifier. Using getBlockId() when recording events is wrong:
+// getBlockId() returns the logical block id, which has nothing to do with the raw memory
+// block pointers involved in a cudaMemcpy.
+//
+
+//
+// Notes about the need for synchronization:
+//
+// Relying on the decoder syncing GPU with CPU to ensure that blocks are ready
+// for offload/onboard/partial copy is dangerous. We have an asynchronous decoder
+// that may not synchronize, or may synchronize at a later point in the execution stream.
+// To avoid synchronization issues caused by changes to the decoder design, we rely on
+// KVCacheTransferManager::syncWithBufferManager(), which ensures that the internal copy
+// streams wait for prefill and decode kernels that have already been scheduled.
+//
+// Earlier versions of this code did not account for all possible cases where a new block copy
+// needed to wait for a previously scheduled copy to finish. For instance, it is possible
+// that two primary blocks are offloaded to the same secondary block in a single step;
+// scheduling the second offload without waiting for the first one to finish leaves
+// a corrupted block after offloading. It is also possible that partial reuse copies
+// from a block that is currently being onboarded; scheduling the partial copy without
+// waiting for the onboarding to finish likewise corrupts the block. To handle all
+// possible cases needing synchronization, we record separate events for reads and writes
+// to a block. When a new block copy is scheduled, we wait for all writes to the source
+// block and for all reads and writes to the destination block.
+//
+// As before, syncTransfers() must be called after the last call to KVCacheManager::addSequence.
+// Failing to do so will eventually lead to corrupted blocks.
+//
+
+void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
     std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
     std::string const& directory)
 {
-    if (mode != executor::KvCacheTransferMode::DRAM
-        && mPendingOffloads.find(offloadBlock->getBlockId()) == mPendingOffloads.end())
+    // Wait for any pending writes before reading from offloadedBlock
+    auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
+    if (offloadedBlockPendingWriteItr != mPendingWrites.end())
     {
-        TLLM_LOG_DEBUG("Skipping onboard for block %d because it was never previously offloaded to disk",
-            offloadBlock->getBlockId());
-        return;
+        mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second);
+        // Don't erase: we are not changing the state of offloadedBlock
     }
-
-    if (mPendingOffloads.find(offloadBlock->getBlockId()) != mPendingOffloads.end())
+    // Wait for any pending reads before overwriting block
+    auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingReadItr != mPendingReads.end())
+    {
+        mOnboardManager.getStream().wait(blockPendingReadItr->second);
+        mPendingReads.erase(blockPendingReadItr);
+    }
+    // Wait for any pending writes before overwriting block
+    auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingWriteItr != mPendingWrites.end())
     {
-        mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]);
+        mOnboardManager.getStream().wait(blockPendingWriteItr->second);
+        mPendingWrites.erase(blockPendingWriteItr);
     }
-    copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory);
+
+    copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);
+
+    // Record new pending read from offloadedBlock
+    mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
+    // Record new pending write to block
+    mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
 }

 void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
     std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
     std::string const& directory)
 {
-    mPendingOffloads[block->getBlockId()] = tr::CudaEvent();
+    // Wait for any pending writes before reading from block
+    auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
+    if (blockPendingWriteItr != mPendingWrites.end())
+    {
+        mOffloadManager.getStream().wait(blockPendingWriteItr->second);
+        // Don't erase: we are not changing the state of block
+    }
+    // Wait for any pending reads before overwriting offloadBlock
+    auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex());
+    if (offloadBlockPendingReadItr != mPendingReads.end())
+    {
+        mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
+        mPendingReads.erase(offloadBlockPendingReadItr);
+    }
+    // Wait for any pending writes before overwriting offloadBlock
+    auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex());
+    if (offloadBlockPendingWriteItr != mPendingWrites.end())
+    {
+        mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
+        mPendingWrites.erase(offloadBlockPendingWriteItr);
+    }
+
     copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory);
-    mOffloadManager.getStream().record(mPendingOffloads[block->getBlockId()]);
+
+    // Record new pending read from block
+    mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]);
+    // Record new pending write to offloadBlock
+    mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
+    mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]);
+}
+
+void KVCacheTransferManager::syncWithBufferManager()
+{
+    tr::CudaEvent readyForOffloadEvent;
+    mBufferManager.getStream().record(readyForOffloadEvent);
+    mOffloadManager.getStream().wait(readyForOffloadEvent);
+
+    tr::CudaEvent readyForOnboardEvent;
+    mBufferManager.getStream().record(readyForOnboardEvent);
+    mOnboardManager.getStream().wait(readyForOnboardEvent);
+
+    // Once we synchronize, clear our list of pending transfers.
+    mPendingReads.clear();
+    mPendingWrites.clear();
 }

 void KVCacheTransferManager::syncTransfers()
 {
     tr::CudaEvent offloadEvent;
     mOffloadManager.getStream().record(offloadEvent);
+    mBufferManager.getStream().wait(offloadEvent);

     tr::CudaEvent onboardEvent;
     mOnboardManager.getStream().record(onboardEvent);
-
-    mBufferManager.getStream().wait(offloadEvent);
     mBufferManager.getStream().wait(onboardEvent);

     // Once we synchronize, clear our list of pending transfers.
-    mPendingOffloads.clear();
+    mPendingReads.clear();
+    mPendingWrites.clear();
 }

 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
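The read/write bookkeeping above generalizes the old single-map scheme. For reference, a standalone sketch of the same pattern, assuming plain CUDA events and integer pool indices in place of tr::CudaEvent and getMemoryPoolBlockIndex() (TransferTracker and launchBlockCopy are hypothetical):

#include <cuda_runtime.h>
#include <unordered_map>

// Tracks the last unconsumed read and write per raw pool index, mirroring
// mPendingReads/mPendingWrites above. A real implementation would use an RAII
// event wrapper; overwritten map entries here are leaked for brevity.
struct TransferTracker
{
    std::unordered_map<int, cudaEvent_t> pendingReads;
    std::unordered_map<int, cudaEvent_t> pendingWrites;

    void scheduleCopy(cudaStream_t stream, int srcIdx, int dstIdx)
    {
        // Reading src must wait for outstanding writes to src (keep the entry:
        // the copy does not change the state of src).
        if (auto it = pendingWrites.find(srcIdx); it != pendingWrites.end())
            cudaStreamWaitEvent(stream, it->second, 0);
        // Writing dst must wait for outstanding reads and writes to dst.
        if (auto it = pendingReads.find(dstIdx); it != pendingReads.end())
        {
            cudaStreamWaitEvent(stream, it->second, 0);
            pendingReads.erase(it);
        }
        if (auto it = pendingWrites.find(dstIdx); it != pendingWrites.end())
        {
            cudaStreamWaitEvent(stream, it->second, 0);
            pendingWrites.erase(it);
        }

        // launchBlockCopy(stream, srcIdx, dstIdx); // stand-in for copyBlock()

        // Record fresh events so later copies can order themselves after this one.
        cudaEvent_t readDone;
        cudaEvent_t writeDone;
        cudaEventCreateWithFlags(&readDone, cudaEventDisableTiming);
        cudaEventCreateWithFlags(&writeDone, cudaEventDisableTiming);
        cudaEventRecord(readDone, stream);
        cudaEventRecord(writeDone, stream);
        pendingReads[srcIdx] = readDone;
        pendingWrites[dstIdx] = writeDone;
    }
};

This is why offloading two primary blocks to the same secondary block in one step is now safe: the second copy finds the first copy's entry in pendingWrites for that destination index and waits on it before overwriting the block.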
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index bfa3eaf169e..7a3bcae7cf1 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -235,6 +235,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
         NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx);
     }

+    void syncTransferManagerWithBufferManager() override
+    {
+        NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager);
+    }
+
     void refreshBlocks() override
     {
         NB_OVERRIDE_PURE(refreshBlocks);
@@ -481,6 +486,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
             nb::call_guard<nb::gil_scoped_release>())
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
             nb::call_guard<nb::gil_scoped_release>())
+        .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
+            nb::call_guard<nb::gil_scoped_release>())
+
         .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
         .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
         .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
index e2ebc3aa451..6ab03315e1a 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -240,6 +240,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
         PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
     }

+    void syncTransferManagerWithBufferManager() override
+    {
+        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager);
+    }
+
     void refreshBlocks() override
     {
         PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
@@ -485,6 +490,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
             py::call_guard<py::gil_scoped_release>())
         .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
             py::call_guard<py::gil_scoped_release>())
+        .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
+            py::call_guard<py::gil_scoped_release>())
+
         .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
         .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
         .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
         .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index c2f8f8175e2..fee4745022b 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -353,7 +353,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
         # Note that this stream is unused for now. Will be used for copying to host
         # when that feature is enabled.
-        self._stream = torch.cuda.Stream()
+        self._stream = torch.cuda.current_stream()
         kwargs = {
             'num_kv_heads_per_layer': self.num_kv_heads_per_layer,
             'size_per_head': head_dim,
@@ -434,6 +434,10 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
         with request_context(self.is_draft, scheduled_batch):
             context_batch = scheduled_batch.context_requests
             generation_batch = scheduled_batch.generation_requests
+
+            # make offload/onboard/partial-copy streams wait for already scheduled prefill/decode work
+            self.impl.sync_transfer_manager_with_buffer_manager()
+
             # allocate KV Cache
             for req in context_batch:
                 req_beam_width = req.sampling_config.beam_width
@@ -468,6 +472,9 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                 for _ in range(get_draft_token_length(req)):
                     self.impl.add_token(req.py_request_id)

+            # make prefill and generation kernels wait for the offload/onboard/partial-copy work scheduled above
+            self.impl.refresh_blocks()
+
             if self.kv_connector_manager is not None:
                 self.kv_connector_manager.build_scheduler_output(
                     scheduled_batch, self)