
Can't replicate O(1) memory with Adjoint method #259

@petercmh01

Description


Hi, I tried to use the following code to test the memory savings of the adjoint method:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint
import time

class LargeODEFunc(nn.Module):
    def __init__(self, state_dim, hidden_dim, output_dim):
        super().__init__()
        # 1 input layer + 8 hidden layers + 1 output layer, ReLU between each
        layers = [nn.Linear(state_dim, hidden_dim), nn.ReLU()]
        for _ in range(8):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, t, y):
        return self.net(y)

state_dim = 1024
hidden_dim = 512
output_dim = 1024

ode_func = LargeODEFunc(state_dim, hidden_dim, output_dim).cuda()

y0 = torch.randn(1, state_dim).cuda()
t_dense = torch.linspace(0, 1, 1000).cuda()
t_sparse = torch.linspace(0, 1, 10).cuda()

optimizer = torch.optim.Adam(ode_func.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

def print_cuda_memory_usage():
    print(f"Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"Memory Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

def train_step(all_time_steps, label):
    optimizer.zero_grad()
    torch.cuda.reset_peak_memory_stats()

    print(f"\n[{label}] After optimizer.zero_grad() and before odeint()")
    print_cuda_memory_usage()

    y_ode = odeint(ode_func, y0, all_time_steps)[-1]

    target = torch.randn_like(y_ode)  # created on the same device as y_ode
    loss = loss_fn(y_ode, target)
    loss.backward()
    print(f"\n[{label}] After loss.backward()")
    print_cuda_memory_usage()
    optimizer.step()
    torch.cuda.synchronize()
    optimizer.zero_grad()
    peak_mem = torch.cuda.max_memory_allocated() / 1024**2
    print(f"Peak Memory = {peak_mem:.2f} MB, Output Steps = {len(all_time_steps)}")

print("Training with dense time steps:")
for epoch in range(5):
    train_step(t_dense, "dense")
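
(t_sparse is defined above but never used; assuming the intended second half of the comparison was to train on the sparse grid as well, that would be:)

print("\nTraining with sparse time steps:")
for epoch in range(5):
    train_step(t_sparse, "sparse")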

For comparison, I seem to get lower memory consumption with the normal odeint than with odeint_adjoint. Can anyone see a problem in my setup?
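
To make the comparison explicit, I measure peak memory like this. peak_mem_mb is just a small helper of my own (not part of torchdiffeq), and it assumes the same ode_func, y0, t_dense, and loss_fn as above:

from torchdiffeq import odeint as odeint_plain

def peak_mem_mb(solver, func, y0, t):
    # One forward/backward pass; return the peak allocated CUDA memory in MB.
    torch.cuda.reset_peak_memory_stats()
    y = solver(func, y0, t)[-1]
    loss = loss_fn(y, torch.randn_like(y))
    loss.backward()
    func.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2

print(f"plain odeint:   {peak_mem_mb(odeint_plain, ode_func, y0, t_dense):.2f} MB")
print(f"adjoint odeint: {peak_mem_mb(odeint, ode_func, y0, t_dense):.2f} MB")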
