Skip to content

Commit 71d1208

Browse files
authored
release local tensors alloc during cudagraph warmup (#568)
1 parent 9bf04ae commit 71d1208

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

lightllm/common/basemodel/cuda_graph.py

Lines changed: 9 additions & 0 deletions
(diff reconstructed below in standard unified-diff form; the original rendering had line-number columns fused into the text and indentation stripped — indentation is inferred and should be checked against the repository)

```diff
@@ -44,6 +44,7 @@ def replay(self, input_ids, infer_state):
             graph_obj.replay()
         return graph_predict_logics
 
+    @torch.no_grad()
     def warmup(self, model):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
         for batch_size in range(self.max_batch_size, 0, -1):
@@ -70,6 +71,9 @@ def warmup(self, model):
             prob_out = torch.softmax(logics, dim=-1)
             predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
             predict_ids = predict_ids.detach().cpu().numpy()
+            del logics
+            del prob_out
+            torch.cuda.empty_cache()
 
             # dummy decoding, capture the cudagraph
             b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
@@ -87,6 +91,11 @@ def warmup(self, model):
             )
             model.mem_manager.free_all()
             model.req_manager.free_all()
+            # release local tensors
+            for var_name, var_value in list(locals().items()):
+                if isinstance(var_value, torch.Tensor):
+                    del locals()[var_name]
+            torch.cuda.empty_cache()
         logger.info(
             f"Capture cudagraph success, batch_size <={self.max_batch_size} "
             f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph."
```

0 commit comments

Comments
 (0)