-
Notifications
You must be signed in to change notification settings - Fork 981
Open
Description
Hi, I used the following code to test the effect of using the adjoint method:
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint
import time
class LargeODEFunc(nn.Module):
    """Time-invariant ODE right-hand side dy/dt = f(y) given by a ReLU MLP.

    Architecture: Linear(state_dim -> hidden_dim), then ``num_hidden_layers``
    repetitions of ReLU + Linear(hidden_dim -> hidden_dim), then
    ReLU + Linear(hidden_dim -> output_dim). The default of 8 hidden layers
    reproduces the original fixed architecture, including its layer order and
    therefore its ``net.<i>`` state_dict keys.

    NOTE: an ODE solver adds f(y) to y, so callers should pass
    ``output_dim == state_dim`` (as this script does).
    """

    def __init__(self, state_dim, hidden_dim, output_dim, num_hidden_layers=8):
        """Build the MLP.

        Args:
            state_dim: size of the last dimension of the input state ``y``.
            hidden_dim: width of every hidden Linear layer.
            output_dim: size of the last dimension of the returned derivative.
            num_hidden_layers: number of hidden_dim -> hidden_dim Linear layers
                (each preceded by a ReLU). Defaults to 8.

        Raises:
            ValueError: if ``num_hidden_layers`` is negative.
        """
        super().__init__()
        if num_hidden_layers < 0:
            raise ValueError(
                f"num_hidden_layers must be >= 0, got {num_hidden_layers}"
            )
        # Modules are created in the same order as the original explicit list,
        # so parameter initialisation consumes the RNG identically.
        layers = [nn.Linear(state_dim, hidden_dim)]
        for _ in range(num_hidden_layers):
            layers += [nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)]
        layers += [nn.ReLU(), nn.Linear(hidden_dim, output_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, t, y):
        """Return f(y). ``t`` is accepted for the solver's (t, y) signature but ignored."""
        return self.net(y)
# Benchmark configuration: 1024-dim state, MLP hidden width 512.
state_dim = 1024
hidden_dim = 512
# Must equal state_dim: the ODE function returns dy/dt with the same shape as y.
output_dim = 1024
device = torch.device("cuda")
ode_func = LargeODEFunc(state_dim, hidden_dim, output_dim).to(device)
# Allocate directly on the GPU rather than building on the CPU and copying with .cuda().
y0 = torch.randn(1, state_dim, device=device)
# Both grids span [0, 1]. train_step only uses the state at the last time
# point, so the grid mostly controls how many intermediate outputs are requested.
t_dense = torch.linspace(0, 1, 1000, device=device)
t_sparse = torch.linspace(0, 1, 10, device=device)
optimizer = torch.optim.Adam(ode_func.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
def print_cuda_memory_usage(device=None):
    """Print current, reserved and peak CUDA memory usage in MB.

    ``memory_allocated()`` reports only what is live at the moment of the call.
    After ``backward()`` it therefore misses the temporary buffers the backward
    pass allocated and freed. The peak value is the figure to compare between
    ``odeint`` and ``odeint_adjoint``. It counts since the last
    ``torch.cuda.reset_peak_memory_stats()``.

    Args:
        device: CUDA device to query; ``None`` means the current device.
    """
    mib = 1024**2
    print(f"Memory Allocated: {torch.cuda.memory_allocated(device) / mib:.2f} MB")
    print(f"Memory Cached: {torch.cuda.memory_reserved(device) / mib:.2f} MB")
    print(f"Peak Memory Allocated: {torch.cuda.max_memory_allocated(device) / mib:.2f} MB")
def train_step(all_time_steps, label):
    """Run one optimisation step through the ODE solver and report CUDA memory.

    Integrates ``ode_func`` from ``y0`` over ``all_time_steps`` with ``odeint``,
    then regresses the final state onto a fresh random target with ``loss_fn``.
    It then back-propagates and takes one ``optimizer`` step.

    Args:
        all_time_steps: 1-D tensor of output times passed to ``odeint``. Only
            the state at the last time is used in the loss.
        label: short tag (e.g. "dense") included in the peak-memory report.

    Returns:
        Peak CUDA memory allocated during this step, in MB. Callers that ignore
        the return value are unaffected.

    NOTE(review): the "After loss_backward()" numbers show only memory still
    live after backward, not the transient peak inside it. Compare the
    returned peak between ``odeint`` and ``odeint_adjoint`` instead.
    With a batch of 1, per-evaluation activations are small relative to the
    ~3.2M MLP parameters. The adjoint backward pass integrates an augmented
    state that includes parameter-sized adjoints, which may explain why the
    adjoint variant uses more memory here. Confirm with the peak numbers.
    """
    optimizer.zero_grad()
    torch.cuda.reset_peak_memory_stats()
    print("\nAfter optimizer.zero_grad() and before odeint()")
    print_cuda_memory_usage()
    y_ode = odeint(ode_func, y0, all_time_steps)[-1]
    # randn_like already allocates on y_ode's device; no .cuda() needed.
    target = torch.randn_like(y_ode)
    loss = loss_fn(y_ode, target)
    loss.backward()
    print("\nAfter loss_backward()")
    print_cuda_memory_usage()
    optimizer.step()
    torch.cuda.synchronize()
    # Peak since reset_peak_memory_stats() above: forward, backward and optimizer step.
    peak_mem = torch.cuda.max_memory_allocated() / 1024**2
    print(
        f"[{label}] Peak Memory = {peak_mem:.2f} MB, "
        f"Time Points = {len(all_time_steps)}"
    )
    # Clear gradients so they do not carry over into the next step.
    optimizer.zero_grad()
    return peak_mem
# Drive the benchmark with the dense 1000-point grid (t_dense, labelled "dense").
num_epochs = 5
print("Training with dense time steps:")
for epoch in range(num_epochs):
    train_step(t_dense, "dense")
For comparison, I seem to get lower memory consumption with the regular `odeint` than with `odeint_adjoint`. Can anyone see a problem with my setup?
Metadata
Metadata
Assignees
Labels
No labels