@@ -21,7 +21,7 @@ def __init__(self, hypernet, dtype=torch.float32):
21
21
self .stepping_class = 'fixed'
22
22
self .op1 = self .order + 1
23
23
24
- def step (self , f , x , t , dt , k1 = None ):
24
+ def step (self , f , x , t , dt , k1 = None , args = None ):
25
25
_ , x_sol , _ = super ().step (f , x , t , dt , k1 )
26
26
return None , x_sol + dt ** (self .op1 ) * self .hypernet (t , x ), None
27
27
@@ -32,7 +32,7 @@ def __init__(self, hypernet, dtype=torch.float32):
32
32
self .stepping_class = 'fixed'
33
33
self .op1 = self .order + 1
34
34
35
- def step (self , f , x , t , dt , k1 = None ):
35
+ def step (self , f , x , t , dt , k1 = None , args = None ):
36
36
_ , x_sol , _ = super ().step (f , x , t , dt , k1 )
37
37
return None , x_sol + dt ** (self .op1 ) * self .hypernet (t , x ), None
38
38
@@ -42,6 +42,6 @@ def __init__(self, hypernet, dtype=torch.float32):
42
42
self .hypernet = hypernet
43
43
self .op1 = self .order + 1
44
44
45
- def step (self , f , x , t , dt , k1 = None ):
45
+ def step (self , f , x , t , dt , k1 = None , args = None ):
46
46
_ , x_sol , _ = super ().step (f , x , t , dt , k1 )
47
47
return None , x_sol + dt ** (self .op1 ) * self .hypernet (t , x ), None
0 commit comments