File tree Expand file tree Collapse file tree 3 files changed +10
-4
lines changed
include/tensorrt_llm/executor
tensorrt_llm/_torch/pyexecutor Expand file tree Collapse file tree 3 files changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -1020,7 +1020,7 @@ class KvCacheConfig
10201020 void setEnableBlockReuse(bool enableBlockReuse);
10211021 void setEnablePartialReuse(bool enablePartialReuse);
10221022 void setCopyOnPartialReuse(bool copyOnPartialReuse);
1023- void setMaxTokens(SizeType32 maxTokens);
1023+ void setMaxTokens(std::optional< SizeType32> maxTokens);
10241024 void setMaxAttentionWindowVec(std::vector<SizeType32> maxAttentionWindowVec);
10251025 void setSinkTokenLength(SizeType32 sinkTokenLength);
10261026 void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction);
Original file line number Diff line number Diff line change @@ -143,9 +143,12 @@ void KvCacheConfig::setCopyOnPartialReuse(bool copyOnPartialReuse)
143143 mCopyOnPartialReuse = copyOnPartialReuse;
144144}
145145
146- void KvCacheConfig::setMaxTokens(SizeType32 maxTokens)
146+ void KvCacheConfig::setMaxTokens(std::optional< SizeType32> maxTokens)
147147{
148- TLLM_CHECK(maxTokens > 0);
148+ if (maxTokens)
149+ {
150+ TLLM_CHECK(maxTokens.value() > 0);
151+ }
149152 mMaxTokens = maxTokens;
150153}
151154
Original file line number Diff line number Diff line change @@ -206,6 +206,9 @@ def __init__(
206206 assert isinstance(
207207 kv_cache_config, KvCacheConfigCpp
208208 ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp"
209+
210+ # overwrite max_tokens in VSWA case
211+ kv_cache_config.max_tokens = None
209212 blocks_per_window = self.calculate_max_num_blocks_from_cpp(
210213 kv_cache_config=kv_cache_config,
211214 model_config=model_config,
@@ -633,7 +636,7 @@ def calculate_max_num_blocks_from_cpp(
633636 logger.debug(f"window_size_to_layers: {window_size_to_layers}")
634637
635638 free_mem, total_mem = torch.cuda.mem_get_info()
636- primary_pool_memory_bytes = free_mem
639+ primary_pool_memory_bytes = int( free_mem * 0.9)
637640 secondary_pool_memory_bytes = 0
638641 logger.debug(
639642 f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes/1024**3}GB, \nsecondary_pool_memory_bytes is set to {secondary_pool_memory_bytes/1024**3}GB"
You can’t perform that action at this time.
0 commit comments