Hey, thank you for the work that you put into this package!
I have a general question about how the `odeint` solver works.
Suppose we have the following example of an oscillator:
import numpy as np
import torch
import torch.nn as nn
from torchdiffeq import odeint
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

t = torch.linspace(0, 10, 101).to(device)
true_y0 = torch.tensor([[np.pi - 0.1, 0.0]]).to(device)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)  # not used below

class Lambda(nn.Module):
    """Damped pendulum: dtheta/dt = omega, domega/dt = -b*omega - c*sin(theta)."""
    def __init__(self):
        super(Lambda, self).__init__()  # type: ignore
        self.index = 0  # counts how often forward() is called

    def forward(self, t: torch.Tensor, y: torch.Tensor):
        self.index += 1
        b = 0.25
        c = 5.0
        theta, omega = y[0, 0], y[0, 1]
        # torch.stack keeps the result on y's device and inside the autograd graph
        dydt = torch.stack([omega, -b * omega - c * torch.sin(theta)])
        dydt = dydt.unsqueeze(0)
        return dydt

lambda_object = Lambda()
with torch.no_grad():
    true_y = odeint(func=lambda_object, y0=true_y0, t=t, method='dopri5')

print(f"The forward method was called {lambda_object.index} times while the length of t is {len(t)}.")

plt.title("ODE Integration (torchdiffeq)")
plt.plot(t.detach().cpu(), true_y[:, 0, 0].detach().cpu(), 'b', label='theta(t)')
plt.plot(t.detach().cpu(), true_y[:, 0, 1].detach().cpu(), 'g', label='omega(t)')
plt.legend(loc='best')
plt.xlabel('t')
plt.grid()
plt.show()
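For reference, the right-hand side implemented by `Lambda.forward` above (with `theta, omega = y[0, 0], y[0, 1]`) is the damped pendulum below, integrated over $t \in [0, 10]$ with outputs requested at 101 equally spaced times:

```latex
\begin{aligned}
\dot\theta &= \omega, \\
\dot\omega &= -b\,\omega - c\,\sin\theta, \qquad b = 0.25,\quad c = 5.0,
\end{aligned}
\qquad (\theta(0),\, \omega(0)) = (\pi - 0.1,\; 0).
```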
When I count the number of times the forward function is called, I would expect it to equal the number of time steps in `t`, but the two numbers differ.
Why is this the case?
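One way to see why the counter does not match `len(t)`: it counts right-hand-side evaluations, not output times. Runge–Kutta methods evaluate the right-hand side several times per step, and an adaptive method such as `dopri5` (a Dormand–Prince 5(4) pair, roughly six new evaluations per attempted step) also chooses its own step sizes from `rtol`/`atol`, pays for rejected trial steps, and, as far as I know, obtains the values at the requested `t` by interpolation rather than by stepping onto them. The sketch below is a hedged illustration in plain Python floats, not torchdiffeq's code: it swaps in classical RK4 and a much simpler adaptive Heun–Euler 2(1) pair, and the adaptive version clips its steps to land on each output time. `Counter`, `rk4_on_grid` and `adaptive_heun_euler` are hypothetical names.

```python
import math


def pendulum(t, y, b=0.25, c=5.0):
    """Damped pendulum from the question: y = [theta, omega]."""
    theta, omega = y
    return [omega, -b * omega - c * math.sin(theta)]


class Counter:
    """Wraps a right-hand side and counts how often it is evaluated."""
    def __init__(self, f):
        self.f, self.calls = f, 0

    def __call__(self, t, y):
        self.calls += 1
        return self.f(t, y)


def axpy(y, h, k):
    """Return y + h * k for plain Python lists."""
    return [yi + h * ki for yi, ki in zip(y, k)]


def rk4_on_grid(f, y0, ts):
    """Classical RK4 with exactly one step between consecutive output times."""
    y, out = list(y0), [list(y0)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = t1 - t0
        k1 = f(t0, y)
        k2 = f(t0 + h / 2, axpy(y, h / 2, k1))
        k3 = f(t0 + h / 2, axpy(y, h / 2, k2))
        k4 = f(t1, axpy(y, h, k3))
        y = [yi + h / 6 * (a1 + 2 * a2 + 2 * a3 + a4)
             for yi, a1, a2, a3, a4 in zip(y, k1, k2, k3, k4)]
        out.append(y)
    return out


def adaptive_heun_euler(f, y0, ts, rtol=1e-3, atol=1e-6, h=0.01):
    """Adaptive Heun-Euler 2(1) pair that steps exactly onto every output time."""
    y, t, out, rejected = list(y0), ts[0], [list(y0)], 0
    for t_next in ts[1:]:
        while t < t_next:
            clipped = h >= t_next - t              # would this step reach the output time?
            if clipped:
                h = t_next - t
            k1 = f(t, y)
            y_low = axpy(y, h, k1)                 # Euler step, order 1
            k2 = f(t + h, y_low)
            y_high = [yi + h / 2 * (a1 + a2) for yi, a1, a2 in zip(y, k1, k2)]  # Heun, order 2
            err = max(abs(p - q) / (atol + rtol * max(abs(p), abs(yi)))
                      for p, q, yi in zip(y_high, y_low, y))
            if err <= 1.0:                         # accept: advance time and state
                t, y = (t_next if clipped else t + h), y_high
            else:                                  # reject: both evaluations are wasted
                rejected += 1
            # the error estimate scales like h**2, hence the square root
            h *= min(5.0, max(0.2, 0.9 / math.sqrt(max(err, 1e-12))))
        out.append(list(y))
    return out, rejected


ts = [i * 0.1 for i in range(101)]   # the same 101-point grid as torch.linspace(0, 10, 101)
y0 = [math.pi - 0.1, 0.0]

fixed = Counter(pendulum)
rk4_out = rk4_on_grid(fixed, y0, ts)
print("RK4:", len(ts), "output times,", fixed.calls, "evaluations")

for rtol in (1e-3, 1e-6):
    adaptive = Counter(pendulum)
    _, rejected = adaptive_heun_euler(adaptive, y0, ts, rtol=rtol)
    print(f"Heun-Euler rtol={rtol}:", adaptive.calls, "evaluations,", rejected, "rejected steps")
```

Even the fixed-step RK4 run needs four evaluations per interval, so it reports 400 evaluations for 101 output times. The adaptive run needs at least one two-evaluation step per interval, and tightening `rtol` multiplies the count without changing the number of outputs, which is the same kind of behaviour the `dopri5` counter shows.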
Ultimately, I would like to add an external force inside the forward method that is looked up via `self.index`.
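Because `forward` is also called at intermediate times and during rejected trial steps, `self.index` does not correspond to a position in `t`, and the mapping would change whenever the tolerances or the method change. A common alternative is to make the force a function of the time argument `t` itself. The sketch below assumes the force is given as samples on the output grid and interpolates linearly between them; `SampledForce` and `forced_pendulum` are hypothetical names, and it uses plain Python floats rather than tensors.

```python
import bisect
import math


class SampledForce:
    """Hypothetical helper: piecewise-linear force F(t) built from samples on a time grid."""
    def __init__(self, times, values):
        self.times = list(times)
        self.values = list(values)

    def __call__(self, t):
        # hold the first/last sample constant outside the sampled range
        if t <= self.times[0]:
            return self.values[0]
        if t >= self.times[-1]:
            return self.values[-1]
        i = bisect.bisect_right(self.times, t)   # times[i-1] <= t < times[i]
        t0, t1 = self.times[i - 1], self.times[i]
        v0, v1 = self.values[i - 1], self.values[i]
        w = (t - t0) / (t1 - t0)
        return v0 + w * (v1 - v0)


def forced_pendulum(t, y, force, b=0.25, c=5.0):
    """Damped pendulum with an additive external term force(t) in the omega equation."""
    theta, omega = y
    return [omega, -b * omega - c * math.sin(theta) + force(t)]


ts = [i * 0.1 for i in range(101)]
force = SampledForce(ts, [math.sin(2.0 * t) for t in ts])
print(force(0.05))   # (sin(0.0) + sin(0.2)) / 2, halfway between the first two samples
print(forced_pendulum(0.05, [math.pi - 0.1, 0.0], force))
```

Inside `Lambda.forward`, `t` arrives as a 0-d tensor, so this plain-Python version would be called as `force(float(t))`; to stay on the GPU and keep gradients, the same lookup could be written with `torch.searchsorted`. The kinks at the sample times can make an adaptive solver shorten its steps there.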