Status: Open
Labels: bug
Description
Describe the bug
TorchScript is unable to script the NeuralODE class. The first attempt fails with an argument-count error in DEFuncBase.forward, and later attempts fail because that method is being redefined. This is a problem because the code contains control flow that tracing would not necessarily respect, so the alternative, torch.jit.trace, could produce incorrect output.
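To illustrate why tracing is not a safe substitute when there is data-dependent control flow: torch.jit.trace records only the operations executed for the example input, while torch.jit.script compiles every branch. The function step below is a hypothetical example, not torchdyn code.

```python
import torch

def step(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: the path taken depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return -x

# Tracing with a positive input records only the `x * 2` path
# (PyTorch emits a TracerWarning about converting a tensor to a bool).
traced = torch.jit.trace(step, torch.ones(2))

# Scripting compiles both branches, so the control flow is preserved.
scripted = torch.jit.script(step)

neg = -torch.ones(2)
print(step(neg))      # tensor([1., 1.])
print(traced(neg))    # tensor([-2., -2.]), the recorded branch, which is wrong here
print(scripted(neg))  # tensor([1., 1.])
```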
Steps to Reproduce
Minimal working example:
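import torch
import torch.nn as nn
from torchdyn.core import NeuralODE

f = nn.Sequential(
    nn.Linear(2, 16),
    nn.Tanh(),
    nn.Linear(16, 2)
)
model = NeuralODE(f, solver='tsit5', solver_adjoint='dopri5')
out = torch.jit.script(model)  # raises the errors shown below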
This produces the following errors. The first time you run it, the error is:
forward(__torch__.torch.nn.modules.container.Sequential self, Tensor input) -> (Tensor):
Expected at most 2 arguments but found 3 positional arguments.
:
  File "/home/shogg/.pyenv/versions/3.8.12/envs/mldi/lib/python3.8/site-packages/torchdyn/core/defunc.py", line 32
    def forward(self, t:Tensor, x:Tensor) -> Tensor:
        self.nfe += 1
        if self.has_time_arg: return self.vf(t, x)
                                     ~~~~~~~ <--- HERE
        else: return self.vf(x)
On subsequent attempts, the error is:
RuntimeError: Can't redefine method: forward on class: __torch__.torchdyn.core.defunc.DEFuncBase (of Python compilation unit at: 0x560bca2c9a70)
Note that changing the solvers doesn't appear to change anything.
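The first error can likely be reproduced without torchdyn. This sketch assumes DEFuncBase.forward follows the pattern shown in the traceback: TorchScript compiles and type-checks both branches of `if self.has_time_arg`, so `self.vf(t, x)` is checked against nn.Sequential.forward, which accepts one input, even when has_time_arg is False at runtime. If that is the cause, it would be consistent with the solver choice making no difference. BranchingFunc is a hypothetical stand-in, not torchdyn's class.

```python
import torch
import torch.nn as nn

class BranchingFunc(nn.Module):
    """Hypothetical stand-in for the branch shown in the traceback (not torchdyn code)."""

    def __init__(self, vf: nn.Module, has_time_arg: bool):
        super().__init__()
        self.vf = vf
        self.has_time_arg = has_time_arg  # plain attribute, not a compile-time constant

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        if self.has_time_arg:
            # TorchScript compiles this call even if the branch is never taken.
            return self.vf(t, x)
        return self.vf(x)

f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
wrapped = BranchingFunc(f, has_time_arg=False)

# Eager mode works: only the branch that is actually taken runs.
y = wrapped(torch.tensor(0.0), torch.randn(3, 2))

try:
    torch.jit.script(wrapped)
    script_failed = False
except RuntimeError as err:
    # Expected: nn.Sequential.forward cannot be called with (t, x).
    script_failed = True
    print(err)
```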
Expected behavior
torch.jit.script should be able to compile the NeuralODE model without errors.
I'm using the latest version of torchdyn available from pip and torch==1.11.0+cu102. I would be very grateful for advice on how to TorchScript this safely.
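One possible workaround, sketched under stated assumptions: if a fixed-step solver is acceptable, script a small self-contained module instead of NeuralODE. TimeInvariant and RK4ODE are hypothetical names, not part of torchdyn. RK4ODE swaps in classical fourth-order Runge-Kutta in place of tsit5/dopri5 and implements no adjoint method; gradients come from ordinary backpropagation through the unrolled steps.

```python
import torch
import torch.nn as nn

class TimeInvariant(nn.Module):
    # Hypothetical adapter: gives a time-independent field a (t, x) signature.
    def __init__(self, vf: nn.Module):
        super().__init__()
        self.vf = vf

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return self.vf(x)  # t is accepted but unused

class RK4ODE(nn.Module):
    # Hypothetical fixed-step RK4 integrator over the grid t_span (not torchdyn).
    def __init__(self, vf: nn.Module):
        super().__init__()
        self.vf = vf

    def forward(self, x0: torch.Tensor, t_span: torch.Tensor) -> torch.Tensor:
        xs = [x0]
        x = x0
        for i in range(t_span.shape[0] - 1):
            t = t_span[i]
            h = t_span[i + 1] - t
            k1 = self.vf(t, x)
            k2 = self.vf(t + h / 2, x + h / 2 * k1)
            k3 = self.vf(t + h / 2, x + h / 2 * k2)
            k4 = self.vf(t + h, x + h * k3)
            x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
            xs.append(x)
        return torch.stack(xs)  # shape: (len(t_span), *x0.shape)

f = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
model = torch.jit.script(RK4ODE(TimeInvariant(f)))
t_span = torch.linspace(0, 1, 11)
traj = model(torch.randn(4, 2), t_span)
print(traj.shape)  # torch.Size([11, 4, 2])
```

The adapter gives every vector field the same (t, x) signature, so the scripted forward contains no branch whose call signature depends on a runtime flag.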