Refactor and generalize loss.py #635

Merged
merged 11 commits into from
Aug 21, 2025

Changes from all commits

1 change: 1 addition & 0 deletions docs/src/advanced-concepts/index.rst
@@ -9,6 +9,7 @@ such as output naming, auxiliary outputs, and wrapper models.
:maxdepth: 1

output-naming
loss-functions
auxiliary-outputs
multi-gpu
auto-restarting
123 changes: 123 additions & 0 deletions docs/src/advanced-concepts/loss-functions.rst
@@ -0,0 +1,123 @@
.. _loss-functions:

Loss functions
==============

``metatrain`` supports a variety of loss functions, which can be configured
in the ``loss`` subsection of the ``training`` section for each ``architecture``
in the options file. The loss functions are designed to be flexible and can be
tailored to the specific needs of the dataset and the targets being predicted.

The ``loss`` subsection describes the loss functions to be used. The most basic
configuration is

.. code-block:: yaml

loss: mse

which sets the loss function to mean squared error (MSE) for all targets.
When training a potential energy surface on energy, forces, and virial,
for example, this configuration is internally expanded to

.. code-block:: yaml

loss:
energy:
type: mse
weight: 1.0
reduction: mean
forces:
type: mse
weight: 1.0
reduction: mean
virial:
type: mse
weight: 1.0
reduction: mean

This internal, more detailed configuration can be used in the options file
to specify different loss functions for each target, or to override the default
values of the parameters. The parameters accepted by each loss function are:

1. ``type``. This controls the type of loss to be used. The default value is
   ``mse``, and the other standard options are ``mae`` and ``huber``. These
   correspond to the PyTorch loss functions
   `MSELoss <https://docs.pytorch.org/docs/stable/generated/torch.nn.MSELoss.html>`_,
   `L1Loss <https://docs.pytorch.org/docs/stable/generated/torch.nn.L1Loss.html>`_,
   and
   `HuberLoss <https://docs.pytorch.org/docs/stable/generated/torch.nn.HuberLoss.html>`_,
   respectively.
   There are also "masked" versions of these losses, which are useful when using
   padded targets whose padding values should be masked before computing the loss.
   The masked losses are named ``masked_mse``, ``masked_mae``, and ``masked_huber``.

2. ``weight``. This controls the weighting of the different contributions to the
   loss (e.g., energy, forces, virial). The default value of 1.0 for all targets
   works well for most datasets, but it can be adjusted if required.

3. ``reduction``. This controls how the loss is reduced over a batch. The default
   is ``mean``; ``sum`` is also supported. An example configuration overriding
   these parameters per target is shown below.
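
The following configuration (with illustrative weights) uses an MSE loss for the
energy and an MAE loss for the forces:

.. code-block:: yaml

    loss:
      energy:
        type: mse
        weight: 1.0
        reduction: mean
      forces:
        type: mae
        weight: 0.1
        reduction: mean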

Some losses, like ``huber``, require additional parameters to be specified. The
table below summarizes the available losses and their additional parameters:

.. list-table:: Loss Functions and Parameters
   :header-rows: 1
   :widths: 20 30 50

   * - Loss Type
     - Description
     - Additional Parameters
   * - ``mse``
     - Mean squared error
     - N/A
   * - ``mae``
     - Mean absolute error
     - N/A
   * - ``huber``
     - Huber loss
     - ``delta``: threshold at which to switch from squared error to absolute error.
   * - ``masked_mse``
     - Masked mean squared error
     - N/A
   * - ``masked_mae``
     - Masked mean absolute error
     - N/A
   * - ``masked_huber``
     - Masked Huber loss
     - ``delta``: threshold at which to switch from squared error to absolute error.
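
For instance, assuming ``delta`` is given alongside the other parameters (the
value here is illustrative), a Huber loss on forces could be configured as:

.. code-block:: yaml

    loss:
      forces:
        type: huber
        weight: 1.0
        reduction: mean
        delta: 0.01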


Masked loss functions
---------------------

Masked loss functions are particularly useful when dealing with datasets that contain
padded targets. In such cases, the loss function can be configured to ignore the padded
values during the loss computation. This is done by using the ``masked_`` prefix in
the loss type. For example, if the target contains padded values, you can use
``masked_mse`` or ``masked_mae`` to ensure that the loss is computed only on the
valid (non-padded) values. The values of the masks must be passed as ``extra_data``
in the training set, and the loss function will automatically apply the mask to
the target values. An example configuration for a masked loss is as follows:

.. code-block:: yaml

loss:
energy:
type: masked_mse
weight: 1.0
reduction: sum
forces:
type: masked_mae
weight: 0.1
reduction: sum
...

training_set:
systems:
...
targets:
mtt::my_target:
...
...
extra_data:
mtt::my_target_mask:
read_from: my_target_mask.mts
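
The mask file above is a serialized :py:class:`metatensor.torch.TensorMap`. As a
minimal sketch, assuming a 0/1 floating-point mask (whether the loss expects
boolean or floating-point values, and the exact metadata layout, depend on the
target being masked), such a file could be produced as follows:

.. code-block:: python

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap, save

    # 1.0 marks valid entries, 0.0 marks padded ones (one row per sample).
    block = TensorBlock(
        values=torch.tensor([[1.0], [1.0], [0.0]]),
        samples=Labels("system", torch.tensor([[0], [1], [2]])),
        components=[],
        properties=Labels("value", torch.tensor([[0]])),
    )
    mask = TensorMap(keys=Labels.single(), blocks=[block])
    save("my_target_mask.mts", mask)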
16 changes: 3 additions & 13 deletions docs/src/architectures/nanopet.rst
@@ -60,19 +60,9 @@ hyperparameters to tune are (in decreasing order of importance):
- ``num_attention_layers``: The number of attention layers in each layer of the graph
neural network. Depending on the dataset, increasing this hyperparameter might lead to
better accuracy, at the cost of increased training and evaluation time.
- ``loss``: This section describes the loss function to be used, and it has three
subsections. 1. ``weights``. This controls the weighting of different contributions
to the loss (e.g., energy, forces, virial, etc.). The default values of 1.0 for all
targets work well for most datasets, but they might need to be adjusted. For example,
to set a weight of 1.0 for the energy and 0.1 for the forces, you can set the
following in the ``options.yaml`` file under ``loss``:
``weights: {"energy": 1.0, "forces": 0.1}``. 2. ``type``. This controls the type of
loss to be used. The default value is ``mse``, and other options are ``mae`` and
``huber``. ``huber`` is a subsection of its own, and it requires the user to specify
the ``deltas`` parameters in a similar way to how the ``weights`` are specified (e.g.,
``deltas: {"energy": 0.1, "forces": 0.01}``). 3. ``reduction``. This controls how the
loss is reduced over batches. The default value is ``mean``, and the other allowed
option is ``sum``.
- ``loss``: This section describes the loss function to be used. See the
:doc:`dedicated documentation page <../advanced-concepts/loss-functions>` for more
details.
- ``long_range``: In some systems and datasets, enabling long-range Coulomb interactions
might be beneficial for the accuracy of the model and/or its physical correctness.
See below for a breakdown of the long-range section of the model hyperparameters.
16 changes: 3 additions & 13 deletions docs/src/architectures/pet.rst
@@ -58,19 +58,9 @@ hyperparameters to tune are (in decreasing order of importance):
- ``num_attention_layers``: The number of attention layers in each layer of the graph
neural network. Depending on the dataset, increasing this hyperparameter might lead to
better accuracy, at the cost of increased training and evaluation time.
- ``loss``: This section describes the loss function to be used, and it has three
subsections. 1. ``weights``. This controls the weighting of different contributions
to the loss (e.g., energy, forces, virial, etc.). The default values of 1.0 for all
targets work well for most datasets, but they might need to be adjusted. For example,
to set a weight of 1.0 for the energy and 0.1 for the forces, you can set the
following in the ``options.yaml`` file under ``loss``:
``weights: {"energy": 1.0, "forces": 0.1}``. 2. ``type``. This controls the type of
loss to be used. The default value is ``mse``, and other options are ``mae`` and
``huber``. ``huber`` is a subsection of its own, and it requires the user to specify
the ``deltas`` parameters in a similar way to how the ``weights`` are specified (e.g.,
``deltas: {"energy": 0.1, "forces": 0.01}``). 3. ``reduction``. This controls how the
loss is reduced over batches. The default value is ``mean``, and the other allowed
option is ``sum``.
- ``loss``: This section describes the loss function to be used. See the
:doc:`dedicated documentation page <../advanced-concepts/loss-functions>` for more
details.
- ``long_range``: In some systems and datasets, enabling long-range Coulomb interactions
might be beneficial for the accuracy of the model and/or its physical correctness.
See below for a breakdown of the long-range section of the model hyperparameters.
16 changes: 3 additions & 13 deletions docs/src/architectures/soap-bpnn.rst
@@ -54,19 +54,9 @@ hyperparameters to tune are (in decreasing order of importance):
- ``layernorm``: Whether to use layer normalization before the neural network. Setting
this hyperparameter to ``false`` will lead to slower convergence of training, but
might lead to better generalization outside of the training set distribution.
- ``loss``: This section describes the loss function to be used, and it has three
subsections. 1. ``weights``. This controls the weighting of different contributions
to the loss (e.g., energy, forces, virial, etc.). The default values of 1.0 for all
targets work well for most datasets, but they might need to be adjusted. For example,
to set a weight of 1.0 for the energy and 0.1 for the forces, you can set the
following in the ``options.yaml`` file under ``loss``:
``weights: {"energy": 1.0, "forces": 0.1}``. 2. ``type``. This controls the type of
loss to be used. The default value is ``mse``, and other options are ``mae`` and
``huber``. ``huber`` is a subsection of its own, and it requires the user to specify
the ``deltas`` parameters in a similar way to how the ``weights`` are specified (e.g.,
``deltas: {"energy": 0.1, "forces": 0.01}``). 3. ``reduction``. This controls how the
loss is reduced over batches. The default value is ``mean``, and the other allowed
option is ``sum``.
- ``loss``: This section describes the loss function to be used. See the
:doc:`dedicated documentation page <../advanced-concepts/loss-functions>` for more
details.
- ``long_range``: In some systems and datasets, enabling long-range Coulomb interactions
might be beneficial for the accuracy of the model and/or its physical correctness.
See below for a breakdown of the long-range section of the model hyperparameters.
8 changes: 6 additions & 2 deletions docs/src/dev-docs/changelog.rst
@@ -15,8 +15,12 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
.. Added
.. #####

.. Changed
.. #######
Changed
#######

- Refactored the ``loss.py`` module to provide an easier-to-extend interface for
  custom loss functions.
- Updated the trainer checkpoints to account for changes in the loss-related hypers.

.. Removed
.. #######
1 change: 1 addition & 0 deletions docs/src/dev-docs/index.rst
@@ -13,6 +13,7 @@ module.
architecture-life-cycle
new-architecture
dataset-information
new-loss
cli/index
utils/index
changelog
61 changes: 61 additions & 0 deletions docs/src/dev-docs/new-loss.rst
@@ -0,0 +1,61 @@
.. _adding-new-loss:

Adding a new loss function
==========================

This page describes the classes and files required to add a new loss function to
``metatrain``. Defining a new loss can be useful, for example, when extra data
has to be used to compute the loss.

Loss functions in ``metatrain`` are implemented as subclasses of
:py:class:`metatrain.utils.loss.LossInterface`. This interface defines the
required method :py:meth:`compute`, which takes the model predictions and
the ground truth values as input and returns the computed loss value. The
:py:meth:`compute` method accepts an additional argument ``extra_data`` besides
``predictions`` and ``targets``, which can be used to pass any extra information
needed for the loss computation.

.. code-block:: python

from typing import Dict, Optional
import torch
from metatrain.utils.loss import LossInterface
from metatensor.torch import TensorMap

class NewLoss(LossInterface):
def __init__(
self,
name: str,
gradient: Optional[str],
weight: float,
reduction: str,
) -> None:
...

def compute(
self,
predictions: Dict[str, TensorMap],
targets: Dict[str, TensorMap],
extra_data: Dict[str, TensorMap]
) -> torch.Tensor:
...
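
As a concrete illustration, the sketch below fills in these methods with a simple
weighted mean squared error. The attribute storage and the single-block access
pattern are assumptions made for this example; a real implementation should follow
the existing losses described below.

.. code-block:: python

    from typing import Dict, Optional

    import torch
    from metatensor.torch import TensorMap

    from metatrain.utils.loss import LossInterface


    class ScaledMSELoss(LossInterface):
        """Illustrative pointwise MSE loss scaled by a per-target weight."""

        def __init__(
            self,
            name: str,
            gradient: Optional[str],
            weight: float,
            reduction: str,
        ) -> None:
            # Assumed attribute storage; the actual base class may already
            # handle part of this.
            self.name = name
            self.gradient = gradient
            self.weight = weight
            self.reduction = reduction

        def compute(
            self,
            predictions: Dict[str, TensorMap],
            targets: Dict[str, TensorMap],
            extra_data: Dict[str, TensorMap],
        ) -> torch.Tensor:
            # Assumes the named target is stored in a single-block TensorMap.
            pred = predictions[self.name].block().values
            ref = targets[self.name].block().values
            return self.weight * torch.nn.functional.mse_loss(
                pred, ref, reduction=self.reduction
            )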


Examples of loss functions already implemented in ``metatrain`` are
:py:class:`metatrain.utils.loss.TensorMapMSELoss` and
:py:class:`metatrain.utils.loss.TensorMapMAELoss`. They both inherit from the
:py:class:`metatrain.utils.loss.BaseTensorMapLoss` class, which implements pointwise
losses for :py:class:`metatensor.torch.TensorMap` objects.


Loss weight scheduling
----------------------

Currently, only one loss weight scheduler is implemented in ``metatrain``, which is
:py:class:`metatrain.utils.loss.EMAScheduler`. This class is used to schedule the weight
of a loss function based on the Exponential Moving Average (EMA) of the loss value.
The EMA scheduler adapts the loss weights during training, allowing the
contribution of each target to be adjusted dynamically as training progresses.
New schedulers can be implemented by inheriting from the
:py:class:`metatrain.utils.loss.WeightScheduler` abstract class, which defines the
:py:meth:`initialize` and :py:meth:`update` methods that need to be implemented.
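
As a minimal sketch, a constant-weight scheduler might look as follows. The
signatures of :py:meth:`initialize` and :py:meth:`update` are assumptions made
for illustration; consult :py:class:`metatrain.utils.loss.WeightScheduler` for
the actual interface.

.. code-block:: python

    from metatrain.utils.loss import WeightScheduler


    class ConstantScheduler(WeightScheduler):
        """Keeps the loss weight fixed throughout training (illustrative)."""

        def __init__(self, weight: float) -> None:
            self.weight = weight

        def initialize(self, *args, **kwargs) -> None:
            # Nothing to set up for a constant weight.
            pass

        def update(self, *args, **kwargs) -> float:
            # Return the same weight regardless of training progress.
            return self.weight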
12 changes: 10 additions & 2 deletions src/metatrain/cli/train.py
@@ -39,7 +39,12 @@
)
from ..utils.jsonschema import validate
from ..utils.logging import ROOT_LOGGER, WandbHandler, human_readable
from ..utils.omegaconf import BASE_OPTIONS, check_units, expand_dataset_config
from ..utils.omegaconf import (
BASE_OPTIONS,
check_units,
expand_dataset_config,
expand_loss_config,
)
from .eval import _eval_targets
from .export import _has_extensions
from .formatter import CustomHelpFormatter
@@ -205,7 +210,6 @@ def train_model(
{"architecture": get_default_hypers(architecture_name)},
options,
)
hypers = OmegaConf.to_container(options["architecture"])

###########################
# PROCESS BASE PARAMETERS #
@@ -386,6 +390,10 @@
test_datasets.append(dataset)
test_indices.append(None)

# Expand loss options and finalize the hypers
options = expand_loss_config(options)
hypers = OmegaConf.to_container(options["architecture"])

############################################
# SAVE TRAIN, VALIDATION, TEST INDICES #####
############################################
20 changes: 20 additions & 0 deletions src/metatrain/experimental/nanopet/checkpoints.py
@@ -0,0 +1,20 @@
# ===== Model checkpoint updates =====

# ...

# ===== Trainer checkpoint updates =====


def trainer_update_v1_v2(checkpoint):
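    # v1 checkpoints stored a single flat ``loss`` section with a per-target
    # ``weights`` dict; v2 stores a complete set of loss hypers per target.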
old_loss_hypers = checkpoint["train_hypers"]["loss"].copy()
dataset_info = checkpoint["model_data"]["dataset_info"]
new_loss_hypers = {}

for target_name in dataset_info.targets.keys():
new_loss_hypers[target_name] = {
"type": old_loss_hypers["type"],
"weight": old_loss_hypers["weights"].get(target_name, 1.0),
"reduction": old_loss_hypers["reduction"],
"sliding_factor": old_loss_hypers.get("sliding_factor", None),
}
checkpoint["train_hypers"]["loss"] = new_loss_hypers
5 changes: 1 addition & 4 deletions src/metatrain/experimental/nanopet/default-hypers.yaml
@@ -31,7 +31,4 @@ architecture:
log_mae: false
log_separate_blocks: false
best_model_metric: rmse_prod
loss:
type: mse
weights: {}
reduction: mean
Comment on lines -34 to -37

Collaborator: Is it necessary to remove the defaults here?

Collaborator: And what is the new default? Is it documented?

loss: mse