Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions torch_harmonics/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
class AttentionS2(nn.Module):
"""
(Global) attention on the 2-sphere.

Parameters
-----------
in_channels: int
Expand All @@ -67,6 +68,13 @@ class AttentionS2(nn.Module):
number of dimensions for interior inner product in the attention matrix (corresponds to kdim in MHA in PyTorch)
out_channels: int, optional
number of dimensions for interior inner product in the attention matrix (corresponds to vdim in MHA in PyTorch)

Reference
---------
Bonev, B., Rietmann, M., Paris, A., Carpentieri, A., & Kurth, T. (2025).
"Attention on the Sphere."
Advances in Neural Information Processing Systems (NeurIPS).
https://arxiv.org/abs/2505.11157
"""

def __init__(
Expand Down Expand Up @@ -210,6 +218,13 @@ class NeighborhoodAttentionS2(nn.Module):
number of dimensions for interior inner product in the attention matrix (corresponds to vdim in MHA in PyTorch)
optimized_kernel: Optional[bool]
Whether to use the optimized kernel (if available)

Reference
---------
Bonev, B., Rietmann, M., Paris, A., Carpentieri, A., & Kurth, T. (2025).
"Attention on the Sphere."
Advances in Neural Information Processing Systems (NeurIPS).
https://arxiv.org/abs/2505.11157
"""

def __init__(
Expand Down
54 changes: 31 additions & 23 deletions torch_harmonics/disco/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,6 @@ def _normalize_convolution_tensor_s2(
If basis_norm_mode is not one of the supported modes.
"""

# exit here if no normalization is needed
if basis_norm_mode == "none":
return psi_vals

# reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)

Expand All @@ -117,7 +113,8 @@ def _normalize_convolution_tensor_s2(
q = quad_weights[ilat_in].reshape(-1)

# buffer to store intermediate values
vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
bias = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
scale = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)

# loop through dimensions to compute the norms
Expand All @@ -128,8 +125,14 @@ def _normalize_convolution_tensor_s2(
iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

# compute the 1-norm
# vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
# scale[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
if basis_norm_mode == "modal":
# if ik != 0:
# bias[ik, ilat] = torch.sum(psi_vals[iidx] * q[iidx])
# scale[ik, ilat] = torch.sqrt(torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs().pow(2) * q[iidx]))
scale[ik, ilat] = torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs() * q[iidx])
else:
scale[ik, ilat] = torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs() * q[iidx])

# compute the support
support[ik, ilat] = torch.sum(q[iidx])
Expand All @@ -140,18 +143,23 @@ def _normalize_convolution_tensor_s2(

iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

if basis_norm_mode == "individual":
val = vnorm[ik, ilat]
elif basis_norm_mode == "mean":
val = vnorm[ik, :].mean()
if basis_norm_mode in ["nodal", "individual", "modal"]:
b = bias[ik, ilat]
s = scale[ik, ilat]
elif basis_norm_mode in ["mean"]:
if ilat == 0:
b = bias[ik, :].mean()
s = scale[ik, :].mean()
elif basis_norm_mode == "support":
val = support[ik, ilat]
b = 0.0
s = support[ik, ilat]
elif basis_norm_mode == "none":
val = 1.0
b = 0.0
s = 1.0
else:
raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

psi_vals[iidx] = psi_vals[iidx] / (val + eps)
psi_vals[iidx] = (psi_vals[iidx] - b)/ (s + eps)

if merge_quadrature:
psi_vals[iidx] = psi_vals[iidx] * q[iidx]
Expand All @@ -167,13 +175,13 @@ def _precompute_convolution_tensor_s2(
in_shape: Tuple[int],
out_shape: Tuple[int],
filter_basis: FilterBasis,
grid_in: Optional[str]="equiangular",
grid_out: Optional[str]="equiangular",
theta_cutoff: Optional[float]=0.01 * math.pi,
theta_eps: Optional[float]=1e-3,
transpose_normalization: Optional[bool]=False,
basis_norm_mode: Optional[str]="mean",
merge_quadrature: Optional[bool]=False,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
theta_cutoff: Optional[float] = 0.01 * math.pi,
theta_eps: Optional[float] = 1e-3,
transpose_normalization: Optional[bool] = False,
basis_norm_mode: Optional[str] = "nodal",
merge_quadrature: Optional[bool] = False,
):
r"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Expand Down Expand Up @@ -439,7 +447,7 @@ def __init__(
out_shape: Tuple[int],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean",
basis_norm_mode: Optional[str] = "nodal",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
Expand Down Expand Up @@ -578,7 +586,7 @@ def __init__(
out_shape: Tuple[int],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean",
basis_norm_mode: Optional[str] = "nodal",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
Expand Down
8 changes: 5 additions & 3 deletions torch_harmonics/filter_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi

if basis_type == "piecewise linear":
return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "morlet":
return MorletFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "harmonic":
return HarmonicFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "zernike":
return ZernikeFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "morlet":
raise NotImplementedError("Morlet basis functions are not supported anymore. Use harmonic basis functions with a Morlet window function instead.")
else:
raise ValueError(f"Unknown basis_type {basis_type}")

Expand Down Expand Up @@ -214,7 +216,7 @@ def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: flo
return self._compute_support_vals_isotropic(r, phi, r_cutoff=r_cutoff)


class MorletFilterBasis(FilterBasis):
class HarmonicFilterBasis(FilterBasis):
"""Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions."""

def __init__(
Expand Down
8 changes: 4 additions & 4 deletions torch_harmonics/sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class RealSHT(nn.Module):
"""

def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):


super().__init__()

Expand Down Expand Up @@ -252,7 +252,7 @@ def forward(self, x: torch.Tensor):
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
x[..., self.nlon // 2].imag = 0.0

x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

return x
Expand Down Expand Up @@ -293,7 +293,7 @@ class RealVectorSHT(nn.Module):
"""

def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):


super().__init__()

Expand Down Expand Up @@ -488,7 +488,7 @@ def forward(self, x: torch.Tensor):
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
x[..., self.nlon // 2].imag = 0.0

x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

return x
Loading