Skip to content

Commit 7432901

Browse files
committed
Add memory estimation for attention metadata to fix an OOM issue that occurs when CUDA graphs are enabled and the maximum attention window size is large
Signed-off-by: qixiang-99 <[email protected]>
1 parent d2b1162 commit 7432901

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,59 @@ def _get_free_gpu_memory_fraction(self) -> float:
9696
fraction = 0.9
9797
return fraction
9898

99+
def _get_num_graphs(self) -> int:
100+
return len(self._model_engine._cuda_graph_batch_sizes)
101+
102+
def _get_extra_memory_for_attention_metadata(
        self, kv_cache_manager: KVCacheManager) -> int:
    """Estimate extra GPU memory `kv_cache_block_offsets` will need at runtime.

    `kv_cache_block_offsets` (see `TrtllmAttentionMetadata`) stores the KV-cache
    block offsets for every request. Its layout is
    [num_pools, max_num_sequences, 2, max_blocks_per_seq].

    • Estimation phase: we run a dry-run with requests of length
      `max_num_tokens` (e.g. 8192). Consequently, `max_blocks_per_seq` is small
      and the tensor's footprint appears modest.

    • Real inference: `max_blocks_per_seq` can increase to
      `max_seq_len / tokens_per_block`. For long-context models this is
      orders of magnitude larger, so the tensor consumes significantly more
      GPU memory.

    • CUDA graphs: when graph capture is enabled the full
      `kv_cache_block_offsets` tensor must be pre-allocated,
      making the extra memory grow linearly with the number of graphs.

    Args:
        kv_cache_manager: manager whose `max_blocks_per_seq` / `num_pools`
            reflect the estimation-phase configuration.

    Returns:
        Total extra bytes across all CUDA graphs (0 when the dry-run already
        covered the worst-case block count).
    """
    # max_blocks_per_seq as seen during the estimation dry-run.
    est_phase_max_blocks_per_seq = kv_cache_manager.max_blocks_per_seq

    max_batch_size = self._executor_config.max_batch_size
    if max_batch_size is None:
        # Plain string: no placeholders, so an f-string is unnecessary.
        logger.warning("max_batch_size is not set, using 1")
        max_batch_size = 1
    max_attention_window_from_config = self._executor_config.kv_cache_config.max_attention_window
    # Widest attention window actually usable at inference time; fall back to
    # max_seq_len when no per-layer window list is configured.
    max_window_size = max(
        max_attention_window_from_config
    ) if max_attention_window_from_config is not None else self._executor_config.max_seq_len
    tokens_per_block = self._executor_config.tokens_per_block
    num_pools = kv_cache_manager.num_pools

    # Ceiling division: blocks needed to cover the full window at runtime.
    real_phase_max_blocks_per_seq = (
        max_window_size + tokens_per_block - 1) // tokens_per_block

    # Clamp at 0: if the dry-run already used the worst-case block count the
    # delta would be negative, and adding a negative "extra" to peak_memory
    # would wrongly enlarge the KV-cache budget.
    extra_blocks_per_seq = max(
        0, real_phase_max_blocks_per_seq - est_phase_max_blocks_per_seq)
    # Per-graph growth of kv_cache_block_offsets. The factor 2 is the K/V
    # dimension of the layout above; 4 presumably is the byte width of each
    # offset entry (int32) — TODO confirm against TrtllmAttentionMetadata.
    extra_bytes_per_graph = (
        extra_blocks_per_seq * num_pools * max_batch_size * 2 * 4)
    # One pre-allocated tensor per captured CUDA graph.
    num_graphs = self._get_num_graphs()
    total_extra_bytes = int(extra_bytes_per_graph * num_graphs)
    logger.info(
        f"extra bytes per graph from kv_cache_block_offsets: {extra_bytes_per_graph / (GB):.2f} GiB, total extra bytes: {total_extra_bytes / (GB):.2f} GiB"
    )
    return total_extra_bytes
99152
def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
100153
alloc_kv_tokens: int) -> int:
101154
"""
@@ -256,6 +309,13 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
256309
logger.info(
257310
f"Memory used outside torch (e.g., NCCL and CUDA graphs) in memory usage profiling: {extra_cost / (GB):.2f} GiB"
258311
)
312+
313+
# get extra memory from attention metadata
314+
extra_memory_for_attention_metadata = self._get_extra_memory_for_attention_metadata(
315+
py_executor.resource_manager.resource_managers.get(
316+
ResourceManagerType.KV_CACHE_MANAGER))
317+
peak_memory += extra_memory_for_attention_metadata
318+
259319
kv_stats = py_executor.resource_manager.resource_managers.get(
260320
ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
261321

0 commit comments

Comments
 (0)