diff --git a/torch_harmonics/attention/attention.py b/torch_harmonics/attention/attention.py
index d40fdc9d..3605825a 100644
--- a/torch_harmonics/attention/attention.py
+++ b/torch_harmonics/attention/attention.py
@@ -47,6 +47,7 @@ class AttentionS2(nn.Module):
     """
     (Global) attention on the 2-sphere.
 
+
     Parameters
     -----------
     in_channels: int
@@ -67,6 +68,13 @@ class AttentionS2(nn.Module):
        number of dimensions for interior inner product in the attention matrix (corresponds to kdim in MHA in PyTorch)
     out_channels: int, optional
        number of dimensions for interior inner product in the attention matrix (corresponds to vdim in MHA in PyTorch)
+
+    Reference
+    ---------
+    Bonev, B., Rietmann, M., Paris, A., Carpentieri, A., & Kurth, T. (2025).
+    "Attention on the Sphere."
+    Advances in Neural Information Processing Systems (NeurIPS).
+    https://arxiv.org/abs/2505.11157
     """
 
     def __init__(
@@ -210,6 +218,13 @@ class NeighborhoodAttentionS2(nn.Module):
        number of dimensions for interior inner product in the attention matrix (corresponds to vdim in MHA in PyTorch)
     optimized_kernel: Optional[bool]
        Whether to use the optimized kernel (if available)
+
+    Reference
+    ---------
+    Bonev, B., Rietmann, M., Paris, A., Carpentieri, A., & Kurth, T. (2025).
+    "Attention on the Sphere."
+    Advances in Neural Information Processing Systems (NeurIPS).
+    https://arxiv.org/abs/2505.11157
     """
 
     def __init__(
diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py
index 30c891e7..f9f7b944 100644
--- a/torch_harmonics/disco/convolution.py
+++ b/torch_harmonics/disco/convolution.py
@@ -92,10 +92,6 @@ def _normalize_convolution_tensor_s2(
         If basis_norm_mode is not one of the supported modes.
     """
 
-    # exit here if no normalization is needed
-    if basis_norm_mode == "none":
-        return psi_vals
-
     # reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
     idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
 
@@ -117,7 +113,8 @@
     q = quad_weights[ilat_in].reshape(-1)
 
     # buffer to store intermediate values
-    vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
+    bias = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
+    scale = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
     support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
 
     # loop through dimensions to compute the norms
@@ -128,8 +125,14 @@
            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))
 
            # compute the 1-norm
-           # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
-           vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
+           # scale[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
+           if basis_norm_mode == "modal":
+               # if ik != 0:
+               #     bias[ik, ilat] = torch.sum(psi_vals[iidx] * q[iidx])
+               # scale[ik, ilat] = torch.sqrt(torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs().pow(2) * q[iidx]))
+               scale[ik, ilat] = torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs() * q[iidx])
+           else:
+               scale[ik, ilat] = torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs() * q[iidx])
 
            # compute the support
            support[ik, ilat] = torch.sum(q[iidx])
@@ -140,18 +143,23 @@
 
            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))
 
-           if basis_norm_mode == "individual":
-               val = vnorm[ik, ilat]
-           elif basis_norm_mode == "mean":
-               val = vnorm[ik, :].mean()
+           if basis_norm_mode in ["nodal", "individual", "modal"]:
+               b = bias[ik, ilat]
+               s = scale[ik, ilat]
+           elif basis_norm_mode in ["mean"]:
+               if ilat == 0:
+                   b = bias[ik, :].mean()
+                   s = scale[ik, :].mean()
            elif basis_norm_mode == "support":
-               val = support[ik, ilat]
+               b = 0.0
+               s = support[ik, ilat]
            elif basis_norm_mode == "none":
-               val = 1.0
+               b = 0.0
+               s = 1.0
            else:
                raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
 
-           psi_vals[iidx] = psi_vals[iidx] / (val + eps)
+           psi_vals[iidx] = (psi_vals[iidx] - b) / (s + eps)
 
            if merge_quadrature:
                psi_vals[iidx] = psi_vals[iidx] * q[iidx]
@@ -167,13 +175,13 @@ def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
-   grid_in: Optional[str]="equiangular",
-   grid_out: Optional[str]="equiangular",
-   theta_cutoff: Optional[float]=0.01 * math.pi,
-   theta_eps: Optional[float]=1e-3,
-   transpose_normalization: Optional[bool]=False,
-   basis_norm_mode: Optional[str]="mean",
-   merge_quadrature: Optional[bool]=False,
+   grid_in: Optional[str] = "equiangular",
+   grid_out: Optional[str] = "equiangular",
+   theta_cutoff: Optional[float] = 0.01 * math.pi,
+   theta_eps: Optional[float] = 1e-3,
+   transpose_normalization: Optional[bool] = False,
+   basis_norm_mode: Optional[str] = "nodal",
+   merge_quadrature: Optional[bool] = False,
 ):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
@@ -439,7 +447,7 @@ def __init__(
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
-       basis_norm_mode: Optional[str] = "mean",
+       basis_norm_mode: Optional[str] = "nodal",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
@@ -578,7 +586,7 @@ def __init__(
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
-       basis_norm_mode: Optional[str] = "mean",
+       basis_norm_mode: Optional[str] = "nodal",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
diff --git a/torch_harmonics/filter_basis.py b/torch_harmonics/filter_basis.py
index 67965125..cb31e4d6 100644
--- a/torch_harmonics/filter_basis.py
+++ b/torch_harmonics/filter_basis.py
@@ -96,10 +96,12 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi
 
    if basis_type == "piecewise linear":
        return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
-   elif basis_type == "morlet":
-       return MorletFilterBasis(kernel_shape=kernel_shape)
+   elif basis_type == "harmonic":
+       return HarmonicFilterBasis(kernel_shape=kernel_shape)
    elif basis_type == "zernike":
        return ZernikeFilterBasis(kernel_shape=kernel_shape)
+   elif basis_type == "morlet":
+       raise NotImplementedError("Morlet basis functions are not supported anymore. Use harmonic basis functions with a Morlet window function instead.")
    else:
        raise ValueError(f"Unknown basis_type {basis_type}")
 
@@ -214,7 +216,7 @@ def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: flo
        return self._compute_support_vals_isotropic(r, phi, r_cutoff=r_cutoff)
 
 
-class MorletFilterBasis(FilterBasis):
+class HarmonicFilterBasis(FilterBasis):
    """Morlet-style filter basis on the disk.
    A Gaussian is multiplied with a Fourier basis in x and y directions."""
 
    def __init__(
diff --git a/torch_harmonics/sht.py b/torch_harmonics/sht.py
index e80cc6ae..739ca20a 100644
--- a/torch_harmonics/sht.py
+++ b/torch_harmonics/sht.py
@@ -72,7 +72,7 @@ class RealSHT(nn.Module):
    """
 
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
-        
+
        super().__init__()
@@ -252,7 +252,7 @@ def forward(self, x: torch.Tensor):
            x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
            x[..., self.nlon // 2].imag = 0.0
-        
+
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
 
        return x
@@ -293,7 +293,7 @@ class RealVectorSHT(nn.Module):
    """
 
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
-        
+
        super().__init__()
@@ -488,7 +488,7 @@ def forward(self, x: torch.Tensor):
            x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
            x[..., self.nlon // 2].imag = 0.0
-        
+
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
 
        return x
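Note for reviewers: a minimal, untested usage sketch of the public API after this change. It assumes the DiscreteContinuousConvS2 constructor arguments visible in this diff are otherwise unchanged, that the class is still exported at the package level, and that an integer kernel_shape remains valid for the renamed basis; shapes and channel counts below are placeholders.

# Sketch only; assumptions noted above, not part of this patch.
import torch
from torch_harmonics import DiscreteContinuousConvS2

conv = DiscreteContinuousConvS2(
    in_channels=3,
    out_channels=3,
    in_shape=(32, 64),          # (nlat, nlon) of the input grid
    out_shape=(32, 64),         # (nlat, nlon) of the output grid
    kernel_shape=3,
    basis_type="harmonic",      # replaces the removed "morlet" basis type
    basis_norm_mode="nodal",    # new default introduced by this patch
)

x = torch.randn(1, 3, 32, 64)   # batch, channels, nlat, nlon
y = conv(x)                     # spatial dims follow out_shape

Callers that previously passed basis_type="morlet" or relied on the old basis_norm_mode="mean" default would need to update their call sites accordingly.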