1313from ...models .factory import ModelFactory
1414from ...shim .interface import CachedSequenceInterface
1515from ...utils ._graph import add_graph_input
16+ from ...utils .cuda_mem_tracker import get_mem_info_in_mb
1617from ...utils .node_utils import get_all_input_output_nodes , is_op
1718from ..interface import (
1819 BaseTransform ,
@@ -245,16 +246,12 @@ def _apply_to_full_model(
245246 ) -> Tuple [nn .Module , TransformInfo ]:
246247 free_mem_ratio = self .config .free_mem_ratio
247248
248- def _get_mem_info_in_mb ():
249- free_mem , total_mem = torch .cuda .mem_get_info ()
250- return free_mem // 1024 ** 2 , total_mem // 1024 ** 2
251-
252- free_mem , total_mem = _get_mem_info_in_mb ()
249+ free_mem , total_mem = get_mem_info_in_mb (empty_cache = True )
253250 self ._log_info (f"Free memory (MB): { free_mem } , Total memory (MB): { total_mem } " )
254251 current_cache_size = cm .current_cache_size_bytes ()
255252 current_num_pages = cm .info .num_pages
256253 self ._log_info (
257- f"Current cache size (MB): { current_cache_size // 1024 // 1024 } , "
254+ f"Current cache size (MB): { current_cache_size // 1024 ** 2 } , "
258255 f"Current num pages: { current_num_pages } "
259256 )
260257
@@ -269,16 +266,33 @@ def _get_mem_info_in_mb():
269266
270267 # Let's run a forward pass to get the memory usage
271268 cm .info .set_max_num_tokens_sample ()
272- free_mem_pre , _ = _get_mem_info_in_mb ( )
269+ free_mem_pre , _ = get_mem_info_in_mb ( empty_cache = True )
273270 self ._log_info (f"Free memory before forward pass (MB): { free_mem_pre } " )
274271
275- mod (** cm .named_args )
272+ # Reset peak memory stats to get the extra memory used during the forward pass
273+ torch .cuda .reset_peak_memory_stats ()
274+ memory_allocated_before_forward_pass_mb = torch .cuda .memory_allocated () // 1024 ** 2
275+ try :
276+ mod (** cm .named_args )
277+ except torch .OutOfMemoryError as e :
278+ self .ad_logger .error (
279+ f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n { e } "
280+ )
281+ raise e
276282
277- free_mem_post , _ = _get_mem_info_in_mb ()
278- self ._log_info (f"Free memory after forward pass (MB): { free_mem_post } " )
283+ peak_memory_during_forward_pass_mb = torch .cuda .max_memory_allocated () // 1024 ** 2
284+ mem_used_during_forward_pass_mb = (
285+ peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
286+ )
287+ self ._log_info (
288+ f"Peak memory usage during forward pass (MB): { peak_memory_during_forward_pass_mb } "
289+ )
290+ self ._log_info (
291+ f"Extra memory used during forward pass (MB): { mem_used_during_forward_pass_mb } "
292+ )
279293
280- memory_for_forward_pass = free_mem_pre - free_mem_post
281- self ._log_info (f"Memory for forward pass (MB): { memory_for_forward_pass } " )
294+ free_mem_post , _ = get_mem_info_in_mb ( empty_cache = True )
295+ self ._log_info (f"Free memory after forward pass (MB): { free_mem_post } " )
282296
283297 new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
284298 new_num_pages = int (new_cache_size // (current_cache_size // current_num_pages ))
0 commit comments