
Torchscripting not possible with NeuralODE due to function redefinition #163

@StephenHogg

Description


Describe the bug

TorchScript is unable to script the NeuralODE class because a method is redefined. This is a problem because the code contains control flow that tracing (`torch.jit.trace`) would not necessarily respect, so using tracing instead could produce incorrect output.
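For context, here is a minimal sketch in plain PyTorch (no torchdyn; `step` is a hypothetical function) of why tracing is not a safe fallback when control flow depends on tensor values: `torch.jit.trace` records only the branch taken for the example input, while `torch.jit.script` compiles both branches.

```python
import torch


def step(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: which branch runs depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return -x


# Tracing runs step once on the example input and records the ops it executed,
# so the `x * 2` branch (taken for a positive input) is baked into the graph.
traced = torch.jit.trace(step, (torch.ones(3),))

# Scripting compiles the Python source, so both branches are preserved.
scripted = torch.jit.script(step)

neg = -torch.ones(3)
print(step(neg))      # eager: takes the second branch, tensor([1., 1., 1.])
print(scripted(neg))  # scripted: matches eager
print(traced(neg))    # traced: still applies x * 2, tensor([-2., -2., -2.])
```

Tracing this function also emits a `TracerWarning` about converting a tensor to a Python boolean, which is a hint that the traced graph may only be valid for inputs like the example one.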

Steps to Reproduce

Minimal working example:

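```python
import torch
import torch.nn as nn
from torchdyn.core import NeuralODE

# Autonomous vector field: forward takes only the state x, not (t, x).
f = nn.Sequential(
    nn.Linear(2, 16),
    nn.Tanh(),
    nn.Linear(16, 2)
)

model = NeuralODE(f, solver='tsit5', solver_adjoint='dopri5')
out = torch.jit.script(model)
```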

This produces the following errors.

The first time it is run, the error is:

```
forward(__torch__.torch.nn.modules.container.Sequential self, Tensor input) -> (Tensor):
Expected at most 2 arguments but found 3 positional arguments.
:
  File "/home/shogg/.pyenv/versions/3.8.12/envs/mldi/lib/python3.8/site-packages/torchdyn/core/defunc.py", line 32
    def forward(self, t:Tensor, x:Tensor) -> Tensor:
        self.nfe += 1
        if self.has_time_arg: return self.vf(t, x)
                                     ~~~~~~~ <--- HERE
        else: return self.vf(x)
```

On subsequent attempts, the error is:

```
RuntimeError: Can't redefine method: forward on class: __torch__.torchdyn.core.defunc.DEFuncBase (of Python compilation unit at: 0x560bca2c9a70)
```

Note that changing the solvers does not appear to change either error.
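The first error is consistent with TorchScript compiling both branches of `if self.has_time_arg` in `DEFuncBase.forward`: `nn.Sequential.forward` takes a single input, so the `self.vf(t, x)` call fails to type-check even though that branch would never run for this model. The sketch below reproduces that pattern without torchdyn and shows one possible workaround, an adapter with a fixed `(t, x)` signature plus a defunc that always calls it. `BranchingDEFunc`, `TimeInvariantVF` and `ScriptableDEFunc` are hypothetical names, not torchdyn API, and this has not been tested against torchdyn itself.

```python
import torch
import torch.nn as nn
from torch import Tensor


class BranchingDEFunc(nn.Module):
    # Hypothetical stand-in for the pattern shown in the traceback.
    def __init__(self, vf: nn.Module, has_time_arg: bool):
        super().__init__()
        self.vf = vf
        self.has_time_arg = has_time_arg

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        # Both branches are compiled, so self.vf(t, x) must type-check
        # even when has_time_arg is False at runtime.
        if self.has_time_arg:
            return self.vf(t, x)
        else:
            return self.vf(x)


class TimeInvariantVF(nn.Module):
    # Hypothetical adapter: gives an autonomous network a fixed (t, x)
    # signature and ignores t.
    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        return self.net(x)


class ScriptableDEFunc(nn.Module):
    # Hypothetical defunc with a single call signature: no branch on arity.
    def __init__(self, vf: nn.Module):
        super().__init__()
        self.vf = vf

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        return self.vf(t, x)


net = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))

# The single-signature version scripts cleanly.
scripted = torch.jit.script(ScriptableDEFunc(TimeInvariantVF(net)))
x = torch.randn(4, 2)
out = scripted(torch.tensor(0.0), x)

# The branching version fails at scripting time with an arity error.
try:
    torch.jit.script(BranchingDEFunc(net, has_time_arg=False))
    branching_scripted = True
except RuntimeError:
    branching_scripted = False
```

The design choice here is to move the "does the vector field take `t`?" decision out of the compiled `forward` and into module construction, so TorchScript only ever sees one call signature.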

Expected behavior

`torch.jit.script(model)` should succeed instead of raising an error.

I'm using the latest version of torchdyn available from pip and torch==1.11.0+cu102. I would be very grateful for your advice on how to TorchScript this safely.

Metadata


Assignees: no one assigned

Labels: bug (Something isn't working)
