Skip to content

Commit 953c875

Browse files
committed
Add displacements and momenta outputs
1 parent cc4cf6b commit 953c875

File tree

9 files changed

+343
-3
lines changed

9 files changed

+343
-3
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
.. _displacements-output:
2+
3+
Displacements
4+
^^^^^^^^^^^^^
5+
6+
Displacements are differences between atomic positions at two different times.
7+
They can be used to predict the next configuration in molecular dynamics
8+
(see, e.g., https://arxiv.org/pdf/2505.19350).
9+
10+
In metatomic models, they are associated with the ``"displacements"``
11+
key in the model outputs, and must adhere to the following metadata schema:
12+
13+
.. list-table:: Metadata for displacements
14+
:widths: 2 3 7
15+
:header-rows: 1
16+
17+
* - Metadata
18+
- Names
19+
- Description
20+
21+
* - keys
22+
- ``"_"``
23+
- the keys must have a single dimension named ``"_"``, with a single
24+
entry set to ``0``. Displacements are always a
25+
:py:class:`metatensor.torch.TensorMap` with a single block.
26+
27+
* - samples
28+
- ``["system", "atom"]``
29+
- the samples must be named ``["system", "atom"]``, since
30+
displacements are always per-atom.
31+
32+
``"system"`` must range from 0 to the number of systems given as an input
33+
to the model. ``"atom"`` must range between 0 and the number of
34+
atoms/particles in the corresponding system. If ``selected_atoms`` is
35+
provided, then only the selected atoms for each system should be part of
36+
the samples.
37+
38+
* - components
39+
- ``"xyz"``
40+
- displacements must have a single component dimension named
41+
``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The
42+
displacements are always 3D vectors, and the order of the
43+
components is x, y, z.
44+
45+
* - properties
46+
- ``"displacements"``
47+
- displacements must have a single property dimension named
48+
``"displacements"``, with a single entry set to ``0``.
49+
50+
At the moment, displacements are not integrated into any simulation engines.
51+
52+
.. _momenta-output:
53+
54+
Momenta
55+
^^^^^^^
56+
57+
The momentum of a particle is a vector defined as its mass times its velocity.
58+
Predictions of momenta can be used, for example, to predict a future step in molecular
59+
dynamics (see, e.g., https://arxiv.org/pdf/2505.19350).
60+
61+
In metatomic models, they are associated with the ``"momenta"``
62+
key in the model outputs, and must adhere to the following metadata schema:
63+
64+
.. list-table:: Metadata for momenta
65+
:widths: 2 3 7
66+
:header-rows: 1
67+
68+
* - Metadata
69+
- Names
70+
- Description
71+
72+
* - keys
73+
- ``"_"``
74+
- the keys must have a single dimension named ``"_"``, with a single
75+
entry set to ``0``. Momenta are always a
76+
:py:class:`metatensor.torch.TensorMap` with a single block.
77+
78+
* - samples
79+
- ``["system", "atom"]``
80+
- the samples must be named ``["system", "atom"]``, since
81+
momenta are always per-atom.
82+
83+
``"system"`` must range from 0 to the number of systems given as an input
84+
to the model. ``"atom"`` must range between 0 and the number of
85+
atoms/particles in the corresponding system. If ``selected_atoms`` is
86+
provided, then only the selected atoms for each system should be part of
87+
the samples.
88+
89+
* - components
90+
- ``"xyz"``
91+
- momenta must have a single component dimension named
92+
``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The
93+
momenta are always 3D vectors, and the order of the
94+
components is x, y, z.
95+
96+
* - properties
97+
- ``"momenta"``
98+
- momenta must have a single property dimension named
99+
``"momenta"``, with a single entry set to ``0``.
100+
101+
At the moment, momenta are not integrated into any simulation engines.

docs/src/outputs/index.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ section to these pages.
2020
energy
2121
features
2222
non_conservative
23+
displacements-and-momenta
2324

2425

2526
Physical quantities
@@ -76,6 +77,22 @@ quantities, i.e. quantities with a well-defined physical meaning.
7677
Stress directly predicted by the model, not derived from the potential
7778
energy.
7879

80+
.. grid-item-card:: Displacements
81+
:link: displacements-output
82+
:link-type: ref
83+
84+
.. image:: /../static/images/displacements-output.png
85+
86+
Atomic displacements predicted by the model, to be used in ML-driven simulations.
87+
88+
.. grid-item-card:: Momenta
89+
:link: momenta-output
90+
:link-type: ref
91+
92+
.. image:: /../static/images/momenta-output.png
93+
94+
Atomic momenta predicted by the model, to be used in ML-driven simulations.
95+
7996

8097
Machine learning outputs
8198
^^^^^^^^^^^^^^^^^^^^^^^^

docs/src/torch/reference/models/index.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ In the mean time, you can create :py:class:`metatomic.torch.ModelOutput` with
3434
quantities that are not in this table. A warning will be issued and no unit
3535
conversion will be performed.
3636

37-
When working with one of the quantity in this table, the unit you use must be
37+
When working with one of the quantities in this table, the unit you use must be
3838
one of the registered unit.
3939

4040
+----------------+---------------------------------------------------------------------------------------------------+
@@ -48,3 +48,5 @@ one of the registered unit.
4848
+----------------+---------------------------------------------------------------------------------------------------+
4949
| **pressure** | eV/Angstrom^3 (eV/A^3, eV/Angstrom^3) |
5050
+----------------+---------------------------------------------------------------------------------------------------+
51+
| **momentum** | sqrt(eV*u) |
52+
+----------------+---------------------------------------------------------------------------------------------------+
7.4 KB
Loading

docs/static/images/momenta-output.png

5.71 KB
Loading

metatomic-torch/src/model.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ std::unordered_set<std::string> KNOWN_OUTPUTS = {
141141
"energy_uncertainty",
142142
"features",
143143
"non_conservative_forces",
144-
"non_conservative_stress"
144+
"non_conservative_stress",
145+
"displacements",
146+
"momenta"
145147
};
146148

147149
void ModelCapabilitiesHolder::set_outputs(torch::Dict<std::string, ModelOutput> outputs) {
@@ -1082,6 +1084,11 @@ static std::map<std::string, Quantity> KNOWN_QUANTITIES = {
10821084
// alternative names
10831085
{"eV/A^3", "eV/Angstrom^3"},
10841086
}}},
1087+
{"momentum", Quantity{/* name */ "momentum", /* baseline */ "sqrt(eV*u)", {
1088+
{"sqrt(eV*u)", 1.0},
1089+
}, {
1090+
// alternative names
1091+
}}},
10851092
};
10861093

10871094
bool metatomic_torch::valid_quantity(const std::string& quantity) {

metatomic-torch/tests/models.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ TEST_CASE("Models metadata") {
109109
struct WarningHandler: public torch::WarningHandler {
110110
virtual ~WarningHandler() override = default;
111111
void process(const torch::Warning& warning) override {
112-
CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length pressure] are supported");
112+
CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length momentum pressure] are supported");
113113
}
114114
};
115115

python/metatomic_torch/metatomic/torch/outputs.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def _check_outputs(
5151
_check_non_conservative_forces(value, systems, request, selected_atoms)
5252
elif name == "non_conservative_stress":
5353
_check_non_conservative_stress(value, systems, request)
54+
elif name == "displacements":
55+
_check_displacements(value, systems, request)
56+
elif name == "momenta":
57+
_check_momenta(value, systems, request)
5458
else:
5559
# this is a non-standard output, there is nothing to check
5660
continue
@@ -263,6 +267,106 @@ def _check_non_conservative_stress(
263267
)
264268

265269

270+
def _check_displacements(
271+
value: TensorMap,
272+
systems: List[System],
273+
request: ModelOutput,
274+
):
275+
"""
276+
Check output metadata for displacements.
277+
"""
278+
# Ensure the output contains a single block with the expected key
279+
_validate_single_block("displacements", value)
280+
281+
# Check samples values from systems
282+
_validate_atomic_samples(
283+
"displacements", value, systems, request, selected_atoms=None
284+
)
285+
286+
displacements_block = value.block_by_id(0)
287+
288+
# Check that the block has correct "Cartesian-form" components
289+
if len(displacements_block.components) != 1:
290+
raise ValueError(
291+
"invalid components for 'displacements' output: "
292+
f"expected one component, got {len(displacements_block.components)}"
293+
)
294+
expected_component = Labels(
295+
"xyz", torch.tensor([[0], [1], [2]], device=value.device)
296+
)
297+
if displacements_block.components[0] != expected_component:
298+
raise ValueError(
299+
f"invalid components for 'displacements' output: "
300+
f"expected {expected_component}, got {displacements_block.components[0]}"
301+
)
302+
303+
expected_properties = Labels(
304+
"displacements", torch.tensor([[0]], device=value.device)
305+
)
306+
message = "`Labels('displacements', [[0]])`"
307+
308+
if displacements_block.properties != expected_properties:
309+
raise ValueError(
310+
f"invalid properties for 'displacements' output: expected {message}, "
311+
f"got {displacements_block.properties}"
312+
)
313+
314+
# Should not have any gradients
315+
if len(displacements_block.gradients_list()) > 0:
316+
raise ValueError(
317+
"invalid gradients for 'displacements' output: "
318+
f"expected no gradients, found {displacements_block.gradients_list()}"
319+
)
320+
321+
322+
def _check_momenta(
323+
value: TensorMap,
324+
systems: List[System],
325+
request: ModelOutput,
326+
):
327+
"""
328+
Check output metadata for momenta.
329+
"""
330+
# Ensure the output contains a single block with the expected key
331+
_validate_single_block("momenta", value)
332+
333+
# Check samples values from systems
334+
_validate_atomic_samples("momenta", value, systems, request, selected_atoms=None)
335+
336+
momenta_block = value.block_by_id(0)
337+
338+
# Check that the block has correct "Cartesian-form" components
339+
if len(momenta_block.components) != 1:
340+
raise ValueError(
341+
"invalid components for 'momenta' output: "
342+
f"expected one component, got {len(momenta_block.components)}"
343+
)
344+
expected_component = Labels(
345+
"xyz", torch.tensor([[0], [1], [2]], device=value.device)
346+
)
347+
if momenta_block.components[0] != expected_component:
348+
raise ValueError(
349+
f"invalid components for 'momenta' output: "
350+
f"expected {expected_component}, got {momenta_block.components[0]}"
351+
)
352+
353+
expected_properties = Labels("momenta", torch.tensor([[0]], device=value.device))
354+
message = "`Labels('momenta', [[0]])`"
355+
356+
if momenta_block.properties != expected_properties:
357+
raise ValueError(
358+
f"invalid properties for 'momenta' output: expected {message}, "
359+
f"got {momenta_block.properties}"
360+
)
361+
362+
# Should not have any gradients
363+
if len(momenta_block.gradients_list()) > 0:
364+
raise ValueError(
365+
"invalid gradients for 'momenta' output: "
366+
f"expected no gradients, found {momenta_block.gradients_list()}"
367+
)
368+
369+
266370
def _validate_atomic_samples(
267371
name: str,
268372
value: TensorMap,

0 commit comments

Comments
 (0)