
Commit c191b38

fix: WAR for under-allocation in torch VSWA kvcachemanager
Signed-off-by: qixiang-99 <[email protected]>
1 parent c198402

File tree

cpp/include/tensorrt_llm/executor/executor.h
cpp/tensorrt_llm/executor/kvCacheConfig.cpp
tensorrt_llm/_torch/pyexecutor/resource_manager.py

3 files changed: 10 additions & 4 deletions

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 1 addition & 1 deletion

@@ -1020,7 +1020,7 @@ class KvCacheConfig
     void setEnableBlockReuse(bool enableBlockReuse);
     void setEnablePartialReuse(bool enablePartialReuse);
     void setCopyOnPartialReuse(bool copyOnPartialReuse);
-    void setMaxTokens(SizeType32 maxTokens);
+    void setMaxTokens(std::optional<SizeType32> maxTokens);
     void setMaxAttentionWindowVec(std::vector<SizeType32> maxAttentionWindowVec);
     void setSinkTokenLength(SizeType32 sinkTokenLength);
     void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction);
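This signature change makes the max-tokens cap clearable instead of always-required. On the Python side it surfaces through the KvCacheConfigCpp binding, which the resource_manager.py change below relies on. A minimal sketch, assuming the binding exposes the setter as a max_tokens property (the import path is illustrative):

from tensorrt_llm.bindings.executor import KvCacheConfig as KvCacheConfigCpp

config = KvCacheConfigCpp()
config.max_tokens = 4096  # explicit cap on tokens held in the KV cache
config.max_tokens = None  # now valid: std::nullopt clears the cap entirely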

cpp/tensorrt_llm/executor/kvCacheConfig.cpp

Lines changed: 5 additions & 2 deletions

@@ -143,9 +143,12 @@ void KvCacheConfig::setCopyOnPartialReuse(bool copyOnPartialReuse)
     mCopyOnPartialReuse = copyOnPartialReuse;
 }
 
-void KvCacheConfig::setMaxTokens(SizeType32 maxTokens)
+void KvCacheConfig::setMaxTokens(std::optional<SizeType32> maxTokens)
 {
-    TLLM_CHECK(maxTokens > 0);
+    if (maxTokens)
+    {
+        TLLM_CHECK(maxTokens.value() > 0);
+    }
     mMaxTokens = maxTokens;
 }
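The guard now fires only when a value is actually supplied, so std::nullopt passes through and clears mMaxTokens. A rough Python analogue of the new validation flow (set_max_tokens and the config argument here are hypothetical, for illustration only):

from typing import Optional

def set_max_tokens(config, max_tokens: Optional[int]) -> None:
    # Validate only when a value is present, mirroring the C++ setter:
    # None (std::nullopt) skips the check and simply clears the cap.
    if max_tokens is not None:
        assert max_tokens > 0, "maxTokens must be positive"  # TLLM_CHECK analogue
    config.max_tokens = max_tokens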

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 4 additions & 1 deletion

@@ -206,6 +206,9 @@ def __init__(
         assert isinstance(
             kv_cache_config, KvCacheConfigCpp
         ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp"
+
+        # overwrite max_tokens in VSWA case
+        kv_cache_config.max_tokens = None
         blocks_per_window = self.calculate_max_num_blocks_from_cpp(
             kv_cache_config=kv_cache_config,
             model_config=model_config,
@@ -633,7 +636,7 @@ def calculate_max_num_blocks_from_cpp(
         logger.debug(f"window_size_to_layers: {window_size_to_layers}")
 
         free_mem, total_mem = torch.cuda.mem_get_info()
-        primary_pool_memory_bytes = free_mem
+        primary_pool_memory_bytes = int(free_mem * 0.9)
         secondary_pool_memory_bytes = 0
         logger.debug(
             f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes/1024**3}GB, \nsecondary_pool_memory_bytes is set to {secondary_pool_memory_bytes/1024**3}GB"

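Previously the primary pool claimed all of free memory; the 0.9 factor holds back 10%, which appears intended to leave headroom for allocations made after the estimate, while dropping max_tokens lets the VSWA pool sizing proceed uncapped. A self-contained sketch of the sizing arithmetic; bytes_per_block is a hypothetical stand-in, since the real per-block cost depends on dtype, head count, head size, tokens per block, and the layers in each attention-window group:

import torch

def estimate_primary_pool_bytes(headroom: float = 0.9) -> int:
    # Size the primary KV-cache pool from currently free device memory,
    # holding back 10% so later allocations do not overrun the budget.
    free_mem, _total_mem = torch.cuda.mem_get_info()
    return int(free_mem * headroom)

# Hypothetical per-block cost: K and V, fp16 (2 bytes), 8 KV heads,
# head size 128, 32 tokens per block, for a single layer.
bytes_per_block = 2 * 2 * 8 * 128 * 32
max_num_blocks = estimate_primary_pool_bytes() // bytes_per_block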