Skip to content

Commit a0d0fc5

Browse files
authored
Merge pull request #188 from jcallaham/184-args-in-hypersolver-step
Add `args` kwarg to hypersolvers
2 parents 74bd32b + 9b84dea commit a0d0fc5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchdyn/numerics/solvers/hyper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, hypernet, dtype=torch.float32):
2121
self.stepping_class = 'fixed'
2222
self.op1 = self.order + 1
2323

24-
def step(self, f, x, t, dt, k1=None):
24+
def step(self, f, x, t, dt, k1=None, args=None):
2525
_, x_sol, _ = super().step(f, x, t, dt, k1)
2626
return None, x_sol + dt**(self.op1) * self.hypernet(t, x), None
2727

@@ -32,7 +32,7 @@ def __init__(self, hypernet, dtype=torch.float32):
3232
self.stepping_class = 'fixed'
3333
self.op1 = self.order + 1
3434

35-
def step(self, f, x, t, dt, k1=None):
35+
def step(self, f, x, t, dt, k1=None, args=None):
3636
_, x_sol, _ = super().step(f, x, t, dt, k1)
3737
return None, x_sol + dt**(self.op1) * self.hypernet(t, x), None
3838

@@ -42,6 +42,6 @@ def __init__(self, hypernet, dtype=torch.float32):
4242
self.hypernet = hypernet
4343
self.op1 = self.order + 1
4444

45-
def step(self, f, x, t, dt, k1=None):
45+
def step(self, f, x, t, dt, k1=None, args=None):
4646
_, x_sol, _ = super().step(f, x, t, dt, k1)
4747
return None, x_sol + dt**(self.op1) * self.hypernet(t, x), None

0 commit comments

Comments
 (0)