Enable fast attention in nanoPET #454

Open · wants to merge 8 commits into main

17 changes: 9 additions & 8 deletions src/metatrain/experimental/nanopet/model.py
@@ -48,7 +48,7 @@ class NanoPET(torch.nn.Module):
"""

__supported_devices__ = ["cuda", "cpu"]
__supported_dtypes__ = [torch.float64, torch.float32]
__supported_dtypes__ = [torch.float32, torch.float64]

component_labels: Dict[str, List[List[Labels]]]

@@ -77,10 +77,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
4 * self.hypers["d_pet"],
self.hypers["num_heads"],
self.hypers["num_attention_layers"],
0.0, # MLP dropout rate
0.0, # attention dropout rate
)
# empirically, the model seems to perform better without dropout

self.num_mp_layers = self.hypers["num_gnn_layers"] - 1
gnn_contractions = []
@@ -97,8 +94,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
4 * self.hypers["d_pet"],
self.hypers["num_heads"],
self.hypers["num_attention_layers"],
0.0, # MLP dropout rate
0.0, # attention dropout rate
)
)
self.gnn_contractions = torch.nn.ModuleList(gnn_contractions)
@@ -278,6 +273,10 @@ def forward(

edge_vectors = positions[neighbors] - positions[centers] + cell_contributions

# the scaled_dot_product_attention function from torch cannot do
# double backward, so we will use manual attention if needed
use_manual_attention = edge_vectors.requires_grad and self.training

bincount = torch.bincount(centers)
if bincount.numel() == 0: # no edges
max_edges_per_node = 0
@@ -320,7 +319,7 @@ def forward(
features = self.encoder(features)

# Transformer
features = self.transformer(features, radial_mask)
features = self.transformer(features, radial_mask, use_manual_attention)

# GNN
if self.num_mp_layers > 0:
@@ -342,7 +341,9 @@ def forward(
)
new_features = contraction(new_features)
new_features = edge_array_to_nef(new_features, nef_indices)
new_features = transformer(new_features, radial_mask)
new_features = transformer(
new_features, radial_mask, use_manual_attention
)
features = (features + new_features) * 0.5**0.5

edge_features = features * radial_mask[:, :, None]
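Note on the `use_manual_attention` flag added in `forward()` above: training on forces needs gradients of the energy with respect to positions, and backpropagating a force loss then requires a second backward pass through the attention. A minimal sketch of the limitation, with illustrative shapes not taken from the PR:

```python
import torch
import torch.nn.functional as F

# Toy query/key/value tensors: (batch, heads, sequence, head_dim)
q = torch.randn(1, 2, 4, 8, requires_grad=True)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

out = F.scaled_dot_product_attention(q, k, v)

# First backward, keeping the graph so the gradient can be differentiated again
(grad_q,) = torch.autograd.grad(out.sum(), q, create_graph=True)

# Second backward: depending on the SDPA backend that was selected, this can
# raise a RuntimeError because the fused kernels do not implement double backward
try:
    torch.autograd.grad(grad_q.sum(), q)
    print("double backward supported by this backend")
except RuntimeError as err:
    print("double backward failed:", err)
```

The fused path is therefore kept only when it is safe (inference, or training runs that do not differentiate `edge_vectors`), and the plain matmul-plus-softmax `manual_attention` is used otherwise.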
37 changes: 21 additions & 16 deletions src/metatrain/experimental/nanopet/modules/attention.py
@@ -12,21 +12,19 @@ def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float,
attention_dropout_rate: float,
):
super().__init__()

self.num_heads = num_heads
self.in_proj = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)
self.out_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
self.layernorm = torch.nn.LayerNorm(normalized_shape=hidden_size)
self.attention_dropout_rate = attention_dropout_rate

def forward(
self,
inputs: torch.Tensor, # seq_len hidden_size
radial_mask: torch.Tensor, # seq_len
use_manual_attention: bool,
) -> torch.Tensor: # seq_len hidden_size
# Pre-layer normalization
normed_inputs = self.layernorm(inputs)
@@ -41,21 +39,19 @@ def forward(
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Attention
attention_weights = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
attention_weights = attention_weights.softmax(dim=-1)
attention_weights = torch.nn.functional.dropout(
attention_weights, p=self.attention_dropout_rate, training=self.training
)

# Radial mask
attention_weights = attention_weights * radial_mask[:, None, None, :]
attention_weights = attention_weights / (
attention_weights.sum(dim=-1, keepdim=True) + 1e-6
)
# Attention
attn_mask = torch.log(radial_mask[:, None, None, :])
if use_manual_attention:
attention_output = manual_attention(q, k, v, attn_mask)
else:
attention_output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
)

# Attention output
attention_output = torch.matmul(attention_weights, v)
attention_output = attention_output.transpose(1, 2)
attention_output = attention_output.reshape(
attention_output.size(0),
@@ -70,3 +66,12 @@ def forward(
outputs = (outputs + inputs) * 0.5**0.5

return outputs


def manual_attention(q, k, v, attn_mask):
attention_weights = (
torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
) + attn_mask
attention_weights = attention_weights.softmax(dim=-1)
attention_output = torch.matmul(attention_weights, v)
return attention_output
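A note on the masking change above: the previous code multiplied the softmax output by `radial_mask` and renormalized, while the new code adds `log(radial_mask)` to the scores before the softmax. The two are mathematically equivalent up to the old `1e-6` stabilizer, since softmax(s + log m)_j = m_j * exp(s_j) / sum_k m_k * exp(s_k). A quick numerical check with illustrative shapes:

```python
import torch

# Illustrative shapes: attention scores (centers, heads, neighbors, neighbors)
# and a per-neighbor radial mask in [0, 1)
scores = torch.randn(3, 4, 6, 6, dtype=torch.float64)
radial_mask = torch.rand(3, 6, dtype=torch.float64)

# New formulation: additive log-mask inside the softmax
new = (scores + torch.log(radial_mask)[:, None, None, :]).softmax(dim=-1)

# Old formulation: multiply after the softmax, then renormalize
old = scores.softmax(dim=-1) * radial_mask[:, None, None, :]
old = old / (old.sum(dim=-1, keepdim=True) + 1e-6)

print(torch.allclose(new, old, atol=1e-5))  # True up to the stabilizer
```

The rewrite also drops the attention-weight dropout that used to sit between the softmax and the value matmul, consistent with the comment removed in `model.py` that the model performs better without dropout.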
6 changes: 0 additions & 6 deletions src/metatrain/experimental/nanopet/modules/feedforward.py
@@ -8,7 +8,6 @@ def __init__(
self,
hidden_size: int,
intermediate_size: int,
dropout_rate: float,
):
super().__init__()

@@ -18,9 +17,7 @@ def __init__(
self.output = torch.nn.Linear(
in_features=intermediate_size, out_features=hidden_size, bias=False
)

self.layernorm = torch.nn.LayerNorm(normalized_shape=hidden_size)
self.dropout = torch.nn.Dropout(dropout_rate)

def forward(
self,
@@ -36,9 +33,6 @@ def forward(
# Project back to input size
outputs = self.output(hidden)

# Apply dropout
outputs = self.dropout(outputs)

# Residual connection
outputs = (outputs + inputs) * 0.5**0.5

17 changes: 6 additions & 11 deletions src/metatrain/experimental/nanopet/modules/transformer.py
@@ -12,29 +12,27 @@ def __init__(
hidden_size: int,
intermediate_size: int,
num_heads: int,
dropout_rate: float,
attention_dropout_rate: float,
):
super().__init__()

self.attention_block = AttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
)
self.ff_block = FeedForwardBlock(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dropout_rate=dropout_rate,
)

def forward(
self,
inputs: torch.Tensor,
radial_mask: torch.Tensor,
use_manual_attention: bool,
) -> torch.Tensor:
attention_output = self.attention_block(inputs, radial_mask)
attention_output = self.attention_block(
inputs, radial_mask, use_manual_attention
)
output = self.ff_block(attention_output)

return output
@@ -49,8 +47,6 @@ def __init__(
intermediate_size: int,
num_heads: int,
num_layers: int,
dropout_rate: float,
attention_dropout_rate: float,
):
super().__init__()

@@ -60,8 +56,6 @@
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
)
for _ in range(num_layers)
]
@@ -71,8 +65,9 @@ def forward(
self,
inputs,
radial_mask,
use_manual_attention: bool,
):
x = inputs
for layer in self.layers:
x = layer(x, radial_mask)
x = layer(x, radial_mask, use_manual_attention)
return x
20 changes: 20 additions & 0 deletions src/metatrain/experimental/nanopet/tests/test_functionality.py
@@ -6,6 +6,7 @@
from omegaconf import OmegaConf

from metatrain.experimental.nanopet.model import NanoPET
from metatrain.experimental.nanopet.modules.attention import AttentionBlock
from metatrain.utils.architectures import check_architecture_options
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import (
@@ -466,3 +467,22 @@ def test_spherical_output_multi_block(per_atom):
{"spherical_tensor": model.outputs["spherical_tensor"]},
)
assert len(outputs["spherical_tensor"]) == 3


def test_consistency():
"""Tests that the two implementations of attention are consistent."""

num_centers = 100
num_neighbors_per_center = 50
hidden_size = 128
num_heads = 4

attention = AttentionBlock(hidden_size, num_heads)

inputs = torch.randn(num_centers, num_neighbors_per_center, hidden_size)
radial_mask = torch.rand(num_centers, num_neighbors_per_center)

attention_output_torch = attention(inputs, radial_mask, use_manual_attention=False)
attention_output_manual = attention(inputs, radial_mask, use_manual_attention=True)

assert torch.allclose(attention_output_torch, attention_output_manual, atol=1e-6)
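A possible follow-up check, not part of this PR (sizes below are arbitrary), is that the manual path supports the double backward that motivates it:

```python
import torch

from metatrain.experimental.nanopet.modules.attention import AttentionBlock

attention = AttentionBlock(hidden_size=32, num_heads=4)
inputs = torch.randn(10, 5, 32, requires_grad=True)
radial_mask = torch.rand(10, 5)

output = attention(inputs, radial_mask, use_manual_attention=True)

# First-order gradient with the graph kept, then a second-order gradient,
# mimicking what training on forces requires
(grad,) = torch.autograd.grad(output.sum(), inputs, create_graph=True)
torch.autograd.grad(grad.sum(), inputs)
```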
20 changes: 10 additions & 10 deletions src/metatrain/experimental/nanopet/tests/test_regression.py
@@ -44,11 +44,11 @@ def test_regression_init():

expected_output = torch.tensor(
[
[-0.031101370230],
[0.057578764856],
[0.023473259062],
[-0.040009934455],
[0.019827004522],
[0.030772615224],
[0.014496028423],
[-0.018107019365],
[0.051709491760],
[0.006714724004],
]
)

@@ -113,11 +113,11 @@ def test_regression_train():

expected_output = torch.tensor(
[
[0.909050107002],
[0.501401424408],
[0.290860712528],
[0.577842593193],
[0.250561147928],
[0.388283729553],
[0.310406684875],
[0.207200437784],
[0.224172845483],
[0.119103781879],
]
)

38 changes: 9 additions & 29 deletions src/metatrain/experimental/nanopet/trainer.py
@@ -1,12 +1,10 @@
import copy
import logging
from pathlib import Path
from typing import Dict, List, Tuple, Union
from typing import List, Union

import torch
import torch.distributed
from metatensor.torch import TensorMap
from metatensor.torch.atomistic import System
from torch.utils.data import DataLoader, DistributedSampler

from ...utils.additive import remove_additive
@@ -25,6 +23,10 @@
)
from ...utils.per_atom import average_by_num_atoms
from ...utils.scaler import remove_scale
from ...utils.transfer import (
systems_and_targets_to_device,
systems_and_targets_to_dtype,
)
from .model import NanoPET
from .modules.augmentation import RotationalAugmenter

@@ -238,27 +240,7 @@ def train(
logger.info(f"Initial learning rate: {old_lr}")

start_epoch = 0 if self.epoch is None else self.epoch + 1

@torch.jit.script
def systems_and_targets_to_device(
systems: List[System], targets: Dict[str, TensorMap], device: torch.device
) -> Tuple[List[System], Dict[str, TensorMap]]:
return (
[system.to(device=device) for system in systems],
{key: value.to(device=device) for key, value in targets.items()},
)

@torch.jit.script
def systems_and_targets_to_dtype(
systems: List[System], targets: Dict[str, TensorMap], dtype: torch.dtype
) -> Tuple[List[System], Dict[str, TensorMap]]:
return (
[system.to(dtype=dtype) for system in systems],
{key: value.to(dtype=dtype) for key, value in targets.items()},
)

rotational_augmenter = RotationalAugmenter(train_targets)

# Train the model:
if self.best_metric is None:
self.best_metric = float("inf")
@@ -340,10 +322,9 @@ def systems_and_targets_to_dtype(
val_loss = 0.0
for batch in val_dataloader:
systems, targets = batch
systems = [system.to(device=device) for system in systems]
targets = {
key: value.to(device=device) for key, value in targets.items()
}
systems, targets = systems_and_targets_to_device(
systems, targets, device
)
for additive_model in (
model.module if is_distributed else model
).additive_models:
@@ -353,8 +334,7 @@
targets = remove_scale(
targets, (model.module if is_distributed else model).scaler
)
systems = [system.to(dtype=dtype) for system in systems]
targets = {key: value.to(dtype=dtype) for key, value in targets.items()}
systems, targets = systems_and_targets_to_dtype(systems, targets, dtype)
predictions = evaluate_model(
model,
systems,