Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions torchdiffeq/_impl/fixed_grid_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,31 @@ class ImplicitMidpoint(FixedGridFIRKODESolver):
order = 2
tableau = _IMPLICIT_MIDPOINT_TABLEAU

_GAUSS_LEGENDRE_4_TABLEAU = _ButcherTableau(
alpha=torch.tensor([1 / 2 - _sqrt_3 / 6, 1 / 2 - _sqrt_3 / 6], dtype=torch.float64),
beta=[
torch.tensor([1 / 4, 1 / 4 - _sqrt_3 / 6], dtype=torch.float64),
torch.tensor([1 / 4 + _sqrt_3 / 6, 1 / 4], dtype=torch.float64),
],
c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
c_error=torch.tensor([], dtype=torch.float64),
)

_TRAPEZOID_TABLEAU = _ButcherTableau(
alpha=torch.tensor([0, 1], dtype=torch.float64),
beta=[
torch.tensor([0, 0], dtype=torch.float64),
torch.tensor([0], dtype=torch.float64),
torch.tensor([1 /2, 1 / 2], dtype=torch.float64),
],
c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
c_error=torch.tensor([], dtype=torch.float64),
)

class Trapezoid(FixedGridFIRKODESolver):
class Trapezoid(FixedGridDIRKODESolver):
order = 2
tableau = _TRAPEZOID_TABLEAU


_GAUSS_LEGENDRE_4_TABLEAU = _ButcherTableau(
alpha=torch.tensor([1 / 2 - _sqrt_3 / 6, 1 / 2 + _sqrt_3 / 6], dtype=torch.float64),
beta=[
torch.tensor([1 / 4, 1 / 4 - _sqrt_3 / 6], dtype=torch.float64),
torch.tensor([1 / 4 + _sqrt_3 / 6, 1 / 4], dtype=torch.float64),
],
c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64),
c_error=torch.tensor([], dtype=torch.float64),
)

class GaussLegendre4(FixedGridFIRKODESolver):
order = 4
tableau = _GAUSS_LEGENDRE_4_TABLEAU
Expand Down
105 changes: 55 additions & 50 deletions torchdiffeq/_impl/rk_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,34 +379,13 @@ class FixedGridFIRKODESolver(FixedGridODESolver):
order: int
tableau: _ButcherTableau

def __init__(self, func, y0, step_size=None, grid_constructor=None, interp='linear', perturb=False, max_iters=100, **unused_kwargs):
def __init__(self, func, y0, rtol=1e-6, atol=0., max_iters=100, **kwargs):
super(FixedGridFIRKODESolver, self).__init__(func, y0, rtol=rtol, atol=rtol, **kwargs)

self.rtol = rtol
self.atol = atol
self.max_iters = max_iters
self.atol = unused_kwargs.pop('atol')
unused_kwargs.pop('rtol', None)
unused_kwargs.pop('norm', None)
_handle_unused_kwargs(self, unused_kwargs)
del unused_kwargs

self.func = func
self.y0 = y0
self.dtype = y0.dtype
self.device = y0.device
self.step_size = step_size
self.interp = interp
self.perturb = perturb

if step_size is None:
if grid_constructor is None:
self.grid_constructor = lambda f, y0, t: t
else:
self.grid_constructor = grid_constructor
else:
if grid_constructor is None:
self.grid_constructor = self._grid_constructor_from_step_size(step_size)
else:
raise ValueError("step_size and grid_constructor are mutually exclusive arguments.")


self.tableau = _ButcherTableau(alpha=self.tableau.alpha.to(device=self.device, dtype=y0.dtype),
beta=[b.to(device=self.device, dtype=y0.dtype) for b in self.tableau.beta],
c_sol=self.tableau.c_sol.to(device=self.device, dtype=y0.dtype),
Expand All @@ -422,11 +401,6 @@ def _step_func(self, func, t0, dt, t1, y0):
f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)

t_dtype = y0.abs().dtype
tol = 1e-8
if t_dtype == torch.float64:
tol = 1e-8
if t_dtype == torch.float32:
tol = 1e-6

t0 = t0.to(t_dtype)
dt = dt.to(t_dtype)
Expand All @@ -438,25 +412,43 @@ def _step_func(self, func, t0, dt, t1, y0):
# Broyden's Method to solve the system of nonlinear equations
y = torch.matmul(k, beta * dt).add(y0.unsqueeze(-1)).movedim(-1, 0)
f = self._residual(func, k, y, t0, dt, t1)
J = torch.ones_like(f).diag()
# J = torch.ones_like(f).diag()
values = torch.ones_like(f)
indices = torch.arange(values.numel(), dtype=torch.int64, device=values.device)
indices = torch.stack((indices, indices))
size = (f.numel(), f.numel())
Jinv = torch.sparse_coo_tensor(indices, values, size).coalesce() # Sherman-Morrison Good Method

converged = False
dense_update = False
for _ in range(self.max_iters):
if torch.linalg.norm(f, 2) < tol:
if torch.linalg.vector_norm(f) < (self.rtol * torch.linalg.vector_norm(k) + self.atol + torch.finfo(t_dtype).eps):
converged = True
break

# If the matrix becomes singular, just stop and return the last value
try:
s = -torch.linalg.solve(J, f)
except torch._C._LinAlgError:
# s = -torch.linalg.solve(J, f)
s = -torch.sparse.mm(Jinv, f.unsqueeze(1)).squeeze(1)
if not torch.all(torch.isfinite(s)):
break

k = k + s.reshape_as(k)
y = torch.matmul(k, beta * dt).add(y0.unsqueeze(-1)).movedim(-1, 0)
newf = self._residual(func, k, y, t0, dt, t1)
z = newf - f
f = newf
J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s))
# J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s))

sJinv = torch.sparse.mm(Jinv.t(), s.unsqueeze(1)).squeeze(1)

if dense_update:
update = (torch.outer((s - torch.sparse.mm(Jinv, z.unsqueeze(1)).squeeze(1)), sJinv)) / (torch.dot(sJinv, z))
update = update.to_sparse()
else:
# Only update nonzero elements
update = torch.mul((s - torch.sparse.mm(Jinv, z.unsqueeze(1)).squeeze(1))[Jinv.indices()[0,:]], sJinv[Jinv.indices()[1,:]]) / (torch.dot(sJinv, z))
update = torch.sparse_coo_tensor(Jinv.indices(), update, size)

Jinv = (Jinv + update).coalesce()

if not converged:
warnings.warn('Functional iteration did not converge. Solution may be incorrect.')
Expand Down Expand Up @@ -495,11 +487,6 @@ def _step_func(self, func, t0, dt, t1, y0):
f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)

t_dtype = y0.abs().dtype
tol = 1e-8
if t_dtype == torch.float64:
tol = 1e-8
if t_dtype == torch.float32:
tol = 1e-6

t0 = t0.to(t_dtype)
dt = dt.to(t_dtype)
Expand All @@ -525,17 +512,23 @@ def _step_func(self, func, t0, dt, t1, y0):
# Broyden's Method to solve the system of nonlinear equations
y_i = torch.matmul(k_i, beta_i * dt).add(y0)
f = self._residual(func, k_i, y_i, ti, perturb)
J = torch.ones_like(f).diag()
# J = torch.ones_like(f).diag()
values = torch.ones_like(f)
indices = torch.arange(values.numel(), dtype=torch.int64, device=values.device)
indices = torch.stack((indices, indices))
size = (f.numel(), f.numel())
Jinv = torch.sparse_coo_tensor(indices, values, size).coalesce() # Sherman-Morrison Good Method

converged = False
dense_update = False
for _ in range(self.max_iters):
if torch.linalg.norm(f, 2) < tol:
if torch.linalg.vector_norm(f) < (self.rtol * torch.linalg.vector_norm(k[i].unsqueeze(-1)) + self.atol + torch.finfo(t_dtype).eps):
converged = True
break

# If the matrix becomes singular, just stop and return the last value
try:
s = -torch.linalg.solve(J, f)
except torch._C._LinAlgError:
# s = -torch.linalg.solve(J, f)
s = -torch.sparse.mm(Jinv, f.unsqueeze(1)).squeeze(1)
if not torch.all(torch.isfinite(s)):
break

k[i] = k[i] + s.reshape_as(k[i])
Expand All @@ -544,7 +537,19 @@ def _step_func(self, func, t0, dt, t1, y0):
newf = self._residual(func, k_i, y_i, ti, perturb)
z = newf - f
f = newf
J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s))
# J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s))

sJinv = torch.sparse.mm(Jinv.t(), s.unsqueeze(1)).squeeze(1)

if dense_update:
update = (torch.outer((s - torch.sparse.mm(Jinv, z.unsqueeze(1)).squeeze(1)), sJinv)) / (torch.dot(sJinv, z))
update = update.to_sparse()
else:
# Only update nonzero elements
update = torch.mul((s - torch.sparse.mm(Jinv, z.unsqueeze(1)).squeeze(1))[Jinv.indices()[0,:]], sJinv[Jinv.indices()[1,:]]) / (torch.dot(sJinv, z))
update = torch.sparse_coo_tensor(Jinv.indices(), update, size)

Jinv = (Jinv + update).coalesce()

if not converged:
warnings.warn('Functional iteration did not converge. Solution may be incorrect.')
Expand Down