diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py
index 71b4b8b2c5c..f67251d3d6c 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/factory.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -1,3 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 """The model factory interface used by auto-deploy to build custom models."""
 
 import copy
@@ -12,6 +28,7 @@
 from torch.fx import GraphModule
 
 from ..custom_ops.attention_interface import CacheConfig
+from ..utils.cuda_mem_tracker import get_mem_info_in_mb
 from ..utils.logger import ad_logger
 
 DynamicShape = Dict[int, Dim]  # indicating the dynamic shape in tensor dimension
@@ -290,11 +307,20 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
         """
         ad_logger.info("Loading and initializing weights.")
+        free_mem_pre, _ = get_mem_info_in_mb()
+        ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}")
         self._to_maybe_random(model, device)
+        params_size = sum(p.numel() * p.element_size() for p in model.parameters())
+        total_size_GB = params_size / (1024**3)
+        ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
+
         if not self.skip_loading_weights:
             self.prefetch_checkpoint(force=True)
             self._load_checkpoint(model, device)
+        ad_logger.info("Loading and initializing weights. Done.")
+        free_mem_post, _ = get_mem_info_in_mb()
+        ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}")
 
     @staticmethod
     def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
index ecf42d0b238..de657087147 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
@@ -1,3 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 """Graph transformation to automatically add kv cache into fused MHA op."""
 
 import operator
@@ -19,6 +35,7 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils._graph import add_graph_input
+from ...utils.cuda_mem_tracker import get_mem_info_in_mb
 from ...utils.node_utils import get_all_input_output_nodes, is_op
 from ..interface import (
     BaseTransform,
@@ -256,11 +273,7 @@ def _apply_to_full_model(
     ) -> Tuple[nn.Module, TransformInfo]:
         free_mem_ratio = self.config.free_mem_ratio
 
-        def _get_mem_info_in_mb():
-            free_mem, total_mem = torch.cuda.mem_get_info()
-            return free_mem // 1024**2, total_mem // 1024**2
-
-        free_mem, total_mem = _get_mem_info_in_mb()
+        free_mem, total_mem = get_mem_info_in_mb(empty_cache=True)
         self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
         current_cache_size = cm.current_cache_size_bytes()
         current_kv_cache_size = getattr(cm, "current_kv_cache_size_bytes", None)
@@ -269,7 +282,7 @@
         )
         current_num_pages = cm.info.num_pages
         self._log_info(
-            f"Current cache size (MB): {current_cache_size // 1024 // 1024}, "
+            f"Current cache size (MB): {current_cache_size // 1024**2}, "
             f"Current num pages: {current_num_pages}"
         )
         if current_kv_cache_size != current_cache_size:
@@ -288,12 +301,32 @@
 
         # Let's run a forward pass to get the memory usage
         cm.info.set_max_num_tokens_sample()
-        free_mem_pre, _ = _get_mem_info_in_mb()
+        free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True)
         self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")
 
-        mod(**cm.named_args)
+        # Reset peak memory stats to get the extra memory used during the forward pass
+        torch.cuda.reset_peak_memory_stats()
+        memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2
+        try:
+            mod(**cm.named_args)
+        except torch.OutOfMemoryError as e:
+            self.ad_logger.error(
+                f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
+            )
+            raise e
+
+        peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2
+        mem_used_during_forward_pass_mb = (
+            peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
+        )
+        self._log_info(
+            f"Peak memory usage during forward pass (MB): {peak_memory_during_forward_pass_mb}"
+        )
+        self._log_info(
+            f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}"
+        )
 
-        free_mem_post, _ = _get_mem_info_in_mb()
+        free_mem_post, _ = get_mem_info_in_mb(empty_cache=True)
         self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")
 
         memory_for_forward_pass = free_mem_pre - free_mem_post
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py b/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py
index ddf57a6e6f7..e73cec39e7c 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py
@@ -1,5 +1,22 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 import gc
 from contextlib import contextmanager
+from typing import Tuple
 
 import torch
 
@@ -24,3 +41,12 @@ def cuda_memory_tracker(logger=ad_logger):
     leaked = mem_after - mem_before
     if leaked > 0:
         logger.warning(f"Potential memory leak detected, leaked memory: {leaked} bytes")
+
+
+def get_mem_info_in_mb(empty_cache: bool = True) -> Tuple[int, int]:
+    if empty_cache:
+        # Clear the memory cache to get the exact free memory
+        torch.cuda.empty_cache()
+    free_mem, total_mem = torch.cuda.mem_get_info()
+    MB = 1024**2
+    return free_mem // MB, total_mem // MB
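
The diff above introduces a shared `get_mem_info_in_mb` helper and a peak-memory bookkeeping pattern (`reset_peak_memory_stats` / `memory_allocated` / `max_memory_allocated`) around the forward pass. Below is a minimal sketch of how the same pattern can be used in isolation; it is not part of the diff, assumes a CUDA-capable PyTorch build, and `run_workload` is a hypothetical stand-in for the model forward pass.

```python
import torch

from tensorrt_llm._torch.auto_deploy.utils.cuda_mem_tracker import get_mem_info_in_mb


def run_workload() -> torch.Tensor:
    # Placeholder workload: allocate and square a ~256 MB tensor on the GPU.
    x = torch.randn(64 * 1024 * 1024, device="cuda")
    return x * x


free_pre, total = get_mem_info_in_mb()  # empty_cache=True by default
print(f"Free memory before workload (MB): {free_pre} / {total}")

# Track the extra memory the workload needs on top of what is already allocated.
torch.cuda.reset_peak_memory_stats()
allocated_before_mb = torch.cuda.memory_allocated() // 1024**2

out = run_workload()

peak_mb = torch.cuda.max_memory_allocated() // 1024**2
print(f"Peak memory during workload (MB): {peak_mb}")
print(f"Extra memory used by workload (MB): {peak_mb - allocated_before_mb}")

free_post, _ = get_mem_info_in_mb()
print(f"Free memory after workload (MB): {free_post}")
```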