Description
memory leak when using torch.cat
Truncated tracemalloc traceback of the top allocation site:

    result = torch.cat(td_list, dim=0)
  File "/usr/local/lib/python3.10/dist-packages/tensordict/base.py", line 673
    return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensordict/_torch_func.py", line 371
    out[key] = torch.cat(items, dim)
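For context, a minimal sketch of the call pattern shown in the traceback (shapes are illustrative, not the ones from the real run): torch.cat on a list of TensorDicts is dispatched through TD_HANDLED_FUNCTIONS in tensordict/base.py, which then concatenates the tensors of each key individually in tensordict/_torch_func.py.

# Minimal sketch of the call pattern from the traceback above.
import torch
from tensordict import TensorDict

td_list = [
    TensorDict({"observation": torch.randn(32, 4)}, batch_size=[32])
    for _ in range(5)
]
result = torch.cat(td_list, dim=0)   # handled by tensordict, not plain torch.cat
print(result["observation"].shape)   # torch.Size([160, 4])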
To Reproduce
Steps to reproduce the behavior.
import gc
import tracemalloc

import torch
from tensordict import TensorDict


def test_torch_cat_memory_leak():
    """Test whether concatenating TensorDicts with torch.cat leaks memory."""
    # Start memory tracing
    tracemalloc.start(10)
    # Record the initial memory snapshot
    snapshot1 = tracemalloc.take_snapshot()

    def create_tensordict_batch(batch_size=32, feature_dim=34816):
        """Create one TensorDict batch."""
        return TensorDict({
            'observation': torch.randn(batch_size, feature_dim),
            'action': torch.randint(0, 10, (batch_size, feature_dim)),
            'reward': torch.randn(batch_size, 1)
        }, batch_size=[batch_size])

    # Repeat the operation many times and watch memory growth
    num_iterations = 100
    print(f"Starting test, iterations: {num_iterations}")
    print("=" * 50)

    for i in range(num_iterations):
        # Create several TensorDicts and concatenate them
        td_list = []
        for j in range(5):  # 5 TensorDicts per iteration
            td = create_tensordict_batch()
            td_list.append(td)
        result = torch.cat(td_list, dim=0)
        del result
        gc.collect()

        # Check memory every 10 iterations
        if (i + 1) % 10 == 0:
            snapshot2 = tracemalloc.take_snapshot()
            top_stats = snapshot2.compare_to(snapshot1, 'traceback')
            print(f"\nMemory change after iteration {i + 1}:")
            print("-" * 30)
            # Show the 5 locations with the largest growth
            for stat in top_stats[:5]:
                print(f"{''.join(stat.traceback.format())[:100]}...")
                print(f"  Memory growth: {stat.size_diff / 1024:.2f} KB")
                print(f"  Total memory: {stat.size / 1024:.2f} KB")
                print()

    # Force garbage collection to see whether the memory can be released
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Final memory statistics
    snapshot_final = tracemalloc.take_snapshot()
    top_stats_final = snapshot_final.compare_to(snapshot1, 'lineno')
    print("=" * 50)
    print("Final memory statistics:")
    print("=" * 50)

    total_leak = 0
    for stat in top_stats_final[:10]:
        if stat.size_diff > 0:  # only count allocations that grew
            total_leak += stat.size_diff
            print(f"Memory leak: {stat.size_diff / 1024:.2f} KB")
            print(f"Traceback: {''.join(stat.traceback.format())[:200]}...")
            print("-" * 30)

    print(f"\nTotal memory leak: {total_leak / 1024:.2f} KB")

    # Decide whether the growth looks like a leak
    if total_leak > 1024 * 1024:  # more than 1 MB
        print("🚨 Severe memory leak detected!")
    elif total_leak > 100 * 1024:  # more than 100 KB
        print("⚠️ Noticeable memory leak detected")
    else:
        print("✅ No obvious memory leak detected")

    tracemalloc.stop()
    return total_leak


if __name__ == "__main__":
    leak_cat = test_torch_cat_memory_leak()
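A complementary check, as a sketch assuming psutil is installed (the function names below are illustrative): tracemalloc only tracks allocations made through Python's own allocator, so watching the process resident set size as well helps tell whether the growth reported above corresponds to memory the process actually retains or only to Python-side bookkeeping.

# Complementary RSS check -- a sketch, assuming psutil is available.
# tracemalloc only sees allocations made through Python's allocator,
# so the resident set size gives a second, allocator-independent signal.
import gc
import psutil
import torch
from tensordict import TensorDict

def rss_mb():
    """Resident set size of the current process, in MiB."""
    return psutil.Process().memory_info().rss / (1024 * 1024)

def cat_loop(iterations=100):
    baseline = rss_mb()
    for i in range(iterations):
        td_list = [
            TensorDict({"observation": torch.randn(32, 34816)}, batch_size=[32])
            for _ in range(5)
        ]
        result = torch.cat(td_list, dim=0)
        del result, td_list
        gc.collect()
        if (i + 1) % 10 == 0:
            print(f"iteration {i + 1}: RSS {rss_mb() - baseline:+.1f} MiB vs baseline")

if __name__ == "__main__":
    cat_loop()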
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)