diff --git a/examples/disk-dataset/dump_to_disk.py b/examples/disk-dataset/dump_to_disk.py new file mode 100644 index 000000000..9d34f0251 --- /dev/null +++ b/examples/disk-dataset/dump_to_disk.py @@ -0,0 +1,34 @@ +import ase.io +import torch +import tqdm +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import NeighborListOptions, systems_to_torch + +from metatrain.utils.data import DiskDatasetWriter +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + + +disk_dataset_writer = DiskDatasetWriter("qm9_reduced_100.zip") +for i in tqdm.tqdm(range(100)): + frame = ase.io.read("qm9_reduced_100.xyz", index=i) + system = systems_to_torch(frame, dtype=torch.float64) + system = get_system_with_neighbor_lists( + system, + [NeighborListOptions(cutoff=5.0, full_list=True, strict=True)], + ) + energy = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[frame.info["U0"]]], dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.tensor([[i]]), + ), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + ], + ) + disk_dataset_writer.write_sample(system, {"energy": energy}) +del disk_dataset_writer diff --git a/examples/programmatic/electron_density/README.rst b/examples/programmatic/electron_density/README.rst new file mode 100644 index 000000000..520bc6818 --- /dev/null +++ b/examples/programmatic/electron_density/README.rst @@ -0,0 +1,2 @@ +Learning electron densities +=========================== diff --git a/examples/programmatic/electron_density/electron_density.py b/examples/programmatic/electron_density/electron_density.py new file mode 100644 index 000000000..add3dd170 --- /dev/null +++ b/examples/programmatic/electron_density/electron_density.py @@ -0,0 +1,111 @@ +""" +Learning electron densities +=========================== + +This tutorial demonstrates how to train a model for the electron density of an +atomic system. +""" + +import subprocess + +import ase.io +import numpy as np +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import NeighborListOptions, systems_to_torch + +from metatrain.utils.data import DiskDatasetWriter +from metatrain.utils.io import load_model +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + + +def _get_fake_electron_density(atoms: ase.Atoms, structure_number: int) -> TensorMap: + # Returns a random electron-density-like TensorMap object + all_densities = {} + for o3_lambda in range(6): + for atomic_number in [1, 6, 7, 8, 9]: + base_n_properties = 5 if atomic_number == 1 else 10 + all_densities[(o3_lambda, atomic_number)] = torch.tensor( + np.random.normal( + size=( + np.sum(atoms.numbers == atomic_number), + 2 * o3_lambda + 1, + base_n_properties - o3_lambda, + ) + ), + dtype=torch.float64, + ) + + return TensorMap( + keys=Labels( + names=["o3_lambda", "o3_sigma", "center_type"], + values=torch.tensor( + [ + [o3_lambda, 1, atomic_number] + for o3_lambda, atomic_number in all_densities.keys() + ] + ), + ), + blocks=[ + TensorBlock( + values=values, + samples=Labels( + names=["system", "atom"], + values=torch.tensor( + [ + [structure_number, i] + for i, is_correct_type in enumerate( + atoms.numbers == atomic_number + ) + if is_correct_type + ] + ).reshape(-1, 2), + ), + components=[ + Labels( + "o3_mu", + torch.arange( + -o3_lambda, o3_lambda + 1, dtype=torch.int + ).reshape(-1, 1), + ) + ], + properties=Labels.range("properties", values.shape[2]), + ) + for (o3_lambda, atomic_number), values in all_densities.items() + ], + ) + + +disk_dataset_writer = DiskDatasetWriter("qm9_reduced_100.zip") +for i in range(100): + frame = ase.io.read("qm9_reduced_100.xyz", index=i) + system = systems_to_torch(frame, dtype=torch.float64) + system = get_system_with_neighbor_lists( + system, + [NeighborListOptions(cutoff=5.0, full_list=True, strict=True)], + ) + electron_density = _get_fake_electron_density(frame, i) + disk_dataset_writer.write_sample( + system, {"mtt::electron_density": electron_density} + ) +del disk_dataset_writer + +# %% +# +# Now that the dataset has been saved to disk, we can train a model on it. +# The model will be trained using the following training options. +# +# .. literalinclude:: options.yaml +# :language: yaml + +# Launch `mtt train options.yaml` from this script +subprocess.run(["mtt", "train", "options.yaml"]) + +# Once the model is trained, we can load it and use it: +load_model("model.pt", extensions_directory="extensions/") + +# %% +# +# Analysis and plotting (@Joe) + +... diff --git a/examples/programmatic/electron_density/options.yaml b/examples/programmatic/electron_density/options.yaml new file mode 100644 index 000000000..7065cc503 --- /dev/null +++ b/examples/programmatic/electron_density/options.yaml @@ -0,0 +1,17 @@ +architecture: + name: experimental.nanopet + training: + batch_size: 16 + num_epochs: 100 + +# Section defining the parameters for system and target data +training_set: + systems: "qm9_reduced_100.zip" + targets: + mtt::electron_density: + read_from: "qm9_reduced_100.zip" + quantity: "electron_density" + metatensor_target_disable_checks: true + +validation_set: 0.1 +test_set: 0.1 diff --git a/examples/programmatic/electron_density/qm9_reduced_100.xyz b/examples/programmatic/electron_density/qm9_reduced_100.xyz new file mode 120000 index 000000000..a98d02872 --- /dev/null +++ b/examples/programmatic/electron_density/qm9_reduced_100.xyz @@ -0,0 +1 @@ +../../../tests/resources/qm9_reduced_100.xyz \ No newline at end of file diff --git a/examples/programmatic/electron_density/qm9_reduced_100.zip b/examples/programmatic/electron_density/qm9_reduced_100.zip new file mode 100644 index 000000000..87eadafcf Binary files /dev/null and b/examples/programmatic/electron_density/qm9_reduced_100.zip differ diff --git a/examples/programmatic/electron_density/train.sh b/examples/programmatic/electron_density/train.sh new file mode 100644 index 000000000..f73b048e6 --- /dev/null +++ b/examples/programmatic/electron_density/train.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +mtt train options.yaml diff --git a/src/metatrain/experimental/nanopet/model.py b/src/metatrain/experimental/nanopet/model.py index c02f9f5b2..d847aedd3 100644 --- a/src/metatrain/experimental/nanopet/model.py +++ b/src/metatrain/experimental/nanopet/model.py @@ -447,28 +447,47 @@ def forward( for output_name, last_layer in self.last_layers.items(): if output_name in outputs: atomic_features = atomic_features_dict[output_name] - atomic_properties_by_block = [] - for last_layer_by_block in last_layer.values(): - atomic_properties_by_block.append( - last_layer_by_block(atomic_features) + atomic_properties_tensors = [] + sample_values_list = [] + for last_layer_key, last_layer_by_block in last_layer.items(): + shape = self.output_shapes[output_name][last_layer_key] + if "center_type" in last_layer_key: + center_type = int(last_layer_key.split("center_type_")[-1]) + center_type_mask = species == center_type + relevant_atomic_features = atomic_features[center_type_mask] + sample_values_list.append(sample_values[center_type_mask]) + else: + relevant_atomic_features = atomic_features + sample_values_list.append(sample_values) + atomic_property = last_layer_by_block(relevant_atomic_features) + atomic_properties_tensors.append( + atomic_property.reshape([atomic_property.shape[0]] + shape) ) - blocks = [ - TensorBlock( - values=atomic_property.reshape([-1] + shape), - samples=sample_labels, - components=components, - properties=properties, - ) - for atomic_property, shape, components, properties in zip( - atomic_properties_by_block, - self.output_shapes[output_name].values(), - self.component_labels[output_name], - self.property_labels[output_name], + atomic_properties_blocks: List[TensorBlock] = [] + for tensor, sv, components, properties in zip( + atomic_properties_tensors, + sample_values_list, + self.component_labels[output_name], + self.property_labels[output_name], + ): + atomic_properties_blocks.append( + TensorBlock( + values=tensor, + samples=( + Labels( + names=["system", "atom"], + values=sv, + ) + if "center_type" in last_layer_key + else sample_labels + ), + components=components, + properties=properties, + ) ) - ] atomic_properties_tmap_dict[output_name] = TensorMap( keys=self.key_labels[output_name], - blocks=blocks, + blocks=atomic_properties_blocks, ) if selected_atoms is not None: diff --git a/src/metatrain/experimental/nanopet/modules/augmentation.py b/src/metatrain/experimental/nanopet/modules/augmentation.py index 41421535e..8c19fd6ee 100644 --- a/src/metatrain/experimental/nanopet/modules/augmentation.py +++ b/src/metatrain/experimental/nanopet/modules/augmentation.py @@ -123,10 +123,13 @@ def _apply_wigner_D_matrices( for key, block in target_tmap.items(): ell, sigma = int(key[0]), int(key[1]) values = block.values + split_indices = ( + [int(torch.sum(system.types == key["center_type"])) for system in systems] + if "center_type" in key.names + else [len(system.positions) for system in systems] + ) if "atom" in block.samples.names: - split_values = torch.split( - values, [len(system.positions) for system in systems] - ) + split_values = torch.split(values, split_indices) else: split_values = torch.split(values, [1 for _ in systems]) new_values = [] diff --git a/src/metatrain/share/schema-dataset.json b/src/metatrain/share/schema-dataset.json index ac42264c9..49adb6af5 100644 --- a/src/metatrain/share/schema-dataset.json +++ b/src/metatrain/share/schema-dataset.json @@ -195,6 +195,9 @@ }, "virial": { "$ref": "#/$defs/gradient_section" + }, + "metatensor_target_disable_checks": { + "type": "boolean" } }, "additionalProperties": false diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py index 3e8319c76..98f660ab2 100644 --- a/src/metatrain/utils/additive/composition.py +++ b/src/metatrain/utils/additive/composition.py @@ -3,7 +3,7 @@ import metatensor.torch import torch -from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap +from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.atomistic import ModelOutput, System from ..data import Dataset, DatasetInfo, TargetInfo, get_all_targets, get_atomic_types @@ -25,6 +25,7 @@ class CompositionModel(torch.nn.Module): target quantities and atomic types. """ + all_layouts = Dict[str, TensorMap] weights: Dict[str, TensorMap] outputs: Dict[str, ModelOutput] @@ -53,6 +54,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): for target_name, target_info in dataset_info.targets.items() } + self.all_layouts = {} self.weights = {} self.outputs: Dict[str, ModelOutput] = {} for target_name, target_info in self.dataset_info.targets.items(): @@ -165,9 +167,6 @@ def train_model( ) continue - total_num_structures = sum( - [len(dataset) for dataset in datasets_with_target] - ) dtype = datasets[0][0]["system"].positions.dtype if dtype != torch.float64: raise ValueError( @@ -175,168 +174,210 @@ def train_model( f"Got dtype: {dtype}." ) - composition_features = torch.zeros( - (total_num_structures, len(self.atomic_types)), - dtype=dtype, - device=device, - ) - system_index = 0 - per_block_targets_list: Dict[LabelsEntry, List[TensorBlock]] = {} - for dataset in datasets_with_target: - for sample in dataset: - systems = [sample["system"]] - targets = {target_key: sample[target_key]} - systems, targets = systems_and_targets_to_device( - systems, targets, device - ) - for additive_model in additive_models: - target_info_dict = { - target_key: self.new_targets[target_key] - } - targets = remove_additive( # remove other additive models - systems, - targets, - additive_model, - target_info_dict, + is_spherical = self.dataset_info.targets[target_key].is_spherical + is_per_atom = self.dataset_info.targets[target_key].per_atom + + if is_spherical: + if is_per_atom: + self.weights[target_key] = ( + self._get_composition_spherical_per_atom( + datasets_with_target, + target_key, + additive_models, + device, + dtype, ) - for j, t in enumerate(self.atomic_types): - composition_features[system_index, j] = torch.sum( - systems[0].types == t + ) + else: + self.weights[target_key] = ( + self._get_composition_spherical_per_structure( + datasets_with_target, + target_key, + additive_models, + device, + dtype, ) - system_index += 1 - if self.dataset_info.targets[target_key].per_atom: - # we need the center type in the samples to do - # mean_over_samples - if "center_type" in targets[target_key].keys.names: - # it's in the keys: move it to the samples - targets[target_key] = targets[ - target_key - ].keys_to_samples("center_type") - if targets[target_key].block(0).samples.names == [ - "system", - "atom", - ]: - # there is no center type, we need to add it - # and we will rely on the fact that per-atom targets - # should be in the same order as the atoms in the system - targets[target_key] = metatensor.torch.append_dimension( - targets[target_key], - "samples", - "center_type", - systems[0].types, - ) - # TODO: abstract even more for more complex targets? - for key, block in targets[target_key].items(): - # `if key not in per_block_targets_list` doesn't work, so: - matching_keys = [ - k for k in per_block_targets_list if k == key - ] - assert len(matching_keys) <= 1 - if len(matching_keys) == 0: - per_block_targets_list[key] = [block] - else: - per_block_targets_list[matching_keys[0]].append(block) - - weight_blocks = [] - for key, block_list in per_block_targets_list.items(): - # distinguish between spherical and scalar targets - needs_unsqueeze = False - if self.dataset_info.targets[target_key].is_spherical: # spherical - is_invariant = ( - int(key["o3_lambda"]) == 0 and int(key["o3_sigma"]) == 1 ) - if is_invariant: - # squeeze components dimension - tensor_list = [b.values.squeeze(1) for b in block_list] - needs_unsqueeze = True - else: - # we don't need the targets as we will set the composition - # to zero - tensor_list = None - else: # scalar - tensor_list = [b.values for b in block_list] - - metadata_block = self.dataset_info.targets[target_key].layout.block( - key - ) - if tensor_list is None: # spherical non-invariant - weights_tensor = torch.zeros( - ( - len(self.atomic_types), - len(metadata_block.components[0]), - len(metadata_block.properties), - ), - dtype=dtype, - device=device, + else: + if is_per_atom: + self.weights[target_key] = ( + self._get_composition_scalar_per_atom( + datasets_with_target, + target_key, + additive_models, + device, + dtype, + ) ) else: - if self.dataset_info.targets[target_key].per_atom: - # hack: metatensor.join doesn't work on single blocks; - # create TensorMaps, join, and then extract the joined block - joined_blocks = metatensor.torch.join( - [ - TensorMap( - keys=Labels.single(), - blocks=[b], - ) - for b in block_list - ], - axis="samples", - remove_tensor_name=True, - ).block() - # This code doesn't work because mean_over_samples_block - # actually does a sum... - # weights_tensor = ( - # metatensor.torch.sort_block( - # metatensor.torch.mean_over_samples_block( - # joined_blocks, - # [ - # n - # for n in joined_blocks.samples.names - # if n != "center_type" - # ], - # ) - # ) - # .values - # ) - weights_tensor = torch.empty( - len(self.atomic_types), len(metadata_block.properties) + self.weights[target_key] = ( + self._get_composition_scalar_per_structure( + datasets_with_target, + target_key, + additive_models, + device, + dtype, ) - for i_type, atomic_type in enumerate(self.atomic_types): - mask = ( - joined_blocks.samples.column("center_type") - == atomic_type - ) - weights_tensor[i_type] = joined_blocks.values[ - mask - ].mean(dim=0) - else: - # concatenate samples, for each block - all_targets = torch.concatenate(tensor_list) - weights_tensor = _solve_linear_system( - composition_features, all_targets - ) - if needs_unsqueeze: # scalar invariant, needs extra dimension - weights_tensor = weights_tensor.unsqueeze(1) - weight_blocks.append( - TensorBlock( - values=weights_tensor, - samples=Labels( - ["center_type"], - values=torch.tensor( - self.atomic_types, dtype=torch.int, device=device - ).reshape(-1, 1), - ), - components=[ - c.to(device) for c in metadata_block.components - ], - properties=metadata_block.properties.to(device), ) - ) - self.weights[target_key] = TensorMap( - keys=self.dataset_info.targets[target_key].layout.keys.to(device), - blocks=weight_blocks, - ) + + # for dataset in datasets_with_target: + # for sample in dataset: + # systems = [sample["system"]] + # targets = {target_key: sample[target_key]} + # systems, targets = systems_and_targets_to_device( + # systems, targets, device + # ) + # for additive_model in additive_models: + # target_info_dict = { + # target_key: self.new_targets[target_key] + # } + # targets = remove_additive( # remove other additive models + # systems, + # targets, + # additive_model, + # target_info_dict, + # ) + # for j, t in enumerate(self.atomic_types): + # composition_features[system_index, j] = torch.sum( + # systems[0].types == t + # ) + # system_index += 1 + # if self.dataset_info.targets[target_key].per_atom: + # if "center_type" not in targets[ + # target_key + # ].keys.names and targets[target_key].block( + # 0 + # ).samples.names == [ + # "system", + # "atom", + # ]: + # # there is no center type, we need to add it + # # and we will rely on the fact that per-atom targets + # # should be in the same order as the atoms in the system + # targets[target_key] = metatensor.torch.append_dimension( + # targets[target_key], + # "samples", + # "center_type", + # systems[0].types, + # ) + # # TODO: abstract even more for more complex targets? + # for key, block in targets[target_key].items(): + # # `if key not in per_block_targets_list` doesn't work, so: + # matching_keys = [ + # k for k in per_block_targets_list if k == key + # ] + # assert len(matching_keys) <= 1 + # if len(matching_keys) == 0: + # per_block_targets_list[key] = [block] + # else: + # per_block_targets_list[matching_keys[0]].append(block) + + # weight_blocks = [] + # for key, block_list in per_block_targets_list.items(): + # # distinguish between spherical and scalar targets + # is_spherical = self.dataset_info.targets[target_key].is_spherical + # is_spherical_and_invariant = False + # if is_spherical: + # is_spherical_and_invariant = ( + # int(key["o3_lambda"]) == 0 and int(key["o3_sigma"]) == 1 + # ) + # needs_unsqueeze = False + # if self.dataset_info.targets[target_key].is_spherical: # spherical + # is_invariant = ( + # int(key["o3_lambda"]) == 0 and int(key["o3_sigma"]) == 1 + # ) + # if is_invariant: + # # squeeze components dimension + # tensor_list = [b.values.squeeze(1) for b in block_list] + # needs_unsqueeze = True + # else: + # # we don't need the targets as we will set the composition + # # to zero + # tensor_list = None + # else: # scalar + # tensor_list = [b.values for b in block_list] + + # metadata_block = self.dataset_info.targets[target_key].layout.block( + # key + # ) + # if is_spherical and not is_spherical_and_invariant: + # weights_tensor = torch.zeros( + # ( + # len(self.atomic_types), + # len(metadata_block.components[0]), + # len(metadata_block.properties), + # ), + # dtype=dtype, + # device=device, + # ) + # else: + # if self.dataset_info.targets[target_key].per_atom: + # # HACK: metatensor.join doesn't work on single blocks; + # # create TensorMaps, join, and then extract the joined block + # joined_blocks = metatensor.torch.join( + # [ + # TensorMap( + # keys=Labels.single(), + # blocks=[b], + # ) + # for b in block_list + # ], + # axis="samples", + # remove_tensor_name=True, + # ).block() + # # This code doesn't work because mean_over_samples_block + # # actually does a sum... TODO: change for next release + # # weights_tensor = ( + # # metatensor.torch.sort_block( + # # metatensor.torch.mean_over_samples_block( + # # joined_blocks, + # # [ + # # n + # # for n in joined_blocks.samples.names + # # if n != "center_type" + # # ], + # # ) + # # ) + # # .values + # # ) + # weights_tensor = torch.empty( + # len(self.atomic_types), len(metadata_block.properties) + # ) + # for i_type, atomic_type in enumerate(self.atomic_types): + # mask = ( + # joined_blocks.samples.column("center_type") + # == atomic_type + # ) + # weights_tensor[i_type] = joined_blocks.values[ + # mask + # ].mean(dim=0) + # else: + # # concatenate samples, for each block + # all_targets = torch.concatenate(tensor_list) + # weights_tensor = _solve_linear_system( + # composition_features, all_targets + # ) + # if needs_unsqueeze: # scalar invariant, needs extra dimension + # weights_tensor = weights_tensor.unsqueeze(1) + # weight_blocks.append( + # TensorBlock( + # values=weights_tensor, + # samples=Labels( + # ["center_type"], + # values=torch.tensor( + # self.atomic_types, dtype=torch.int, device=device + # ).reshape(-1, 1), + # ), + # components=[ + # c.to(device) for c in metadata_block.components + # ], + # properties=metadata_block.properties.to(device), + # ) + # ) + # self.weights[target_key] = TensorMap( + # keys=self.dataset_info.targets[target_key].layout.keys.to(device), + # blocks=weight_blocks, + # ) def restart(self, dataset_info: DatasetInfo) -> "CompositionModel": """Restart the model with a new dataset info. @@ -433,27 +474,118 @@ def forward( composition_result_dict: Dict[str, TensorMap] = {} for output_name, output_options in outputs.items(): blocks: List[TensorBlock] = [] - for weight_key, weight_block in self.weights[output_name].items(): - weights_tensor = self.weights[output_name].block(weight_key).values - composition_values_per_atom = torch.empty( - [len(concatenated_types)] + weight_block.shape[1:], - dtype=dtype, - device=device, - ) - for i_type, atomic_type in enumerate(self.atomic_types): - composition_values_per_atom[concatenated_types == atomic_type] = ( - weights_tensor[i_type] - ) - blocks.append( - TensorBlock( - values=composition_values_per_atom, - samples=sample_labels, - components=weight_block.components, - properties=weight_block.properties, - ) + if "center_type" in self.weights[output_name].keys.names: + # weird stuff going on here because iterating a Labels to get a LabelsEntry + # apparently doesn't work in torchscript + center_type_position = self.weights[output_name].keys.names.index( + "center_type" ) + for v in self.all_layouts[output_name].keys.values: + key = { + name: int(value) + for name, value in zip(self.weights[output_name].keys.names, v) + } + if torch.any( + torch.all( + torch.eq(self.weights[output_name].keys.values, v), dim=1 + ) + ): + weight_block = self.weights[output_name].block(key) + center_type = int(v[center_type_position]) + center_type_mask = concatenated_types == center_type + weights_tensor = weight_block.values + composition_values_per_atom = weights_tensor.expand( + [int(torch.sum(center_type_mask))] + + [-1 for _ in weight_block.shape[1:]] + ) + blocks.append( + TensorBlock( + values=composition_values_per_atom, + samples=Labels( + sample_labels.names, + sample_labels.values[center_type_mask], + ), + components=weight_block.components, + properties=weight_block.properties, + ) + ) + else: + center_type = int(v[center_type_position]) + center_type_mask = concatenated_types == center_type + blocks.append( + TensorBlock( + values=torch.zeros( + [int(torch.sum(center_type_mask))] + + self.all_layouts[output_name] + .block(key) + .shape[1:], + dtype=dtype, + device=device, + ), + samples=Labels( + sample_labels.names, + sample_labels.values[center_type_mask], + ), + components=self.all_layouts[output_name] + .block(key) + .components, + properties=self.all_layouts[output_name] + .block(key) + .properties, + ) + ) + else: + for v in self.all_layouts[output_name].keys.values: + key = { + name: int(value) + for name, value in zip(self.weights[output_name].keys.names, v) + } + if torch.any( + torch.all( + torch.eq(self.weights[output_name].keys.values, v), dim=1 + ) + ): + weight_block = self.weights[output_name].block(key) + weights_tensor = weight_block.values + composition_values_per_atom = torch.empty( + [len(concatenated_types)] + weight_block.shape[1:], + dtype=dtype, + device=device, + ) + for i_type, atomic_type in enumerate(self.atomic_types): + composition_values_per_atom[ + concatenated_types == atomic_type + ] = weights_tensor[i_type] + blocks.append( + TensorBlock( + values=composition_values_per_atom, + samples=sample_labels, + components=weight_block.components, + properties=weight_block.properties, + ) + ) + else: # spherical non-invariant target + blocks.append( + TensorBlock( + values=torch.zeros( + [len(concatenated_types)] + + self.all_layouts[output_name] + .block(key) + .shape[1:], + dtype=dtype, + device=device, + ), + samples=sample_labels, + components=self.all_layouts[output_name] + .block(key) + .components, + properties=self.all_layouts[output_name] + .block(key) + .properties, + ) + ) composition_result_dict[output_name] = TensorMap( - keys=self.weights[output_name].keys, + keys=self.all_layouts[output_name].keys, blocks=blocks, ) @@ -473,31 +605,64 @@ def forward( return composition_result_dict def _add_output(self, target_name: str, target_info: TargetInfo) -> None: + self.all_layouts[target_name] = target_info.layout self.outputs[target_name] = ModelOutput( quantity=target_info.quantity, unit=target_info.unit, per_atom=True, ) - self.weights[target_name] = TensorMap( - keys=target_info.layout.keys, - blocks=[ - TensorBlock( - values=torch.zeros( - ([len(self.atomic_types)] + b.shape[1:]), - dtype=torch.float64, - ), - samples=Labels( - names=["center_type"], - values=torch.tensor(self.atomic_types, dtype=torch.int).reshape( - -1, 1 - ), - ), - components=b.components, - properties=b.properties, - ) - for b in target_info.layout.blocks() - ], + center_type_in_keys = "center_type" in target_info.layout.keys.names + new_keys = ( + Labels( + target_info.layout.keys.names, + target_info.layout.keys.values[ + target_info.layout.keys.select( + Labels(["o3_lambda", "o3_sigma"], torch.tensor([[0, 1]])) + ) + ], + ) + if target_info.is_spherical + else target_info.layout.keys ) + if center_type_in_keys: + self.weights[target_name] = TensorMap( + keys=new_keys, + blocks=[ + TensorBlock( + values=torch.zeros( + ([1] + block.shape[1:]), + dtype=torch.float64, + ), + samples=Labels.single(), + components=block.components, + properties=block.properties, + ) + for key, block in target_info.layout.items() + if (key["o3_lambda"] == 0 and key["o3_sigma"] == 1) + ], + ) + else: + self.weights[target_name] = TensorMap( + keys=new_keys, + blocks=[ + TensorBlock( + values=torch.zeros( + ([len(self.atomic_types)] + block.shape[1:]), + dtype=torch.float64, + ), + samples=Labels( + names=["center_type"], + values=torch.tensor( + self.atomic_types, dtype=torch.int + ).reshape(-1, 1), + ), + components=block.components, + properties=block.properties, + ) + for key, block in target_info.layout.items() + if (key["o3_lambda"] == 0 and key["o3_sigma"] == 1) + ], + ) def _move_weights_to_device_and_dtype( self, device: torch.device, dtype: torch.dtype @@ -507,6 +672,166 @@ def _move_weights_to_device_and_dtype( self.weights = {k: v.to(device) for k, v in self.weights.items()} if self.weights[list(self.weights.keys())[0]].dtype != dtype: self.weights = {k: v.to(dtype) for k, v in self.weights.items()} + if len(self.all_layouts) != 0: + if self.all_layouts[list(self.all_layouts.keys())[0]].device != device: + self.all_layouts = { + k: v.to(device) for k, v in self.all_layouts.items() + } + if self.all_layouts[list(self.all_layouts.keys())[0]].dtype != dtype: + self.all_layouts = {k: v.to(dtype) for k, v in self.all_layouts.items()} + + def _get_composition_spherical_per_atom( + self, + datasets_with_target: List[Union[Dataset, torch.utils.data.Subset]], + target_key: str, + additive_models: List[torch.nn.Module], + device: torch.device, + dtype: torch.dtype, + ): + metadata_tensor_map = self.dataset_info.targets[target_key].layout + center_type_in_keys = "center_type" in metadata_tensor_map.keys.names + + # Initialize one accumulator per block (only invariant blocks) + if center_type_in_keys: + mean_accumulators = { + tuple(int(k) for k in key.values): _MeanAccumulator( + shape=metadata_tensor_map.block(key).values.shape[1:], + device=device, + dtype=dtype, + ) + for key in metadata_tensor_map.keys + if (key["o3_lambda"] == 0 and key["o3_sigma"] == 1) + } + else: + mean_accumulators = { + tuple(int(k) for k in key.values) + (center_type): _MeanAccumulator( + shape=metadata_tensor_map.block(key).values.shape[1:], + device=device, + dtype=dtype, + ) + for center_type in self.atomic_types + for key in metadata_tensor_map.keys + if (key["o3_lambda"] == 0 and key["o3_sigma"] == 1) + } + + for dataset in datasets_with_target: + for sample in dataset: + systems = [sample["system"]] + targets = {target_key: sample[target_key]} + systems, targets = systems_and_targets_to_device( + systems, targets, device + ) + for additive_model in additive_models: + target_info_dict = {target_key: self.new_targets[target_key]} + targets = remove_additive( + systems, targets, additive_model, target_info_dict + ) + for key, block in targets[target_key].items(): + if key["o3_lambda"] == 0 and key["o3_sigma"] == 1: + # Two cases: with and without center_type + if center_type_in_keys: + mean_accumulators[tuple(int(k) for k in key.values)].add( + block.values + ) + else: + for center_type in self.atomic_types: + mask = systems[0].types == center_type + mean_accumulators[ + tuple(int(k) for k in key.values) + (center_type,) + ].add(block.values[mask]) + + composition_tensor_map = TensorMap( + keys=Labels( + names=metadata_tensor_map.keys.names, + values=torch.stack( + [ + k.values + for k in metadata_tensor_map.keys + if (k["o3_lambda"] == 0 and k["o3_sigma"] == 1) + ] + ).to(device), + ), + blocks=( + [ + TensorBlock( + values=mean_accumulators[tuple(int(k) for k in key.values)] + .return_result() + .reshape( + (1,) + metadata_tensor_map.block(key).values.shape[1:] + ), + samples=Labels.single().to(device), + components=[ + c.to(device) + for c in metadata_tensor_map.block(key).components + ], + properties=self.dataset_info.targets[target_key] + .layout.block(key) + .properties.to(device), + ) + for key in metadata_tensor_map.keys + if (key["o3_lambda"] == 0 and key["o3_sigma"] == 1) + ] + if center_type_in_keys + else [ + TensorBlock( + values=torch.stack( + [ + mean_accumulators[ + tuple(int(k) for k in key.values) + (center_type,) + ].return_result() + for center_type in self.atomic_types + ] + ), + samples=Labels( + names=["center_type"], + values=torch.tensor( + self.atomic_types, dtype=torch.int, device=device + ).reshape(-1, 1), + ), + components=[ + c.to(device) + for c in metadata_tensor_map.block(key).components + ], + properties=self.dataset_info.targets[target_key] + .layout.block(key) + .properties.to(device), + ) + for key in metadata_tensor_map.keys + if (key["o3_lambda"] == 0 and key["o3_sigma"] == 1) + ] + ), + ) + return composition_tensor_map + + def _get_composition_spherical_per_structure( + self, + datasets_with_target: List[Union[Dataset, torch.utils.data.Subset]], + target_key: str, + additive_models: List[torch.nn.Module], + device: torch.device, + dtype: torch.dtype, + ): + raise NotImplementedError() + + def _get_composition_scalar_per_atom( + self, + datasets_with_target: List[Union[Dataset, torch.utils.data.Subset]], + target_key: str, + additive_models: List[torch.nn.Module], + device: torch.device, + dtype: torch.dtype, + ): + raise NotImplementedError() + + def _get_composition_scalar_per_structure( + self, + datasets_with_target: List[Union[Dataset, torch.utils.data.Subset]], + target_key: str, + additive_models: List[torch.nn.Module], + device: torch.device, + dtype: torch.dtype, + ): + raise NotImplementedError() @staticmethod def is_valid_target(target_name: str, target_info: TargetInfo) -> bool: @@ -533,18 +858,48 @@ def is_valid_target(target_name: str, target_info: TargetInfo) -> bool: return True -def _solve_linear_system(composition_features, all_targets) -> torch.Tensor: - compf_t_at_compf = composition_features.T @ composition_features - compf_t_at_targets = composition_features.T @ all_targets - trace_magnitude = float(torch.diag(compf_t_at_compf).abs().mean()) - regularizer = 1e-14 * trace_magnitude - return torch.linalg.solve( - compf_t_at_compf - + regularizer - * torch.eye( - composition_features.shape[1], - dtype=composition_features.dtype, - device=composition_features.device, - ), - compf_t_at_targets, - ) +class _MeanAccumulator: + def __init__(self, shape: List[int], device: torch.device, dtype: torch.dtype): + self.sum = torch.zeros(shape, dtype=dtype, device=device) + self.count = 0 + + def add(self, tensor: float): + self.sum += torch.sum(tensor, dim=0) + self.count += tensor.numel() + + def return_result(self) -> torch.Tensor: + return self.sum / self.count + + +class _LinearSystemAccumulator: + def __init__( + self, + feature_size: int, + target_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.feat_t_at_feat = torch.zeros( + feature_size, feature_size, dtype=dtype, device=device + ) + self.feat_t_at_targets = torch.zeros( + feature_size, target_size, dtype=dtype, device=device + ) + + def add(self, features: torch.Tensor, targets: torch.Tensor): + self.feat_t_at_feat += features.T @ features + self.feat_t_at_targets += features.T @ targets + + def return_result(self) -> torch.Tensor: + trace_magnitude = float(torch.diag(self.compf_t_at_compf).abs().mean()) + regularizer = 1e-14 * trace_magnitude + return torch.linalg.solve( + self.feat_t_at_feat + + regularizer + * torch.eye( + self.feat_t_at_feat.shape[0], + dtype=self.feat_t_at_feat.dtype, + device=self.feat_t_at_feat.device, + ), + self.feat_t_at_targets, + ) diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 218a5f23f..b414217b1 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -417,24 +417,36 @@ def get_target_info(self, target_config: DictConfig) -> Dict[str, TargetInfo]: and (not target["per_atom"]) and target["num_subtargets"] == 1 and target["type"] == "scalar" + and target["metatensor_target_disable_checks"] is False ) tensor_map = self[0][target_key] # always > 0 samples, see above if is_energy: if len(tensor_map) != 1: raise ValueError("Energy TensorMaps should have exactly one block.") - add_position_gradients = tensor_map.block().has_gradient("positions") - add_strain_gradients = tensor_map.block().has_gradient("strain") + add_position_gradients = target["forces"] + add_strain_gradients = target["stress"] or target["virial"] target_info = get_energy_target_info( target, add_position_gradients, add_strain_gradients ) - _check_tensor_map_metadata(tensor_map, target_info.layout) + if not target["metatensor_target_disable_checks"]: + # the check here will be skipped if the flag is set to True + _check_tensor_map_metadata(tensor_map, target_info.layout) target_info_dict[target_key] = target_info else: - target_info = get_generic_target_info(target) - _check_tensor_map_metadata(tensor_map, target_info.layout) - # make sure that the properties of the target_info.layout also match the - # actual properties of the tensor maps - target_info.layout = _empty_tensor_map_like(tensor_map) + # TODO!!!!: do the same as here in the metatensor reader + if not target["metatensor_target_disable_checks"]: + # the check here will be skipped if the flag is set to True + target_info = get_generic_target_info(target) + _check_tensor_map_metadata(tensor_map, target_info.layout) + # make sure that the properties of the target_info.layout also match the + # actual properties of the tensor maps + target_info.layout = _empty_tensor_map_like(tensor_map) + else: + target_info = TargetInfo( + quantity=target["quantity"], + unit=target["unit"], + layout=_empty_tensor_map_like(tensor_map), + ) target_info_dict[target_key] = target_info return target_info_dict diff --git a/src/metatrain/utils/data/readers/metatensor.py b/src/metatrain/utils/data/readers/metatensor.py index 2c35f4708..9da0f836c 100644 --- a/src/metatrain/utils/data/readers/metatensor.py +++ b/src/metatrain/utils/data/readers/metatensor.py @@ -70,7 +70,9 @@ def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]: raise ValueError("Only energy targets can have gradient blocks.") target_info = get_generic_target_info(target) - _check_tensor_map_metadata(tensor_map, target_info.layout) + if not target["metatensor_target_disable_checks"]: + # the check here will be skipped if the flag is set to True + _check_tensor_map_metadata(tensor_map, target_info.layout) # make sure that the properties of the target_info.layout also match the # actual properties of the tensor maps diff --git a/src/metatrain/utils/data/target_info.py b/src/metatrain/utils/data/target_info.py index bcf3e51a3..f1e80379e 100644 --- a/src/metatrain/utils/data/target_info.py +++ b/src/metatrain/utils/data/target_info.py @@ -147,10 +147,15 @@ def _check_layout(self, layout: TensorMap) -> None: ) if self.is_spherical: - if layout.keys.names != ["o3_lambda", "o3_sigma"]: + if layout.keys.names != ["o3_lambda", "o3_sigma"] and layout.keys.names != [ + "o3_lambda", + "o3_sigma", + "center_type", + ]: raise ValueError( "The layout ``TensorMap`` of a spherical tensor target " - "should have two keys named 'o3_lambda' and 'o3_sigma'." + "should have keys named 'o3_lambda', 'o3_sigma', and, optionally, " + "'center_type'." f"Found '{layout.keys.names}' instead." ) for key, block in layout.items(): diff --git a/src/metatrain/utils/devices.py b/src/metatrain/utils/devices.py index 04442b9ea..edea113c5 100644 --- a/src/metatrain/utils/devices.py +++ b/src/metatrain/utils/devices.py @@ -83,7 +83,7 @@ def pick_devices( if possible_devices.index(desired_device) > 0: warnings.warn( f"Device {desired_device!r} requested, but {possible_devices[0]!r} is " - "prefferred by the architecture and available on current system.", + "preferred by the architecture and available on current system.", stacklevel=2, ) diff --git a/src/metatrain/utils/logging.py b/src/metatrain/utils/logging.py index 4e78ba44f..f602790ec 100644 --- a/src/metatrain/utils/logging.py +++ b/src/metatrain/utils/logging.py @@ -177,7 +177,10 @@ def _get_digits(value: float) -> Tuple[int, int]: """ # Get order of magnitude of the value: - order = int(np.floor(np.log10(value))) + if not np.isfinite(value): + return 5, 2 + else: + order = int(np.floor(np.log10(value))) # Get the number of digits before the decimal point: if order < 0: diff --git a/src/metatrain/utils/loss.py b/src/metatrain/utils/loss.py index 0dd0ebeea..105b6bfa7 100644 --- a/src/metatrain/utils/loss.py +++ b/src/metatrain/utils/loss.py @@ -122,7 +122,8 @@ def __call__( for block_1, block_2 in zip(tensor_map_1.blocks(), tensor_map_2.blocks()): values_1 = block_1.values values_2 = block_2.values - loss += self.weight * self.losses["values"](values_1, values_2) + if values_1.numel() != 0: + loss += self.weight * self.losses["values"](values_1, values_2) for gradient_name, gradient_weight in self.gradient_weights.items(): values_1 = block_1.gradient(gradient_name).values values_2 = block_2.gradient(gradient_name).values diff --git a/src/metatrain/utils/metrics.py b/src/metatrain/utils/metrics.py index e45e5eea1..fc94e3091 100644 --- a/src/metatrain/utils/metrics.py +++ b/src/metatrain/utils/metrics.py @@ -105,7 +105,9 @@ def finalize( out_key = f"{key} RMSE" else: out_key = f"{key} RMSE (per atom)" - finalized_info[out_key] = (value[0] / value[1]) ** 0.5 + finalized_info[out_key] = ( + (value[0] / value[1]) ** 0.5 if value[1] > 0 else 0.0 + ) return finalized_info @@ -212,7 +214,7 @@ def finalize( out_key = f"{key} MAE" else: out_key = f"{key} MAE (per atom)" - finalized_info[out_key] = value[0] / value[1] + finalized_info[out_key] = value[0] / value[1] if value[1] > 0 else 0.0 return finalized_info diff --git a/src/metatrain/utils/omegaconf.py b/src/metatrain/utils/omegaconf.py index f0c2f0a9a..efd3c4082 100644 --- a/src/metatrain/utils/omegaconf.py +++ b/src/metatrain/utils/omegaconf.py @@ -99,6 +99,7 @@ def _resolve_single_str(config: str) -> DictConfig: "per_atom": False, "type": "scalar", "num_subtargets": 1, + "metatensor_target_disable_checks": False, } )