diff --git a/pinn/pinn_1d.py b/pinn/pinn_1d.py index 368673d..24ccdef 100644 --- a/pinn/pinn_1d.py +++ b/pinn/pinn_1d.py @@ -285,6 +285,8 @@ def __init__(self, loss_type, loss_func=nn.MSELoss(), bc_weight=1.0): self.name = "PINN Loss" elif self.type == 1: self.name = "DRM Loss" + elif self.type== 2: + self.name = "DCGD(PINN+DRM) Loss" else: raise ValueError(f"Unknown loss type: {self.type}") self.bc_weight = bc_weight @@ -294,7 +296,7 @@ def super_loss(self, model, mesh, loss_func): x = mesh.x_train u = model.get_solution(x) loss = loss_func(u, mesh.u_ex) - return loss + return loss, () # "PINN" loss def pinn_loss(self, model, mesh, loss_func): @@ -306,7 +308,7 @@ def pinn_loss(self, model, mesh, loss_func): # Internal loss pde = mesh.pde - loss = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1]) + loss = loss_pinn = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1]) # Boundary loss if not model.enforce_bc: u_bc = u[[0, -1]] @@ -314,7 +316,8 @@ def pinn_loss(self, model, mesh, loss_func): loss_b = loss_func(u_bc, u_ex_bc) loss += self.bc_weight * loss_b - return loss + return loss, (loss_pinn, loss_b) + return loss, (loss_pinn,) def drm_loss(self, model, mesh: Mesh): """Deep Ritz Method loss""" @@ -332,7 +335,7 @@ def drm_loss(self, model, mesh: Mesh): fu_prod = f_val * u integrand_values = 0.5 * grad_u_pred_sq[1:-1] + 0.5 * mesh.pde.r * u_pred_sq[1:-1] - fu_prod[1:-1] - loss = torch.mean(integrand_values) + loss = loss_drm = torch.mean(integrand_values) # Boundary loss u_bc = u[[0,-1]] @@ -342,7 +345,7 @@ def drm_loss(self, model, mesh: Mesh): xs.requires_grad_(False) # Disable gradient tracking for x - return loss + return loss, (loss_drm, loss_b) def loss(self, model, mesh): if self.type == -1: @@ -351,6 +354,10 @@ def loss(self, model, mesh): loss_value = self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func) elif self.type == 1: loss_value = self.drm_loss(model=model, mesh=mesh) + elif self.type == 2: + loss_p, pinn_losses = 
self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func) + _, drm_losses = self.drm_loss(model=model, mesh=mesh) + loss_value = pinn_losses[0] + drm_losses[0], [pinn_losses[0], drm_losses[0]] else: raise ValueError(f"Unknown loss type: {self.type}") return loss_value @@ -359,8 +366,10 @@ def loss(self, model, mesh): # %% # Define the training loop def train(model, mesh, criterion, iterations, adam_iterations, learning_rate, - num_check, num_plots, sweep_idx, level_idx, frame_dir): + num_check, num_plots, sweep_idx, level_idx, frame_dir, loss_type=0): optimizer = optim.Adam(model.parameters(), lr=learning_rate) + if loss_type ==2: + optimizer = DCGD(optimizer, 1, type="center") # optimizer = SOAP(model.parameters(), lr = 3e-3, betas=(.95, .95), weight_decay=.01, # precondition_frequency=10) scheduler = StepLR(optimizer, step_size=1000, gamma=0.9) @@ -390,12 +399,16 @@ def closure(): # we need to set to zero the gradients of all model parameters (PyTorch accumulates grad by default) optimizer.zero_grad() # compute the loss value for the current batch of data - loss = criterion.loss(model=model, mesh=mesh) + loss, losses = criterion.loss(model=model, mesh=mesh) # backpropagation to compute gradients of model param respect to the loss. computes dloss/dx # for every parameter x which has requires_grad=True. 
- loss.backward() # update the model param doing an optim step using the computed gradients and learning rate - optimizer.step() + if loss_type != 2: + loss.backward() + optimizer.step() + else: + # Dual Cone GD optimizer + optimizer.step(losses) # scheduler.step() @@ -444,12 +457,13 @@ def main(args=None): # Input and output dimension: x -> u(x) dim_inputs = 1 dim_outputs = 1 + enforce_bc = args.enforce_bc if args.loss_type !=2 else True # Dual Cone GD enforces hard constraint on BC model = MultiLevelNN(mesh=mesh, num_levels=args.levels, dim_inputs=dim_inputs, dim_outputs=dim_outputs, dim_hidden=args.hidden_dims, act=get_activation(args.activation), - enforce_bc=args.enforce_bc) + enforce_bc=enforce_bc) print(model) model.to(device) # Plotting diff --git a/pinn/utils.py b/pinn/utils.py index c7c96fa..3bce07e 100644 --- a/pinn/utils.py +++ b/pinn/utils.py @@ -56,8 +56,8 @@ def parse_args(args=None): help="Learning rate for the optimizer.") parser.add_argument('--levels', type=int, default=4, help="Number of levels in multilevel training.") - parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0], - help="Loss type: -1 for supervised (true solution), 0 for PINN loss.") + parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0,1,2], + help="Loss type: -1 for supervised (true solution), 0 for PINN loss. 1 for DRM loss, 2 for DCGD loss") parser.add_argument('--activation', type=str, default='tanh', choices=['tanh', 'silu', 'relu', 'gelu', 'softmax'], help="Activation function to use.")