diff --git a/torchdiffeq/_impl/fixed_grid_implicit.py b/torchdiffeq/_impl/fixed_grid_implicit.py index 7519efc8..682964a3 100644 --- a/torchdiffeq/_impl/fixed_grid_implicit.py +++ b/torchdiffeq/_impl/fixed_grid_implicit.py @@ -34,31 +34,31 @@ class ImplicitMidpoint(FixedGridFIRKODESolver): order = 2 tableau = _IMPLICIT_MIDPOINT_TABLEAU -_GAUSS_LEGENDRE_4_TABLEAU = _ButcherTableau( - alpha=torch.tensor([1 / 2 - _sqrt_3 / 6, 1 / 2 - _sqrt_3 / 6], dtype=torch.float64), - beta=[ - torch.tensor([1 / 4, 1 / 4 - _sqrt_3 / 6], dtype=torch.float64), - torch.tensor([1 / 4 + _sqrt_3 / 6, 1 / 4], dtype=torch.float64), - ], - c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64), - c_error=torch.tensor([], dtype=torch.float64), -) - _TRAPEZOID_TABLEAU = _ButcherTableau( alpha=torch.tensor([0, 1], dtype=torch.float64), beta=[ - torch.tensor([0, 0], dtype=torch.float64), + torch.tensor([0], dtype=torch.float64), torch.tensor([1 /2, 1 / 2], dtype=torch.float64), ], c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64), c_error=torch.tensor([], dtype=torch.float64), ) -class Trapezoid(FixedGridFIRKODESolver): +class Trapezoid(FixedGridDIRKODESolver): order = 2 tableau = _TRAPEZOID_TABLEAU +_GAUSS_LEGENDRE_4_TABLEAU = _ButcherTableau( + alpha=torch.tensor([1 / 2 - _sqrt_3 / 6, 1 / 2 + _sqrt_3 / 6], dtype=torch.float64), + beta=[ + torch.tensor([1 / 4, 1 / 4 - _sqrt_3 / 6], dtype=torch.float64), + torch.tensor([1 / 4 + _sqrt_3 / 6, 1 / 4], dtype=torch.float64), + ], + c_sol=torch.tensor([1 / 2, 1 / 2], dtype=torch.float64), + c_error=torch.tensor([], dtype=torch.float64), +) + class GaussLegendre4(FixedGridFIRKODESolver): order = 4 tableau = _GAUSS_LEGENDRE_4_TABLEAU diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py index f0050dbf..c242a7d3 100644 --- a/torchdiffeq/_impl/rk_common.py +++ b/torchdiffeq/_impl/rk_common.py @@ -379,34 +379,13 @@ class FixedGridFIRKODESolver(FixedGridODESolver): order: int tableau: _ButcherTableau - def __init__(self, func, 
y0, step_size=None, grid_constructor=None, interp='linear', perturb=False, max_iters=100, **unused_kwargs): + def __init__(self, func, y0, rtol=1e-6, atol=0., max_iters=100, **kwargs): + super(FixedGridFIRKODESolver, self).__init__(func, y0, rtol=rtol, atol=atol, **kwargs) + self.rtol = rtol + self.atol = atol self.max_iters = max_iters - self.atol = unused_kwargs.pop('atol') - unused_kwargs.pop('rtol', None) - unused_kwargs.pop('norm', None) - _handle_unused_kwargs(self, unused_kwargs) - del unused_kwargs - - self.func = func - self.y0 = y0 - self.dtype = y0.dtype - self.device = y0.device - self.step_size = step_size - self.interp = interp - self.perturb = perturb - - if step_size is None: - if grid_constructor is None: - self.grid_constructor = lambda f, y0, t: t - else: - self.grid_constructor = grid_constructor - else: - if grid_constructor is None: - self.grid_constructor = self._grid_constructor_from_step_size(step_size) - else: - raise ValueError("step_size and grid_constructor are mutually exclusive arguments.") - + self.tableau = _ButcherTableau(alpha=self.tableau.alpha.to(device=self.device, dtype=y0.dtype), beta=[b.to(device=self.device, dtype=y0.dtype) for b in self.tableau.beta], c_sol=self.tableau.c_sol.to(device=self.device, dtype=y0.dtype), @@ -422,11 +401,6 @@ def _step_func(self, func, t0, dt, t1, y0): f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) t_dtype = y0.abs().dtype - tol = 1e-8 - if t_dtype == torch.float64: - tol = 1e-8 - if t_dtype == torch.float32: - tol = 1e-6 t0 = t0.to(t_dtype) dt = dt.to(t_dtype) @@ -438,17 +412,23 @@ def _step_func(self, func, t0, dt, t1, y0): # Broyden's Method to solve the system of nonlinear equations y = torch.matmul(k, beta * dt).add(y0.unsqueeze(-1)).movedim(-1, 0) f = self._residual(func, k, y, t0, dt, t1) - J = torch.ones_like(f).diag() + # J = torch.ones_like(f).diag() + values = torch.ones_like(f) + indices = torch.arange(values.numel(), dtype=torch.int64, device=values.device) 
+ indices = torch.stack((indices, indices)) + size = (f.numel(), f.numel()) + Jinv = torch.sparse_coo_tensor(indices, values, size).coalesce() # Sherman-Morrison Good Method + converged = False + dense_update = False for _ in range(self.max_iters): - if torch.linalg.norm(f, 2) < tol: + if torch.linalg.vector_norm(f) < (self.rtol * torch.linalg.vector_norm(k) + self.atol + torch.finfo(t_dtype).eps): converged = True break - # If the matrix becomes singular, just stop and return the last value - try: - s = -torch.linalg.solve(J, f) - except torch._C._LinAlgError: + # s = -torch.linalg.solve(J, f) + s = -torch.sparse.mm(Jinv, f.unsqueeze(1)).squeeze(1) + if not torch.all(torch.isfinite(s)): break k = k + s.reshape_as(k) @@ -456,7 +436,19 @@ def _step_func(self, func, t0, dt, t1, y0): newf = self._residual(func, k, y, t0, dt, t1) z = newf - f f = newf - J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s)) + # J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s)) + + sJinv = torch.sparse.mm(Jinv.t(), s.unsqueeze(1)).squeeze(1) + + if dense_update: + update = (torch.outer((s - torch.sparse.mm(Jinv, z.unsqueeze(1)).squeeze(1)), sJinv)) / (torch.dot(sJinv, z)) + update = update.to_sparse() + else: + # Only update nonzero elements + update = torch.mul((s - torch.sparse.mm(Jinv, z.unsqueeze(1)).squeeze(1))[Jinv.indices()[0,:]], sJinv[Jinv.indices()[1,:]]) / (torch.dot(sJinv, z)) + update = torch.sparse_coo_tensor(Jinv.indices(), update, size) + + Jinv = (Jinv + update).coalesce() if not converged: warnings.warn('Functional iteration did not converge. 
Solution may be incorrect.') @@ -495,11 +487,6 @@ def _step_func(self, func, t0, dt, t1, y0): f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE) t_dtype = y0.abs().dtype - tol = 1e-8 - if t_dtype == torch.float64: - tol = 1e-8 - if t_dtype == torch.float32: - tol = 1e-6 t0 = t0.to(t_dtype) dt = dt.to(t_dtype) @@ -525,17 +512,23 @@ def _step_func(self, func, t0, dt, t1, y0): # Broyden's Method to solve the system of nonlinear equations y_i = torch.matmul(k_i, beta_i * dt).add(y0) f = self._residual(func, k_i, y_i, ti, perturb) - J = torch.ones_like(f).diag() + # J = torch.ones_like(f).diag() + values = torch.ones_like(f) + indices = torch.arange(values.numel(), dtype=torch.int64, device=values.device) + indices = torch.stack((indices, indices)) + size = (f.numel(), f.numel()) + Jinv = torch.sparse_coo_tensor(indices, values, size).coalesce() # Sherman-Morrison Good Method + converged = False + dense_update = False for _ in range(self.max_iters): - if torch.linalg.norm(f, 2) < tol: + if torch.linalg.vector_norm(f) < (self.rtol * torch.linalg.vector_norm(k[i].unsqueeze(-1)) + self.atol + torch.finfo(t_dtype).eps): converged = True break - # If the matrix becomes singular, just stop and return the last value - try: - s = -torch.linalg.solve(J, f) - except torch._C._LinAlgError: + # s = -torch.linalg.solve(J, f) + s = -torch.sparse.mm(Jinv, f.unsqueeze(1)).squeeze(1) + if not torch.all(torch.isfinite(s)): break k[i] = k[i] + s.reshape_as(k[i]) @@ -544,7 +537,19 @@ def _step_func(self, func, t0, dt, t1, y0): newf = self._residual(func, k_i, y_i, ti, perturb) z = newf - f f = newf - J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s)) + # J = J + (torch.outer ((z - torch.linalg.vecdot(J,s)),s)) / (torch.dot(s,s)) + + sJinv = torch.sparse.mm(Jinv.t(), s.unsqueeze(1)).squeeze(1) + + if dense_update: + update = (torch.outer((s - torch.sparse.mm(Jinv, z.unsqueeze(1)).squeeze(1)), sJinv)) / (torch.dot(sJinv, z)) + update = 
update.to_sparse() + else: + # Only update nonzero elements + update = torch.mul((s - torch.sparse.mm(Jinv, z.unsqueeze(1)).squeeze(1))[Jinv.indices()[0,:]], sJinv[Jinv.indices()[1,:]]) / (torch.dot(sJinv, z)) + update = torch.sparse_coo_tensor(Jinv.indices(), update, size) + + Jinv = (Jinv + update).coalesce() if not converged: warnings.warn('Functional iteration did not converge. Solution may be incorrect.')