@@ -44,6 +44,7 @@ def replay(self, input_ids, infer_state):
4444 graph_obj .replay ()
4545 return graph_predict_logics
4646
47+ @torch .no_grad ()
4748 def warmup (self , model ):
4849 logger .info ("Begin capture cudagraph, use the --disable_cudagraph to disable it." )
4950 for batch_size in range (self .max_batch_size , 0 , - 1 ):
@@ -70,6 +71,9 @@ def warmup(self, model):
7071 prob_out = torch .softmax (logics , dim = - 1 )
7172 predict_ids = torch .argmax (prob_out , dim = 1 , keepdim = True )
7273 predict_ids = predict_ids .detach ().cpu ().numpy ()
74+ del logics
75+ del prob_out
76+ torch .cuda .empty_cache ()
7377
7478 # dummy decoding, capture the cudagraph
7579 b_start_loc = b_start_loc + torch .arange (0 , batch_size , dtype = torch .int32 , device = "cuda" )
@@ -87,6 +91,11 @@ def warmup(self, model):
8791 )
8892 model .mem_manager .free_all ()
8993 model .req_manager .free_all ()
94+ # release local tensors
95+ for var_name , var_value in list (locals ().items ()):
96+ if isinstance (var_value , torch .Tensor ):
97+ del locals ()[var_name ]
98+ torch .cuda .empty_cache ()
9099 logger .info (
91100 f"Capture cudagraph success, batch_size <={ self .max_batch_size } "
92101 f"and max_len_in_batch <= { self .graph_max_len_in_batch } will infer with cudagraph."
0 commit comments