
Commit b4ab81d

[None][feat] AutoDeploy: refactor memory usage logging
1. Log the model size.
2. Fix the logging of memory used during the forward pass when reconfiguring the KV cache.
3. Catch cache-resize OOM errors to give users a gentler experience.

Signed-off-by: Neta Zmora <[email protected]>
1 parent d05079b commit b4ab81d

3 files changed, with 46 additions and 12 deletions.

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 10 additions & 0 deletions
@@ -12,6 +12,7 @@
 from torch.fx import GraphModule
 
 from ..custom_ops.attention_interface import CacheConfig
+from ..utils.cuda_mem_tracker import get_mem_info_in_mb
 from ..utils.logger import ad_logger
 
 DynamicShape = Dict[int, Dim]  # indicating the dynamic shape in tensor dimension
@@ -273,11 +274,20 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
 
         """
         ad_logger.info("Loading and initializing weights.")
+        free_mem_pre, _ = get_mem_info_in_mb()
+        ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}")
         self._to_maybe_random(model, device)
+        params_size = sum(p.numel() * p.element_size() for p in model.parameters())
+        total_size_GB = params_size / (1024**3)
+        ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
+
         if not self.skip_loading_weights:
             self.prefetch_checkpoint(force=True)
             self._load_checkpoint(model, device)
+
         ad_logger.info("Loading and initializing weights. Done.")
+        free_mem_post, _ = get_mem_info_in_mb()
+        ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}")
 
     @staticmethod
     def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
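For context, the parameter-size estimate added to load_or_random_init can be reproduced in isolation. A minimal sketch follows; the helper name estimate_param_size_gb and the toy nn.Linear model are illustrative only and not part of this commit:

import torch
import torch.nn as nn


def estimate_param_size_gb(model: nn.Module) -> float:
    # Sum of parameter bytes (numel * element_size), reported in GiB,
    # mirroring the estimate logged by the factory.
    params_size = sum(p.numel() * p.element_size() for p in model.parameters())
    return params_size / (1024**3)


model = nn.Linear(4096, 4096, dtype=torch.float16)  # ~32 MiB of fp16 weights
print(f"Estimated parameters memory: {estimate_param_size_gb(model):.2f} GB")

This counts only parameters; buffers and activation workspaces are not included, so the free-memory readings logged before and after loading remain useful as a cross-check.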

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

Lines changed: 26 additions & 12 deletions
@@ -13,6 +13,7 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils._graph import add_graph_input
+from ...utils.cuda_mem_tracker import get_mem_info_in_mb
 from ...utils.node_utils import get_all_input_output_nodes, is_op
 from ..interface import (
     BaseTransform,
@@ -245,16 +246,12 @@ def _apply_to_full_model(
     ) -> Tuple[nn.Module, TransformInfo]:
         free_mem_ratio = self.config.free_mem_ratio
 
-        def _get_mem_info_in_mb():
-            free_mem, total_mem = torch.cuda.mem_get_info()
-            return free_mem // 1024**2, total_mem // 1024**2
-
-        free_mem, total_mem = _get_mem_info_in_mb()
+        free_mem, total_mem = get_mem_info_in_mb(empty_cache=True)
         self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
         current_cache_size = cm.current_cache_size_bytes()
         current_num_pages = cm.info.num_pages
         self._log_info(
-            f"Current cache size (MB): {current_cache_size // 1024 // 1024}, "
+            f"Current cache size (MB): {current_cache_size // 1024**2}, "
             f"Current num pages (MB): {current_num_pages}"
         )
 
@@ -269,16 +266,33 @@ def _get_mem_info_in_mb():
 
         # Let's run a forward pass to get the memory usage
         cm.info.set_max_num_tokens_sample()
-        free_mem_pre, _ = _get_mem_info_in_mb()
+        free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True)
         self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")
 
-        mod(**cm.named_args)
+        # Reset peak memory stats to get the extra memory used during the forward pass
+        torch.cuda.reset_peak_memory_stats()
+        memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2
+        try:
+            mod(**cm.named_args)
+        except torch.OutOfMemoryError as e:
+            self.ad_logger.error(
+                f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
+            )
+            raise e
 
-        free_mem_post, _ = _get_mem_info_in_mb()
-        self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")
+        peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2
+        mem_used_during_forward_pass_mb = (
+            peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
+        )
+        self._log_info(
+            f"Peak memory usage during forward pass (MB): {peak_memory_during_forward_pass_mb}"
+        )
+        self._log_info(
+            f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}"
+        )
 
-        memory_for_forward_pass = free_mem_pre - free_mem_post
-        self._log_info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
+        free_mem_post, _ = get_mem_info_in_mb(empty_cache=True)
+        self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")
 
         new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
         new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
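The peak-memory bookkeeping introduced above follows a standard PyTorch pattern: reset the peak counter, run the workload, then read torch.cuda.max_memory_allocated(). A minimal standalone sketch, assuming a CUDA device is available; run_forward is a hypothetical stand-in for mod(**cm.named_args) and is not part of this commit:

import torch


def measure_forward_memory_mb(run_forward) -> tuple[int, int]:
    # Reset the peak counter so max_memory_allocated() reflects only this call.
    torch.cuda.reset_peak_memory_stats()
    allocated_before_mb = torch.cuda.memory_allocated() // 1024**2
    try:
        run_forward()
    except torch.OutOfMemoryError:
        # Mirror the commit's approach: surface a clearer message, then re-raise.
        print("OutOfMemoryError during the measured forward pass")
        raise
    peak_mb = torch.cuda.max_memory_allocated() // 1024**2
    return peak_mb, peak_mb - allocated_before_mb

The cache-resize arithmetic that follows the measurement can be checked with small made-up numbers: with 10,000 MB free after the forward pass, free_mem_ratio = 0.8, and a 2,048 MB cache split into 1,024 pages (2 MB per page), new_cache_size works out to 10,048 MB and new_num_pages to 5,024.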

tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py

Lines changed: 10 additions & 0 deletions
@@ -1,5 +1,6 @@
 import gc
 from contextlib import contextmanager
+from typing import Tuple
 
 import torch
 
@@ -24,3 +25,12 @@ def cuda_memory_tracker(logger=ad_logger):
         leaked = mem_after - mem_before
         if leaked > 0:
             logger.warning(f"Potential memory leak detected, leaked memory: {leaked} bytes")
+
+
+def get_mem_info_in_mb(empty_cache: bool = True) -> Tuple[int, int]:
+    if empty_cache:
+        # Clear the memory cache to get the exact free memory
+        torch.cuda.empty_cache()
+    free_mem, total_mem = torch.cuda.mem_get_info()
+    MB = 1024**2
+    return free_mem // MB, total_mem // MB
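A usage sketch for the new helper, assuming a CUDA device and that the module path mirrors the file location; the allocation size is arbitrary:

import torch
from tensorrt_llm._torch.auto_deploy.utils.cuda_mem_tracker import get_mem_info_in_mb

free_mb, total_mb = get_mem_info_in_mb(empty_cache=True)  # trims the caching allocator first
print(f"Free: {free_mb} MB / Total: {total_mb} MB")

x = torch.empty(64 * 1024 * 1024, device="cuda")  # ~256 MB of fp32
free_after_mb, _ = get_mem_info_in_mb(empty_cache=False)  # leave the allocator pool untouched
print(f"Free after allocation: {free_after_mb} MB")

Passing empty_cache=False skips the torch.cuda.empty_cache() call on hot paths; the default of True matches how the factory and the kvcache transform use the helper when measuring free memory.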
