1414from ...models .factory import ModelFactory
1515from ...shim .interface import CachedSequenceInterface
1616from ...utils ._graph import add_graph_input
17+ from ...utils .cuda_mem_tracker import get_mem_info_in_mb
1718from ...utils .node_utils import get_all_input_output_nodes , is_op
1819from ..interface import (
1920 BaseTransform ,
@@ -246,11 +247,7 @@ def _apply_to_full_model(
246247 ) -> Tuple [nn .Module , TransformInfo ]:
247248 free_mem_ratio = self .config .free_mem_ratio
248249
249- def _get_mem_info_in_mb ():
250- free_mem , total_mem = torch .cuda .mem_get_info ()
251- return free_mem // 1024 ** 2 , total_mem // 1024 ** 2
252-
253- free_mem , total_mem = _get_mem_info_in_mb ()
250+ free_mem , total_mem = get_mem_info_in_mb (empty_cache = True )
254251 self ._log_info (f"Free memory (MB): { free_mem } , Total memory (MB): { total_mem } " )
255252 current_cache_size = cm .current_cache_size_bytes ()
256253 current_kv_cache_size = getattr (cm , "current_kv_cache_size_bytes" , None )
@@ -259,8 +256,8 @@ def _get_mem_info_in_mb():
259256 )
260257 current_num_pages = cm .info .num_pages
261258 self ._log_info (
262- f"Current cache size (MB): { current_cache_size // 1024 // 1024 } , "
263- f"Current num pages: { current_num_pages } "
259+ f"Current cache size (MB): { current_cache_size // 1024 ** 2 } , "
260+ f"Current num pages: { current_num_pages } "
264261 )
265262 if current_kv_cache_size != current_cache_size :
266263 self ._log_info (
@@ -278,12 +275,32 @@ def _get_mem_info_in_mb():
278275
279276 # Let's run a forward pass to get the memory usage
280277 cm .info .set_max_num_tokens_sample ()
281- free_mem_pre , _ = _get_mem_info_in_mb ( )
278+ free_mem_pre , _ = get_mem_info_in_mb ( empty_cache = True )
282279 self ._log_info (f"Free memory before forward pass (MB): { free_mem_pre } " )
283280
284- mod (** cm .named_args )
281+ # Reset peak memory stats to get the extra memory used during the forward pass
282+ torch .cuda .reset_peak_memory_stats ()
283+ memory_allocated_before_forward_pass_mb = torch .cuda .memory_allocated () // 1024 ** 2
284+ try :
285+ mod (** cm .named_args )
286+ except torch .OutOfMemoryError as e :
287+ self .ad_logger .error (
288+ f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n { e } "
289+ )
290+ raise e
291+
292+ peak_memory_during_forward_pass_mb = torch .cuda .max_memory_allocated () // 1024 ** 2
293+ mem_used_during_forward_pass_mb = (
294+ peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
295+ )
296+ self ._log_info (
297+ f"Peak memory usage during forward pass (MB): { peak_memory_during_forward_pass_mb } "
298+ )
299+ self ._log_info (
300+ f"Extra memory used during forward pass (MB): { mem_used_during_forward_pass_mb } "
301+ )
285302
286- free_mem_post , _ = _get_mem_info_in_mb ( )
303+ free_mem_post , _ = get_mem_info_in_mb ( empty_cache = True )
287304 self ._log_info (f"Free memory after forward pass (MB): { free_mem_post } " )
288305
289306 memory_for_forward_pass = free_mem_pre - free_mem_post
0 commit comments