Skip to content

Commit b165f8b

Browse files
authored
fix/improve kvcache allocation in PyTorch runtime (#5933)
Signed-off-by: qixiang-99 <[email protected]>
1 parent 9257648 commit b165f8b

File tree

15 files changed

+392
-67
lines changed

15 files changed

+392
-67
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ struct KvCacheStats
192192
float cacheHitRate;
193193
// Number of free blocks for every configured attention-window size.
194194
std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
195+
// GPU bytes allocated for KV-cache
196+
std::size_t allocatedBytes{};
195197
};
196198

197199
// Basic building block of a paged KV cache - a single
@@ -1474,6 +1476,7 @@ class KVCacheManager : public BaseKVCacheManager
14741476
: static_cast<float>(kvCacheStats.reusedBlocks)
14751477
/ static_cast<float>(kvCacheStats.reusedBlocks + kvCacheStats.missedBlocks);
14761478
kvCacheStats.numFreeBlocksPerWindowSize = getNumFreeBlocksPerWindowSize();
1479+
kvCacheStats.allocatedBytes = mAllocatedBytes;
14771480
return kvCacheStats;
14781481
}
14791482

@@ -1677,6 +1680,8 @@ class KVCacheManager : public BaseKVCacheManager
16771680
runtime::ITensor::SharedPtr mBlockPoolPointers;
16781681
runtime::ITensor::SharedPtr mLayerToPoolMapping;
16791682
runtime::ITensor::SharedPtr mBlockScalePoolPointers;
1683+
// GPU bytes allocated for KV-cache
1684+
std::size_t mAllocatedBytes{0};
16801685
};
16811686

16821687
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,8 @@ class KvCacheConfig
10061006
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
10071007
bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
10081008
SizeType32 attentionDpEventsGatherPeriodMs = 5,
1009-
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);
1009+
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt,
1010+
uint64_t const& maxGpuTotalBytes = 0);
10101011

10111012
[[nodiscard]] bool getEnableBlockReuse() const;
10121013
[[nodiscard]] bool getEnablePartialReuse() const;
@@ -1022,11 +1023,12 @@ class KvCacheConfig
10221023
[[nodiscard]] size_t getEventBufferMaxSize() const;
10231024
[[nodiscard]] bool getUseUvm() const;
10241025
[[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const;
1026+
[[nodiscard]] uint64_t getMaxGpuTotalBytes() const;
10251027

10261028
void setEnableBlockReuse(bool enableBlockReuse);
10271029
void setEnablePartialReuse(bool enablePartialReuse);
10281030
void setCopyOnPartialReuse(bool copyOnPartialReuse);
1029-
void setMaxTokens(SizeType32 maxTokens);
1031+
void setMaxTokens(std::optional<SizeType32> maxTokens);
10301032
void setMaxAttentionWindowVec(std::vector<SizeType32> maxAttentionWindowVec);
10311033
void setSinkTokenLength(SizeType32 sinkTokenLength);
10321034
void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction);
@@ -1037,6 +1039,7 @@ class KvCacheConfig
10371039
void setEventBufferMaxSize(size_t eventBufferMaxSize);
10381040
void setUseUvm(bool useUvm);
10391041
void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs);
1042+
void setMaxGpuTotalBytes(uint64_t maxGpuTotalBytes);
10401043

10411044
void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults);
10421045

@@ -1095,6 +1098,10 @@ class KvCacheConfig
10951098

10961099
/// @brief The period in milliseconds to gather attention DP events across ranks
10971100
SizeType32 mAttentionDpEventsGatherPeriodMs;
1101+
/// @brief The maximum size in bytes of GPU memory that can be allocated for the KV cache.
1102+
/// If both mMaxGpuTotalBytes and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will
1103+
/// be allocated.
1104+
uint64_t mMaxGpuTotalBytes;
10981105
};
10991106

11001107
/// @brief Configuration class for the runtime perf knobs

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,27 +1681,29 @@ void KVCacheManager::allocatePools(bool useUvm)
16811681
mBlockManager.allocatePools(useUvm);
16821682
auto const numPools = mBlockManager.getNumPools();
16831683

1684-
if (tc::Logger::getLogger()->getLevel() <= tc::Logger::INFO)
1684+
uint64_t cacheSizeBytes = 0;
1685+
for (SizeType32 poolIdx = 0; poolIdx < numPools; poolIdx++)
16851686
{
1686-
uint64_t cacheSizeBytes = 0;
1687-
for (SizeType32 poolIdx = 0; poolIdx < numPools; poolIdx++)
1688-
{
1689-
auto const cacheShape = mBlockManager.getPrimaryPool(poolIdx)->getShape();
1690-
auto const cacheVolume = ITensor::volume(cacheShape);
1687+
auto const cacheShape = mBlockManager.getPrimaryPool(poolIdx)->getShape();
1688+
auto const cacheVolume = ITensor::volume(cacheShape);
16911689
#ifdef ENABLE_FP4
1692-
auto const isFp4 = mDataType == nvinfer1::DataType::kFP4;
1690+
auto const isFp4 = mDataType == nvinfer1::DataType::kFP4;
16931691
#else
1694-
auto const isFp4 = false;
1692+
auto const isFp4 = false;
16951693
#endif
1696-
if (!isFp4)
1697-
{
1698-
cacheSizeBytes += cacheVolume * BufferDataType(mDataType).getSize();
1699-
}
1700-
else
1701-
{
1702-
cacheSizeBytes += (cacheVolume * 4) / 8;
1703-
}
1694+
if (!isFp4)
1695+
{
1696+
cacheSizeBytes += cacheVolume * BufferDataType(mDataType).getSize();
1697+
}
1698+
else
1699+
{
1700+
cacheSizeBytes += (cacheVolume * 4) / 8;
17041701
}
1702+
}
1703+
// Save the total number of bytes allocated for the KV-cache for KvCacheStats
1704+
mAllocatedBytes = cacheSizeBytes;
1705+
if (tc::Logger::getLogger()->getLevel() <= tc::Logger::INFO)
1706+
{
17051707

17061708
TLLM_LOG_INFO("Number of tokens per block: %d.", mBlockManager.getTokensPerBlock());
17071709
auto const maxNumTokens = mBlockManager.getNumPrimaryBlocks() * mBlockManager.getTokensPerBlock();

cpp/tensorrt_llm/executor/kvCacheConfig.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
2828
std::optional<FloatType> const& crossKvCacheFraction, std::optional<RetentionPriority> secondaryOffloadMinPriority,
2929
size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm,
3030
SizeType32 attentionDpEventsGatherPeriodMs,
31-
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults)
31+
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults, uint64_t const& maxGpuTotalBytes)
3232
: mEnableBlockReuse(enableBlockReuse)
3333
, mHostCacheSize(hostCacheSize)
3434
, mOnboardBlocks(onboardBlocks)
@@ -38,6 +38,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
3838
, mCopyOnPartialReuse{copyOnPartialReuse}
3939
, mUseUvm{useUvm}
4040
, mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
41+
, mMaxGpuTotalBytes{maxGpuTotalBytes}
4142
{
4243
if (maxTokens)
4344
{
@@ -63,6 +64,10 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
6364
{
6465
fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value());
6566
}
67+
if (maxGpuTotalBytes)
68+
{
69+
setMaxGpuTotalBytes(maxGpuTotalBytes);
70+
}
6671
TLLM_CHECK_WITH_INFO(
6772
mAttentionDpEventsGatherPeriodMs > 0, "Attention DP events gather period must be greater than 0");
6873
}
@@ -137,6 +142,11 @@ SizeType32 KvCacheConfig::getAttentionDpEventsGatherPeriodMs() const
137142
return mAttentionDpEventsGatherPeriodMs;
138143
}
139144

145+
uint64_t KvCacheConfig::getMaxGpuTotalBytes() const
146+
{
147+
return mMaxGpuTotalBytes;
148+
}
149+
140150
void KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse)
141151
{
142152
mEnableBlockReuse = enableBlockReuse;
@@ -152,9 +162,12 @@ void KvCacheConfig::setCopyOnPartialReuse(bool copyOnPartialReuse)
152162
mCopyOnPartialReuse = copyOnPartialReuse;
153163
}
154164

155-
void KvCacheConfig::setMaxTokens(SizeType32 maxTokens)
165+
void KvCacheConfig::setMaxTokens(std::optional<SizeType32> maxTokens)
156166
{
157-
TLLM_CHECK(maxTokens > 0);
167+
if (maxTokens)
168+
{
169+
TLLM_CHECK(maxTokens.value() > 0);
170+
}
158171
mMaxTokens = maxTokens;
159172
}
160173

@@ -219,6 +232,11 @@ void KvCacheConfig::setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEve
219232
mAttentionDpEventsGatherPeriodMs = attentionDpEventsGatherPeriodMs;
220233
}
221234

235+
void KvCacheConfig::setMaxGpuTotalBytes(uint64_t maxGpuTotalBytes)
236+
{
237+
mMaxGpuTotalBytes = maxGpuTotalBytes;
238+
}
239+
222240
void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults)
223241
{
224242
if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec)

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
303303
.def_rw("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
304304
.def_rw("missed_blocks", &tbk::KvCacheStats::missedBlocks)
305305
.def_rw("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
306-
.def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize);
306+
.def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize)
307+
.def_ro("allocated_bytes", &tbk::KvCacheStats::allocatedBytes);
307308

308309
nb::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
309310
.def(nb::init<>())

cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "tensorrt_llm/nanobind/common/customCasters.h"
2222
#include "tensorrt_llm/runtime/cudaStream.h"
2323
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
24+
#include <cstdint>
2425
#include <nanobind/nanobind.h>
2526
#include <nanobind/stl/function.h>
2627
#include <nanobind/stl/map.h>
@@ -111,11 +112,11 @@ void initConfigBindings(nb::module_& m)
111112
self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(),
112113
self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(),
113114
self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(),
114-
self.getAttentionDpEventsGatherPeriodMs());
115+
self.getAttentionDpEventsGatherPeriodMs(), self.getMaxGpuTotalBytes());
115116
};
116117
auto kvCacheConfigSetstate = [](tle::KvCacheConfig& self, nb::tuple const& state)
117118
{
118-
if (state.size() != 14)
119+
if (state.size() != 15)
119120
{
120121
throw std::runtime_error("Invalid state!");
121122
}
@@ -125,20 +126,21 @@ void initConfigBindings(nb::module_& m)
125126
nb::cast<bool>(state[6]), nb::cast<std::optional<float>>(state[7]),
126127
nb::cast<std::optional<tle::RetentionPriority>>(state[8]), nb::cast<size_t>(state[9]),
127128
nb::cast<bool>(state[10]), nb::cast<bool>(state[11]), nb::cast<bool>(state[12]),
128-
nb::cast<SizeType32>(state[13]));
129+
nb::cast<SizeType32>(state[13]), std::nullopt, nb::cast<uint64_t>(state[14]));
129130
};
130131
nb::class_<tle::KvCacheConfig>(m, "KvCacheConfig")
131132
.def(nb::init<bool, std::optional<SizeType32> const&, std::optional<std::vector<SizeType32>> const&,
132133
std::optional<SizeType32> const&, std::optional<float> const&, std::optional<size_t> const&, bool,
133134
std::optional<float> const&, std::optional<tle::RetentionPriority>, size_t const&, bool, bool, bool,
134-
SizeType32, std::optional<RuntimeDefaults> const&>(),
135+
SizeType32, std::optional<RuntimeDefaults> const&, uint64_t const&>(),
135136
nb::arg("enable_block_reuse") = true, nb::arg("max_tokens") = nb::none(),
136137
nb::arg("max_attention_window") = nb::none(), nb::arg("sink_token_length") = nb::none(),
137138
nb::arg("free_gpu_memory_fraction") = nb::none(), nb::arg("host_cache_size") = nb::none(),
138139
nb::arg("onboard_blocks") = true, nb::arg("cross_kv_cache_fraction") = nb::none(),
139140
nb::arg("secondary_offload_min_priority") = nb::none(), nb::arg("event_buffer_max_size") = 0, nb::kw_only(),
140141
nb::arg("enable_partial_reuse") = true, nb::arg("copy_on_partial_reuse") = true, nb::arg("use_uvm") = false,
141-
nb::arg("attention_dp_events_gather_period_ms") = 5, nb::arg("runtime_defaults") = nb::none())
142+
nb::arg("attention_dp_events_gather_period_ms") = 5, nb::arg("runtime_defaults") = nb::none(),
143+
nb::arg("max_gpu_total_bytes") = 0)
142144
.def_prop_rw(
143145
"enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse)
144146
.def_prop_rw("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens)
@@ -163,6 +165,8 @@ void initConfigBindings(nb::module_& m)
163165
.def_prop_rw("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm)
164166
.def_prop_rw("attention_dp_events_gather_period_ms", &tle::KvCacheConfig::getAttentionDpEventsGatherPeriodMs,
165167
&tle::KvCacheConfig::setAttentionDpEventsGatherPeriodMs)
168+
.def_prop_rw(
169+
"max_gpu_total_bytes", &tle::KvCacheConfig::getMaxGpuTotalBytes, &tle::KvCacheConfig::setMaxGpuTotalBytes)
166170
.def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults)
167171
.def("__getstate__", kvCacheConfigGetstate)
168172
.def("__setstate__", kvCacheConfigSetstate);

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
299299
.def_readwrite("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
300300
.def_readwrite("missed_blocks", &tbk::KvCacheStats::missedBlocks)
301301
.def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
302-
.def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize);
302+
.def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize)
303+
.def_readonly("allocated_bytes", &tbk::KvCacheStats::allocatedBytes);
303304

304305
py::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
305306
.def(py::init<>())

cpp/tensorrt_llm/pybind/executor/executorConfig.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ void initConfigBindings(pybind11::module_& m)
104104
self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(),
105105
self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(),
106106
self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(),
107-
self.getAttentionDpEventsGatherPeriodMs());
107+
self.getAttentionDpEventsGatherPeriodMs(), self.getMaxGpuTotalBytes());
108108
};
109109
auto kvCacheConfigSetstate = [](py::tuple const& state)
110110
{
111-
if (state.size() != 14)
111+
if (state.size() != 15)
112112
{
113113
throw std::runtime_error("Invalid state!");
114114
}
@@ -117,20 +117,21 @@ void initConfigBindings(pybind11::module_& m)
117117
state[4].cast<std::optional<float>>(), state[5].cast<std::optional<size_t>>(), state[6].cast<bool>(),
118118
state[7].cast<std::optional<float>>(), state[8].cast<std::optional<tle::RetentionPriority>>(),
119119
state[9].cast<size_t>(), state[10].cast<bool>(), state[11].cast<bool>(), state[12].cast<bool>(),
120-
state[13].cast<SizeType32>());
120+
state[13].cast<SizeType32>(), std::nullopt, state[14].cast<uint64_t>());
121121
};
122122
py::class_<tle::KvCacheConfig>(m, "KvCacheConfig")
123123
.def(py::init<bool, std::optional<SizeType32> const&, std::optional<std::vector<SizeType32>> const&,
124124
std::optional<SizeType32> const&, std::optional<float> const&, std::optional<size_t> const&, bool,
125125
std::optional<float> const&, std::optional<tle::RetentionPriority>, size_t const&, bool, bool, bool,
126-
SizeType32, std::optional<RuntimeDefaults> const&>(),
126+
SizeType32, std::optional<RuntimeDefaults> const&, uint64_t const&>(),
127127
py::arg("enable_block_reuse") = true, py::arg("max_tokens") = py::none(),
128128
py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none(),
129129
py::arg("free_gpu_memory_fraction") = py::none(), py::arg("host_cache_size") = py::none(),
130130
py::arg("onboard_blocks") = true, py::arg("cross_kv_cache_fraction") = py::none(),
131131
py::arg("secondary_offload_min_priority") = py::none(), py::arg("event_buffer_max_size") = 0, py::kw_only(),
132132
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, py::arg("use_uvm") = false,
133-
py::arg("attention_dp_events_gather_period_ms") = 5, py::arg("runtime_defaults") = py::none())
133+
py::arg("attention_dp_events_gather_period_ms") = 5, py::arg("runtime_defaults") = py::none(),
134+
py::arg("max_gpu_total_bytes") = 0)
134135
.def_property(
135136
"enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse)
136137
.def_property("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens)
@@ -140,6 +141,8 @@ void initConfigBindings(pybind11::module_& m)
140141
"sink_token_length", &tle::KvCacheConfig::getSinkTokenLength, &tle::KvCacheConfig::setSinkTokenLength)
141142
.def_property("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction,
142143
&tle::KvCacheConfig::setFreeGpuMemoryFraction)
144+
.def_property(
145+
"max_gpu_total_bytes", &tle::KvCacheConfig::getMaxGpuTotalBytes, &tle::KvCacheConfig::setMaxGpuTotalBytes)
143146
.def_property("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize)
144147
.def_property("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks)
145148
.def_property("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction,

0 commit comments

Comments
 (0)