Skip to content

Add positions and momenta outputs #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions docs/src/outputs/displacements-and-momenta.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
.. _positions-output:

positions
^^^^^^^^^^^^^

positions are differences between atomic positions at two different times.
They can be used to predict the next configuration in molecular dynamics
(see, e.g., https://arxiv.org/pdf/2505.19350).

In metatomic models, they are associated with the ``"positions"``
key in the model outputs, and must adhere to the following metadata schema:

.. list-table:: Metadata for positions
:widths: 2 3 7
:header-rows: 1

* - Metadata
- Names
- Description

* - keys
- ``"_"``
- the keys must have a single dimension named ``"_"``, with a single
entry set to ``0``. positions are always a
:py:class:`metatensor.torch.TensorMap` with a single block.

* - samples
- ``["system", "atom"]``
- the samples must be named ``["system", "atom"]``, since
positions are always per-atom.

``"system"`` must range from 0 to the number of systems given as an input
to the model. ``"atom"`` must range between 0 and the number of
atoms/particles in the corresponding system. If ``selected_atoms`` is
provided, then only the selected atoms for each system should be part of
the samples.

* - components
- ``"xyz"``
- positions must have a single component dimension named
``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The
positions are always 3D vectors, and the order of the
components is x, y, z.

* - properties
- ``"positions"``
- positions must have a single property dimension named
``"positions"``, with a single entry set to ``0``.

At the moment, positions are not integrated into any simulation engines.

.. _momenta-output:

Momenta
^^^^^^^

The momentum of a particle is a vector defined as its mass times its velocity.
Predictions of momenta can be used, for example, to predict a future step in molecular
dynamics (see, e.g., https://arxiv.org/pdf/2505.19350).

In metatomic models, they are associated with the ``"momenta"``
key in the model outputs, and must adhere to the following metadata schema:

.. list-table:: Metadata for momenta
:widths: 2 3 7
:header-rows: 1

* - Metadata
- Names
- Description

* - keys
- ``"_"``
- the keys must have a single dimension named ``"_"``, with a single
entry set to ``0``. Momenta are always a
:py:class:`metatensor.torch.TensorMap` with a single block.

* - samples
- ``["system", "atom"]``
- the samples must be named ``["system", "atom"]``, since
momenta are always per-atom.

``"system"`` must range from 0 to the number of systems given as an input
to the model. ``"atom"`` must range between 0 and the number of
atoms/particles in the corresponding system. If ``selected_atoms`` is
provided, then only the selected atoms for each system should be part of
the samples.

* - components
- ``"xyz"``
- momenta must have a single component dimension named
``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The
momenta are always 3D vectors, and the order of the
components is x, y, z.

* - properties
- ``"momenta"``
- momenta must have a single property dimension named
``"momenta"``, with a single entry set to ``0``.

At the moment, momenta are not integrated into any simulation engines.
17 changes: 17 additions & 0 deletions docs/src/outputs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ section to these pages.
energy
features
non_conservative
positions-and-momenta


Physical quantities
Expand Down Expand Up @@ -76,6 +77,22 @@ quantities, i.e. quantities with a well-defined physical meaning.
Stress directly predicted by the model, not derived from the potential
energy.

.. grid-item-card:: positions
:link: positions-output
:link-type: ref

.. image:: /../static/images/positions-output.png

Atomic positions predicted by the model, to be used in ML-driven simulations.

.. grid-item-card:: Momenta
:link: momenta-output
:link-type: ref

.. image:: /../static/images/momenta-output.png

Atomic momenta predicted by the model, to be used in ML-driven simulations.


Machine learning outputs
^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 3 additions & 1 deletion docs/src/torch/reference/models/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ In the mean time, you can create :py:class:`metatomic.torch.ModelOutput` with
quantities that are not in this table. A warning will be issued and no unit
conversion will be performed.

When working with one of the quantity in this table, the unit you use must be
When working with one of the quantities in this table, the unit you use must be
one of the registered unit.

+----------------+---------------------------------------------------------------------------------------------------+
Expand All @@ -48,3 +48,5 @@ one of the registered unit.
+----------------+---------------------------------------------------------------------------------------------------+
| **pressure** | eV/Angstrom^3 (eV/A^3, eV/Angstrom^3) |
+----------------+---------------------------------------------------------------------------------------------------+
| **momentum** | sqrt(eV*u) |
+----------------+---------------------------------------------------------------------------------------------------+
Binary file added docs/static/images/momenta-output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/static/images/positions-output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 8 additions & 1 deletion metatomic-torch/src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ std::unordered_set<std::string> KNOWN_OUTPUTS = {
"energy_uncertainty",
"features",
"non_conservative_forces",
"non_conservative_stress"
"non_conservative_stress",
"positions",
"momenta"
};

void ModelCapabilitiesHolder::set_outputs(torch::Dict<std::string, ModelOutput> outputs) {
Expand Down Expand Up @@ -1082,6 +1084,11 @@ static std::map<std::string, Quantity> KNOWN_QUANTITIES = {
// alternative names
{"eV/A^3", "eV/Angstrom^3"},
}}},
{"momentum", Quantity{/* name */ "momentum", /* baseline */ "sqrt(eV*u)", {
{"sqrt(eV*u)", 1.0},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's u?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u is the atomic mass unit

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}, {
// alternative names
}}},
};

bool metatomic_torch::valid_quantity(const std::string& quantity) {
Expand Down
2 changes: 1 addition & 1 deletion metatomic-torch/tests/models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ TEST_CASE("Models metadata") {
struct WarningHandler: public torch::WarningHandler {
virtual ~WarningHandler() override = default;
void process(const torch::Warning& warning) override {
CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length pressure] are supported");
CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length momentum pressure] are supported");
}
};

Expand Down
100 changes: 100 additions & 0 deletions python/metatomic_torch/metatomic/torch/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def _check_outputs(
_check_non_conservative_forces(value, systems, request, selected_atoms)
elif name == "non_conservative_stress":
_check_non_conservative_stress(value, systems, request)
elif name == "positions":
_check_positions(value, systems, request)
elif name == "momenta":
_check_momenta(value, systems, request)
else:
# this is a non-standard output, there is nothing to check
continue
Expand Down Expand Up @@ -263,6 +267,102 @@ def _check_non_conservative_stress(
)


def _check_positions(
value: TensorMap,
systems: List[System],
request: ModelOutput,
):
"""
Check output metadata for positions.
"""
# Ensure the output contains a single block with the expected key
_validate_single_block("positions", value)

# Check samples values from systems
_validate_atomic_samples("positions", value, systems, request, selected_atoms=None)

positions_block = value.block_by_id(0)

# Check that the block has correct "Cartesian-form" components
if len(positions_block.components) != 1:
raise ValueError(
"invalid components for 'positions' output: "
f"expected one component, got {len(positions_block.components)}"
)
expected_component = Labels(
"xyz", torch.tensor([[0], [1], [2]], device=value.device)
)
if positions_block.components[0] != expected_component:
raise ValueError(
f"invalid components for 'positions' output: "
f"expected {expected_component}, got {positions_block.components[0]}"
)

expected_properties = Labels("positions", torch.tensor([[0]], device=value.device))
message = "`Labels('positions', [[0]])`"

if positions_block.properties != expected_properties:
raise ValueError(
f"invalid properties for 'positions' output: expected {message}, "
f"got {positions_block.properties}"
)

# Should not have any gradients
if len(positions_block.gradients_list()) > 0:
raise ValueError(
"invalid gradients for 'positions' output: "
f"expected no gradients, found {positions_block.gradients_list()}"
)


def _check_momenta(
value: TensorMap,
systems: List[System],
request: ModelOutput,
):
"""
Check output metadata for momenta.
"""
# Ensure the output contains a single block with the expected key
_validate_single_block("momenta", value)

# Check samples values from systems
_validate_atomic_samples("momenta", value, systems, request, selected_atoms=None)

momenta_block = value.block_by_id(0)

# Check that the block has correct "Cartesian-form" components
if len(momenta_block.components) != 1:
raise ValueError(
"invalid components for 'momenta' output: "
f"expected one component, got {len(momenta_block.components)}"
)
expected_component = Labels(
"xyz", torch.tensor([[0], [1], [2]], device=value.device)
)
if momenta_block.components[0] != expected_component:
raise ValueError(
f"invalid components for 'momenta' output: "
f"expected {expected_component}, got {momenta_block.components[0]}"
)

expected_properties = Labels("momenta", torch.tensor([[0]], device=value.device))
message = "`Labels('momenta', [[0]])`"

if momenta_block.properties != expected_properties:
raise ValueError(
f"invalid properties for 'momenta' output: expected {message}, "
f"got {momenta_block.properties}"
)

# Should not have any gradients
if len(momenta_block.gradients_list()) > 0:
raise ValueError(
"invalid gradients for 'momenta' output: "
f"expected no gradients, found {momenta_block.gradients_list()}"
)


def _validate_atomic_samples(
name: str,
value: TensorMap,
Expand Down
Loading
Loading