Skip to content

Enable electron density learning #491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions examples/disk-dataset/dump_to_disk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import ase.io
import torch
import tqdm
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import NeighborListOptions, systems_to_torch

from metatrain.utils.data import DiskDatasetWriter
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists


# Convert the first 100 frames of ``qm9_reduced_100.xyz`` into an on-disk dataset
# in which every sample stores the system (with its neighbor list attached)
# together with its ``U0`` energy.

# Parse the trajectory once up front: calling ``ase.io.read(..., index=i)``
# inside the loop would re-open and re-scan the file for every single frame.
frames = ase.io.read("qm9_reduced_100.xyz", index=":")

disk_dataset_writer = DiskDatasetWriter("qm9_reduced_100.zip")
for i in tqdm.tqdm(range(100)):
    # Indexing (rather than slicing) keeps the IndexError if the file holds
    # fewer than 100 frames.
    frame = frames[i]
    system = systems_to_torch(frame, dtype=torch.float64)
    system = get_system_with_neighbor_lists(
        system,
        [NeighborListOptions(cutoff=5.0, full_list=True, strict=True)],
    )
    # Single-block, per-system scalar target: one sample (the system index),
    # no components and one "energy" property.
    energy = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.tensor([[frame.info["U0"]]], dtype=torch.float64),
                samples=Labels(
                    names=["system"],
                    values=torch.tensor([[i]]),
                ),
                components=[],
                properties=Labels("energy", torch.tensor([[0]])),
            )
        ],
    )
    disk_dataset_writer.write_sample(system, {"energy": energy})
# NOTE(review): dropping the last reference presumably finalizes the archive --
# confirm against ``DiskDatasetWriter``.
del disk_dataset_writer
2 changes: 2 additions & 0 deletions examples/programmatic/electron_density/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Learning electron densities
===========================
111 changes: 111 additions & 0 deletions examples/programmatic/electron_density/electron_density.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
Learning electron densities
===========================

This tutorial demonstrates how to train a model for the electron density of an
atomic system.
"""

import subprocess

import ase.io
import numpy as np
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import NeighborListOptions, systems_to_torch

from metatrain.utils.data import DiskDatasetWriter
from metatrain.utils.io import load_model
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists


def _get_fake_electron_density(atoms: ase.Atoms, structure_number: int) -> TensorMap:
    """Build a random, electron-density-like ``TensorMap`` for one structure.

    One block is created for every ``(o3_lambda, center_type)`` pair with
    ``o3_lambda`` in ``0..5`` and ``center_type`` in ``[1, 6, 7, 8, 9]``. Block
    values are drawn from a standard normal distribution (global NumPy random
    state) and have shape
    ``(n_atoms_of_type, 2 * o3_lambda + 1, base_n_properties - o3_lambda)``,
    where ``base_n_properties`` is 5 for hydrogen and 10 for every other type.

    :param atoms: structure to generate the fake density for; only
        ``atoms.numbers`` is read.
    :param structure_number: value stored in the ``"system"`` dimension of
        every sample.
    :return: a ``TensorMap`` with keys ``["o3_lambda", "o3_sigma",
        "center_type"]`` (``o3_sigma`` is always 1), samples
        ``["system", "atom"]``, a single ``"o3_mu"`` component and a
        ``"properties"`` range.
    """
    key_values = []
    blocks = []
    # Same (o3_lambda outer, atomic_number inner) order as the keys, so the
    # i-th key row always describes the i-th block.
    for o3_lambda in range(6):
        for atomic_number in [1, 6, 7, 8, 9]:
            atom_indices = np.flatnonzero(atoms.numbers == atomic_number)
            base_n_properties = 5 if atomic_number == 1 else 10
            values = torch.tensor(
                np.random.normal(
                    size=(
                        len(atom_indices),
                        2 * o3_lambda + 1,
                        base_n_properties - o3_lambda,
                    )
                ),
                dtype=torch.float64,
            )
            # Explicit integer dtype: when the structure has no atom of this
            # type the list is empty, and ``torch.tensor([])`` would otherwise
            # default to float32, which is not a valid dtype for label values.
            sample_values = torch.tensor(
                [[structure_number, int(atom)] for atom in atom_indices],
                dtype=torch.int32,
            ).reshape(-1, 2)

            key_values.append([o3_lambda, 1, atomic_number])
            blocks.append(
                TensorBlock(
                    values=values,
                    samples=Labels(names=["system", "atom"], values=sample_values),
                    components=[
                        Labels(
                            "o3_mu",
                            torch.arange(
                                -o3_lambda, o3_lambda + 1, dtype=torch.int
                            ).reshape(-1, 1),
                        )
                    ],
                    properties=Labels.range("properties", values.shape[2]),
                )
            )

    return TensorMap(
        keys=Labels(
            names=["o3_lambda", "o3_sigma", "center_type"],
            values=torch.tensor(key_values),
        ),
        blocks=blocks,
    )


# Write the first 100 structures of ``qm9_reduced_100.xyz`` to an on-disk dataset,
# each paired with a fake (random) electron density target.

# Parse the trajectory once up front: calling ``ase.io.read(..., index=i)``
# inside the loop would re-open and re-scan the file for every single frame.
frames = ase.io.read("qm9_reduced_100.xyz", index=":")

disk_dataset_writer = DiskDatasetWriter("qm9_reduced_100.zip")
for i in range(100):
    # Indexing (rather than slicing) keeps the IndexError if the file holds
    # fewer than 100 frames.
    frame = frames[i]
    system = systems_to_torch(frame, dtype=torch.float64)
    system = get_system_with_neighbor_lists(
        system,
        [NeighborListOptions(cutoff=5.0, full_list=True, strict=True)],
    )
    electron_density = _get_fake_electron_density(frame, i)
    disk_dataset_writer.write_sample(
        system, {"mtt::electron_density": electron_density}
    )
# NOTE(review): dropping the last reference presumably finalizes the archive --
# confirm against ``DiskDatasetWriter``.
del disk_dataset_writer

# %%
#
# Now that the dataset has been saved to disk, we can train a model on it.
# The model will be trained using the following training options.
#
# .. literalinclude:: options.yaml
# :language: yaml

# Launch `mtt train options.yaml` from this script. ``check=True`` makes a failed
# training run raise ``subprocess.CalledProcessError`` right here, instead of
# surfacing later as a confusing error while loading a missing ``model.pt``.
subprocess.run(["mtt", "train", "options.yaml"], check=True)

# Once the model is trained, we can load it and use it:
load_model("model.pt", extensions_directory="extensions/")

# %%
#
# Analysis and plotting (@Joe)

...
17 changes: 17 additions & 0 deletions examples/programmatic/electron_density/options.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
architecture:
name: experimental.nanopet
training:
batch_size: 16
num_epochs: 100

# Section defining the parameters for system and target data
training_set:
systems: "qm9_reduced_100.zip"
targets:
mtt::electron_density:
read_from: "qm9_reduced_100.zip"
quantity: "electron_density"
metatensor_target_disable_checks: true

validation_set: 0.1
test_set: 0.1
1 change: 1 addition & 0 deletions examples/programmatic/electron_density/qm9_reduced_100.xyz
Binary file not shown.
3 changes: 3 additions & 0 deletions examples/programmatic/electron_density/train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

# Run `mtt train` with the options file `options.yaml`, given as a relative
# path (so it is looked up in the current working directory).
mtt train options.yaml
55 changes: 37 additions & 18 deletions src/metatrain/experimental/nanopet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,28 +447,47 @@ def forward(
for output_name, last_layer in self.last_layers.items():
if output_name in outputs:
atomic_features = atomic_features_dict[output_name]
atomic_properties_by_block = []
for last_layer_by_block in last_layer.values():
atomic_properties_by_block.append(
last_layer_by_block(atomic_features)
atomic_properties_tensors = []
sample_values_list = []
for last_layer_key, last_layer_by_block in last_layer.items():
shape = self.output_shapes[output_name][last_layer_key]
if "center_type" in last_layer_key:
center_type = int(last_layer_key.split("center_type_")[-1])
center_type_mask = species == center_type
relevant_atomic_features = atomic_features[center_type_mask]
sample_values_list.append(sample_values[center_type_mask])
else:
relevant_atomic_features = atomic_features
sample_values_list.append(sample_values)
atomic_property = last_layer_by_block(relevant_atomic_features)
atomic_properties_tensors.append(
atomic_property.reshape([atomic_property.shape[0]] + shape)
)
blocks = [
TensorBlock(
values=atomic_property.reshape([-1] + shape),
samples=sample_labels,
components=components,
properties=properties,
)
for atomic_property, shape, components, properties in zip(
atomic_properties_by_block,
self.output_shapes[output_name].values(),
self.component_labels[output_name],
self.property_labels[output_name],
atomic_properties_blocks: List[TensorBlock] = []
for tensor, sv, components, properties in zip(
atomic_properties_tensors,
sample_values_list,
self.component_labels[output_name],
self.property_labels[output_name],
):
atomic_properties_blocks.append(
TensorBlock(
values=tensor,
samples=(
Labels(
names=["system", "atom"],
values=sv,
)
if "center_type" in last_layer_key
else sample_labels
),
components=components,
properties=properties,
)
)
]
atomic_properties_tmap_dict[output_name] = TensorMap(
keys=self.key_labels[output_name],
blocks=blocks,
blocks=atomic_properties_blocks,
)

if selected_atoms is not None:
Expand Down
9 changes: 6 additions & 3 deletions src/metatrain/experimental/nanopet/modules/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,13 @@ def _apply_wigner_D_matrices(
for key, block in target_tmap.items():
ell, sigma = int(key[0]), int(key[1])
values = block.values
split_indices = (
[int(torch.sum(system.types == key["center_type"])) for system in systems]
if "center_type" in key.names
else [len(system.positions) for system in systems]
)
if "atom" in block.samples.names:
split_values = torch.split(
values, [len(system.positions) for system in systems]
)
split_values = torch.split(values, split_indices)
else:
split_values = torch.split(values, [1 for _ in systems])
new_values = []
Expand Down
3 changes: 3 additions & 0 deletions src/metatrain/share/schema-dataset.json
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@
},
"virial": {
"$ref": "#/$defs/gradient_section"
},
"metatensor_target_disable_checks": {
"type": "boolean"
}
},
"additionalProperties": false
Expand Down
Loading
Loading