Skip to content

[BUG]memory leak when using torch.cat #1475

@zwc163

Description

@zwc163

memory leak when using torch.cat

```
result = torch.cat(td_list, dim=0)
  File "/usr/local/lib/python3.10/dist-packages/tensordict/base.py", line 673
    return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensordict/_torch_func.py", line 371
    out[key] = torch.cat(items, dim)
```

To Reproduce

Steps to reproduce the behavior.

import torch
import tracemalloc
from tensordict import TensorDict
import gc
import os

def test_torch_cat_memory_leak():
    """Check whether concatenating TensorDicts with ``torch.cat`` leaks memory.

    Repeatedly builds small batches of TensorDicts, concatenates them with
    ``torch.cat``, deletes the result, and compares ``tracemalloc`` snapshots
    against a baseline to measure residual allocation growth.

    Returns:
        int: total leaked bytes observed — the sum of positive ``size_diff``
        entries among the top stats of the final snapshot comparison.
    """
    # Start allocation tracing, keeping 10 frames of traceback per allocation.
    tracemalloc.start(10)

    # Baseline snapshot taken before any TensorDict work happens.
    snapshot1 = tracemalloc.take_snapshot()

    def create_tensordict_batch(batch_size=32, feature_dim=34816):
        """Build one TensorDict batch with random observation/action/reward."""
        return TensorDict({
            'observation': torch.randn(batch_size, feature_dim),
            'action': torch.randint(0, 10, (batch_size, feature_dim)),
            'reward': torch.randn(batch_size, 1)
        }, batch_size=[batch_size])

    num_iterations = 100

    print(f"开始测试,迭代次数: {num_iterations}")
    print("=" * 50)

    for i in range(num_iterations):
        # Concatenate 5 freshly created TensorDicts, then drop the result.
        td_list = [create_tensordict_batch() for _ in range(5)]

        result = torch.cat(td_list, dim=0)

        del result
        gc.collect()

        # Every 10 iterations, report the biggest allocation growth so far.
        if (i + 1) % 10 == 0:
            snapshot2 = tracemalloc.take_snapshot()
            top_stats = snapshot2.compare_to(snapshot1, 'traceback')

            print(f"\n{i+1} 次迭代后内存变化:")
            print("-" * 30)

            # Show the top 5 locations by memory growth.
            for stat in top_stats[:5]:
                # BUG FIX: Traceback.format() returns a *list* of lines, so
                # slicing it with [:100] truncated lines (and printed a list
                # repr), not characters. Join first, then truncate the string.
                trace_text = "\n".join(stat.traceback.format())
                print(f"{trace_text[:100]}...")
                print(f"    内存增长: {stat.size_diff / 1024:.2f} KB")
                print(f"    总内存: {stat.size / 1024:.2f} KB")
                print()

            # Force a collection (and free cached CUDA blocks when available)
            # so the next comparison only reflects unreleasable memory.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Final snapshot, grouped by line number for a per-line leak summary.
    snapshot_final = tracemalloc.take_snapshot()
    top_stats_final = snapshot_final.compare_to(snapshot1, 'lineno')

    print("=" * 50)
    print("最终内存统计:")
    print("=" * 50)

    total_leak = 0
    for stat in top_stats_final[:10]:
        if stat.size_diff > 0:  # only count growth; ignore freed memory
            total_leak += stat.size_diff
            print(f"内存泄漏: {stat.size_diff / 1024:.2f} KB")
            # Same list-vs-string fix as above: join before truncating.
            trace_text = "\n".join(stat.traceback.format())
            print(f"调用栈: {trace_text[:200]}...")
            print("-" * 30)

    print(f"\n总内存泄漏: {total_leak / 1024:.2f} KB")

    # Classify the result against rough thresholds (1 MB / 100 KB).
    if total_leak > 1024 * 1024:
        print("🚨 检测到严重内存泄漏!")
    elif total_leak > 100 * 1024:
        print("⚠️  检测到明显内存泄漏")
    else:
        print("✅ 未检测到明显内存泄漏")

    tracemalloc.stop()
    return total_leak


if __name__ == "__main__":
    # Script entry point: run the reproduction once and keep the leak size
    # (in bytes) around for inspection in an interactive session.
    leak_cat = test_torch_cat_memory_leak()
 
```
result = torch.cat(td_list, dim=0)
  File "/usr/local/lib/python3.10/dist-packages/tensordict/base.py", line 673
    return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensordict/_torch_func.py", line 371
    out[key] = torch.cat(items, dim)
```

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions