diff --git a/pinn/pinn_1d.py b/pinn/pinn_1d.py index 2599d8d..4fd5f86 100644 --- a/pinn/pinn_1d.py +++ b/pinn/pinn_1d.py @@ -47,6 +47,7 @@ import torch.optim as optim import numpy as np from enum import Enum +import warnings from utils import parse_args, get_activation, print_args, save_frame, make_video_from_frames from utils import is_notebook, cleanfiles, fourier_analysis, get_scheduler_generator, scheduler_step # from SOAP.soap import SOAP @@ -126,10 +127,38 @@ def set_pde(self, pde: PDE): # %% +class FourierEmbedding(nn.Module): + def __init__(self, dim_inputs:int, half_dim_output:int, sigma:int=5): + """ + Fourier Features Embedding for the input data. This can help learning high frequency functions. + Args: + dim_inputs: Dimension of the input data + half_dim_outputs: The output dimension is 2*half_dim_output. + sigma: Scaling factor for the frequencies. Recommended is [1, 10] + Ref: https://arxiv.org/abs/2006.10739 + """ + super().__init__() + self.sigma = sigma + m = half_dim_output # number of frequencies pairs (cos, sin) + B = torch.rand(m, dim_inputs) * sigma + self.B = nn.Parameter(B, requires_grad=False) # fixed frequencies coefficients + + def forward(self, x): + """ + x: Tensor of shape [batch_size, dim_inputs] + """ + if len(x.size()) == 1: + x = x.unsqueeze(0) # X mush be 2D + Bxs = torch.einsum("BD, MD-> BM", x, self.B) + sin_xs = torch.sin(Bxs) # (B, M) + cos_xs = torch.cos(Bxs) # (B, M) + feature = torch.cat([cos_xs, sin_xs], dim=-1) # (B, 2M) + return feature + # Define one level NN class Level(nn.Module): def __init__(self, dim_inputs, dim_outputs, dim_hidden: list, - act: nn.Module = nn.Tanh()) -> None: + act: nn.Module = nn.Tanh(), fourier_embedding_sigma:int=None) -> None: """Simple neural network with linear layers and non-linear activation function This class is used as universal function approximate for the solution of partial differential equations using PINNs @@ -137,10 +166,21 @@ def __init__(self, dim_inputs, dim_outputs, dim_hidden: list, super().__init__() self.dim_inputs = dim_inputs self.dim_outputs = dim_outputs - # multi-layer MLP - layer_dim = [dim_inputs] + dim_hidden + [dim_outputs] - self.linear = nn.ModuleList([nn.Linear(layer_dim[i], layer_dim[i + 1]) - for i in range(len(layer_dim) - 1)]) + if fourier_embedding_sigma is not None: + # Check output dim is divisible by 2 + if dim_hidden[0] != 2 * dim_hidden[0]//2: + dim_hidden[0] = 2 * dim_hidden[0]//2 + warnings.warn(f"dim_hidden[0] is changed to {dim_hidden[0]} to be divisible by 2 for Fourier embedding.") + # Fourier embedding + layer_dim = dim_hidden + [dim_outputs] + layer_fourier = FourierEmbedding(dim_inputs=dim_inputs, half_dim_output=dim_hidden[0]//2,sigma=fourier_embedding_sigma) + layers = [layer_fourier] + layers.extend([nn.Linear(layer_dim[i], layer_dim[i + 1]) for i in range(len(layer_dim) - 1)]) + else: + layer_dim = [dim_inputs] + dim_hidden + [dim_outputs] + # multi-layer MLP + layers = [nn.Linear(layer_dim[i], layer_dim[i + 1]) for i in range(len(layer_dim) - 1)] + self.linear = nn.ModuleList(layers) # activation function self.act = act @@ -165,12 +205,12 @@ class LevelStatus(Enum): # Define multilevel NN class MultiLevelNN(nn.Module): def __init__(self, mesh: Mesh, num_levels: int, dim_inputs, dim_outputs, dim_hidden: list, - act: nn.Module = nn.ReLU(), enforce_bc: bool = False) -> None: + act: nn.Module = nn.ReLU(), enforce_bc: bool = False, fourier_embedding_sigma:int=None) -> None: super().__init__() self.mesh = mesh # currently the same model on each level self.models = nn.ModuleList([ - Level(dim_inputs=dim_inputs, dim_outputs=dim_outputs, dim_hidden=dim_hidden, act=act) + Level(dim_inputs=dim_inputs, dim_outputs=dim_outputs, dim_hidden=dim_hidden, act=act, fourier_embedding_sigma= fourier_embedding_sigma) for _ in range(num_levels) ]) self.dim_inputs = dim_inputs @@ -449,7 +489,7 @@ def main(args=None): dim_inputs=dim_inputs, dim_outputs=dim_outputs, dim_hidden=args.hidden_dims, act=get_activation(args.activation), - enforce_bc=args.enforce_bc) + enforce_bc=args.enforce_bc, fourier_embedding_sigma=args.fourier_embedding_sigma) print(model) model.to(device) # Plotting diff --git a/pinn/utils.py b/pinn/utils.py index b7720b8..c443935 100644 --- a/pinn/utils.py +++ b/pinn/utils.py @@ -101,6 +101,7 @@ def parse_args(args=None): help="Configuration for learning rate scheduler. " "Follow https://docs.pytorch.org/docs/stable/optim.html for full list of schedulers. " "The setting is corresponding to `--scheduler` setting.") + parser.add_argument("--fourier_embedding_sigma", type=float, default=-1, help="Sigma for Fourier embedding. Recommended [1,10] ") args = parser.parse_args(args)