26 changes: 26 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -1,3 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""The model factory interface used by auto-deploy to build custom models."""

import copy
@@ -12,6 +28,7 @@
from torch.fx import GraphModule

from ..custom_ops.attention_interface import CacheConfig
from ..utils.cuda_mem_tracker import get_mem_info_in_mb
from ..utils.logger import ad_logger

DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension
@@ -290,11 +307,20 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):

"""
ad_logger.info("Loading and initializing weights.")
free_mem_pre, _ = get_mem_info_in_mb()
ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}")
self._to_maybe_random(model, device)
params_size = sum(p.numel() * p.element_size() for p in model.parameters())
total_size_GB = params_size / (1024**3)
ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB")

if not self.skip_loading_weights:
self.prefetch_checkpoint(force=True)
self._load_checkpoint(model, device)

ad_logger.info("Loading and initializing weights. Done.")
free_mem_post, _ = get_mem_info_in_mb()
ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}")

@staticmethod
def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
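For context, the measurement pattern this hunk adds around weight loading can be reproduced in isolation: sample free device memory before and after moving the weights, and estimate the parameter footprint from `numel() * element_size()`. A minimal sketch, assuming a CUDA device and a placeholder model that is not part of the PR:

```python
import torch
import torch.nn as nn


def mem_info_mb():
    # Same idea as get_mem_info_in_mb: (free, total) device memory in MB.
    free_mem, total_mem = torch.cuda.mem_get_info()
    return free_mem // 1024**2, total_mem // 1024**2


model = nn.Linear(4096, 4096)  # placeholder; stands in for the factory-built model

free_pre, _ = mem_info_mb()
print(f"Free memory before loading weights (MB): {free_pre}")

model.to("cuda")  # stands in for _to_maybe_random + _load_checkpoint

# Parameter-size estimate, as in load_or_random_init.
params_size = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Estimated parameters memory: {params_size / 1024**3:.2f} GB")

free_post, _ = mem_info_mb()
print(f"Free memory after loading weights (MB): {free_post}")
```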
51 changes: 42 additions & 9 deletions tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
@@ -1,3 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Graph transformation to automatically add kv cache into fused MHA op."""

import operator
@@ -19,6 +35,7 @@
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import add_graph_input
from ...utils.cuda_mem_tracker import get_mem_info_in_mb
from ...utils.node_utils import get_all_input_output_nodes, is_op
from ..interface import (
BaseTransform,
@@ -256,11 +273,7 @@ def _apply_to_full_model(
) -> Tuple[nn.Module, TransformInfo]:
free_mem_ratio = self.config.free_mem_ratio

def _get_mem_info_in_mb():
free_mem, total_mem = torch.cuda.mem_get_info()
return free_mem // 1024**2, total_mem // 1024**2

free_mem, total_mem = _get_mem_info_in_mb()
free_mem, total_mem = get_mem_info_in_mb(empty_cache=True)
self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
current_cache_size = cm.current_cache_size_bytes()
current_kv_cache_size = getattr(cm, "current_kv_cache_size_bytes", None)
@@ -269,7 +282,7 @@ def _get_mem_info_in_mb():
)
current_num_pages = cm.info.num_pages
self._log_info(
f"Current cache size (MB): {current_cache_size // 1024 // 1024}, "
f"Current cache size (MB): {current_cache_size // 1024**2}, "
f"Current num pages: {current_num_pages}"
)
if current_kv_cache_size != current_cache_size:
@@ -288,12 +301,32 @@

# Let's run a forward pass to get the memory usage
cm.info.set_max_num_tokens_sample()
free_mem_pre, _ = _get_mem_info_in_mb()
free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True)
self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")

mod(**cm.named_args)
# Reset peak memory stats to get the extra memory used during the forward pass
torch.cuda.reset_peak_memory_stats()
memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2
try:
mod(**cm.named_args)
except torch.OutOfMemoryError as e:
self.ad_logger.error(
f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
)
raise

peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2
mem_used_during_forward_pass_mb = (
peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
)
self._log_info(
f"Peak memory uasge during forward pass (MB): {peak_memory_during_forward_pass_mb}"
)
self._log_info(
f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}"
)

free_mem_post, _ = _get_mem_info_in_mb()
free_mem_post, _ = get_mem_info_in_mb(empty_cache=True)
self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")

memory_for_forward_pass = free_mem_pre - free_mem_post
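The extra-memory accounting added in this hunk follows PyTorch's peak-memory statistics: reset the peak counter, record the currently allocated bytes, run the forward pass, then read back the peak. A standalone sketch of that technique (the dummy module and shapes are assumptions; a CUDA device is required):

```python
import torch
import torch.nn as nn

mod = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).cuda()
x = torch.randn(64, 1024, device="cuda")

# Reset the peak counter so max_memory_allocated() reflects only this pass.
torch.cuda.reset_peak_memory_stats()
allocated_before_mb = torch.cuda.memory_allocated() // 1024**2

out = mod(x)

peak_mb = torch.cuda.max_memory_allocated() // 1024**2
extra_mb = peak_mb - allocated_before_mb
print(f"Peak memory during forward pass (MB): {peak_mb}")
print(f"Extra memory used during forward pass (MB): {extra_mb}")
```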
26 changes: 26 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py
@@ -1,5 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import gc
from contextlib import contextmanager
from typing import Tuple

import torch

@@ -24,3 +41,12 @@ def cuda_memory_tracker(logger=ad_logger):
leaked = mem_after - mem_before
if leaked > 0:
logger.warning(f"Potential memory leak detected, leaked memory: {leaked} bytes")


def get_mem_info_in_mb(empty_cache: bool = True) -> Tuple[int, int]:
    """Return (free, total) CUDA memory of the current device in MB."""
    if empty_cache:
        # Release cached allocator blocks so the reported free memory is more accurate.
        torch.cuda.empty_cache()
    free_mem, total_mem = torch.cuda.mem_get_info()
    MB = 1024**2
    return free_mem // MB, total_mem // MB
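
A possible usage sketch for the new helper alongside the existing cuda_memory_tracker context manager (import path inferred from the file location in this diff; a CUDA device is assumed):

```python
import torch

from tensorrt_llm._torch.auto_deploy.utils.cuda_mem_tracker import (
    cuda_memory_tracker,
    get_mem_info_in_mb,
)

free_mb, total_mb = get_mem_info_in_mb(empty_cache=True)
print(f"Free/total device memory (MB): {free_mb} / {total_mb}")

# cuda_memory_tracker warns if allocated memory grows across the block.
with cuda_memory_tracker():
    x = torch.empty(1024, 1024, device="cuda")
    del x
```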