diff --git a/docs/src/advanced-concepts/index.rst b/docs/src/advanced-concepts/index.rst index 22b3ff3f6..40d08a5d1 100644 --- a/docs/src/advanced-concepts/index.rst +++ b/docs/src/advanced-concepts/index.rst @@ -9,6 +9,7 @@ such as output naming, auxiliary outputs, and wrapper models. :maxdepth: 1 output-naming + loss-functions auxiliary-outputs multi-gpu auto-restarting diff --git a/docs/src/advanced-concepts/loss-functions.rst b/docs/src/advanced-concepts/loss-functions.rst new file mode 100644 index 000000000..9bddba9ea --- /dev/null +++ b/docs/src/advanced-concepts/loss-functions.rst @@ -0,0 +1,123 @@ +.. _loss-functions: + +Loss functions +============== + +``metatrain`` supports a variety of loss functions, which can be configured +in the ``loss`` subsection of the ``training`` section for each ``architecture`` +in the options file. The loss functions are designed to be flexible and can be +tailored to the specific needs of the dataset and the targets being predicted. + +The ``loss`` subsection describes the loss functions to be used. The most basic +configuration is + +.. code-block:: yaml + + loss: mse + +which sets the loss function to mean squared error (MSE) for all targets. +When training a potential energy surface on energy, forces, and virial, +for example, this configuration is internally expanded to + +.. code-block:: yaml + + loss: + energy: + type: mse + weight: 1.0 + reduction: mean + forces: + type: mse + weight: 1.0 + reduction: mean + virial: + type: mse + weight: 1.0 + reduction: mean + +This internal, more detailed configuration can be used in the options file +to specify different loss functions for each target, or to override default +values for the parameters. The parameters accepted by each loss function are + +1. ``type``. This controls the type of loss to be used. The default value is ``mse``, + and other standard options are ``mae`` and ``huber``, which implement the equivalent + PyTorch loss functions + `MSELoss `_, + `L1Loss `_, + and + `HuberLoss `_, + respectively. + There are also "masked" versions of these losses, which are useful when using + padded targets with values that should be masked before computing the loss. The + masked losses are named ``masked_mse``, ``masked_mae``, and ``masked_huber``. + +2. ``weight``. This controls the weighting of different contributions to the loss + (e.g., energy, forces, virial, etc.). The default value of 1.0 for all targets + works well for most datasets, but can be adjusted if required. + +3. ``reduction``. This controls how the overall loss is computed across batches. + The default for this is to use the ``mean`` of the batch losses. The ``sum`` + function is also supported. + +Some losses, like ``huber``, require additional parameters to be specified. Below is +a table summarizing losses that require or allow additional parameters: + +.. list-table:: Loss Functions and Parameters + :header-rows: 1 + :widths: 20 30 50 + + * - Loss Type + - Description + - Additional Parameters + * - ``mse`` + - Mean squared error + - N/A + * - ``mae`` + - Mean absolute error + - N/A + * - ``mse_masked`` + - Masked mean squared error + - N/A + * - ``mae_masked`` + - Masked mean absolute error + - N/A + * - ``huber`` + - Huber loss + - ``delta``: Threshold at which to switch from squared error to absolute error. + + +Masked loss functions +--------------------- + +Masked loss functions are particularly useful when dealing with datasets that contain +padded targets. In such cases, the loss function can be configured to ignore the padded +values during the loss computation. This is done by using the ``masked_`` prefix in +the loss type. For example, if the target contains padded values, you can use +``masked_mse`` or ``masked_mae`` to ensure that the loss is computed only on the +valid (non-padded) values. The values of the masks must be passed as ``extra_data`` +in the training set, and the loss function will automatically apply the mask to +the target values. An example configuration for a masked loss is as follows: + + .. code-block:: yaml + + loss: + energy: + type: masked_mse + weight: 1.0 + reduction: sum + forces: + type: masked_mae + weight: 0.1 + reduction: sum + ... + + training_set: + systems: + ... + targets: + mtt::my_target: + ... + ... + extra_data: + mtt::my_target_mask: + read_from: my_target_mask.mts diff --git a/docs/src/architectures/nanopet.rst b/docs/src/architectures/nanopet.rst index f4ec15949..55d6d140b 100644 --- a/docs/src/architectures/nanopet.rst +++ b/docs/src/architectures/nanopet.rst @@ -60,19 +60,9 @@ hyperparameters to tune are (in decreasing order of importance): - ``num_attention_layers``: The number of attention layers in each layer of the graph neural network. Depending on the dataset, increasing this hyperparameter might lead to better accuracy, at the cost of increased training and evaluation time. -- ``loss``: This section describes the loss function to be used, and it has three - subsections. 1. ``weights``. This controls the weighting of different contributions - to the loss (e.g., energy, forces, virial, etc.). The default values of 1.0 for all - targets work well for most datasets, but they might need to be adjusted. For example, - to set a weight of 1.0 for the energy and 0.1 for the forces, you can set the - following in the ``options.yaml`` file under ``loss``: - ``weights: {"energy": 1.0, "forces": 0.1}``. 2. ``type``. This controls the type of - loss to be used. The default value is ``mse``, and other options are ``mae`` and - ``huber``. ``huber`` is a subsection of its own, and it requires the user to specify - the ``deltas`` parameters in a similar way to how the ``weights`` are specified (e.g., - ``deltas: {"energy": 0.1, "forces": 0.01}``). 3. ``reduction``. This controls how the - loss is reduced over batches. The default value is ``mean``, and the other allowed - option is ``sum``. +- ``loss``: This section describes the loss function to be used. See the + :doc:`dedicated documentation page <../advanced-concepts/loss-functions>` for more + details. - ``long_range``: In some systems and datasets, enabling long-range Coulomb interactions might be beneficial for the accuracy of the model and/or its physical correctness. See below for a breakdown of the long-range section of the model hyperparameters. diff --git a/docs/src/architectures/pet.rst b/docs/src/architectures/pet.rst index 6b2e09449..3f893d80f 100644 --- a/docs/src/architectures/pet.rst +++ b/docs/src/architectures/pet.rst @@ -58,19 +58,9 @@ hyperparameters to tune are (in decreasing order of importance): - ``num_attention_layers``: The number of attention layers in each layer of the graph neural network. Depending on the dataset, increasing this hyperparameter might lead to better accuracy, at the cost of increased training and evaluation time. -- ``loss``: This section describes the loss function to be used, and it has three - subsections. 1. ``weights``. This controls the weighting of different contributions - to the loss (e.g., energy, forces, virial, etc.). The default values of 1.0 for all - targets work well for most datasets, but they might need to be adjusted. For example, - to set a weight of 1.0 for the energy and 0.1 for the forces, you can set the - following in the ``options.yaml`` file under ``loss``: - ``weights: {"energy": 1.0, "forces": 0.1}``. 2. ``type``. This controls the type of - loss to be used. The default value is ``mse``, and other options are ``mae`` and - ``huber``. ``huber`` is a subsection of its own, and it requires the user to specify - the ``deltas`` parameters in a similar way to how the ``weights`` are specified (e.g., - ``deltas: {"energy": 0.1, "forces": 0.01}``). 3. ``reduction``. This controls how the - loss is reduced over batches. The default value is ``mean``, and the other allowed - option is ``sum``. +- ``loss``: This section describes the loss function to be used. See the + :doc:`dedicated documentation page <../advanced-concepts/loss-functions>` for more + details. - ``long_range``: In some systems and datasets, enabling long-range Coulomb interactions might be beneficial for the accuracy of the model and/or its physical correctness. See below for a breakdown of the long-range section of the model hyperparameters. diff --git a/docs/src/architectures/soap-bpnn.rst b/docs/src/architectures/soap-bpnn.rst index 9ee96426e..88bee0194 100644 --- a/docs/src/architectures/soap-bpnn.rst +++ b/docs/src/architectures/soap-bpnn.rst @@ -54,19 +54,9 @@ hyperparameters to tune are (in decreasing order of importance): - ``layernorm``: Whether to use layer normalization before the neural network. Setting this hyperparameter to ``false`` will lead to slower convergence of training, but might lead to better generalization outside of the training set distribution. -- ``loss``: This section describes the loss function to be used, and it has three - subsections. 1. ``weights``. This controls the weighting of different contributions - to the loss (e.g., energy, forces, virial, etc.). The default values of 1.0 for all - targets work well for most datasets, but they might need to be adjusted. For example, - to set a weight of 1.0 for the energy and 0.1 for the forces, you can set the - following in the ``options.yaml`` file under ``loss``: - ``weights: {"energy": 1.0, "forces": 0.1}``. 2. ``type``. This controls the type of - loss to be used. The default value is ``mse``, and other options are ``mae`` and - ``huber``. ``huber`` is a subsection of its own, and it requires the user to specify - the ``deltas`` parameters in a similar way to how the ``weights`` are specified (e.g., - ``deltas: {"energy": 0.1, "forces": 0.01}``). 3. ``reduction``. This controls how the - loss is reduced over batches. The default value is ``mean``, and the other allowed - option is ``sum``. +- ``loss``: This section describes the loss function to be used. See the + :doc:`dedicated documentation page <../advanced-concepts/loss-functions>` for more + details. - ``long_range``: In some systems and datasets, enabling long-range Coulomb interactions might be beneficial for the accuracy of the model and/or its physical correctness. See below for a breakdown of the long-range section of the model hyperparameters. diff --git a/docs/src/dev-docs/changelog.rst b/docs/src/dev-docs/changelog.rst index 53a2ea93b..335ede281 100644 --- a/docs/src/dev-docs/changelog.rst +++ b/docs/src/dev-docs/changelog.rst @@ -15,8 +15,12 @@ changelog `_ format. This project follows .. Added .. ##### -.. Changed -.. ####### +Changed +####### + +- Refactored the ``loss.py`` module to provide an easier to extend interface for custom + loss functions. +- Updated the trainer checkpoints to account for changes in the loss-related hypers. .. Removed .. ####### diff --git a/docs/src/dev-docs/index.rst b/docs/src/dev-docs/index.rst index 5dff20c61..a91847fa7 100644 --- a/docs/src/dev-docs/index.rst +++ b/docs/src/dev-docs/index.rst @@ -13,6 +13,7 @@ module. architecture-life-cycle new-architecture dataset-information + new-loss cli/index utils/index changelog diff --git a/docs/src/dev-docs/new-loss.rst b/docs/src/dev-docs/new-loss.rst new file mode 100644 index 000000000..dc49cc1d7 --- /dev/null +++ b/docs/src/dev-docs/new-loss.rst @@ -0,0 +1,61 @@ +.. _adding-new-loss: + +Adding a new loss function +========================== + +This page describes the required classes and files necessary for adding a new +loss function to ``metatrain``. Defining a new loss can be useful in case some extra +data has to be used to compute the loss. + +Loss functions in ``metatrain`` are implemented as subclasses of +:py:class:`metatrain.utils.loss.LossInterface`. This interface defines the +required method :py:meth:`compute`, which takes the model predictions and +the ground truth values as input and returns the computed loss value. The +:py:meth:`compute` method accepts an additional argument ``extra_data`` on top of +``predictions`` and ``targets``, that can be used to pass any extra information needed +for the loss computation. + +.. code-block:: python + + from typing import Dict, Optional + import torch + from metatrain.utils.loss import LossInterface + from metatensor.torch import TensorMap + + class NewLoss(LossInterface): + def __init__( + self, + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + ) -> None: + ... + + def compute( + self, + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + extra_data: Dict[str, TensorMap] + ) -> torch.Tensor: + ... + + +Examples of loss functions already implemented in ``metatrain`` are +:py:class:`metatrain.utils.loss.TensorMapMSELoss` and +:py:class:`metatrain.utils.loss.TensorMapMAELoss`. They both inherit from the +:py:class:`metatrain.utils.loss.BaseTensorMapLoss` class, which implements pointwise +losses for :py:class:`metatensor.torch.TensorMap` objects. + + +Loss weight scheduling +---------------------- + +Currently, only one loss weight scheduler is implemented in ``metatrain``, which is +:py:class:`metatrain.utils.loss.EMAScheduler`. This class is used to schedule the weight +of a loss function based on the Exponential Moving Average (EMA) of the loss value. +The EMA scheduler is useful to adapt the loss weight during training, allowing for a +more dynamic adjustment of the loss contribution based on the training progress. +New schedulers can be implemented by inheriting from the +:py:class:`metatrain.utils.loss.WeightScheduler` abstract class, which defines the +:py:meth:`initialize` and :py:meth:`update` methods that need to be implemented. diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 6055c9196..59e481182 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -39,7 +39,12 @@ ) from ..utils.jsonschema import validate from ..utils.logging import ROOT_LOGGER, WandbHandler, human_readable -from ..utils.omegaconf import BASE_OPTIONS, check_units, expand_dataset_config +from ..utils.omegaconf import ( + BASE_OPTIONS, + check_units, + expand_dataset_config, + expand_loss_config, +) from .eval import _eval_targets from .export import _has_extensions from .formatter import CustomHelpFormatter @@ -205,7 +210,6 @@ def train_model( {"architecture": get_default_hypers(architecture_name)}, options, ) - hypers = OmegaConf.to_container(options["architecture"]) ########################### # PROCESS BASE PARAMETERS # @@ -386,6 +390,10 @@ def train_model( test_datasets.append(dataset) test_indices.append(None) + # Expand loss options and finalize the hypers + options = expand_loss_config(options) + hypers = OmegaConf.to_container(options["architecture"]) + ############################################ # SAVE TRAIN, VALIDATION, TEST INDICES ##### ############################################ diff --git a/src/metatrain/experimental/nanopet/checkpoints.py b/src/metatrain/experimental/nanopet/checkpoints.py new file mode 100644 index 000000000..4a83f08f3 --- /dev/null +++ b/src/metatrain/experimental/nanopet/checkpoints.py @@ -0,0 +1,20 @@ +# ===== Model checkpoint updates ===== + +# ... + +# ===== Trainer checkpoint updates ===== + + +def trainer_update_v1_v2(checkpoint): + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers diff --git a/src/metatrain/experimental/nanopet/default-hypers.yaml b/src/metatrain/experimental/nanopet/default-hypers.yaml index 38f01381b..ea9202133 100644 --- a/src/metatrain/experimental/nanopet/default-hypers.yaml +++ b/src/metatrain/experimental/nanopet/default-hypers.yaml @@ -31,7 +31,4 @@ architecture: log_mae: false log_separate_blocks: false best_model_metric: rmse_prod - loss: - type: mse - weights: {} - reduction: mean + loss: mse diff --git a/src/metatrain/experimental/nanopet/schema-hypers.json b/src/metatrain/experimental/nanopet/schema-hypers.json index 0b389739a..7f23b0388 100644 --- a/src/metatrain/experimental/nanopet/schema-hypers.json +++ b/src/metatrain/experimental/nanopet/schema-hypers.json @@ -137,59 +137,13 @@ ] }, "loss": { - "type": "object", - "properties": { - "weights": { - "type": "object", - "patternProperties": { - ".*": { - "type": "number" - } - }, - "additionalProperties": false - }, - "reduction": { - "type": "string", - "enum": [ - "sum", - "mean", - "none" - ] - }, - "type": { - "oneOf": [ - { - "type": "string", - "enum": [ - "mse", - "mae" - ] - }, - { - "type": "object", - "properties": { - "huber": { - "type": "object", - "properties": { - "deltas": { - "type": "object", - "patternProperties": { - ".*": { - "type": "number" - } - }, - "additionalProperties": false - } - }, - "required": [ - "deltas" - ], - "additionalProperties": false - } - }, - "additionalProperties": false - } - ] + "type": [ + "object", + "string" + ], + "patternProperties": { + "^.+$": { + "$ref": "#/definitions/lossTerm" } }, "additionalProperties": false @@ -206,5 +160,47 @@ "uniqueItems": true } }, - "additionalProperties": false + "additionalProperties": false, + "definitions": { + "lossTerm": { + "type": [ + "object", + "string" + ], + "properties": { + "type": { + "type": "string" + }, + "weight": { + "type": "number", + "minimum": 0.0 + }, + "reduction": { + "type": "string", + "enum": [ + "none", + "mean", + "sum" + ] + }, + "sliding_factor": { + "type": [ + "number", + "null" + ], + "minimum": 0.0 + }, + "gradients": { + "type": "object", + "patternProperties": { + "^.+$": { + "$ref": "#/definitions/lossTerm" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": true + } + } } diff --git a/src/metatrain/experimental/nanopet/tests/checkpoints/model-v1_trainer-v2.ckpt.gz b/src/metatrain/experimental/nanopet/tests/checkpoints/model-v1_trainer-v2.ckpt.gz new file mode 100644 index 000000000..d17bb06d3 Binary files /dev/null and b/src/metatrain/experimental/nanopet/tests/checkpoints/model-v1_trainer-v2.ckpt.gz differ diff --git a/src/metatrain/experimental/nanopet/tests/test_checkpoints.py b/src/metatrain/experimental/nanopet/tests/test_checkpoints.py index 2df080223..6e790e90f 100644 --- a/src/metatrain/experimental/nanopet/tests/test_checkpoints.py +++ b/src/metatrain/experimental/nanopet/tests/test_checkpoints.py @@ -2,10 +2,12 @@ import pytest import torch +from omegaconf import OmegaConf from metatrain.experimental.nanopet import NanoPET, Trainer from metatrain.utils.data import DatasetInfo, get_atomic_types, get_dataset from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.omegaconf import CONF_LOSS from metatrain.utils.testing.checkpoints import ( checkpoint_did_not_change, make_checkpoint_load_tests, @@ -59,6 +61,10 @@ def model_trainer(): hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["training"]["num_epochs"] = 1 + loss_hypers = OmegaConf.create({"energy": CONF_LOSS.copy()}) + loss_hypers = OmegaConf.to_container(loss_hypers, resolve=True) + hypers["training"]["loss"] = loss_hypers + trainer = Trainer(hypers["training"]) trainer.train( diff --git a/src/metatrain/experimental/nanopet/tests/test_continue.py b/src/metatrain/experimental/nanopet/tests/test_continue.py index 28812398c..ec67c3280 100644 --- a/src/metatrain/experimental/nanopet/tests/test_continue.py +++ b/src/metatrain/experimental/nanopet/tests/test_continue.py @@ -10,6 +10,7 @@ from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.io import model_from_checkpoint from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.omegaconf import CONF_LOSS from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -58,6 +59,10 @@ def test_continue(monkeypatch, tmp_path): hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 0 + loss_conf = OmegaConf.create({"mtt::U0": CONF_LOSS.copy()}) + OmegaConf.resolve(loss_conf) + hypers["training"]["loss"] = loss_conf + trainer = Trainer(hypers["training"]) trainer.train( model=model, diff --git a/src/metatrain/experimental/nanopet/tests/test_regression.py b/src/metatrain/experimental/nanopet/tests/test_regression.py index dd7ee6d2d..cddb4a72a 100644 --- a/src/metatrain/experimental/nanopet/tests/test_regression.py +++ b/src/metatrain/experimental/nanopet/tests/test_regression.py @@ -10,6 +10,7 @@ from metatrain.utils.data.readers import read_systems, read_targets from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.omegaconf import CONF_LOSS from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -84,6 +85,9 @@ def test_regression_train(): hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 2 + loss_conf = OmegaConf.create({"mtt::U0": CONF_LOSS.copy()}) + OmegaConf.resolve(loss_conf) + hypers["training"]["loss"] = loss_conf dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict diff --git a/src/metatrain/experimental/nanopet/trainer.py b/src/metatrain/experimental/nanopet/trainer.py index 8f3719a3e..00a123f39 100644 --- a/src/metatrain/experimental/nanopet/trainer.py +++ b/src/metatrain/experimental/nanopet/trainer.py @@ -21,10 +21,9 @@ ) from metatrain.utils.distributed.slurm import DistributedEnvironment from metatrain.utils.evaluate_model import evaluate_model -from metatrain.utils.external_naming import to_external_name from metatrain.utils.io import check_file_extension from metatrain.utils.logging import ROOT_LOGGER, MetricLogger -from metatrain.utils.loss import TensorMapDictLoss +from metatrain.utils.loss import LossAggregator from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, @@ -34,11 +33,12 @@ from metatrain.utils.scaler import remove_scale from metatrain.utils.transfer import batch_to +from . import checkpoints from .model import NanoPET class Trainer(TrainerInterface): - __checkpoint_version__ = 1 + __checkpoint_version__ = 2 def __init__(self, hypers): super().__init__(hypers) @@ -232,29 +232,23 @@ def train( outputs_list.append(target_name) for gradient_name in target_info.gradients: outputs_list.append(f"{target_name}_{gradient_name}_gradients") - # Create a loss weight dict: - loss_weights_dict = {} - for output_name in outputs_list: - loss_weights_dict[output_name] = ( - self.hypers["loss"]["weights"][ - to_external_name(output_name, train_targets) - ] - if to_external_name(output_name, train_targets) - in self.hypers["loss"]["weights"] - else 1.0 - ) - loss_weights_dict_external = { - to_external_name(key, train_targets): value - for key, value in loss_weights_dict.items() - } - loss_hypers = copy.deepcopy(self.hypers["loss"]) - loss_hypers["weights"] = loss_weights_dict - logging.info(f"Training with loss weights: {loss_weights_dict_external}") # Create a loss function: - loss_fn = TensorMapDictLoss( - **loss_hypers, + loss_hypers = self.hypers["loss"] + loss_fn = LossAggregator( + targets=train_targets, + config=loss_hypers, ) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") # Create an optimizer: optimizer = torch.optim.Adam( @@ -346,7 +340,7 @@ def train( ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - train_loss_batch = loss_fn(predictions, targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) train_loss_batch.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() @@ -408,7 +402,7 @@ def train( ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - val_loss_batch = loss_fn(predictions, targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) if is_distributed: # sum the loss over all processes @@ -562,6 +556,16 @@ def load_checkpoint( return trainer - @staticmethod - def upgrade_checkpoint(checkpoint: Dict) -> Dict: - raise NotImplementedError("checkpoint upgrade is not implemented for NanoPET") + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + if checkpoint["trainer_ckpt_version"] == 1: + checkpoints.trainer_update_v1_v2(checkpoint) + checkpoint["trainer_ckpt_version"] = 2 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using " + f"trainer version {checkpoint['trainer_ckpt_version']}, while the " + f"current trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/gap/tests/test_torchscript.py b/src/metatrain/gap/tests/test_torchscript.py index 2705ed831..869ac5d2d 100644 --- a/src/metatrain/gap/tests/test_torchscript.py +++ b/src/metatrain/gap/tests/test_torchscript.py @@ -93,8 +93,6 @@ def test_torchscript_integers(): new_hypers["soap"]["density"]["scaling"]["scale"] = 2 new_hypers["soap"]["density"]["scaling"]["exponent"] = 7 - # print(new_hypers) - target_info_dict = {} target_info_dict["mtt::U0"] = get_energy_target_info({"unit": "eV"}) diff --git a/src/metatrain/pet/checkpoints.py b/src/metatrain/pet/checkpoints.py index a4fbec538..132204d19 100644 --- a/src/metatrain/pet/checkpoints.py +++ b/src/metatrain/pet/checkpoints.py @@ -47,8 +47,25 @@ def model_update_v5_v6(checkpoint): def trainer_update_v1_v2(checkpoint): - checkpoint["train_hypers"] = checkpoint["train_hypers"].get("scheduler_factor", 0.5) + checkpoint["train_hypers"]["scheduler_factor"] = checkpoint["train_hypers"].get( + "scheduler_factor", 0.5 + ) def trainer_update_v2_v3(checkpoint): checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v3_v4(checkpoint): + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers diff --git a/src/metatrain/pet/default-hypers.yaml b/src/metatrain/pet/default-hypers.yaml index 646490c80..efdce89d1 100644 --- a/src/metatrain/pet/default-hypers.yaml +++ b/src/metatrain/pet/default-hypers.yaml @@ -35,8 +35,4 @@ architecture: log_separate_blocks: false best_model_metric: rmse_prod grad_clip_norm: .inf - loss: - type: mse - weights: {} - reduction: mean - sliding_factor: null + loss: mse diff --git a/src/metatrain/pet/schema-hypers.json b/src/metatrain/pet/schema-hypers.json index 885f7f520..c190f48d0 100644 --- a/src/metatrain/pet/schema-hypers.json +++ b/src/metatrain/pet/schema-hypers.json @@ -201,65 +201,13 @@ "type": "number" }, "loss": { - "type": "object", - "properties": { - "weights": { - "type": "object", - "patternProperties": { - ".*": { - "type": "number" - } - }, - "additionalProperties": false - }, - "reduction": { - "type": "string", - "enum": [ - "sum", - "mean", - "none" - ] - }, - "type": { - "oneOf": [ - { - "type": "string", - "enum": [ - "mse", - "mae" - ] - }, - { - "type": "object", - "properties": { - "huber": { - "type": "object", - "properties": { - "deltas": { - "type": "object", - "patternProperties": { - ".*": { - "type": "number" - } - }, - "additionalProperties": false - } - }, - "required": [ - "deltas" - ], - "additionalProperties": false - } - }, - "additionalProperties": false - } - ] - }, - "sliding_factor": { - "type": [ - "number", - "null" - ] + "type": [ + "object", + "string" + ], + "patternProperties": { + "^.+$": { + "$ref": "#/definitions/lossTerm" } }, "additionalProperties": false @@ -276,5 +224,47 @@ "uniqueItems": true } }, - "additionalProperties": false + "additionalProperties": false, + "definitions": { + "lossTerm": { + "type": [ + "object", + "string" + ], + "properties": { + "type": { + "type": "string" + }, + "weight": { + "type": "number", + "minimum": 0.0 + }, + "reduction": { + "type": "string", + "enum": [ + "none", + "mean", + "sum" + ] + }, + "sliding_factor": { + "type": [ + "number", + "null" + ], + "minimum": 0.0 + }, + "gradients": { + "type": "object", + "patternProperties": { + "^.+$": { + "$ref": "#/definitions/lossTerm" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": true + } + } } diff --git a/src/metatrain/pet/tests/checkpoints/model-v4_trainer-v4.ckpt.gz b/src/metatrain/pet/tests/checkpoints/model-v4_trainer-v4.ckpt.gz new file mode 100644 index 000000000..4bcfcbd04 Binary files /dev/null and b/src/metatrain/pet/tests/checkpoints/model-v4_trainer-v4.ckpt.gz differ diff --git a/src/metatrain/pet/tests/checkpoints/model-v6_trainer-v4.ckpt.gz b/src/metatrain/pet/tests/checkpoints/model-v6_trainer-v4.ckpt.gz new file mode 100644 index 000000000..93fe64ca2 Binary files /dev/null and b/src/metatrain/pet/tests/checkpoints/model-v6_trainer-v4.ckpt.gz differ diff --git a/src/metatrain/pet/tests/test_checkpoints.py b/src/metatrain/pet/tests/test_checkpoints.py index 8c36fda68..24381d342 100644 --- a/src/metatrain/pet/tests/test_checkpoints.py +++ b/src/metatrain/pet/tests/test_checkpoints.py @@ -3,10 +3,12 @@ import pytest import torch +from omegaconf import OmegaConf from metatrain.pet import PET, Trainer from metatrain.utils.data import DatasetInfo, get_atomic_types, get_dataset from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.omegaconf import CONF_LOSS from metatrain.utils.testing.checkpoints import ( checkpoint_did_not_change, make_checkpoint_load_tests, @@ -62,6 +64,10 @@ def model_trainer(): hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["training"]["num_epochs"] = 1 + loss_hypers = OmegaConf.create({"energy": CONF_LOSS.copy()}) + loss_hypers = OmegaConf.to_container(loss_hypers, resolve=True) + hypers["training"]["loss"] = loss_hypers + trainer = Trainer(hypers["training"]) trainer.train( diff --git a/src/metatrain/pet/tests/test_continue.py b/src/metatrain/pet/tests/test_continue.py index 661161481..1d8e7e6f5 100644 --- a/src/metatrain/pet/tests/test_continue.py +++ b/src/metatrain/pet/tests/test_continue.py @@ -10,6 +10,7 @@ from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.io import model_from_checkpoint from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.omegaconf import CONF_LOSS from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -57,6 +58,10 @@ def test_continue(monkeypatch, tmp_path): hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 0 + loss_conf = OmegaConf.create({"mtt::U0": CONF_LOSS.copy()}) + OmegaConf.resolve(loss_conf) + hypers["training"]["loss"] = loss_conf + trainer = Trainer(hypers["training"]) trainer.train( model=model, diff --git a/src/metatrain/pet/tests/test_finetuning.py b/src/metatrain/pet/tests/test_finetuning.py index eba2885f1..3235c5659 100644 --- a/src/metatrain/pet/tests/test_finetuning.py +++ b/src/metatrain/pet/tests/test_finetuning.py @@ -12,6 +12,7 @@ from metatrain.utils.data.readers import read_systems, read_targets from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.io import model_from_checkpoint +from metatrain.utils.omegaconf import CONF_LOSS from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -143,6 +144,10 @@ def test_finetuning_restart(monkeypatch, tmp_path): hypers["training"]["num_epochs"] = 1 + loss_conf = OmegaConf.create({"mtt::U0": CONF_LOSS.copy()}) + OmegaConf.resolve(loss_conf) + hypers["training"]["loss"] = loss_conf + # Pre-training trainer = Trainer(hypers["training"]) trainer.train( diff --git a/src/metatrain/pet/tests/test_regression.py b/src/metatrain/pet/tests/test_regression.py index 4b74e04d2..98656fbe8 100644 --- a/src/metatrain/pet/tests/test_regression.py +++ b/src/metatrain/pet/tests/test_regression.py @@ -15,6 +15,7 @@ from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.evaluate_model import evaluate_model from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.omegaconf import CONF_LOSS from . import DATASET_PATH, DATASET_WITH_FORCES_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -96,6 +97,11 @@ def test_regression_energies_forces_train(device): hypers["training"]["num_epochs"] = 2 hypers["training"]["scheduler_patience"] = 1 hypers["training"]["fixed_composition_weights"] = {} + loss_conf = {"energy": CONF_LOSS.copy()} + loss_conf["energy"]["gradients"] = {"positions": CONF_LOSS.copy()} + loss_conf = OmegaConf.create(loss_conf) + OmegaConf.resolve(loss_conf) + hypers["training"]["loss"] = loss_conf dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[6], targets=target_info_dict @@ -135,7 +141,7 @@ def test_regression_energies_forces_train(device): [0.208536088467, -0.117365449667, -0.278660595417], device=device ) - # if you need to change the hardcoded values: + # # if you need to change the hardcoded values: # torch.set_printoptions(precision=12) # print(output["energy"].block().values) # print(output["energy"].block().gradient("positions").values.squeeze(-1)[0]) diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py index 0b72edfd2..1f95c983e 100644 --- a/src/metatrain/pet/trainer.py +++ b/src/metatrain/pet/trainer.py @@ -21,10 +21,9 @@ ) from metatrain.utils.distributed.slurm import DistributedEnvironment from metatrain.utils.evaluate_model import evaluate_model -from metatrain.utils.external_naming import to_external_name from metatrain.utils.io import check_file_extension from metatrain.utils.logging import ROOT_LOGGER, MetricLogger -from metatrain.utils.loss import TensorMapDictLoss +from metatrain.utils.loss import LossAggregator from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, @@ -52,7 +51,7 @@ def func_lr_scheduler(epoch): class Trainer(TrainerInterface): - __checkpoint_version__ = 3 + __checkpoint_version__ = 4 def __init__(self, hypers): super().__init__(hypers) @@ -254,29 +253,22 @@ def train( for gradient_name in target_info.gradients: outputs_list.append(f"{target_name}_{gradient_name}_gradients") - # Create a loss weight dict: - loss_weights_dict = {} - for output_name in outputs_list: - loss_weights_dict[output_name] = ( - self.hypers["loss"]["weights"][ - to_external_name(output_name, train_targets) - ] - if to_external_name(output_name, train_targets) - in self.hypers["loss"]["weights"] - else 1.0 - ) - loss_weights_dict_external = { - to_external_name(key, train_targets): value - for key, value in loss_weights_dict.items() - } - loss_hypers = copy.deepcopy(self.hypers["loss"]) - loss_hypers["weights"] = loss_weights_dict - logging.info(f"Training with loss weights: {loss_weights_dict_external}") - # Create a loss function: - loss_fn = TensorMapDictLoss( - **loss_hypers, + loss_hypers = self.hypers["loss"] + loss_fn = LossAggregator( + targets=train_targets, + config=loss_hypers, ) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") if self.hypers["weight_decay"] is not None: optimizer = torch.optim.AdamW( @@ -370,7 +362,7 @@ def train( predictions, systems, per_structure_targets ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - train_loss_batch = loss_fn(predictions, targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) train_loss_batch.backward() torch.nn.utils.clip_grad_norm_( model.parameters(), self.hypers["grad_clip_norm"] @@ -429,8 +421,7 @@ def train( predictions, systems, per_structure_targets ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - - val_loss_batch = loss_fn(predictions, targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) if is_distributed: # sum the loss over all processes @@ -588,6 +579,7 @@ def load_checkpoint( def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: for v in range(1, cls.__checkpoint_version__): if checkpoint["trainer_ckpt_version"] == v: + print(v, checkpoint["train_hypers"]) update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") update(checkpoint) checkpoint["trainer_ckpt_version"] = v + 1 diff --git a/src/metatrain/soap_bpnn/checkpoints.py b/src/metatrain/soap_bpnn/checkpoints.py index fcb1c8844..8cb290ec0 100644 --- a/src/metatrain/soap_bpnn/checkpoints.py +++ b/src/metatrain/soap_bpnn/checkpoints.py @@ -26,3 +26,18 @@ def model_update_v2_v3(checkpoint): def trainer_update_v1_v2(checkpoint): checkpoint["best_epoch"] = checkpoint.get("best_epoch") + + +def trainer_update_v2_v3(checkpoint): + old_loss_hypers = checkpoint["train_hypers"]["loss"].copy() + dataset_info = checkpoint["model_data"]["dataset_info"] + new_loss_hypers = {} + + for target_name in dataset_info.targets.keys(): + new_loss_hypers[target_name] = { + "type": old_loss_hypers["type"], + "weight": old_loss_hypers["weights"].get(target_name, 1.0), + "reduction": old_loss_hypers["reduction"], + "sliding_factor": old_loss_hypers.get("sliding_factor", None), + } + checkpoint["train_hypers"]["loss"] = new_loss_hypers diff --git a/src/metatrain/soap_bpnn/default-hypers.yaml b/src/metatrain/soap_bpnn/default-hypers.yaml index 5b5082232..3b6105a00 100644 --- a/src/metatrain/soap_bpnn/default-hypers.yaml +++ b/src/metatrain/soap_bpnn/default-hypers.yaml @@ -37,7 +37,4 @@ architecture: log_mae: false log_separate_blocks: false best_model_metric: rmse_prod - loss: - type: mse - weights: {} - reduction: mean + loss: mse diff --git a/src/metatrain/soap_bpnn/schema-hypers.json b/src/metatrain/soap_bpnn/schema-hypers.json index 5534cac93..8872ca26c 100644 --- a/src/metatrain/soap_bpnn/schema-hypers.json +++ b/src/metatrain/soap_bpnn/schema-hypers.json @@ -164,56 +164,13 @@ ] }, "loss": { - "type": "object", - "properties": { - "weights": { - "type": "object", - "patternProperties": { - ".*": { - "type": "number" - } - }, - "additionalProperties": false - }, - "reduction": { - "type": "string", - "enum": [ - "sum", - "mean", - "none" - ] - }, - "type": { - "oneOf": [ - { - "type": "string", - "enum": [ - "mse", - "mae" - ] - }, - { - "type": "object", - "properties": { - "huber": { - "type": "object", - "properties": { - "deltas": { - "type": "object", - "patternProperties": { - ".*": { - "type": "number" - } - }, - "additionalProperties": false - } - }, - "additionalProperties": false - } - }, - "additionalProperties": false - } - ] + "type": [ + "object", + "string" + ], + "patternProperties": { + "^.+$": { + "$ref": "#/definitions/lossTerm" } }, "additionalProperties": false @@ -230,5 +187,47 @@ "uniqueItems": true } }, - "additionalProperties": false + "additionalProperties": false, + "definitions": { + "lossTerm": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "weight": { + "type": "number", + "minimum": 0.0 + }, + "reduction": { + "type": "string", + "enum": [ + "none", + "mean", + "sum" + ] + }, + "sliding_factor": { + "type": [ + "number", + "null" + ], + "minimum": 0.0 + }, + "gradients": { + "type": [ + "object", + "string" + ], + "patternProperties": { + "^.+$": { + "$ref": "#/definitions/lossTerm" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": true + } + } } diff --git a/src/metatrain/soap_bpnn/tests/checkpoints/model-v1-trainer-v1.ckpt.gz b/src/metatrain/soap_bpnn/tests/checkpoints/model-v1-trainer-v1.ckpt.gz new file mode 100644 index 000000000..dbf983525 Binary files /dev/null and b/src/metatrain/soap_bpnn/tests/checkpoints/model-v1-trainer-v1.ckpt.gz differ diff --git a/src/metatrain/soap_bpnn/tests/checkpoints/model-v3_trainer-v3.ckpt.gz b/src/metatrain/soap_bpnn/tests/checkpoints/model-v3_trainer-v3.ckpt.gz new file mode 100644 index 000000000..5ecc76db0 Binary files /dev/null and b/src/metatrain/soap_bpnn/tests/checkpoints/model-v3_trainer-v3.ckpt.gz differ diff --git a/src/metatrain/soap_bpnn/tests/test_checkpoints.py b/src/metatrain/soap_bpnn/tests/test_checkpoints.py index 24821dcc0..4807b89ea 100644 --- a/src/metatrain/soap_bpnn/tests/test_checkpoints.py +++ b/src/metatrain/soap_bpnn/tests/test_checkpoints.py @@ -3,6 +3,7 @@ import pytest import torch +from omegaconf import OmegaConf from metatrain.soap_bpnn import SoapBpnn, Trainer from metatrain.utils.data import ( @@ -11,6 +12,7 @@ get_dataset, ) from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.omegaconf import CONF_LOSS from metatrain.utils.testing.checkpoints import ( checkpoint_did_not_change, make_checkpoint_load_tests, @@ -64,6 +66,10 @@ def model_trainer(): hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["training"]["num_epochs"] = 1 + loss_hypers = OmegaConf.create({"energy": CONF_LOSS.copy()}) + loss_hypers = OmegaConf.to_container(loss_hypers, resolve=True) + hypers["training"]["loss"] = loss_hypers + trainer = Trainer(hypers["training"]) trainer.train( diff --git a/src/metatrain/soap_bpnn/tests/test_continue.py b/src/metatrain/soap_bpnn/tests/test_continue.py index cc437c131..f454cd05f 100644 --- a/src/metatrain/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/soap_bpnn/tests/test_continue.py @@ -13,6 +13,7 @@ get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.omegaconf import CONF_LOSS from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -65,6 +66,10 @@ def test_continue(monkeypatch, tmp_path): hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 0 + loss_conf = OmegaConf.create({"mtt::U0": CONF_LOSS.copy()}) + OmegaConf.resolve(loss_conf) + hypers["training"]["loss"] = loss_conf + trainer = Trainer(hypers["training"]) trainer.train( model=model, diff --git a/src/metatrain/soap_bpnn/tests/test_regression.py b/src/metatrain/soap_bpnn/tests/test_regression.py index 7fc8c2c2a..fd1b87d69 100644 --- a/src/metatrain/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/soap_bpnn/tests/test_regression.py @@ -14,6 +14,7 @@ get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.omegaconf import CONF_LOSS from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS, SPHERICAL_DISK_DATASET_PATH @@ -95,6 +96,9 @@ def test_regression_train(device): hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 2 + loss_conf = OmegaConf.create({"mtt::U0": CONF_LOSS.copy()}) + OmegaConf.resolve(loss_conf) + hypers["training"]["loss"] = loss_conf dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -175,14 +179,15 @@ def test_regression_train_spherical(device): }, }, } - # targets, target_info_dict = read_targets(OmegaConf.create(conf)) - # dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) dataset, target_info_dict, _ = get_dataset(conf) hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 2 hypers["training"]["batch_size"] = 1 + hypers["training"]["loss"]["mtt::electron_density_basis"] = hypers["training"][ + "loss" + ].pop("mtt::U0") dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -212,7 +217,7 @@ def test_regression_train_spherical(device): systems, { "mtt::electron_density_basis": ModelOutput( - quantity="energy", unit="", per_atom=True + quantity="", unit="", per_atom=True ) }, ) @@ -220,13 +225,13 @@ def test_regression_train_spherical(device): expected_output = torch.tensor( [ [ - -0.038565825671, - 0.000463733566, - 0.000264365954, - 0.023815866560, - 0.018959790468, - -0.000692606962, - 0.020604602993, + -0.057801067829, + -0.022922841832, + -0.013157465495, + 0.003876133356, + 0.008559819311, + 0.039749406278, + 0.013140974566, ], [ 0.000000000000, @@ -238,13 +243,13 @@ def test_regression_train_spherical(device): 0.000000000000, ], [ - 0.024628046900, - -0.001363838091, - 0.003145742230, - -0.024710856378, - -0.010125328787, - -0.015510082245, - -0.014338681474, + 0.020274644718, + 0.005657686852, + 0.001842519501, + -0.014781145379, + 0.003011089284, + -0.011008090340, + -0.012492593378, ], ], device=device, diff --git a/src/metatrain/soap_bpnn/trainer.py b/src/metatrain/soap_bpnn/trainer.py index 59608d173..66d073c54 100644 --- a/src/metatrain/soap_bpnn/trainer.py +++ b/src/metatrain/soap_bpnn/trainer.py @@ -20,10 +20,9 @@ ) from metatrain.utils.distributed.slurm import DistributedEnvironment from metatrain.utils.evaluate_model import evaluate_model -from metatrain.utils.external_naming import to_external_name from metatrain.utils.io import check_file_extension from metatrain.utils.logging import ROOT_LOGGER, MetricLogger -from metatrain.utils.loss import TensorMapDictLoss +from metatrain.utils.loss import LossAggregator from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, @@ -40,7 +39,7 @@ class Trainer(TrainerInterface): - __checkpoint_version__ = 2 + __checkpoint_version__ = 3 def __init__(self, hypers): super().__init__(hypers) @@ -235,29 +234,23 @@ def train( outputs_list.append(target_name) for gradient_name in target_info.gradients: outputs_list.append(f"{target_name}_{gradient_name}_gradients") - # Create a loss weight dict: - loss_weights_dict = {} - for output_name in outputs_list: - loss_weights_dict[output_name] = ( - self.hypers["loss"]["weights"][ - to_external_name(output_name, train_targets) - ] - if to_external_name(output_name, train_targets) - in self.hypers["loss"]["weights"] - else 1.0 - ) - loss_weights_dict_external = { - to_external_name(key, train_targets): value - for key, value in loss_weights_dict.items() - } - loss_hypers = copy.deepcopy(self.hypers["loss"]) - loss_hypers["weights"] = loss_weights_dict - logging.info(f"Training with loss weights: {loss_weights_dict_external}") # Create a loss function: - loss_fn = TensorMapDictLoss( - **loss_hypers, + loss_hypers = self.hypers["loss"] + loss_fn = LossAggregator( + targets=train_targets, + config=loss_hypers, ) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") # Create an optimizer: optimizer = torch.optim.Adam( @@ -343,7 +336,7 @@ def train( ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - train_loss_batch = loss_fn(predictions, targets) + train_loss_batch = loss_fn(predictions, targets, extra_data) train_loss_batch.backward() optimizer.step() @@ -402,7 +395,7 @@ def train( ) targets = average_by_num_atoms(targets, systems, per_structure_targets) - val_loss_batch = loss_fn(predictions, targets) + val_loss_batch = loss_fn(predictions, targets, extra_data) if is_distributed: # sum the loss over all processes diff --git a/src/metatrain/utils/augmentation.py b/src/metatrain/utils/augmentation.py index beb8a57cb..293f37a69 100644 --- a/src/metatrain/utils/augmentation.py +++ b/src/metatrain/utils/augmentation.py @@ -133,6 +133,9 @@ def apply_random_augmentations( ) for tensormap_dict, info_dict in zip(tensormap_dicts, info_dicts): for name in tensormap_dict.keys(): + if name.endswith("_mask"): + # skip loss masks + continue tensormap_info = info_dict[name] if tensormap_info.is_spherical: for block in tensormap_info.layout.blocks(): @@ -246,6 +249,15 @@ def _apply_random_augmentations( new_targets: Dict[str, TensorMap] = {} new_extra_data: Dict[str, TensorMap] = {} + # Do not transform any masks present in extra_data + if extra_data is not None: + mask_keys: List[str] = [] + for key in extra_data.keys(): + if key.endswith("_mask"): + mask_keys.append(key) + for key in mask_keys: + new_extra_data[key] = extra_data.pop(key) + for tensormap_dict, new_dict in zip( [targets, extra_data], [new_targets, new_extra_data] ): diff --git a/src/metatrain/utils/io.py b/src/metatrain/utils/io.py index 6fe56b417..284837bd2 100644 --- a/src/metatrain/utils/io.py +++ b/src/metatrain/utils/io.py @@ -181,15 +181,15 @@ def model_from_checkpoint( architecture = import_architecture(architecture_name) model_ckpt_version = checkpoint.get("model_ckpt_version") - ckpt_before_versionning = model_ckpt_version is None - if ckpt_before_versionning: + ckpt_before_versioning = model_ckpt_version is None + if ckpt_before_versioning: # assume version 1 and try our best model_ckpt_version = 1 checkpoint["model_ckpt_version"] = model_ckpt_version if model_ckpt_version != architecture.__model__.__checkpoint_version__: try: - if ckpt_before_versionning: + if ckpt_before_versioning: warnings.warn( "trying to upgrade an old model checkpoint with unknown " "version, this might fail and require manual modifications", @@ -229,15 +229,16 @@ def trainer_from_checkpoint( architecture = import_architecture(architecture_name) trainer_ckpt_version = checkpoint.get("trainer_ckpt_version") - ckpt_before_versionning = trainer_ckpt_version is None - if ckpt_before_versionning: + + ckpt_before_versioning = trainer_ckpt_version is None + if ckpt_before_versioning: # assume version 1 and try our best trainer_ckpt_version = 1 checkpoint["trainer_ckpt_version"] = trainer_ckpt_version if trainer_ckpt_version != architecture.__trainer__.__checkpoint_version__: try: - if ckpt_before_versionning: + if ckpt_before_versioning: warnings.warn( "trying to upgrade an old trainer checkpoint with unknown " "version, this might fail and require manual modifications", diff --git a/src/metatrain/utils/loss.py b/src/metatrain/utils/loss.py index 726e3bab4..da33e3509 100644 --- a/src/metatrain/utils/loss.py +++ b/src/metatrain/utils/loss.py @@ -1,348 +1,748 @@ -from typing import Dict, Optional, Tuple, Union +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, Optional, Type +import metatensor.torch as mts import torch from metatensor.torch import TensorMap -from omegaconf import DictConfig from torch.nn.modules.loss import _Loss -from metatrain.utils.external_naming import to_internal_name +from metatrain.utils.data import TargetInfo -class TensorMapLoss: - """A loss function that operates on two ``metatensor.torch.TensorMap``. - - The loss is computed as the sum of the loss on the block values and - the loss on the gradients, with weights specified at initialization. - - At the moment, this loss function assumes that all the gradients - declared at initialization are present in both TensorMaps. - - :param reduction: The reduction to apply to the loss. - See :py:class:`torch.nn.MSELoss`. - :param weight: The weight to apply to the loss on the block values. - :param gradient_weights: The weights to apply to the loss on the gradients. - :param sliding_factor: The factor to apply to the exponential moving average - of the "sliding" weights. These are weights that act on different components of - the loss (for example, energies and forces), based on their individual recent - history. If ``None``, no sliding weights are used in the computation of the - loss. - :param type: The type of loss to use. This can be either "mse" or "mae". - A Huber loss can also be requested as a dictionary with the key "huber" and - the value must be a dictionary with the key "deltas" and the value - must be a dictionary with the keys "values" and the gradient keys. - The values of the dictionary must be the deltas to use for the - Huber loss. +class LossInterface(ABC): + """ + Abstract base for all loss functions. - :returns: The loss as a zero-dimensional :py:class:`torch.Tensor` - (with one entry). + Subclasses must implement the ``compute`` method. """ + weight: float + reduction: str + loss_kwargs: Dict[str, Any] + target: str + gradient: Optional[str] + def __init__( self, - reduction: str = "mean", - weight: float = 1.0, - gradient_weights: Optional[Dict[str, float]] = None, - sliding_factor: Optional[float] = None, - type: Union[str, dict] = "mse", - ): - if gradient_weights is None: - gradient_weights = {} - - losses = {} - if type == "mse": - losses["values"] = torch.nn.MSELoss(reduction=reduction) - for key in gradient_weights.keys(): - losses[key] = torch.nn.MSELoss(reduction=reduction) - elif type == "mae": - losses["values"] = torch.nn.L1Loss(reduction=reduction) - for key in gradient_weights.keys(): - losses[key] = torch.nn.L1Loss(reduction=reduction) - elif isinstance(type, dict) and "huber" in type: - # Huber loss - deltas = type["huber"]["deltas"] - losses["values"] = torch.nn.HuberLoss( - reduction=reduction, delta=deltas["values"] - ) - for key in gradient_weights.keys(): - losses[key] = torch.nn.HuberLoss(reduction=reduction, delta=deltas[key]) - else: - raise ValueError(f"Unknown loss type: {type}") - - self.losses = losses + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + ) -> None: + """ + :param name: key in the predictions/targets dict to select the TensorMap. + :param gradient: optional name of a gradient field to extract. + :param weight: multiplicative weight (used by ScheduledLoss). + :param reduction: reduction mode for torch losses ("mean", "sum", etc.). + """ + self.target = name + self.gradient = gradient self.weight = weight - self.gradient_weights = gradient_weights - self.sliding_factor = sliding_factor - self.sliding_weights: Optional[Dict[str, TensorMap]] = None + self.reduction = reduction + self.loss_kwargs = {} + super().__init__() + + @abstractmethod + def compute( + self, + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + extra_data: Optional[Any] = None, + ) -> torch.Tensor: + """ + Compute the loss. + + :param predictions: mapping from target names to :py:class:`TensorMap`. + :param targets: mapping from target names to :py:class:`TensorMap`. + :param extra_data: optional additional data (e.g., masks). + :return: scalar torch.Tensor representing the loss. + """ + ... def __call__( self, - predictions_tensor_map: TensorMap, - targets_tensor_map: TensorMap, - ) -> Tuple[torch.Tensor, Dict[str, Tuple[float, int]]]: - # Check that the two have the same metadata, except for the samples, - # which can be different due to batching, but must have the same size: - if predictions_tensor_map.keys != targets_tensor_map.keys: - raise ValueError( - "TensorMapLoss requires the two TensorMaps to have the same keys." - ) - for block_1, block_2 in zip( - predictions_tensor_map.blocks(), targets_tensor_map.blocks() - ): - if block_1.properties != block_2.properties: - raise ValueError( - "TensorMapLoss requires the two TensorMaps to have the same " - "properties." - ) - if block_1.components != block_2.components: - raise ValueError( - "TensorMapLoss requires the two TensorMaps to have the same " - "components." + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + extra_data: Optional[Any] = None, + ) -> torch.Tensor: + """ + Alias to compute() for direct invocation. + """ + return self.compute(predictions, targets, extra_data) + + @classmethod + def from_config(cls, cfg: Dict[str, Any]) -> "LossInterface": + """ + Instantiate a loss from a config dict. + + :param cfg: keyword args matching the loss constructor. + :return: instance of a LossInterface subclass. + """ + return cls(**cfg) + + +# --- scheduler interface and implementations ------------------------------------------ + + +class WeightScheduler(ABC): + """ + Abstract interface for scheduling a weight for a :py:class:`LossInterface`. + """ + + initialized: bool = False + + @abstractmethod + def initialize( + self, loss_fn: LossInterface, targets: Dict[str, TensorMap] + ) -> float: + """ + Compute and return the initial weight. + + :param loss_fn: the base loss to initialize. + :param targets: mapping of target names to :py:class:`TensorMap`. + :return: initial weight as a float. + """ + + @abstractmethod + def update( + self, + loss_fn: LossInterface, + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + ) -> float: + """ + Update and return the new weight after a batch. + + :param loss_fn: the base loss. + :param predictions: mapping of target names to :py:class:`TensorMap`. + :param targets: mapping of target names to :py:class:`TensorMap`. + :return: updated weight as a float. + """ + + +class EMAScheduler(WeightScheduler): + """ + Exponential moving average scheduler for loss weights. + """ + + EPSILON = 1e-6 + + def __init__(self, sliding_factor: Optional[float]) -> None: + """ + :param sliding_factor: factor in [0,1] for EMA (0 disables scheduling). + """ + self.sliding_factor = float(sliding_factor or 0.0) + self.current_weight = 1.0 + self.initialized = False + + def initialize( + self, loss_fn: LossInterface, targets: Dict[str, TensorMap] + ) -> float: + # If scheduling disabled, keep weight = 1.0 + if self.sliding_factor <= 0.0: + self.current_weight = 1.0 + else: + # Compute a baseline loss against a constant mean or zero-gradient map + target_name = loss_fn.target + gradient_name = getattr(loss_fn, "gradient", None) + tensor_map_for_target = targets[target_name] + + if gradient_name is None: + # Create a baseline TensorMap with all values = mean over samples + mean_tensor_map = mts.mean_over_samples( + tensor_map_for_target, tensor_map_for_target.sample_names ) - if len(block_1.samples) != len(block_2.samples): - raise ValueError( - "TensorMapLoss requires the two TensorMaps " - "to have the same number of samples." + baseline_tensor_map = TensorMap( + keys=tensor_map_for_target.keys, + blocks=[ + mts.TensorBlock( + samples=block.samples, + components=block.components, + properties=block.properties, + values=torch.ones_like(block.values) * mean_block.values, + ) + for block, mean_block in zip( + tensor_map_for_target, mean_tensor_map + ) + ], ) - for gradient_name in block_2.gradients_list(): - if len(block_1.gradient(gradient_name).samples) != len( - block_2.gradient(gradient_name).samples - ): - raise ValueError( - "TensorMapLoss requires the two TensorMaps " - "to have the same number of gradient samples." - ) - if ( - block_1.gradient(gradient_name).properties - != block_2.gradient(gradient_name).properties - ): - raise ValueError( - "TensorMapLoss requires the two TensorMaps " - "to have the same gradient properties." - ) - if ( - block_1.gradient(gradient_name).components - != block_2.gradient(gradient_name).components - ): - raise ValueError( - "TensorMapLoss requires the two TensorMaps " - "to have the same gradient components." - ) + else: + # Zero baseline for gradient-based losses + baseline_tensor_map = mts.zeros_like(tensor_map_for_target) - # First time the function is called: compute the sliding weights only - # from the targets (if they are enabled) - if self.sliding_factor is not None and self.sliding_weights is None: - self.sliding_weights = get_sliding_weights( - self.losses, - self.sliding_factor, - targets_tensor_map, + initial_loss_value = loss_fn.compute( + {target_name: tensor_map_for_target}, {target_name: baseline_tensor_map} ) + self.current_weight = float(initial_loss_value.clamp_min(self.EPSILON)) + + self.initialized = True + return self.current_weight - # Compute the loss: - loss = torch.zeros( - (), - dtype=predictions_tensor_map.block(0).values.dtype, - device=predictions_tensor_map.block(0).values.device, + def update( + self, + loss_fn: LossInterface, + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + ) -> float: + # If scheduling disabled, return fixed weight + if self.sliding_factor <= 0.0: + return self.current_weight + + # Compute the instantaneous error + instantaneous_error = loss_fn.compute(predictions, targets).detach().item() + # EMA update + new_weight = ( + self.sliding_factor * self.current_weight + + (1.0 - self.sliding_factor) * instantaneous_error ) - for key in targets_tensor_map.keys: - block_1 = predictions_tensor_map.block(key) - block_2 = targets_tensor_map.block(key) - values_1 = block_1.values - values_2 = block_2.values - # sliding weights: default to 1.0 if not used/provided for this target - sliding_weight = ( - 1.0 - if self.sliding_weights is None - else self.sliding_weights.get("values", 1.0) - ) - loss += ( - self.weight * self.losses["values"](values_1, values_2) / sliding_weight - ) - for gradient_name in block_2.gradients_list(): - gradient_weight = self.gradient_weights[gradient_name] - values_1 = block_1.gradient(gradient_name).values - values_2 = block_2.gradient(gradient_name).values - # sliding weights: default to 1.0 if not used/provided for this target - sliding_weigths_value = ( - 1.0 - if self.sliding_weights is None - else self.sliding_weights.get(gradient_name, 1.0) - ) - loss += ( - gradient_weight - * self.losses[gradient_name](values_1, values_2) - / sliding_weigths_value - ) - if self.sliding_factor is not None: - self.sliding_weights = get_sliding_weights( - self.losses, - self.sliding_factor, - targets_tensor_map, - predictions_tensor_map, - self.sliding_weights, + self.current_weight = max(new_weight, self.EPSILON) + return self.current_weight + + +class ScheduledLoss(LossInterface): + """ + Wrap a base :py:class:`LossInterface` with a :py:class:`WeightScheduler`. + After each compute, the scheduler updates the loss weight. + """ + + def __init__(self, base_loss: LossInterface, weight_scheduler: WeightScheduler): + """ + :param base_loss: underlying LossInterface to wrap. + :param weight_scheduler: scheduler that controls the multiplier. + """ + super().__init__( + base_loss.target, + base_loss.gradient, + base_loss.weight, + base_loss.reduction, + ) + self.base_loss = base_loss + self.scheduler = weight_scheduler + self.loss_kwargs = getattr(base_loss, "loss_kwargs", {}) + + def compute( + self, + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + extra_data: Optional[Any] = None, + ) -> torch.Tensor: + # Initialize scheduler on first call + if not self.scheduler.initialized: + self.normalization_factor = self.scheduler.initialize( + self.base_loss, targets ) - return loss + # compute the raw loss using the base loss function + raw_loss_value = self.base_loss.compute(predictions, targets, extra_data) + + # scale by the fixed weight and divide by the sliding weight + weighted_loss_value = raw_loss_value * ( + self.base_loss.weight / self.normalization_factor + ) + + # update the sliding weight + self.normalization_factor = self.scheduler.update( + self.base_loss, predictions, targets + ) -class TensorMapDictLoss: - """A loss function that operates on two ``Dict[str, metatensor.torch.TensorMap]``. + return weighted_loss_value - At initialization, the user specifies a list of keys to use for the loss, - along with a weight for each key. - The loss is then computed as a weighted sum. Any keys that are not present - in the dictionaries are ignored. +# --- specific losses ------------------------------------------------------------------ - :param weights: A dictionary mapping keys to weights. This might contain - gradient keys, in the form ``__gradients``. - :param sliding_factor: The factor to apply to the exponential moving average - of the "sliding" weights. These are weights that act on different components of - the loss (for example, energies and forces), based on their individual recent - history. If ``None``, no sliding weights are used in the computation of the - loss. - :param reduction: The reduction to apply to the loss. - See :py:class:`torch.nn.MSELoss`. - :returns: The loss as a zero-dimensional :py:class:`torch.Tensor` - (with one entry). +class BaseTensorMapLoss(LossInterface): + """ + Backbone for pointwise losses on :py:class:`TensorMap` entries. + + Provides a compute_flattened() helper that extracts values or gradients, + flattens them, applies an optional mask, and computes the torch loss. """ def __init__( self, - weights: Dict[str, float], - sliding_factor: Optional[float] = None, - reduction: str = "mean", - type: Union[str, dict] = "mse", + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + *, + loss_fn: _Loss, ): - outputs = [key for key in weights.keys() if "gradients" not in key] - self.losses = {} - for output in outputs: - value_weight = weights[output] - gradient_weights = {} - for key, weight in weights.items(): - if key.startswith(output) and key.endswith("_gradients"): - gradient_name = key.replace(f"{output}_", "").replace( - "_gradients", "" - ) - gradient_weights[gradient_name] = weight - type_output = _process_type(type, output) - if output == "energy" and sliding_factor is not None: - self.losses[output] = TensorMapLoss( - reduction=reduction, - weight=value_weight, - gradient_weights=gradient_weights, - sliding_factor=sliding_factor, - type=type_output, - ) + """ + :param name: key in the predictions/targets dict. + :param gradient: optional gradient field name. + :param weight: dummy here; real weighting in ScheduledLoss. + :param reduction: reduction mode for torch loss. + :param loss_fn: pre-instantiated torch.nn loss (e.g. MSELoss). + """ + super().__init__(name, gradient, weight, reduction) + self.torch_loss = loss_fn + + def compute_flattened( + self, + tensor_map_predictions_for_target: TensorMap, + tensor_map_targets_for_target: TensorMap, + tensor_map_mask_for_target: Optional[TensorMap] = None, + ) -> torch.Tensor: + """ + Flatten prediction and target blocks (and optional mask), then + apply the torch loss. + + :param tensor_map_predictions_for_target: predicted :py:class:`TensorMap`. + :param tensor_map_targets_for_target: target :py:class:`TensorMap`. + :param tensor_map_mask_for_target: optional mask :py:class:`TensorMap`. + :return: scalar torch.Tensor of the computed loss. + """ + list_of_prediction_segments = [] + list_of_target_segments = [] + + def extract_flattened_values_from_block( + tensor_block: mts.TensorBlock, + ) -> torch.Tensor: + """ + Extract values or gradients from a block, flatten to 1D. + """ + if self.gradient is not None: + values = tensor_block.gradient(self.gradient).values else: - self.losses[output] = TensorMapLoss( - reduction=reduction, - weight=value_weight, - gradient_weights=gradient_weights, - type=type_output, - ) + values = tensor_block.values + return values.reshape(-1) - def __call__( + # Loop over each key in the TensorMap + for single_key in tensor_map_predictions_for_target.keys: + block_for_prediction = tensor_map_predictions_for_target.block(single_key) + block_for_target = tensor_map_targets_for_target.block(single_key) + + flattened_prediction = extract_flattened_values_from_block( + block_for_prediction + ) + flattened_target = extract_flattened_values_from_block(block_for_target) + + if tensor_map_mask_for_target is not None: + # Apply boolean mask if provided + block_for_mask = tensor_map_mask_for_target.block(single_key) + flattened_mask = extract_flattened_values_from_block( + block_for_mask + ).bool() + flattened_prediction = flattened_prediction[flattened_mask] + flattened_target = flattened_target[flattened_mask] + + list_of_prediction_segments.append(flattened_prediction) + list_of_target_segments.append(flattened_target) + + # Concatenate all segments and apply the torch loss + all_predictions_flattened = torch.cat(list_of_prediction_segments) + all_targets_flattened = torch.cat(list_of_target_segments) + return self.torch_loss(all_predictions_flattened, all_targets_flattened) + + def compute( self, - tensor_map_dict_1: Dict[str, TensorMap], - tensor_map_dict_2: Dict[str, TensorMap], + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + extra_data: Optional[Any] = None, ) -> torch.Tensor: - # Assert that the two have the keys: - assert set(tensor_map_dict_1.keys()) == set(tensor_map_dict_2.keys()) + """ + Compute the unmasked pointwise loss. + + :param predictions: mapping of names to :py:class:`TensorMap`. + :param targets: mapping of names to :py:class:`TensorMap`. + :param extra_data: ignored for unmasked losses. + :return: scalar torch.Tensor loss. + """ + tensor_map_pred = predictions[self.target] + tensor_map_targ = targets[self.target] + + # Check gradients are present in the target TensorMap + if self.gradient is not None: + if self.gradient not in tensor_map_targ[0].gradients_list(): + # Skip loss computation if block gradient is missing in the dataset + # Tensor gradients are not tracked + return torch.zeros( + (), dtype=torch.float, device=tensor_map_targ[0].values.device + ) + return self.compute_flattened(tensor_map_pred, tensor_map_targ) + - # Initialize the loss: - first_values = next(iter(tensor_map_dict_1.values())).block(0).values - loss = torch.zeros((), dtype=first_values.dtype, device=first_values.device) +class MaskedTensorMapLoss(BaseTensorMapLoss): + """ + Pointwise masked loss on :py:class:`TensorMap` entries. + + Inherits flattening and torch-loss logic from BaseTensorMapLoss. + """ - # Compute the loss: - for target in tensor_map_dict_1.keys(): - target_loss = self.losses[target]( - tensor_map_dict_1[target], tensor_map_dict_2[target] + def compute( + self, + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + extra_data: Optional[Dict[str, TensorMap]] = None, + ) -> torch.Tensor: + """ + Gather and flatten target and prediction blocks, then compute loss. + + :param predictions: Mapping from target names to TensorMaps. + :param targets: Mapping from target names to TensorMaps. + :param extra_data: Additional data for loss computation. Assumes that, for the + target ``name`` used in the constructor, there is a corresponding data field + ``name + "_mask"`` that contains the tensor to be used for masking. It + should have the same metadata as the target and prediction tensors. + :return: Scalar loss tensor. + """ + mask_key = f"{self.target}_mask" + if extra_data is None or mask_key not in extra_data: + raise ValueError( + f"Expected extra_data to contain TensorMap under '{mask_key}'" ) - loss += target_loss + tensor_map_pred = predictions[self.target] + tensor_map_targ = targets[self.target] + tensor_map_mask = extra_data[mask_key] + return self.compute_flattened(tensor_map_pred, tensor_map_targ, tensor_map_mask) + + +# ------------------------------------------------------------------------ +# Simple explicit subclasses for common pointwise losses +# ------------------------------------------------------------------------ + + +class TensorMapMSELoss(BaseTensorMapLoss): + """ + Unmasked mean-squared error on :py:class:`TensorMap` entries. + """ + + def __init__( + self, + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + ): + super().__init__( + name, + gradient, + weight, + reduction, + loss_fn=torch.nn.MSELoss(reduction=reduction), + ) + + +class TensorMapMAELoss(BaseTensorMapLoss): + """ + Unmasked mean-absolute error on :py:class:`TensorMap` entries. + """ + + def __init__( + self, + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + ): + super().__init__( + name, + gradient, + weight, + reduction, + loss_fn=torch.nn.L1Loss(reduction=reduction), + ) + + +class TensorMapHuberLoss(BaseTensorMapLoss): + """ + Unmasked Huber loss on :py:class:`TensorMap` entries. + + :param delta: threshold parameter for HuberLoss. + """ + + def __init__( + self, + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + delta: float, + ): + super().__init__( + name, + gradient, + weight, + reduction, + loss_fn=torch.nn.HuberLoss(reduction=reduction, delta=delta), + ) + + +class TensorMapMaskedMSELoss(MaskedTensorMapLoss): + """ + Masked mean-squared error on :py:class:`TensorMap` entries. + """ + + def __init__( + self, + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + ): + super().__init__( + name, + gradient, + weight, + reduction, + loss_fn=torch.nn.MSELoss(reduction=reduction), + ) - return loss +class TensorMapMaskedMAELoss(MaskedTensorMapLoss): + """ + Masked mean-absolute error on :py:class:`TensorMap` entries. + """ + + def __init__( + self, + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + ): + super().__init__( + name, + gradient, + weight, + reduction, + loss_fn=torch.nn.L1Loss(reduction=reduction), + ) + + +class TensorMapMaskedHuberLoss(MaskedTensorMapLoss): + """ + Masked Huber loss on :py:class:`TensorMap` entries. -def get_sliding_weights( - losses: Dict[str, _Loss], - sliding_factor: float, - targets: TensorMap, - predictions: Optional[TensorMap] = None, - previous_sliding_weights: Optional[Dict[str, float]] = None, -) -> Dict[str, float]: + :param delta: threshold parameter for HuberLoss. """ - Compute the sliding weights for the loss function. - The sliding weights are computed as the absolute difference between the - predictions and the targets. + def __init__( + self, + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + delta: float, + ): + super().__init__( + name, + gradient, + weight, + reduction, + loss_fn=torch.nn.HuberLoss(reduction=reduction, delta=delta), + ) + - :param predictions: The predictions. - :param targets: The targets. +# --- aggregator ----------------------------------------------------------------------- - :return: The sliding weights. + +class LossAggregator(LossInterface): + """ + Aggregate multiple :py:class:`LossInterface` terms with scheduled weights and + metadata. """ - sliding_weights = {} - if predictions is None: - for block in targets.blocks(): - values = block.values - sliding_weights["values"] = ( - losses["values"](values, values.mean() * torch.ones_like(values)) + 1e-6 + + def __init__( + self, targets: Dict[str, TargetInfo], config: Dict[str, Dict[str, Any]] + ): + """ + :param targets: mapping from target names to :py:class:`TargetInfo`. + :param config: per-target configuration dict. + """ + super().__init__(name="", gradient=None, weight=0.0, reduction="mean") + self.scheduled_losses: Dict[str, ScheduledLoss] = {} + self.metadata: Dict[str, Dict[str, Any]] = {} + + for target_name, target_info in targets.items(): + target_config = config[target_name] + + # Create main loss and its scheduler + base_loss = create_loss( + target_config["type"], + name=target_name, + gradient=None, + weight=target_config["weight"], + reduction=target_config["reduction"], + **{ + pname: pval + for pname, pval in target_config.items() + if pname + not in ( + "type", + "weight", + "reduction", + "sliding_factor", + "gradients", + ) + }, ) - for gradient_name, gradient_block in block.gradients(): - values = gradient_block.values - sliding_weights[gradient_name] = losses[gradient_name]( - values, torch.zeros_like(values) + ema_scheduler = EMAScheduler(target_config["sliding_factor"]) + scheduled_main_loss = ScheduledLoss(base_loss, ema_scheduler) + self.scheduled_losses[target_name] = scheduled_main_loss + self.metadata[target_name] = { + "type": target_config["type"], + "weight": base_loss.weight, + "reduction": base_loss.reduction, + "sliding_factor": target_config["sliding_factor"], + "gradients": {}, + } + for pname, pval in target_config.items(): + if pname not in ( + "type", + "weight", + "reduction", + "sliding_factor", + "gradients", + ): + self.metadata[target_name][pname] = pval + + # Create gradient-based losses + gradient_config = target_config["gradients"] + for gradient_name in target_info.layout[0].gradients_list(): + gradient_key = f"{target_name}_grad_{gradient_name}" + + gradient_specific_config = gradient_config[gradient_name] + + grad_loss = create_loss( + gradient_specific_config["type"], + name=target_name, + gradient=gradient_name, + weight=gradient_specific_config["weight"], + reduction=gradient_specific_config["reduction"], + **{ + pname: pval + for pname, pval in gradient_specific_config.items() + if pname + not in ( + "type", + "weight", + "reduction", + "sliding_factor", + "gradients", + ) + }, ) - elif predictions is not None: - if previous_sliding_weights is None: - raise RuntimeError( - "previous_sliding_weights must be provided if predictions is not None" + ema_scheduler_for_grad = EMAScheduler(target_config["sliding_factor"]) + scheduled_grad_loss = ScheduledLoss(grad_loss, ema_scheduler_for_grad) + self.scheduled_losses[gradient_key] = scheduled_grad_loss + self.metadata[target_name]["gradients"][gradient_name] = { + "type": gradient_specific_config["type"], + "weight": grad_loss.weight, + "reduction": grad_loss.reduction, + "sliding_factor": target_config["sliding_factor"], + } + for pname, pval in gradient_specific_config.items(): + if pname not in ( + "type", + "weight", + "reduction", + "sliding_factor", + "gradients", + ): + self.metadata[target_name]["gradients"][gradient_name][ + pname + ] = pval + + def compute( + self, + predictions: Dict[str, TensorMap], + targets: Dict[str, TensorMap], + extra_data: Optional[Any] = None, + ) -> torch.Tensor: + """ + Sum over all scheduled losses present in the predictions. + """ + # Initialize a zero tensor matching the dtype and device of the first block + first_tensor_map = next(iter(predictions.values())) + first_block = first_tensor_map.block(first_tensor_map.keys[0]) + total_loss = torch.zeros( + (), dtype=first_block.values.dtype, device=first_block.values.device + ) + + # Sum each scheduled term that has a matching prediction + for scheduled_term in self.scheduled_losses.values(): + if scheduled_term.target not in predictions: + continue + total_loss = total_loss + scheduled_term.compute( + predictions, targets, extra_data ) - else: - for predictions_block, target_block in zip( - predictions.blocks(), targets.blocks() - ): - target_values = target_block.values - predictions_values = predictions_block.values - sliding_weights["values"] = ( - sliding_factor * previous_sliding_weights["values"] - + (1 - sliding_factor) - * losses["values"](predictions_values, target_values).detach() - ) - for gradient_name, gradient_block in target_block.gradients(): - target_values = gradient_block.values - predictions_values = predictions_block.gradient( - gradient_name - ).values - sliding_weights[gradient_name] = ( - sliding_factor * previous_sliding_weights[gradient_name] - + (1 - sliding_factor) - * losses[gradient_name]( - predictions_values, target_values - ).detach() - ) - return sliding_weights - - -def _process_type(type: Union[str, DictConfig], output: str) -> Union[str, dict]: - if not isinstance(type, str): - assert "huber" in type - # we process the Huber loss delta dict to make it similar to the - # `weights` dict - type_output = {"huber": {"deltas": {}}} # type: ignore - for key, delta in type["huber"]["deltas"].items(): - key_internal = to_internal_name(key) - if key_internal == output: - type_output["huber"]["deltas"]["values"] = delta - elif key_internal.startswith(output) and key_internal.endswith( - "_gradients" - ): - gradient_name = key_internal.replace(f"{output}_", "").replace( - "_gradients", "" - ) - type_output["huber"]["deltas"][gradient_name] = delta - else: - pass - else: - type_output = type # type: ignore - return type_output + + return total_loss + + +class LossType(Enum): + """ + Enumeration of available loss types and their implementing classes. + """ + + MSE = ("mse", TensorMapMSELoss) + MAE = ("mae", TensorMapMAELoss) + HUBER = ("huber", TensorMapHuberLoss) + MASKED_MSE = ("masked_mse", TensorMapMaskedMSELoss) + MASKED_MAE = ("masked_mae", TensorMapMaskedMAELoss) + MASKED_HUBER = ("masked_huber", TensorMapMaskedHuberLoss) + POINTWISE = ("pointwise", BaseTensorMapLoss) + MASKED_POINTWISE = ("masked_pointwise", MaskedTensorMapLoss) + + def __init__(self, key: str, cls: Type[LossInterface]): + self._key = key + self._cls = cls + + @property + def key(self) -> str: + """String key for this loss type.""" + return self._key + + @property + def cls(self) -> Type[LossInterface]: + """Class implementing this loss type.""" + return self._cls + + @classmethod + def from_key(cls, key: str) -> "LossType": + """ + Look up a LossType by its string key. + + :raises ValueError: if the key is not valid. + """ + for loss_type in cls: + if loss_type.key == key: + return loss_type + valid_keys = ", ".join(loss_type.key for loss_type in cls) + raise ValueError(f"Unknown loss '{key}'. Valid types: {valid_keys}") + + +def create_loss( + loss_type: str, + *, + name: str, + gradient: Optional[str], + weight: float, + reduction: str, + **extra_kwargs: Any, +) -> LossInterface: + """ + Factory to instantiate a concrete :py:class:`LossInterface` given its string key. + + :param loss_type: string key matching one of the members of :py:class:`LossType`. + :param name: target name for the loss. + :param gradient: gradient name, if present. + :param weight: weight for the loss contribution. + :param reduction: reduction mode for the torch loss. + :param extra_kwargs: additional hyperparameters specific to the loss type. + :return: instance of the selected loss. + """ + loss_type_entry = LossType.from_key(loss_type) + try: + return loss_type_entry.cls( + name=name, + gradient=gradient, + weight=weight, + reduction=reduction, + **extra_kwargs, + ) + except TypeError as e: + raise TypeError(f"Error constructing loss '{loss_type}': {e}") from e diff --git a/src/metatrain/utils/old_loss.py b/src/metatrain/utils/old_loss.py new file mode 100644 index 000000000..726e3bab4 --- /dev/null +++ b/src/metatrain/utils/old_loss.py @@ -0,0 +1,348 @@ +from typing import Dict, Optional, Tuple, Union + +import torch +from metatensor.torch import TensorMap +from omegaconf import DictConfig +from torch.nn.modules.loss import _Loss + +from metatrain.utils.external_naming import to_internal_name + + +class TensorMapLoss: + """A loss function that operates on two ``metatensor.torch.TensorMap``. + + The loss is computed as the sum of the loss on the block values and + the loss on the gradients, with weights specified at initialization. + + At the moment, this loss function assumes that all the gradients + declared at initialization are present in both TensorMaps. + + :param reduction: The reduction to apply to the loss. + See :py:class:`torch.nn.MSELoss`. + :param weight: The weight to apply to the loss on the block values. + :param gradient_weights: The weights to apply to the loss on the gradients. + :param sliding_factor: The factor to apply to the exponential moving average + of the "sliding" weights. These are weights that act on different components of + the loss (for example, energies and forces), based on their individual recent + history. If ``None``, no sliding weights are used in the computation of the + loss. + :param type: The type of loss to use. This can be either "mse" or "mae". + A Huber loss can also be requested as a dictionary with the key "huber" and + the value must be a dictionary with the key "deltas" and the value + must be a dictionary with the keys "values" and the gradient keys. + The values of the dictionary must be the deltas to use for the + Huber loss. + + :returns: The loss as a zero-dimensional :py:class:`torch.Tensor` + (with one entry). + """ + + def __init__( + self, + reduction: str = "mean", + weight: float = 1.0, + gradient_weights: Optional[Dict[str, float]] = None, + sliding_factor: Optional[float] = None, + type: Union[str, dict] = "mse", + ): + if gradient_weights is None: + gradient_weights = {} + + losses = {} + if type == "mse": + losses["values"] = torch.nn.MSELoss(reduction=reduction) + for key in gradient_weights.keys(): + losses[key] = torch.nn.MSELoss(reduction=reduction) + elif type == "mae": + losses["values"] = torch.nn.L1Loss(reduction=reduction) + for key in gradient_weights.keys(): + losses[key] = torch.nn.L1Loss(reduction=reduction) + elif isinstance(type, dict) and "huber" in type: + # Huber loss + deltas = type["huber"]["deltas"] + losses["values"] = torch.nn.HuberLoss( + reduction=reduction, delta=deltas["values"] + ) + for key in gradient_weights.keys(): + losses[key] = torch.nn.HuberLoss(reduction=reduction, delta=deltas[key]) + else: + raise ValueError(f"Unknown loss type: {type}") + + self.losses = losses + self.weight = weight + self.gradient_weights = gradient_weights + self.sliding_factor = sliding_factor + self.sliding_weights: Optional[Dict[str, TensorMap]] = None + + def __call__( + self, + predictions_tensor_map: TensorMap, + targets_tensor_map: TensorMap, + ) -> Tuple[torch.Tensor, Dict[str, Tuple[float, int]]]: + # Check that the two have the same metadata, except for the samples, + # which can be different due to batching, but must have the same size: + if predictions_tensor_map.keys != targets_tensor_map.keys: + raise ValueError( + "TensorMapLoss requires the two TensorMaps to have the same keys." + ) + for block_1, block_2 in zip( + predictions_tensor_map.blocks(), targets_tensor_map.blocks() + ): + if block_1.properties != block_2.properties: + raise ValueError( + "TensorMapLoss requires the two TensorMaps to have the same " + "properties." + ) + if block_1.components != block_2.components: + raise ValueError( + "TensorMapLoss requires the two TensorMaps to have the same " + "components." + ) + if len(block_1.samples) != len(block_2.samples): + raise ValueError( + "TensorMapLoss requires the two TensorMaps " + "to have the same number of samples." + ) + for gradient_name in block_2.gradients_list(): + if len(block_1.gradient(gradient_name).samples) != len( + block_2.gradient(gradient_name).samples + ): + raise ValueError( + "TensorMapLoss requires the two TensorMaps " + "to have the same number of gradient samples." + ) + if ( + block_1.gradient(gradient_name).properties + != block_2.gradient(gradient_name).properties + ): + raise ValueError( + "TensorMapLoss requires the two TensorMaps " + "to have the same gradient properties." + ) + if ( + block_1.gradient(gradient_name).components + != block_2.gradient(gradient_name).components + ): + raise ValueError( + "TensorMapLoss requires the two TensorMaps " + "to have the same gradient components." + ) + + # First time the function is called: compute the sliding weights only + # from the targets (if they are enabled) + if self.sliding_factor is not None and self.sliding_weights is None: + self.sliding_weights = get_sliding_weights( + self.losses, + self.sliding_factor, + targets_tensor_map, + ) + + # Compute the loss: + loss = torch.zeros( + (), + dtype=predictions_tensor_map.block(0).values.dtype, + device=predictions_tensor_map.block(0).values.device, + ) + for key in targets_tensor_map.keys: + block_1 = predictions_tensor_map.block(key) + block_2 = targets_tensor_map.block(key) + values_1 = block_1.values + values_2 = block_2.values + # sliding weights: default to 1.0 if not used/provided for this target + sliding_weight = ( + 1.0 + if self.sliding_weights is None + else self.sliding_weights.get("values", 1.0) + ) + loss += ( + self.weight * self.losses["values"](values_1, values_2) / sliding_weight + ) + for gradient_name in block_2.gradients_list(): + gradient_weight = self.gradient_weights[gradient_name] + values_1 = block_1.gradient(gradient_name).values + values_2 = block_2.gradient(gradient_name).values + # sliding weights: default to 1.0 if not used/provided for this target + sliding_weigths_value = ( + 1.0 + if self.sliding_weights is None + else self.sliding_weights.get(gradient_name, 1.0) + ) + loss += ( + gradient_weight + * self.losses[gradient_name](values_1, values_2) + / sliding_weigths_value + ) + if self.sliding_factor is not None: + self.sliding_weights = get_sliding_weights( + self.losses, + self.sliding_factor, + targets_tensor_map, + predictions_tensor_map, + self.sliding_weights, + ) + return loss + + +class TensorMapDictLoss: + """A loss function that operates on two ``Dict[str, metatensor.torch.TensorMap]``. + + At initialization, the user specifies a list of keys to use for the loss, + along with a weight for each key. + + The loss is then computed as a weighted sum. Any keys that are not present + in the dictionaries are ignored. + + :param weights: A dictionary mapping keys to weights. This might contain + gradient keys, in the form ``__gradients``. + :param sliding_factor: The factor to apply to the exponential moving average + of the "sliding" weights. These are weights that act on different components of + the loss (for example, energies and forces), based on their individual recent + history. If ``None``, no sliding weights are used in the computation of the + loss. + :param reduction: The reduction to apply to the loss. + See :py:class:`torch.nn.MSELoss`. + + :returns: The loss as a zero-dimensional :py:class:`torch.Tensor` + (with one entry). + """ + + def __init__( + self, + weights: Dict[str, float], + sliding_factor: Optional[float] = None, + reduction: str = "mean", + type: Union[str, dict] = "mse", + ): + outputs = [key for key in weights.keys() if "gradients" not in key] + self.losses = {} + for output in outputs: + value_weight = weights[output] + gradient_weights = {} + for key, weight in weights.items(): + if key.startswith(output) and key.endswith("_gradients"): + gradient_name = key.replace(f"{output}_", "").replace( + "_gradients", "" + ) + gradient_weights[gradient_name] = weight + type_output = _process_type(type, output) + if output == "energy" and sliding_factor is not None: + self.losses[output] = TensorMapLoss( + reduction=reduction, + weight=value_weight, + gradient_weights=gradient_weights, + sliding_factor=sliding_factor, + type=type_output, + ) + else: + self.losses[output] = TensorMapLoss( + reduction=reduction, + weight=value_weight, + gradient_weights=gradient_weights, + type=type_output, + ) + + def __call__( + self, + tensor_map_dict_1: Dict[str, TensorMap], + tensor_map_dict_2: Dict[str, TensorMap], + ) -> torch.Tensor: + # Assert that the two have the keys: + assert set(tensor_map_dict_1.keys()) == set(tensor_map_dict_2.keys()) + + # Initialize the loss: + first_values = next(iter(tensor_map_dict_1.values())).block(0).values + loss = torch.zeros((), dtype=first_values.dtype, device=first_values.device) + + # Compute the loss: + for target in tensor_map_dict_1.keys(): + target_loss = self.losses[target]( + tensor_map_dict_1[target], tensor_map_dict_2[target] + ) + loss += target_loss + + return loss + + +def get_sliding_weights( + losses: Dict[str, _Loss], + sliding_factor: float, + targets: TensorMap, + predictions: Optional[TensorMap] = None, + previous_sliding_weights: Optional[Dict[str, float]] = None, +) -> Dict[str, float]: + """ + Compute the sliding weights for the loss function. + + The sliding weights are computed as the absolute difference between the + predictions and the targets. + + :param predictions: The predictions. + :param targets: The targets. + + :return: The sliding weights. + """ + sliding_weights = {} + if predictions is None: + for block in targets.blocks(): + values = block.values + sliding_weights["values"] = ( + losses["values"](values, values.mean() * torch.ones_like(values)) + 1e-6 + ) + for gradient_name, gradient_block in block.gradients(): + values = gradient_block.values + sliding_weights[gradient_name] = losses[gradient_name]( + values, torch.zeros_like(values) + ) + elif predictions is not None: + if previous_sliding_weights is None: + raise RuntimeError( + "previous_sliding_weights must be provided if predictions is not None" + ) + else: + for predictions_block, target_block in zip( + predictions.blocks(), targets.blocks() + ): + target_values = target_block.values + predictions_values = predictions_block.values + sliding_weights["values"] = ( + sliding_factor * previous_sliding_weights["values"] + + (1 - sliding_factor) + * losses["values"](predictions_values, target_values).detach() + ) + for gradient_name, gradient_block in target_block.gradients(): + target_values = gradient_block.values + predictions_values = predictions_block.gradient( + gradient_name + ).values + sliding_weights[gradient_name] = ( + sliding_factor * previous_sliding_weights[gradient_name] + + (1 - sliding_factor) + * losses[gradient_name]( + predictions_values, target_values + ).detach() + ) + return sliding_weights + + +def _process_type(type: Union[str, DictConfig], output: str) -> Union[str, dict]: + if not isinstance(type, str): + assert "huber" in type + # we process the Huber loss delta dict to make it similar to the + # `weights` dict + type_output = {"huber": {"deltas": {}}} # type: ignore + for key, delta in type["huber"]["deltas"].items(): + key_internal = to_internal_name(key) + if key_internal == output: + type_output["huber"]["deltas"]["values"] = delta + elif key_internal.startswith(output) and key_internal.endswith( + "_gradients" + ): + gradient_name = key_internal.replace(f"{output}_", "").replace( + "_gradients", "" + ) + type_output["huber"]["deltas"][gradient_name] = delta + else: + pass + else: + type_output = type # type: ignore + return type_output diff --git a/src/metatrain/utils/omegaconf.py b/src/metatrain/utils/omegaconf.py index 4c6d72db0..1f5b282bc 100644 --- a/src/metatrain/utils/omegaconf.py +++ b/src/metatrain/utils/omegaconf.py @@ -56,15 +56,19 @@ def default_precision(_root_: BaseContainer) -> int: ) -def default_random_seed() -> int: - """Return session seed in the range [0, 2**32).""" - return RANDOM_SEED +def default_huber_loss_delta() -> float: + """Return the default delta for the huber loss.""" + return 1.0 # Register custom resolvers OmegaConf.register_new_resolver("default_device", default_device) OmegaConf.register_new_resolver("default_precision", default_precision) -OmegaConf.register_new_resolver("default_random_seed", default_random_seed) +OmegaConf.register_new_resolver("default_random_seed", lambda: RANDOM_SEED) +OmegaConf.register_new_resolver("default_loss_type", lambda: "mse") +OmegaConf.register_new_resolver("default_loss_reduction", lambda: "mean") +OmegaConf.register_new_resolver("default_loss_sliding_factor", lambda: None) +OmegaConf.register_new_resolver("default_loss_weight", lambda: 1.0) def _resolve_single_str(config: str) -> DictConfig: @@ -124,6 +128,16 @@ def _resolve_single_str(config: str) -> DictConfig: } ) +CONF_LOSS = OmegaConf.create( + { + "type": "${default_loss_type:}", + "weight": "${default_loss_weight:}", + "reduction": "${default_loss_reduction:}", + "sliding_factor": "${default_loss_sliding_factor:}", + "gradients": {}, + } +) + KNOWN_GRADIENTS = list(CONF_GRADIENTS.keys()) # Merge configs to get default configs for energies and other targets @@ -363,6 +377,153 @@ def expand_dataset_config(conf: Union[str, DictConfig, ListConfig]) -> ListConfi return conf +def expand_loss_config(conf: DictConfig) -> DictConfig: + """Expand the loss configuration to a list of configurations. + + :param conf: The loss configuration to expand. + :returns: A list of expanded loss configurations. + """ + + training_confs = conf["training_set"] + + if not isinstance(training_confs, ListConfig): + training_confs = OmegaConf.create([training_confs]) + + # initialize + loss_dict: dict = {} + conf_loss = CONF_LOSS.copy() + OmegaConf.resolve(conf_loss) + train_on_forces = False + train_on_stress_or_virial = False + + # fill loss_dict with default values + for tc in training_confs: + for target_name, opts in tc["targets"].items(): + if target_name == "energy": + f, s = _process_energy(loss_dict, opts, conf_loss) + train_on_forces |= f + train_on_stress_or_virial |= s + else: + loss_dict[target_name] = conf_loss.copy() + + train_hypers = conf["architecture"]["training"] + if "loss" not in train_hypers: + # Use default loss configuration + train_hypers["loss"] = OmegaConf.create(loss_dict) + else: + # Expand str -> DictConfig + if isinstance(train_hypers["loss"], str): + # TODO: add test + # the string must be the loss type, which is going to be used + # for all targets + for t in loss_dict.keys(): + loss_dict[t]["type"] = train_hypers["loss"] + if train_hypers["loss"] == "huber": + loss_dict[t]["delta"] = default_huber_loss_delta() + train_hypers["loss"] = OmegaConf.create(loss_dict) + + else: + # Expand per-target str loss configurations + for t in loss_dict.keys(): + if t in train_hypers["loss"]: + if isinstance(train_hypers["loss"][t], str): + train_hypers["loss"][t] = {"type": train_hypers["loss"][t]} + if train_hypers["loss"][t]["type"] == "huber": + train_hypers["loss"][t]["delta"] = ( + default_huber_loss_delta() + ) + + # Adapt the loss configuration to the internal structure + if train_on_forces: + _migrate_gradient_key(train_hypers["loss"], "forces", "positions") + else: + if "forces" in train_hypers["loss"]: + del train_hypers["loss"]["forces"] + + if train_on_stress_or_virial: + for legacy in ["stress", "virial"]: + _migrate_gradient_key(train_hypers["loss"], legacy, "strain") + else: + if "stress" in train_hypers["loss"]: + del train_hypers["loss"]["stress"] + if "virial" in train_hypers["loss"]: + del train_hypers["loss"]["virial"] + + # Add default delta for huber loss if not present + for t in train_hypers["loss"].keys(): + if "type" in train_hypers["loss"][t]: + if train_hypers["loss"][t]["type"] == "huber": + if "delta" not in train_hypers["loss"][t]: + train_hypers["loss"][t]["delta"] = ( + default_huber_loss_delta() + ) + if "gradients" in train_hypers["loss"][t]: + for grad_key in train_hypers["loss"][t]["gradients"].keys(): + if "type" in train_hypers["loss"][t]["gradients"][grad_key]: + if ( + train_hypers["loss"][t]["gradients"][grad_key]["type"] + == "huber" + ): + if ( + "delta" + not in train_hypers["loss"][t]["gradients"][ + grad_key + ] + ): + train_hypers["loss"][t]["gradients"][grad_key][ + "delta" + ] = default_huber_loss_delta() + + train_hypers["loss"] = OmegaConf.merge(loss_dict, train_hypers["loss"]) + + conf["architecture"]["training"] = train_hypers + return conf + + +def _migrate_gradient_key(loss_dict: dict, old_key: str, grad_key: str): + """ + If `old_key` exists in `loss_dict`, move it under + loss_dict['energy']['gradients'][grad_key], creating the necessary nested dicts + along the way. + """ + if old_key in loss_dict: + if "energy" not in loss_dict: + loss_dict["energy"] = {} + if "gradients" not in loss_dict["energy"]: + loss_dict["energy"]["gradients"] = {} + loss_dict["energy"]["gradients"][grad_key] = loss_dict[old_key] + del loss_dict[old_key] + + +def _process_energy( + loss_dict: dict, + opts: dict, + template: dict, +) -> tuple[bool, bool]: + """ + Ensure `loss_dict["energy"]` exists, reset its gradients, and add 'positions' / + 'strain' entries if requested by opts. + Returns (added_forces, added_strain) bools. + """ + if "energy" not in loss_dict: + loss_dict["energy"] = template.copy() + # start with an empty gradients dict each time + loss_dict["energy"]["gradients"] = {} + + added_forces = False + added_strain = False + + if opts.get("forces", False): + loss_dict["energy"]["gradients"]["positions"] = template.copy() + added_forces = True + + if opts.get("stress", False) or opts.get("virial", False): + loss_dict["energy"]["gradients"]["strain"] = template.copy() + added_strain = True + + return added_forces, added_strain + + def check_units( actual_options: Union[DictConfig, ListConfig], desired_options: Union[DictConfig, ListConfig], diff --git a/src/metatrain/utils/testing/checkpoints.py b/src/metatrain/utils/testing/checkpoints.py index 6e866bdb4..0720ea8d3 100644 --- a/src/metatrain/utils/testing/checkpoints.py +++ b/src/metatrain/utils/testing/checkpoints.py @@ -81,6 +81,7 @@ def test_loading_old_checkpoints(model_trainer, context): if context != "export": if checkpoint["trainer_ckpt_version"] != trainer.__checkpoint_version__: + print(context) checkpoint = trainer.__class__.upgrade_checkpoint(checkpoint) trainer.load_checkpoint(checkpoint, DEFAULT_HYPERS, context) diff --git a/src/metatrain/utils/transfer.py b/src/metatrain/utils/transfer.py index 91b96dce6..8b2e6d008 100644 --- a/src/metatrain/utils/transfer.py +++ b/src/metatrain/utils/transfer.py @@ -28,9 +28,15 @@ def batch_to( key: value.to(dtype=dtype, device=device) for key, value in targets.items() } if extra_data is not None: + new_dtypes: List[Optional[int]] = [] + for key in extra_data.keys(): + if key.endswith("_mask"): # masks should always be boolean + new_dtypes.append(torch.bool) + else: + new_dtypes.append(dtype) extra_data = { - key: value.to(dtype=dtype, device=device) - for key, value in extra_data.items() + key: value.to(dtype=_dtype, device=device) + for (key, value), _dtype in zip(extra_data.items(), new_dtypes) } return systems, targets, extra_data diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py index 102f0d4ae..de190aff9 100644 --- a/tests/utils/test_io.py +++ b/tests/utils/test_io.py @@ -90,6 +90,7 @@ def test_load_trainer_checkpoint_wrong_version(monkeypatch, tmp_path): r"checkpoint is using version 5000000, while the current version is \d+; " "and trying to upgrade the checkpoint failed." ) + with pytest.raises(RuntimeError, match=message): checkpoint = torch.load(file, weights_only=False, map_location="cpu") trainer_from_checkpoint(checkpoint, context="restart", hypers={}) diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index 88b4e8eb4..416302c3a 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -25,6 +25,7 @@ get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.omegaconf import CONF_LOSS from . import RESOURCES_PATH @@ -297,6 +298,9 @@ def test_llpr_finetuning(tmpdir): }, } + hypers["training"]["loss"] = {"energy": CONF_LOSS} + hypers["training"]["loss"]["energy"]["gradients"] = {"positions": CONF_LOSS} + trainer = Trainer(hypers["training"]) trainer.train( model=model, diff --git a/tests/utils/test_loss.py b/tests/utils/test_loss.py index 3bb572207..5561e9524 100644 --- a/tests/utils/test_loss.py +++ b/tests/utils/test_loss.py @@ -1,10 +1,26 @@ +# tests/test_losses.py + from pathlib import Path +import metatensor.torch as mts import pytest import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatrain.utils.loss import TensorMapDictLoss, TensorMapLoss +from metatrain.utils.data import TargetInfo +from metatrain.utils.loss import ( + EMAScheduler, + LossAggregator, + LossType, + TensorMapHuberLoss, + TensorMapMAELoss, + TensorMapMaskedHuberLoss, + TensorMapMaskedMAELoss, + TensorMapMaskedMSELoss, + TensorMapMSELoss, + create_loss, +) +from metatrain.utils.old_loss import TensorMapLoss RESOURCES_PATH = Path(__file__).parents[1] / "resources" @@ -14,12 +30,12 @@ def tensor_map_with_grad_1(): block = TensorBlock( values=torch.tensor([[1.0], [2.0], [3.0]]), - samples=Labels.range("samples", 3), + samples=Labels.range("sample", 3), components=[], properties=Labels("energy", torch.tensor([[0]])), ) block.add_gradient( - "gradient", + "positions", TensorBlock( values=torch.tensor([[1.0], [2.0], [3.0]]), samples=Labels.range("sample", 3), @@ -35,12 +51,12 @@ def tensor_map_with_grad_1(): def tensor_map_with_grad_2(): block = TensorBlock( values=torch.tensor([[1.0], [1.0], [3.0]]), - samples=Labels.range("samples", 3), + samples=Labels.range("sample", 3), components=[], properties=Labels("energy", torch.tensor([[0]])), ) block.add_gradient( - "gradient", + "positions", TensorBlock( values=torch.tensor([[1.0], [0.0], [3.0]]), samples=Labels.range("sample", 3), @@ -56,12 +72,12 @@ def tensor_map_with_grad_2(): def tensor_map_with_grad_3(): block = TensorBlock( values=torch.tensor([[0.0], [1.0], [3.0]]), - samples=Labels.range("samples", 3), + samples=Labels.range("sample", 3), components=[], properties=Labels("energy", torch.tensor([[0]])), ) block.add_gradient( - "gradient", + "positions", TensorBlock( values=torch.tensor([[1.0], [0.0], [3.0]]), samples=Labels.range("sample", 3), @@ -77,12 +93,12 @@ def tensor_map_with_grad_3(): def tensor_map_with_grad_4(): block = TensorBlock( values=torch.tensor([[0.0], [1.0], [3.0]]), - samples=Labels.range("samples", 3), + samples=Labels.range("sample", 3), components=[], properties=Labels("energy", torch.tensor([[0]])), ) block.add_gradient( - "gradient", + "positions", TensorBlock( values=torch.tensor([[1.0], [0.0], [2.0]]), samples=Labels.range("sample", 3), @@ -94,113 +110,379 @@ def tensor_map_with_grad_4(): return tensor_map -@pytest.mark.parametrize("type", ["mse", {"huber": {"deltas": {"values": 3.0}}}]) -def test_tmap_loss_no_gradients(type): - """Test that the loss is computed correctly when there are no gradients.""" - loss = TensorMapLoss(type=type, reduction="sum") - - tensor_map_1 = TensorMap( - keys=Labels.single(), - blocks=[ - TensorBlock( - values=torch.tensor([[1.0], [2.0], [3.0]]), - samples=Labels.range("samples", 3), - components=[], - properties=Labels("energy", torch.tensor([[0]])), - ) - ], +@pytest.fixture +def tensor_map_with_grad_1_with_strain(): + block = TensorBlock( + values=torch.tensor([[0.0], [1.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + block.add_gradient( + "positions", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ), ) - tensor_map_2 = TensorMap( - keys=Labels.single(), - blocks=[ - TensorBlock( - values=torch.tensor([[0.0], [2.0], [3.0]]), - samples=Labels.range("samples", 3), - components=[], - properties=Labels("energy", torch.tensor([[0]])), - ) - ], + block.add_gradient( + "strain", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ), ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map - loss_value = loss(tensor_map_1, tensor_map_1) - torch.testing.assert_close(loss_value, torch.tensor(0.0)) - # Expected result: 1.0 - loss_value = loss(tensor_map_1, tensor_map_2) - # Huber loss is scaled by 0.5 due to torch implementation - torch.testing.assert_close( - loss_value, (1.0 if type == "mse" else 0.5) * torch.tensor(1.0) +@pytest.fixture +def tensor_map_with_grad_3_with_strain(): + block = TensorBlock( + values=torch.tensor([[0.0], [1.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), ) + block.add_gradient( + "positions", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ), + ) + block.add_gradient( + "strain", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map +# Pointwise losses must return zero when predictions == targets @pytest.mark.parametrize( - "type", ["mse", {"huber": {"deltas": {"values": 3.0, "gradient": 3.0}}}] + "LossCls", + [ + TensorMapMSELoss, + TensorMapMAELoss, + TensorMapHuberLoss, + ], ) -def test_tmap_loss_with_gradients(tensor_map_with_grad_1, tensor_map_with_grad_2, type): - """Test that the loss is computed correctly when there are gradients.""" - loss = TensorMapLoss(type=type, gradient_weights={"gradient": 0.5}, reduction="sum") - - loss_value = loss(tensor_map_with_grad_1, tensor_map_with_grad_1) - torch.testing.assert_close(loss_value, torch.tensor(0.0)) - - # Expected result: 1.0 + 0.5 * 4.0 - loss_value = loss(tensor_map_with_grad_1, tensor_map_with_grad_2) - torch.testing.assert_close( - loss_value, - # Huber loss is scaled by 0.5 due to torch implementation - (1.0 if type == "mse" else 0.5) * torch.tensor(1.0 + 0.5 * 4.0), +def test_pointwise_zero_loss(tensor_map_with_grad_1, LossCls): + tm = tensor_map_with_grad_1 + key = tm.keys.names[0] + if LossCls == TensorMapHuberLoss: + loss = LossCls(name=key, gradient=None, weight=1.0, reduction="mean", delta=1.0) + else: + loss = LossCls(name=key, gradient=None, weight=1.0, reduction="mean") + pred = {key: tm} + targ = {key: tm} + assert loss(pred, targ).item() == pytest.approx(0.0) + + +# Check consistency between old and new loss implementations +def test_check_old_and_new_loss_consistency(tensor_map_with_grad_2): + tensor_1 = mts.remove_gradients(tensor_map_with_grad_2) + tensor_2 = mts.random_uniform_like(tensor_1) + loss_fn_1 = TensorMapLoss() + loss_fn_2 = TensorMapMSELoss( + name="", + gradient=None, + weight=1.0, + reduction="mean", ) + assert loss_fn_1(tensor_1, tensor_2) == loss_fn_2({"": tensor_1}, {"": tensor_2}) -def test_tmap_dict_loss( - tensor_map_with_grad_1, - tensor_map_with_grad_2, - tensor_map_with_grad_3, - tensor_map_with_grad_4, +# Masked losses must error if no mask is supplied +@pytest.mark.parametrize( + "MaskedCls", + [ + TensorMapMaskedMSELoss, + TensorMapMaskedMAELoss, + TensorMapMaskedHuberLoss, + ], +) +def test_masked_loss_error_on_missing_mask(tensor_map_with_grad_1, MaskedCls): + tm = tensor_map_with_grad_1 + key = tm.keys.names[0] + if MaskedCls == TensorMapMaskedHuberLoss: + loss = MaskedCls( + name=key, gradient=None, weight=1.0, reduction="mean", delta=1.0 + ) + else: + loss = MaskedCls(name=key, gradient=None, weight=1.0, reduction="mean") + with pytest.raises(ValueError): + loss({key: tm}, {key: tm}) + + +# Functional test for masked MSE: only unmasked element contributes +def test_masked_mse_behavior(tensor_map_with_grad_1, tensor_map_with_grad_2): + tm1 = tensor_map_with_grad_1 + tm2 = tensor_map_with_grad_2 + key = tm1.keys.names[0] + + # Construct a mask TensorMap: only index 1 is True + mask_vals = torch.tensor([[False], [True], [False]], dtype=torch.bool) + mask_block = TensorBlock( + values=mask_vals, + samples=tm1.block(0).samples, + components=tm1.block(0).components, + properties=tm1.block(0).properties, + ) + mask_map = TensorMap(keys=tm1.keys, blocks=[mask_block]) + extra_data = {f"{key}_mask": mask_map} + + loss = TensorMapMaskedMSELoss(name=key, gradient=None, weight=1.0, reduction="mean") + # Only element 1 contributes: (1-2)^2 = 1 + result = loss({key: tm2}, {key: tm1}, extra_data) + assert result.item() == pytest.approx(1.0) + + +# EMA scheduler: test both no-sliding and sliding-factor cases +@pytest.mark.parametrize( + "sf, expected_init, expected_update", + [ + (0.0, 1.0, 1.0), + (0.5, 2 / 3, (2 / 3) * 0.5), + ], +) +def test_ema_scheduler( + tensor_map_with_grad_1, tensor_map_with_grad_2, sf, expected_init, expected_update ): - """Test that the dict loss is computed correctly.""" - - loss_rmse = TensorMapDictLoss( - weights={ - "output_1": 0.6, - "output_2": 1.0, - "output_1_gradient_gradients": 0.5, - "output_2_gradient_gradients": 0.5, - }, - reduction="sum", + tm1 = tensor_map_with_grad_1 + tm2 = tensor_map_with_grad_2 + key = tm1.keys.names[0] + loss = TensorMapMSELoss(name=key, gradient=None, weight=1.0, reduction="mean") + sched = EMAScheduler(sliding_factor=sf) + + init_w = sched.initialize(loss, {key: tm1}) + assert init_w == pytest.approx(expected_init) + + new_w = sched.update(loss, {key: tm2}, {key: tm2}) + assert new_w == pytest.approx(expected_update) + + +# Factory and enum resolution +def test_loss_type_and_factory(): + mapping = { + "mse": TensorMapMSELoss, + "mae": TensorMapMAELoss, + "huber": TensorMapHuberLoss, + "masked_mse": TensorMapMaskedMSELoss, + "masked_mae": TensorMapMaskedMAELoss, + "masked_huber": TensorMapMaskedHuberLoss, + } + for key, cls in mapping.items(): + # LossType.from_key should return enum with .key + lt = LossType.from_key(key) + assert lt.key == key + # Factory should produce correct class + extra_kwargs = {} + if key == "huber" or key == "masked_huber": + extra_kwargs = {"delta": 1.0} + loss = create_loss( + key, + name="dummy", + gradient=None, + weight=1.0, + reduction="mean", + **extra_kwargs, + ) + assert isinstance(loss, cls) + + # Invalid keys raise ValueError + with pytest.raises(ValueError): + LossType.from_key("invalid_key") + with pytest.raises(ValueError): + create_loss( + "invalid_key", + name="dummy", + gradient=None, + weight=1.0, + reduction="mean", + ) + + +# Point-wise gradient-only +@pytest.mark.parametrize( + "LossCls, expected", + [ + (TensorMapMSELoss, 1 / 3), # MSEGradient: one error squared -> 1/3 + (TensorMapMAELoss, 1 / 3), # MAEGradient: one abs error -> 1/3 + (TensorMapHuberLoss, 1 / 6), # HuberGradient: 0.5*1^2 /3 = 1/6 + ], +) +def test_pointwise_gradient_loss( + tensor_map_with_grad_3, tensor_map_with_grad_4, LossCls, expected +): + tm3 = tensor_map_with_grad_3 + tm4 = tensor_map_with_grad_4 + key = tm3.keys.names[0] + # instantiate with gradient extraction + if LossCls == TensorMapHuberLoss: + loss = LossCls( + name=key, gradient="positions", weight=1.0, reduction="mean", delta=1.0 + ) + else: + loss = LossCls(name=key, gradient="positions", weight=1.0, reduction="mean") + val = loss({key: tm3}, {key: tm4}).item() + assert val == pytest.approx(expected) + + +def test_create_loss_invalid_kwargs(): + # Passing `foo` into an MSELoss constructor will cause + # a TypeError inside create_loss, which should be caught + # and re-raised with our custom message. + with pytest.raises(TypeError) as exc: + create_loss( + "mse", name="dummy", gradient=None, weight=1.0, reduction="mean", foo=123 + ) + msg = str(exc.value) + assert "Error constructing loss 'mse'" in msg + assert ( + "foo" in msg + ) # original constructor error should mention the unexpected 'foo' + + +def test_masked_pointwise_gradient_branch( + tensor_map_with_grad_3, tensor_map_with_grad_4 +): + tm3 = tensor_map_with_grad_3 + tm4 = tensor_map_with_grad_4 + key = tm3.keys.names[0] + + # Build a mask that selects all entries + mask_vals = torch.tensor([[True], [True], [True]], dtype=torch.bool) + mask_block = TensorBlock( + values=mask_vals, + samples=tm3.block(0).samples, + components=tm3.block(0).components, + properties=tm3.block(0).properties, + ) + + # Add a gradient-block to the mask, so grab(mask_block, "gradient") works + grad_block_for_mask = TensorBlock( + values=mask_vals, + samples=tm3.block(0).samples, + components=tm3.block(0).components, + properties=tm3.block(0).properties, ) - loss_huber = TensorMapDictLoss( - weights={ - "output_1": 0.6, - "output_2": 1.0, - "output_1_gradient_gradients": 0.5, - "output_2_gradient_gradients": 0.5, + mask_block.add_gradient("positions", grad_block_for_mask) + + mask_map = TensorMap(keys=tm3.keys, blocks=[mask_block]) + extra = {f"{key}_mask": mask_map} + + # Create the masked-pointwise loss on the 'positions' channel + loss = TensorMapMaskedMSELoss( + name=key, gradient="positions", weight=1.0, reduction="mean" + ) + + # The gradient values in tm3: [1, 0, 3]; in tm4: [1, 0, 2] + # Only one difference of 1 -> MSE mean = 1/3 + result = loss({key: tm3}, {key: tm4}, extra).item() + assert result == pytest.approx(1 / 3) + + +def test_ema_initialize_gradient_branch(tensor_map_with_grad_1): + tm = tensor_map_with_grad_1 + key = tm.keys.names[0] + + # gradient block values [1,2,3], zero baseline -> MSE = (1+4+9)/3 + loss = TensorMapMSELoss( + name=key, gradient="positions", weight=1.0, reduction="mean" + ) + sched = EMAScheduler(sliding_factor=0.5) + init_w = sched.initialize(loss, {key: tm}) + + assert init_w == pytest.approx((1 + 4 + 9) / 3) + + +def test_tmap_loss_subset(tensor_map_with_grad_1, tensor_map_with_grad_3): + """Test that the loss is computed correctly when only a subset + of the possible targets is present both in outputs and targets.""" + + block = TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("property", 1), + ) + block.add_gradient( + "positions", + TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["sample"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("property", 1), + ), + ) + layout = TensorMap(keys=Labels.single(), blocks=[block]) + + target_info = TargetInfo(quantity="energy", unit="eV", layout=layout) + loss_hypers = { + "output_1": { + "type": "mse", + "weight": 1.0, + "reduction": "sum", + "sliding_factor": None, + "gradients": { + "positions": { + "type": "mse", + "weight": 0.5, + "reduction": "sum", + "sliding_factor": None, + }, + }, }, - type={ - "huber": { - "deltas": { - "output_1": 0.1, - "output_2": 0.1, - "output_1_gradient_gradients": 0.1, - "output_2_gradient_gradients": 0.1, - } - } + "output_2": { + "type": "mse", + "weight": 1.0, + "reduction": "sum", + "sliding_factor": None, + "gradients": { + "positions": { + "type": "mse", + "weight": 0.5, + "reduction": "sum", + "sliding_factor": None, + }, + }, }, - reduction="sum", + } + + loss = LossAggregator( + targets={"output_1": target_info, "output_2": target_info}, + config=loss_hypers, ) output_dict = { "output_1": tensor_map_with_grad_1, - "output_2": tensor_map_with_grad_2, } target_dict = { "output_1": tensor_map_with_grad_3, - "output_2": tensor_map_with_grad_4, } expected_result = ( - 0.6 + 1.0 * ( tensor_map_with_grad_1.block().values - tensor_map_with_grad_3.block().values @@ -209,71 +491,111 @@ def test_tmap_dict_loss( .sum() + 0.5 * ( - tensor_map_with_grad_1.block().gradient("gradient").values - - tensor_map_with_grad_3.block().gradient("gradient").values - ) - .pow(2) - .sum() - + 1.0 - * ( - tensor_map_with_grad_2.block().values - - tensor_map_with_grad_4.block().values - ) - .pow(2) - .sum() - + 0.5 - * ( - tensor_map_with_grad_2.block().gradient("gradient").values - - tensor_map_with_grad_4.block().gradient("gradient").values + tensor_map_with_grad_1.block().gradient("positions").values + - tensor_map_with_grad_3.block().gradient("positions").values ) .pow(2) .sum() ) - loss_value = loss_rmse(output_dict, target_dict) + loss_value = loss(output_dict, target_dict) torch.testing.assert_close(loss_value, expected_result) - # Huber loss should be lower than RMSE - # (scaled by 0.5 due to torch implementation of Huber) - assert loss_huber(output_dict, target_dict) < 0.5 * loss_rmse( - output_dict, target_dict - ) - -def test_tmap_dict_loss_subset(tensor_map_with_grad_1, tensor_map_with_grad_3): - """Test that the dict loss is computed correctly when only a subset - of the possible targets is present both in outputs and targets.""" +def test_tmap_loss_multiple_datasets_same_target_different_gradients( + tensor_map_with_grad_1, + tensor_map_with_grad_1_with_strain, + tensor_map_with_grad_3_with_strain, +): + """Test that the loss is computed correctly when two datasets have the same target, + but different gradients.""" - loss = TensorMapDictLoss( - weights={ - "output_1": 1.0, - "output_2": 1.0, - "output_1_gradient_gradients": 0.5, - "output_2_gradient_gradients": 0.5, + block = TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("property", 1), + ) + block.add_gradient( + "positions", + TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["sample"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("property", 1), + ), + ) + block.add_gradient( + "strain", + TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["sample"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("property", 1), + ), + ) + layout = TensorMap(keys=Labels.single(), blocks=[block]) + + target_info = TargetInfo(quantity="energy", unit="eV", layout=layout) + loss_hypers = { + "output": { + "type": "mse", + "weight": 1.0, + "reduction": "sum", + "sliding_factor": None, + "gradients": { + "positions": { + "type": "mse", + "weight": 0.5, + "reduction": "sum", + "sliding_factor": None, + }, + "strain": { + "type": "mse", + "weight": 0.3, + "reduction": "sum", + "sliding_factor": None, + }, + }, }, - reduction="sum", + } + + loss = LossAggregator( + targets={"output": target_info}, + config=loss_hypers, ) + # Test a case where the target has only one gradient output_dict = { - "output_1": tensor_map_with_grad_1, + "output": tensor_map_with_grad_3_with_strain, } + # The target has no `other_gradient` target_dict = { - "output_1": tensor_map_with_grad_3, + "output": tensor_map_with_grad_1, } expected_result = ( 1.0 * ( tensor_map_with_grad_1.block().values - - tensor_map_with_grad_3.block().values + - tensor_map_with_grad_3_with_strain.block().values ) .pow(2) .sum() + 0.5 * ( - tensor_map_with_grad_1.block().gradient("gradient").values - - tensor_map_with_grad_3.block().gradient("gradient").values + tensor_map_with_grad_1.block().gradient("positions").values + - tensor_map_with_grad_3_with_strain.block().gradient("positions").values ) .pow(2) .sum() @@ -282,108 +604,39 @@ def test_tmap_dict_loss_subset(tensor_map_with_grad_1, tensor_map_with_grad_3): loss_value = loss(output_dict, target_dict) torch.testing.assert_close(loss_value, expected_result) + # Test a case where the target has both gradients + output_dict = { + "output": tensor_map_with_grad_3_with_strain, + } -def test_tmap_loss_mae(): - """Test that the MAE loss is computed correctly.""" - loss = TensorMapLoss(type="mae", reduction="mean") - - tensor_map_1 = TensorMap( - keys=Labels.single(), - blocks=[ - TensorBlock( - values=torch.tensor([[2.0], [2.0], [3.0]]), - samples=Labels.range("samples", 3), - components=[], - properties=Labels("energy", torch.tensor([[0]])), - ) - ], - ) - tensor_map_2 = TensorMap( - keys=Labels.single(), - blocks=[ - TensorBlock( - values=torch.tensor([[0.0], [3.0], [3.0]]), - samples=Labels.range("samples", 3), - components=[], - properties=Labels("energy", torch.tensor([[0]])), - ) - ], - ) - - loss_value = loss(tensor_map_1, tensor_map_1) - torch.testing.assert_close(loss_value, torch.tensor(0.0)) - - # Expected result: 1.0 - loss_value = loss(tensor_map_1, tensor_map_2) - torch.testing.assert_close(loss_value, torch.tensor(1.0)) - - -def test_tmap_loss_huber(): - """Test that the Huber loss is computed correctly.""" - loss_mse = TensorMapLoss(type="mse", reduction="mean") - loss_huber = TensorMapLoss( - type={"huber": {"deltas": {"values": 3.0}}}, reduction="mean" - ) - - tensor_map_1 = TensorMap( - keys=Labels.single(), - blocks=[ - TensorBlock( - values=torch.tensor([[2.0], [2.0], [3.0]]), - samples=Labels.range("samples", 3), - components=[], - properties=Labels("energy", torch.tensor([[0]])), - ) - ], - ) - tensor_map_2 = TensorMap( - keys=Labels.single(), - blocks=[ - TensorBlock( - values=torch.tensor([[0.0], [3.0], [3.0]]), - samples=Labels.range("samples", 3), - components=[], - properties=Labels("energy", torch.tensor([[0]])), - ) - ], - ) - - loss_value = loss_huber(tensor_map_1, tensor_map_1) - torch.testing.assert_close(loss_value, torch.tensor(0.0)) - - # No outliers, should be equal to MSE (scaled by 0.5 due to torch implementation) - loss_value_huber = loss_huber(tensor_map_1, tensor_map_2) - loss_value_mse = loss_mse(tensor_map_1, tensor_map_2) - torch.testing.assert_close(loss_value_huber, 0.5 * loss_value_mse) - - tensor_map_with_outlier = TensorMap( - keys=Labels.single(), - blocks=[ - TensorBlock( - values=torch.tensor([[0.0], [100.0], [3.0]]), - samples=Labels.range("samples", 3), - components=[], - properties=Labels("energy", torch.tensor([[0]])), - ) - ], - ) - - loss_value_huber = loss_huber(tensor_map_1, tensor_map_with_outlier) - loss_value_mse = loss_mse(tensor_map_1, tensor_map_with_outlier) - # Huber loss is lower due to the outlier - assert loss_value_huber < 0.5 * loss_value_mse - + # The target has no `other_gradient` + target_dict = { + "output": tensor_map_with_grad_1_with_strain, + } -def test_tmap_loss_with_sliding_weights(tensor_map_with_grad_1, tensor_map_with_grad_2): - """Test that the loss behaves as expected with sliding weights.""" - loss = TensorMapLoss( - type="mse", gradient_weights={"gradient": 1.0}, sliding_factor=0.7 + expected_result = ( + 1.0 + * ( + tensor_map_with_grad_1_with_strain.block().values + - tensor_map_with_grad_3_with_strain.block().values + ) + .pow(2) + .sum() + + 0.5 + * ( + tensor_map_with_grad_1_with_strain.block().gradient("positions").values + - tensor_map_with_grad_3_with_strain.block().gradient("positions").values + ) + .pow(2) + .sum() + + 0.3 + * ( + tensor_map_with_grad_1_with_strain.block().gradient("strain").values + - tensor_map_with_grad_3_with_strain.block().gradient("strain").values + ) + .pow(2) + .sum() ) - for _ in range(5): - loss(tensor_map_with_grad_1, tensor_map_with_grad_2) - - # in the two TensorMaps above, the loss on the gradients is larger than the loss on - # the values, therefore we should expect a larger sliding weight for the gradients - - assert loss.sliding_weights["gradient"] > loss.sliding_weights["values"] + loss_value = loss(output_dict, target_dict) + torch.testing.assert_close(loss_value, expected_result) diff --git a/tests/utils/test_old_loss.py b/tests/utils/test_old_loss.py new file mode 100644 index 000000000..46df7958a --- /dev/null +++ b/tests/utils/test_old_loss.py @@ -0,0 +1,389 @@ +from pathlib import Path + +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatrain.utils.old_loss import TensorMapDictLoss, TensorMapLoss + + +RESOURCES_PATH = Path(__file__).parents[1] / "resources" + + +@pytest.fixture +def tensor_map_with_grad_1(): + block = TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map + + +@pytest.fixture +def tensor_map_with_grad_2(): + block = TensorBlock( + values=torch.tensor([[1.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map + + +@pytest.fixture +def tensor_map_with_grad_3(): + block = TensorBlock( + values=torch.tensor([[0.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map + + +@pytest.fixture +def tensor_map_with_grad_4(): + block = TensorBlock( + values=torch.tensor([[0.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [2.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map + + +@pytest.mark.parametrize("type", ["mse", {"huber": {"deltas": {"values": 3.0}}}]) +def test_tmap_loss_no_gradients(type): + """Test that the loss is computed correctly when there are no gradients.""" + loss = TensorMapLoss(type=type, reduction="sum") + + tensor_map_1 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + ], + ) + tensor_map_2 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[0.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + ], + ) + + loss_value = loss(tensor_map_1, tensor_map_1) + torch.testing.assert_close(loss_value, torch.tensor(0.0)) + + # Expected result: 1.0 + loss_value = loss(tensor_map_1, tensor_map_2) + # Huber loss is scaled by 0.5 due to torch implementation + torch.testing.assert_close( + loss_value, (1.0 if type == "mse" else 0.5) * torch.tensor(1.0) + ) + + +@pytest.mark.parametrize( + "type", ["mse", {"huber": {"deltas": {"values": 3.0, "gradient": 3.0}}}] +) +def test_tmap_loss_with_gradients(tensor_map_with_grad_1, tensor_map_with_grad_2, type): + """Test that the loss is computed correctly when there are gradients.""" + loss = TensorMapLoss(type=type, gradient_weights={"gradient": 0.5}, reduction="sum") + + loss_value = loss(tensor_map_with_grad_1, tensor_map_with_grad_1) + torch.testing.assert_close(loss_value, torch.tensor(0.0)) + + # Expected result: 1.0 + 0.5 * 4.0 + loss_value = loss(tensor_map_with_grad_1, tensor_map_with_grad_2) + torch.testing.assert_close( + loss_value, + # Huber loss is scaled by 0.5 due to torch implementation + (1.0 if type == "mse" else 0.5) * torch.tensor(1.0 + 0.5 * 4.0), + ) + + +def test_tmap_dict_loss( + tensor_map_with_grad_1, + tensor_map_with_grad_2, + tensor_map_with_grad_3, + tensor_map_with_grad_4, +): + """Test that the dict loss is computed correctly.""" + + loss_rmse = TensorMapDictLoss( + weights={ + "output_1": 0.6, + "output_2": 1.0, + "output_1_gradient_gradients": 0.5, + "output_2_gradient_gradients": 0.5, + }, + reduction="sum", + ) + loss_huber = TensorMapDictLoss( + weights={ + "output_1": 0.6, + "output_2": 1.0, + "output_1_gradient_gradients": 0.5, + "output_2_gradient_gradients": 0.5, + }, + type={ + "huber": { + "deltas": { + "output_1": 0.1, + "output_2": 0.1, + "output_1_gradient_gradients": 0.1, + "output_2_gradient_gradients": 0.1, + } + } + }, + reduction="sum", + ) + + output_dict = { + "output_1": tensor_map_with_grad_1, + "output_2": tensor_map_with_grad_2, + } + + target_dict = { + "output_1": tensor_map_with_grad_3, + "output_2": tensor_map_with_grad_4, + } + + expected_result = ( + 0.6 + * ( + tensor_map_with_grad_1.block().values + - tensor_map_with_grad_3.block().values + ) + .pow(2) + .sum() + + 0.5 + * ( + tensor_map_with_grad_1.block().gradient("gradient").values + - tensor_map_with_grad_3.block().gradient("gradient").values + ) + .pow(2) + .sum() + + 1.0 + * ( + tensor_map_with_grad_2.block().values + - tensor_map_with_grad_4.block().values + ) + .pow(2) + .sum() + + 0.5 + * ( + tensor_map_with_grad_2.block().gradient("gradient").values + - tensor_map_with_grad_4.block().gradient("gradient").values + ) + .pow(2) + .sum() + ) + + loss_value = loss_rmse(output_dict, target_dict) + torch.testing.assert_close(loss_value, expected_result) + + # Huber loss should be lower than RMSE + # (scaled by 0.5 due to torch implementation of Huber) + assert loss_huber(output_dict, target_dict) < 0.5 * loss_rmse( + output_dict, target_dict + ) + + +def test_tmap_dict_loss_subset(tensor_map_with_grad_1, tensor_map_with_grad_3): + """Test that the dict loss is computed correctly when only a subset + of the possible targets is present both in outputs and targets.""" + + loss = TensorMapDictLoss( + weights={ + "output_1": 1.0, + "output_2": 1.0, + "output_1_gradient_gradients": 0.5, + "output_2_gradient_gradients": 0.5, + }, + reduction="sum", + ) + + output_dict = { + "output_1": tensor_map_with_grad_1, + } + + target_dict = { + "output_1": tensor_map_with_grad_3, + } + + expected_result = ( + 1.0 + * ( + tensor_map_with_grad_1.block().values + - tensor_map_with_grad_3.block().values + ) + .pow(2) + .sum() + + 0.5 + * ( + tensor_map_with_grad_1.block().gradient("gradient").values + - tensor_map_with_grad_3.block().gradient("gradient").values + ) + .pow(2) + .sum() + ) + + loss_value = loss(output_dict, target_dict) + torch.testing.assert_close(loss_value, expected_result) + + +def test_tmap_loss_mae(): + """Test that the MAE loss is computed correctly.""" + loss = TensorMapLoss(type="mae", reduction="mean") + + tensor_map_1 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[2.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + ], + ) + tensor_map_2 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[0.0], [3.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + ], + ) + + loss_value = loss(tensor_map_1, tensor_map_1) + torch.testing.assert_close(loss_value, torch.tensor(0.0)) + + # Expected result: 1.0 + loss_value = loss(tensor_map_1, tensor_map_2) + torch.testing.assert_close(loss_value, torch.tensor(1.0)) + + +def test_tmap_loss_huber(): + """Test that the Huber loss is computed correctly.""" + loss_mse = TensorMapLoss(type="mse", reduction="mean") + loss_huber = TensorMapLoss( + type={"huber": {"deltas": {"values": 3.0}}}, reduction="mean" + ) + + tensor_map_1 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[2.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + ], + ) + tensor_map_2 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[0.0], [3.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + ], + ) + + loss_value = loss_huber(tensor_map_1, tensor_map_1) + torch.testing.assert_close(loss_value, torch.tensor(0.0)) + + # No outliers, should be equal to MSE (scaled by 0.5 due to torch implementation) + loss_value_huber = loss_huber(tensor_map_1, tensor_map_2) + loss_value_mse = loss_mse(tensor_map_1, tensor_map_2) + torch.testing.assert_close(loss_value_huber, 0.5 * loss_value_mse) + + tensor_map_with_outlier = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[0.0], [100.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + ], + ) + + loss_value_huber = loss_huber(tensor_map_1, tensor_map_with_outlier) + loss_value_mse = loss_mse(tensor_map_1, tensor_map_with_outlier) + # Huber loss is lower due to the outlier + assert loss_value_huber < 0.5 * loss_value_mse + + +def test_tmap_loss_with_sliding_weights(tensor_map_with_grad_1, tensor_map_with_grad_2): + """Test that the loss behaves as expected with sliding weights.""" + loss = TensorMapLoss( + type="mse", gradient_weights={"gradient": 1.0}, sliding_factor=0.7 + ) + + for _ in range(5): + loss(tensor_map_with_grad_1, tensor_map_with_grad_2) + + # in the two TensorMaps above, the loss on the gradients is larger than the loss on + # the values, therefore we should expect a larger sliding weight for the gradients + + assert loss.sliding_weights["gradient"] > loss.sliding_weights["values"] diff --git a/tests/utils/test_omegaconf.py b/tests/utils/test_omegaconf.py index 3f02d4478..9c85b0577 100644 --- a/tests/utils/test_omegaconf.py +++ b/tests/utils/test_omegaconf.py @@ -2,7 +2,7 @@ import pytest import torch -from omegaconf import ListConfig, OmegaConf +from omegaconf import DictConfig, ListConfig, OmegaConf from metatrain import soap_bpnn from metatrain.utils import omegaconf @@ -10,6 +10,7 @@ check_dataset_options, check_units, expand_dataset_config, + expand_loss_config, ) @@ -233,6 +234,238 @@ def test_expand_dataset_gradient(): conf_expanded["targets"]["my_energy"]["virial"]["read_from"] +def test_expand_loss_config_default(): + """ + When no custom loss is provided, architecture.training.loss + should be created from the default template for each target in training_set. + """ + conf = OmegaConf.create( + { + "training_set": { + "targets": { + "energy": {}, # no gradients requested + "dipole": {}, # non-energy target + } + }, + "architecture": {"training": {}}, + } + ) + expanded = expand_loss_config(conf) + loss = expanded["architecture"]["training"]["loss"] + # top-level loss must be a DictConfig with exactly the two keys + assert isinstance(loss, DictConfig) + assert set(loss.keys()) == {"energy", "dipole"} + + # energy should have an empty gradients dict + assert isinstance(loss["energy"]["gradients"], DictConfig) + assert len(loss["energy"]["gradients"]) == 0 + + # non-energy target gets the default loss template + assert isinstance(loss["dipole"], DictConfig) + + +def test_expand_loss_config_migrates_forces(monkeypatch): + """ + If the training hyperparams include a top-level 'forces' block, + and the dataset requests forces, it should be moved into + energy.gradients.positions. + """ + conf = OmegaConf.create( + { + "training_set": {"targets": {"energy": {"forces": True}}}, + "architecture": { + "training": { + "loss": { + "forces": {"weight": 2.0}, + # also supply some default energy block to be merged + "energy": {"weight": 3.0}, + } + } + }, + } + ) + expanded = expand_loss_config(conf) + loss = expanded["architecture"]["training"]["loss"] + + # no top-level 'forces' or 'stress' or 'virial' + assert "forces" not in loss + assert "stress" not in loss + assert "virial" not in loss + + # custom 'scale' should appear under energy.gradients.positions + pos = loss["energy"]["gradients"]["positions"] + assert isinstance(pos, DictConfig) + assert pos["weight"] == 2.0 + + # custom energy.weight should have been merged + assert loss["energy"]["weight"] == 3.0 + + +def test_expand_loss_config_migrates_virial_to_strain(): + """ + Legacy 'virial' in loss hyperparams should migrate to + energy.gradients.strain when the dataset requests virials. + """ + conf = OmegaConf.create( + { + "training_set": {"targets": {"energy": {"virial": True}}}, + "architecture": { + "training": { + "loss": {"virial": {"weight": 0.5}, "energy": {"type": "huber"}} + } + }, + } + ) + expanded = expand_loss_config(conf) + loss = expanded["architecture"]["training"]["loss"] + + # no top-level 'virial' or 'stress' + assert "virial" not in loss + assert "stress" not in loss + + # migrated into energy.gradients.strain + strain = loss["energy"]["gradients"]["strain"] + assert isinstance(strain, DictConfig) + assert strain["weight"] == 0.5 + + # original energy.type preserved + assert loss["energy"]["type"] == "huber" + + +def test_expand_loss_config_removes_unused_legacy_keys(): + """ + If the dataset does not request a given gradient, any legacy key + (forces, stress, virial) in the loss hyperparams must be deleted. + """ + conf = OmegaConf.create( + { + "training_set": { + "targets": {"energy": {}} # no forces, stress, nor virial + }, + "architecture": { + "training": { + "loss": { + "forces": {"scale": 9.9}, + "stress": {"scale": 8.8}, + "virial": {"scale": 7.7}, + } + } + }, + } + ) + expanded = expand_loss_config(conf) + loss = expanded["architecture"]["training"]["loss"] + + # none of the legacy keys should survive at top level + for legacy in ("forces", "stress", "virial"): + assert legacy not in loss + + # and energy.gradients remains empty + assert loss["energy"]["gradients"] == {} + + +def test_expand_loss_config_non_energy_only(): + """ + If the training_set contains only non-energy targets, no 'energy' + block should appear in the final loss, only the non-energy ones. + """ + conf = OmegaConf.create( + { + "training_set": {"targets": {"dipole": {}, "foo": {}}}, + "architecture": {"training": {}}, + } + ) + expanded = expand_loss_config(conf) + loss = expanded["architecture"]["training"]["loss"] + + # energy should not appear + assert "energy" not in loss + # both non-energy targets must appear, with default template + assert set(loss.keys()) == {"dipole", "foo"} + for target in ("dipole", "foo"): + assert isinstance(loss[target], DictConfig) + + +def test_expand_loss_config_single_string(): + """ + When the loss is given as a single string, it should be expanded into a DictConfig + with the default template for all targets. + """ + conf = OmegaConf.create( + { + "training_set": { + "targets": { + "energy": {}, # no gradients requested + "dipole": {}, # non-energy target + } + }, + "architecture": {"training": {"loss": "mae"}}, + } + ) + expanded = expand_loss_config(conf) + loss = expanded["architecture"]["training"]["loss"] + # top-level loss must be a DictConfig with exactly the two keys + assert isinstance(loss, DictConfig) + assert set(loss.keys()) == {"energy", "dipole"} + + # energy should have an empty gradients dict + assert isinstance(loss["energy"]["gradients"], DictConfig) + assert len(loss["energy"]["gradients"]) == 0 + + # the type of the energy loss should be 'mae' + assert loss["energy"]["type"] == "mae" + + # non-energy target gets the default loss template + assert isinstance(loss["dipole"], DictConfig) + + # the type of the dipole loss should be 'mae' + assert loss["dipole"]["type"] == "mae" + + +def test_expand_loss_config_per_target_string(): + """ + When the loss is given as a string per target, it should be expanded into a + DictConfig with the default template for each target, but with the type + set to the given string. + """ + conf = OmegaConf.create( + { + "training_set": { + "targets": { + "energy": {}, # no gradients requested + "dipole": {}, # non-energy target + } + }, + "architecture": { + "training": {"loss": {"energy": "mse", "dipole": "huber"}} + }, + } + ) + expanded = expand_loss_config(conf) + loss = expanded["architecture"]["training"]["loss"] + # top-level loss must be a DictConfig with exactly the two keys + assert isinstance(loss, DictConfig) + assert set(loss.keys()) == {"energy", "dipole"} + + # energy should have an empty gradients dict + assert isinstance(loss["energy"]["gradients"], DictConfig) + assert len(loss["energy"]["gradients"]) == 0 + + # the type of the energy loss should be 'mse' + assert loss["energy"]["type"] == "mse" + assert loss["energy"]["weight"] == 1.0 + assert loss["energy"]["reduction"] == "mean" + + # non-energy target gets the default loss template + assert isinstance(loss["dipole"], DictConfig) + + # the type of the dipole loss should be 'huber' + assert loss["dipole"]["type"] == "huber" + assert loss["dipole"]["weight"] == 1.0 + assert loss["dipole"]["reduction"] == "mean" + assert loss["dipole"]["delta"] == 1.0 + + def test_check_units(): file_name = "foo.xyz" system_section = {"read_from": file_name, "length_unit": "angstrom"} diff --git a/tests/utils/test_transfer.py b/tests/utils/test_transfer.py index a1021a0c1..4c792d5a4 100644 --- a/tests/utils/test_transfer.py +++ b/tests/utils/test_transfer.py @@ -1,4 +1,5 @@ import metatensor.torch as mts +import pytest import torch from metatensor.torch import Labels, TensorMap from metatomic.torch import System @@ -6,20 +7,27 @@ from metatrain.utils.transfer import batch_to -def test_batch_to_dtype(): - system = System( +@pytest.fixture +def simple_tensormap(): + return TensorMap( + keys=Labels.single(), + blocks=[mts.block_from_array(torch.tensor([[1.0]]))], + ) + + +@pytest.fixture +def simple_system(): + return System( positions=torch.tensor([[1.0, 1.0, 1.0]]), cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), types=torch.tensor([1]), pbc=torch.tensor([True, True, True]), ) - targets = TensorMap( - keys=Labels.single(), - blocks=[mts.block_from_array(torch.tensor([[1.0]]))], - ) - systems = [system] - targets = {"energy": targets} + +def test_batch_to_dtype(simple_system, simple_tensormap): + systems = [simple_system] + targets = {"energy": simple_tensormap} assert systems[0].positions.dtype == torch.float32 assert systems[0].cell.dtype == torch.float32 @@ -32,20 +40,9 @@ def test_batch_to_dtype(): assert targets["energy"].block().values.dtype == torch.float64 -def test_batch_to_device(): - system = System( - positions=torch.tensor([[1.0, 1.0, 1.0]]), - cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), - types=torch.tensor([1]), - pbc=torch.tensor([True, True, True]), - ) - targets = TensorMap( - keys=Labels.single(), - blocks=[mts.block_from_array(torch.tensor([[1.0]]))], - ) - - systems = [system] - targets = {"energy": targets} +def test_batch_to_device(simple_system, simple_tensormap): + systems = [simple_system] + targets = {"energy": simple_tensormap} assert systems[0].positions.device == torch.device("cpu") assert systems[0].types.device == torch.device("cpu") @@ -56,3 +53,30 @@ def test_batch_to_device(): assert systems[0].positions.device == torch.device("meta") assert systems[0].types.device == torch.device("meta") assert targets["energy"].block().values.device == torch.device("meta") + + +def test_batch_to_extra_data_mask_branch(simple_system, simple_tensormap): + system = simple_system + targets = {"energy": simple_tensormap} + + # extra_data with one normal key and one mask key + extra_data = { + "feature": simple_tensormap, + "feature_mask": TensorMap( + keys=Labels.single(), + blocks=[mts.block_from_array(torch.tensor([[1]], dtype=torch.int64))], + ), + } + + # Apply batch_to requesting float64 + _, _, extra_out = batch_to( + [system], targets, extra_data=extra_data, dtype=torch.float64 + ) + + # The non-mask TensorMap should be float64 + feat_tm = extra_out["feature"] + assert feat_tm.block().values.dtype == torch.float64 + + # The mask TensorMap should be bool, despite original dtype and requested dtype + mask_tm = extra_out["feature_mask"] + assert mask_tm.block().values.dtype == torch.bool