
Commit b124f8e

[None][feat] AutoDeploy: refactor memory usage logging
1. Log the model size.
2. Fix the logging of memory used during the forward pass when reconfiguring the KV cache.
3. Catch cache-resize OOM to give users a gentler experience.

Signed-off-by: Neta Zmora <[email protected]>
1 parent d626d13 commit b124f8e

File tree (3 files changed: +47 −10 lines changed)

tensorrt_llm/_torch/auto_deploy/models/factory.py
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 10 additions & 0 deletions
@@ -12,6 +12,7 @@
 from torch.fx import GraphModule
 
 from ..custom_ops.attention_interface import CacheConfig
+from ..utils.cuda_mem_tracker import get_mem_info_in_mb
 from ..utils.logger import ad_logger
 
 DynamicShape = Dict[int, Dim]  # indicating the dynamic shape in tensor dimension

@@ -285,11 +286,20 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
 
         """
         ad_logger.info("Loading and initializing weights.")
+        free_mem_pre, _ = get_mem_info_in_mb()
+        ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}")
         self._to_maybe_random(model, device)
+        params_size = sum(p.numel() * p.element_size() for p in model.parameters())
+        total_size_GB = params_size / (1024**3)
+        ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
+
         if not self.skip_loading_weights:
             self.prefetch_checkpoint(force=True)
             self._load_checkpoint(model, device)
+
         ad_logger.info("Loading and initializing weights. Done.")
+        free_mem_post, _ = get_mem_info_in_mb()
+        ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}")
 
     @staticmethod
     def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
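
For reference, the parameter-size estimate logged above can be reproduced on its own. This is a minimal sketch, not part of the commit; the throwaway nn.Sequential model stands in for whatever model the factory builds.

import torch.nn as nn

# Hypothetical stand-in model, only for illustration.
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))

# Same estimate the commit adds: bytes of all parameters, reported in GB.
params_size = sum(p.numel() * p.element_size() for p in model.parameters())
total_size_GB = params_size / (1024**3)
print(f"Estimated parameters memory: {total_size_GB:.2f} GB")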

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

Lines changed: 27 additions & 10 deletions
@@ -14,6 +14,7 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils._graph import add_graph_input
+from ...utils.cuda_mem_tracker import get_mem_info_in_mb
 from ...utils.node_utils import get_all_input_output_nodes, is_op
 from ..interface import (
     BaseTransform,

@@ -246,11 +247,7 @@ def _apply_to_full_model(
     ) -> Tuple[nn.Module, TransformInfo]:
         free_mem_ratio = self.config.free_mem_ratio
 
-        def _get_mem_info_in_mb():
-            free_mem, total_mem = torch.cuda.mem_get_info()
-            return free_mem // 1024**2, total_mem // 1024**2
-
-        free_mem, total_mem = _get_mem_info_in_mb()
+        free_mem, total_mem = get_mem_info_in_mb(empty_cache=True)
         self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
         current_cache_size = cm.current_cache_size_bytes()
         current_kv_cache_size = getattr(cm, "current_kv_cache_size_bytes", None)
@@ -259,8 +256,8 @@ def _get_mem_info_in_mb():
         )
         current_num_pages = cm.info.num_pages
         self._log_info(
-            f"Current cache size (MB): {current_cache_size // 1024 // 1024}, "
-            f"Current num pages: {current_num_pages}"
+            f"Current cache size (MB): {current_cache_size // 1024**2}, "
+            f"Current num pages: {current_num_pages}"
         )
         if current_kv_cache_size != current_cache_size:
             self._log_info(

@@ -278,12 +275,32 @@ def _get_mem_info_in_mb():
 
         # Let's run a forward pass to get the memory usage
         cm.info.set_max_num_tokens_sample()
-        free_mem_pre, _ = _get_mem_info_in_mb()
+        free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True)
         self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")
 
-        mod(**cm.named_args)
+        # Reset peak memory stats to get the extra memory used during the forward pass
+        torch.cuda.reset_peak_memory_stats()
+        memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2
+        try:
+            mod(**cm.named_args)
+        except torch.OutOfMemoryError as e:
+            self.ad_logger.error(
+                f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
+            )
+            raise e
+
+        peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2
+        mem_used_during_forward_pass_mb = (
+            peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
+        )
+        self._log_info(
+            f"Peak memory usage during forward pass (MB): {peak_memory_during_forward_pass_mb}"
+        )
+        self._log_info(
+            f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}"
+        )
 
-        free_mem_post, _ = _get_mem_info_in_mb()
+        free_mem_post, _ = get_mem_info_in_mb(empty_cache=True)
         self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")
 
         memory_for_forward_pass = free_mem_pre - free_mem_post
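
The measurement pattern introduced here (reset peak stats, run the forward pass, read back the peak allocation, catch OOM) can be illustrated standalone. The sketch below is not from the commit; run_forward and its tensor sizes are placeholders for mod(**cm.named_args), and torch.OutOfMemoryError assumes a recent PyTorch (older releases expose it as torch.cuda.OutOfMemoryError).

import torch


def run_forward() -> torch.Tensor:
    # Placeholder workload standing in for mod(**cm.named_args).
    x = torch.randn(1024, 1024, device="cuda")
    return x @ x


if torch.cuda.is_available():
    # Reset peak stats so max_memory_allocated() reflects only this forward pass.
    torch.cuda.reset_peak_memory_stats()
    allocated_before_mb = torch.cuda.memory_allocated() // 1024**2
    try:
        run_forward()
    except torch.OutOfMemoryError as e:
        # Surface a clearer message than a bare CUDA OOM traceback, then re-raise.
        print(f"OutOfMemoryError during forward pass: {e}")
        raise
    peak_mb = torch.cuda.max_memory_allocated() // 1024**2
    print(f"Peak memory usage during forward pass (MB): {peak_mb}")
    print(f"Extra memory used during forward pass (MB): {peak_mb - allocated_before_mb}")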

tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py

Lines changed: 10 additions & 0 deletions
@@ -1,5 +1,6 @@
 import gc
 from contextlib import contextmanager
+from typing import Tuple
 
 import torch
 

@@ -24,3 +25,12 @@ def cuda_memory_tracker(logger=ad_logger):
     leaked = mem_after - mem_before
     if leaked > 0:
         logger.warning(f"Potential memory leak detected, leaked memory: {leaked} bytes")
+
+
+def get_mem_info_in_mb(empty_cache: bool = True) -> Tuple[int, int]:
+    if empty_cache:
+        # Clear the memory cache to get the exact free memory
+        torch.cuda.empty_cache()
+    free_mem, total_mem = torch.cuda.mem_get_info()
+    MB = 1024**2
+    return free_mem // MB, total_mem // MB
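
A minimal usage sketch of the new helper, assuming the package is importable and a CUDA device is present; the import path is taken from the file header above.

import torch

from tensorrt_llm._torch.auto_deploy.utils.cuda_mem_tracker import get_mem_info_in_mb

if torch.cuda.is_available():
    # empty_cache=True flushes the allocator cache first for a more accurate reading.
    free_mb, total_mb = get_mem_info_in_mb(empty_cache=True)
    print(f"Free memory (MB): {free_mb}, Total memory (MB): {total_mb}")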
