From 953c8752cf06946793c2d3edbbf03a4321895c56 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 16 Jul 2025 01:46:57 -0700 Subject: [PATCH 1/2] Add displacements and momenta outputs --- .../src/outputs/displacements-and-momenta.rst | 101 ++++++++++++++++ docs/src/outputs/index.rst | 17 +++ docs/src/torch/reference/models/index.rst | 4 +- docs/static/images/displacements-output.png | Bin 0 -> 7577 bytes docs/static/images/momenta-output.png | Bin 0 -> 5842 bytes metatomic-torch/src/model.cpp | 9 +- metatomic-torch/tests/models.cpp | 2 +- .../metatomic/torch/outputs.py | 104 +++++++++++++++++ python/metatomic_torch/tests/outputs.py | 109 ++++++++++++++++++ 9 files changed, 343 insertions(+), 3 deletions(-) create mode 100644 docs/src/outputs/displacements-and-momenta.rst create mode 100644 docs/static/images/displacements-output.png create mode 100644 docs/static/images/momenta-output.png diff --git a/docs/src/outputs/displacements-and-momenta.rst b/docs/src/outputs/displacements-and-momenta.rst new file mode 100644 index 00000000..1535c465 --- /dev/null +++ b/docs/src/outputs/displacements-and-momenta.rst @@ -0,0 +1,101 @@ +.. _displacements-output: + +Displacements +^^^^^^^^^^^^^ + +Displacements are differences between atomic positions at two different times. +They can be used to predict the next configuration in molecular dynamics +(see, e.g., https://arxiv.org/pdf/2505.19350). + +In metatomic models, they are associated with the ``"displacements"`` +key in the model outputs, and must adhere to the following metadata schema: + +.. list-table:: Metadata for displacements + :widths: 2 3 7 + :header-rows: 1 + + * - Metadata + - Names + - Description + + * - keys + - ``"_"`` + - the keys must have a single dimension named ``"_"``, with a single + entry set to ``0``. Displacements are always a + :py:class:`metatensor.torch.TensorMap` with a single block. + + * - samples + - ``["system", "atom"]`` + - the samples must be named ``["system", "atom"]``, since + displacements are always per-atom. + + ``"system"`` must range from 0 to the number of systems given as an input + to the model. ``"atom"`` must range between 0 and the number of + atoms/particles in the corresponding system. If ``selected_atoms`` is + provided, then only the selected atoms for each system should be part of + the samples. + + * - components + - ``"xyz"`` + - displacements must have a single component dimension named + ``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The + displacements are always 3D vectors, and the order of the + components is x, y, z. + + * - properties + - ``"displacements"`` + - displacements must have a single property dimension named + ``"displacements"``, with a single entry set to ``0``. + +At the moment, displacements are not integrated into any simulation engines. + +.. _momenta-output: + +Momenta +^^^^^^^ + +The momentum of a particle is a vector defined as its mass times its velocity. +Predictions of momenta can be used, for example, to predict a future step in molecular +dynamics (see, e.g., https://arxiv.org/pdf/2505.19350). + +In metatomic models, they are associated with the ``"momenta"`` +key in the model outputs, and must adhere to the following metadata schema: + +.. list-table:: Metadata for momenta + :widths: 2 3 7 + :header-rows: 1 + + * - Metadata + - Names + - Description + + * - keys + - ``"_"`` + - the keys must have a single dimension named ``"_"``, with a single + entry set to ``0``. Momenta are always a + :py:class:`metatensor.torch.TensorMap` with a single block. + + * - samples + - ``["system", "atom"]`` + - the samples must be named ``["system", "atom"]``, since + momenta are always per-atom. + + ``"system"`` must range from 0 to the number of systems given as an input + to the model. ``"atom"`` must range between 0 and the number of + atoms/particles in the corresponding system. If ``selected_atoms`` is + provided, then only the selected atoms for each system should be part of + the samples. + + * - components + - ``"xyz"`` + - momenta must have a single component dimension named + ``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The + momenta are always 3D vectors, and the order of the + components is x, y, z. + + * - properties + - ``"momenta"`` + - momenta must have a single property dimension named + ``"momenta"``, with a single entry set to ``0``. + +At the moment, momenta are not integrated into any simulation engines. diff --git a/docs/src/outputs/index.rst b/docs/src/outputs/index.rst index 58309fc6..b0219f7d 100644 --- a/docs/src/outputs/index.rst +++ b/docs/src/outputs/index.rst @@ -20,6 +20,7 @@ section to these pages. energy features non_conservative + displacements-and-momenta Physical quantities @@ -76,6 +77,22 @@ quantities, i.e. quantities with a well-defined physical meaning. Stress directly predicted by the model, not derived from the potential energy. + .. grid-item-card:: Displacements + :link: displacements-output + :link-type: ref + + .. image:: /../static/images/displacements-output.png + + Atomic displacements predicted by the model, to be used in ML-driven simulations. + + .. grid-item-card:: Momenta + :link: momenta-output + :link-type: ref + + .. image:: /../static/images/momenta-output.png + + Atomic momenta predicted by the model, to be used in ML-driven simulations. + Machine learning outputs ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/torch/reference/models/index.rst b/docs/src/torch/reference/models/index.rst index f7178e89..d4aa6e55 100644 --- a/docs/src/torch/reference/models/index.rst +++ b/docs/src/torch/reference/models/index.rst @@ -34,7 +34,7 @@ In the mean time, you can create :py:class:`metatomic.torch.ModelOutput` with quantities that are not in this table. A warning will be issued and no unit conversion will be performed. -When working with one of the quantity in this table, the unit you use must be +When working with one of the quantities in this table, the unit you use must be one of the registered unit. +----------------+---------------------------------------------------------------------------------------------------+ @@ -48,3 +48,5 @@ one of the registered unit. +----------------+---------------------------------------------------------------------------------------------------+ | **pressure** | eV/Angstrom^3 (eV/A^3, eV/Angstrom^3) | +----------------+---------------------------------------------------------------------------------------------------+ +| **momentum** | sqrt(eV*u) | ++----------------+---------------------------------------------------------------------------------------------------+ diff --git a/docs/static/images/displacements-output.png b/docs/static/images/displacements-output.png new file mode 100644 index 0000000000000000000000000000000000000000..5b3feefb4cf30dfe1429c4b0ebf5aab17023859a GIT binary patch literal 7577 zcmeHsXH=6}w05|lh(Z_#8(_qd02a#7gajcnjABM1I;e;UNE46}HKCU%I$}Z5AiV_( zQE7@G3W2CJ0|6ok2na|E9mD`3lyLX!`qp>Xy7%wB|HdT?eb0N&*=O(PJp0-Ey?NBy zQf{r%S`-Q;M?CcB-zb#C4Dz#PH9VIEM%)yK*)CY{};5$mppWSvzBa!MPx18Hr=tu^`c>L7sze|h+mcMV2B6dC|{}?D~=$&CX zm9vJ#{U)Mmw~6CUPEIN)hJ!tGR08!bS%He$d6^-(UqT9lTDyH1brfYTefdAu;87d) z^CTFkNHl8IFaN&suV4Om!Eh=cz=vt*tKxJ8Y6ylj_MF6~kkiM@LknhGPv|_tC{QmK z4a(CFX_px&n&N|d|H&eceZ?v6GsaE6amxsQ-fmgi`m7xOjIKu=uz#(I6VDGk&!+Rk zK6ZKM4`ij#h2_Pw#m8jLr9G_gLQ5V8dXUYUXR-oiwWfZ zsM%J>!X~IE(cP|M*@p{!DO0Z+@SI06>L%knyveTX3i$o&>|guXooh%>dsN)9d+KnL zDW?)!N7p>0=XJubl^K)RtZ6-r(()89iJm=s#*viY_dT9G^YSQG z>FaNQ{55zD&5PU?G2dH!Czw8SNZ&ObJCv%*H%I~F8D(=Fdpy#E&#ji-a_`RGj^U~9 zh_e*Gre`)5T#`v{mGGE$#wBS_XJJ>MQ^?$SlC5K~8nNjuWk}PZ#2~HBPNOle*m18@ z$xLp~OB?2O1<^C6dRa(rRn>#DWzVHi4+gqJr>q*&O<$dTABpAUl};uvPO%BZCU5p~ zfoS*>2mXG3##s`{v9|@>)}ogQt?!!d6%AWvr;<$gVD}%BZ->%nMq9FV`3|K4^%rX} z37HmA>O{Xcz1O3+=sv8yxq&N#VXXQZF6+ES&7pcdva=yp(0^sa0NwIV-m`?m_Urq&M~H zk!6;3+HQW1epFf-Duv2L^W5jvQ&$0q9t=en}rscAj@4s+>Kasqz6_pgHP zvd|%Z3V9P{x9Tt|`)OXDcKG64y?1_DJa(j+Nm3`Jr^@a9!sv6Sy1}Yz`P`&tEFWIk zNa>IQU}v?z@?bi#BkT%2kI~)O%X4*BMqLlI9ZHU~WC6u?=jxMNu>9pkVdGBurB>YaUe`ADkqF|L=56y}>sqOR+~*hKlQ#=k45A2~GCkhaHh zwUrG1UFVzryC%rHnsxKNys=)5gUuX}{55=lm)xL19P;!m<{$ARHo2Y9S*yRtt$hPEN zL3{L7-c&RT3v%Vko6hi}x0#;idbLSmhL5TrNxZjS^`KkTt&K!_1=>W@*1s6MViK0U z&?e(AKDV`L5Uq`SrF5A=j$QkY9gNwQRv}Md>GCzwgIpqJ9;rOBVOnq0x!S5Wyk9r3 zd|@)pZFYk7FUC?j zHU}-o<%g*ar3I!ZHva~G)OLU)K?3Mk$|spj6sCqwzJ;y9s8Y&w)#jo3{unY#JZj5M zrZnIzxg$I&jHhkHQoqVjr2@%^B^J_FgsI_k?J!3J=3M-AVDsQ%r-i@!XN}6I@1**- z+PR?DQyFP*zJFLpX(DPEI?@CwJ>0-@0Y*2f4ff|dTzJ0ZOT2zPMoj&H| zl#x$c9>Wc#Qj}B{r(41T3F`i6}Inj5BTRGfNu@4U^B z%zX|fIW3Rc3Nlti8FylS!e#D@>z*^K1|hI50Q8~HHl%v_i0@acYqZN|nzmB7x*PK% z^&q&rN5{qlKk=49i|BX#iBA0J)3vtdZy{C}XJ=<|gAZ&FWbZW#sM=^d@KH4;$8W<1 zYu=SjC&1=dk{OP$)GHs0O~d=w#@RY1sophy*ynz9FPMD<%(ljlJYE{MB>dI%iy%9% zWOtNGL{L>7BMd-t198U4VW--v%}%M0f8#&0jA_V>+ID(cZ+MFo=9DY4`E;d%AU0zO zOSItDL@UNxpmC4;+<6z*;9nd(nRg~hfwWII;mjo>-7U?(l|!N}G;gKy%;D1BJ;1xt z@)~$#D{X1|?RY^Q-Jo(Q**lK|^s?J=H3p)b?LJr=hXY4Z;9OX*8=|5%#}rTqFJJ>* z8T^Y{WCPc`ICV(p?)2-4F&LA*st2Q=1}Rd{M51%(aeHj}PaoZ_ZLxJq!y;sh13^Iw8%efHSqaS;-OEU5v@&$c;bI9FUh!IFimV|Oy=)P2>MeR{6d(L`{~#r|sd zkN19SK}*ts)@Apc%=rkUdEYzltva!wJ7mlqcvEp-(1ohH%j{pN-&9V&bUA<~EPX!{ zI#EPaaxsfonn&`)Mjc5p$Stxr@G)+NXMvc6mDA1m-T9Bza!SW1e*ZR=V>63l_?|8g z3r6Dx?o}t|v&9QZ#upzZS}21CT1p$NOBe#4MRFINwO$}0xf&8{5Tp)+myj|L_bL)1 z!oVM-WP@FhX7hs-zsG>ho}wRviiYQp`?030u|&wD`d&?sXU5v{``q7|tfy`#7WlLm zL@OEYM$#;tm=mkPB=?kq#mSHuG(pWmuHZt>rehsSsGI>#EacdXTL}{NYQwzuTFCgp z=#NOQ2dEhJgi%6-lMd5*XHTrbQ_mhw(5}C+UR5vn#{ibITG{8`E}N{!9*ya+`pGzk zg>(_999iYPnxIG!Z$GnPW?9B;Lq?-uv=AA^;0=K9tzgvN@2wiBBskcM63{%uXX7>P z=s$C5ScRMsH@CQZH=oyl%e0U#aO2Qrx9u}7?h3G{8@Y5BdAEVV-nB9q{Ee#Hn};q- zqCKNU2bFyvK<2mNX8vM5%#&C$_*^30r+}mWc6qMzm9RS$^oyreBzl$In+whh!l?{Uw(7*aUDRestr!t#LHlVv zMv)ADgG=v+e@5E!okmAT;l71|neD)|*_hQi8)5n>m_FA@W^+@FbchA*=TeY!w7bTb z!iCJyibPL*&CV}RhCOuxyMsLdVU4Lc{R|P;0~v??C%#0kTGv!e?)mE9=0K!_!r7Az_9D@hrW+uB#gkLEn`asuIL?GzNIETQlqCWTiaZ;F$ z`#_PnvQS||nXo$rkod}@v4Lqe`|Ny_L>yS4;TRzrM6@qb3((!_w2Hao?XLtJQbYkP zVE@9pMe=gs9HjaufX+L*J?FD?^*kE3>PjmCo|J{a>{&1%8q}dB4ctkj^(kZk-|*ts z;0v8~%I94|W`^%0R8OP>3h@vqj&|wz-{Ae+>0N=Ij~s(~|DjHGm+A7Sf*V8i0w>St zN_W#0@P+2IJECP_PX-pe?h6~HhK;+bWW&lQW$=Z%f}utdh!aNt@j(DJcmXP*r&bXLFRn=eUyy%hdDU-U6P5yc*FAWxwXyg)I5oWX1+Qyy5_b4m)+N@u2r>@F%Q3Ldy`36N&NWYCTlMHQEK%s(*6 zSw%7vxt%|MUdgo_#P1Cp@1YdtbUP%)q>he`C9#lAqK6GzsQgM{RDq)(g0kFUXlQsx zcZEq?V190CabKJ$2A(;poDIE)_>9sNdTO*=6`fD$m&vR0Qo~`Ql{f7C{Vcsow>C{4bvK z#$|9~BRoyf$OzLT7J%-*fjf#bk2>)kl*r?(EOnCU>~Z>lLdN``zC@B<~#Oo!)C1ZR2DrgMj>${)}*F~c97ZJ5bo#-&6iRzr_*wFEp zmdGSvMOKG13iR>4fi||-zr=qmUIeWs;v8HosOnTxd|%sD)uECtc_ZH^*(+jo-5#jD z!PY>jH1q^}Ws%*SYhTQQYJdoxhHjJMo`_I{>mj;1O6JUpg_it?xXG^4Ay92c=wXwT z3sKoNv6J{mf|ZL{)JQW!GLL$S4=~eZkhfH0_I#C26R;+0rs*(|4%LVy_eubX zH*z!rV9kV794s14h{Zb1f7pyv&_19cAXKOHBdwz@LWVTF8Ej>BCsDHM7b^Z*=v0}J zX~@jSJ^3s;7I8RGVRB&Em;=o4RE)Hq{@D!~Z`WV9N%IgLJamG4rEd-XK+)+()<{`pjdt-7X3(QNx`=8$n z%`tF7A=;d2j3lUqnUUCk)@@cNhRlw}f!+XI8kjS8{-b6a& zCup8=+Rt8{c@oL0_(Od2#Mf_N9u_GX}W0N_vK3UKGuj{gkOGK(d;rANZ7b6*OCOj#;V|72P4F?F+tiT~0Y3M63q! z;2otnUDxUxxObhyNc-PucFkI0(_s#72n&*Z-^x-afjpEJt6=@ns?#U7SJiTF^%CPM zx)wS@1(Se@v@+DhJw|hhc!v3y+%u3Fl{`B`Bp#%WJ{ZhrYs#C=fBV@8k?6j{Xbwpa zR&V#tzy8#y;u=;pHYaF#foFl{oUnD$;-l)nrSy54Rzw+@@Qt#(xbmb+Ez*RY@O2U&x>A}kjPq7JntCa*ux12KJ>*~KQ zKQzv4{sAP1aZNX1%?FU0c5!Y7-%ixvEDMWg@RxUpv@nIN1jqApHca)T47cp!&<{;d zINtec%D#pCQq)@hHS4;urYG&Bi3hY_1lW#Y@P z73B}6zfS@2Qg?9niddPHU>68dd)nsx=2B@z{Z*S|bzPPE#2-~@D|5-o18@52%R*AC z6U{eY(o+a;WS5(XM2)e`^z-v_7gBuV!)6}Itw_A+Kt2mW)KLF^3i|gG+5hc#L9q&H Y<2u(@^rI&DjMNmLM^xz3n zqEsX0XecTrNbeAe5eU5m5OVkD9pk>-$NTC-M#7hw)%IF*%}v5J3$w!q#16nhZi+bkTKZg9m(KlhVaEcL z5=tI*;DA_4LV`^Gl^etHOZj~)vWFla{}M zhYkf!<7*^f`dPjCy}U4(OHLlQ{=R59jGqq;J1)(HrQQGi;CDTKx5Mvg_&;PorvS$I z%StI@a=BTGw6*%(%-Gn;$*H@i$E3zFHFR~RQ8H{EG2}<}O({LDV8H7k@7k1{u)V$A zG2GYF(`c?-E2ML8%|6Di)0zEPOaH=Ct>6XXaK_hZ`)dD*Gw)+JnzRV_-OjtUX1*>e zO8WfhWOKG=fG$OkJ^r8|cyUySS6T{n|8}X5 zMEMSuGrRxkI#ZXr@v(52tL`P>y3Nrr?AlzPGF)T)J~|xe;ox(193H?U1tDU1!+t-Y}_mn8w8Fxe_X5 zU2kuhL>?)jE{LV1q>P4k&GMleI1Uty$(3g+r!T*{-kPoX&a1D`GWV3w0o-^L?*Yx? z|KxGAMpvf4x-w`9o!#9+@7&tRgX^pz^#d`6^YxrPLyIDLrarYId~2n8GF1!mGC%OH zC}?+kJ*lawsUKUv`!HM0fF2d(zO^#FoHh1S{@Jr(NASQU0EYHCfRVT(>#lX)?G=(MIdsYua!cB3+;H50YW zNK&?C0Skh_(8i#5shAV*S#x@n8urJxSGccdt7!cc*U48I&g8*}ouRQ*!+gQ9pEK&- z1mqjBXn6Te^9-ewENq=-Qz}wix&9}<$8WqL%IcI5rbB1*O=n=47eAUlqv-Tr#qa%( zzeTlzd_JsvQLv##Y>xb@Hll8AEZTxq>-6G=*{w0`ei&qCPP zjx&VH)s28`SpbySFH{Lt2|@*YZ!aD^EUB#_CMJfd8VM@u45Yv3M_&tAtmi~J29E^d zI!#0(WT$wMglmwU#nD>G+u>+{xJ}HM)m~N4!(ez$?=}gQ4ThhBK=R)Ws&%KF6aO zj`&M9EpxN8jIvMgBtSvTfp_kQl^<4uBNX0j*IK0Y*tsAxobP!m+Lz^>F9m$_NkC6~ zc5cq~drol4L=1fc91vzSa)(ttODPEv=vZ_Y!DU?(*64}cn zq6D^X%kA!HzEPaU?CSzd_iHQ4C%9Jkl3@#ovDByYg$ZtJI-~g^|4|J1oE{nnM>3cN zq*SW{)5eV!Sns7RJ}cR^?39_4De1KiPDeBv}Q zJ?(U1_7Tu&&OTuLK@+jbJjf$i7u+pQOabO7Zlw}_94X8Ra>x!_sG)q8&jM|cSHp)s zz6R>k6eHp+3XQr=OP^yv+vOJZ(U3_oXFe5mgE(X7(CPHcEqXtfK79!N=9m_LDWM|3uXueOSw|0_PtNK1xE|D%@8NJ0h8_p`S;h|b4WD*BwY+R<}w*qx+E zA#tBF1K`qD)gVr73d9_ba&4lWtr|Y^HAA_yweO~Tj>gXXyUFZ;_5{C?_dOs(TTu1; z-*bt&rX)!um=$l3p;Pe4<(57d;x5Rr`&pim!p3kwhmBz!San~*Yu9^|0@^iCwu5DB z51QXASeov|hb}kl(B1J#mefHu`h6MtxpV@WGDa8B4L4xVzMT|=%H1P<8< ziiWGIc=x{X@bXF=!q(TGnSv6A4h*;bqp%sQ45v>Aqbd#Q`G2^zXa5P-w!B5Q{jrt9 z>ceG6R!W?xy&&Ua1i?P$dF1ww%PpLdzPgPHjl~bE3?Ltl;m6Ivcl*J@LV%wxfFUGr znTq=gf$~IVHD68 z=025v8Flx|WAc`Kmd}w-v4@i$S)ebqM7f<{(vu~LKbKDlC*iaM3NkLPDXL%_&?WiM;rmIt+a5wmjn#{F zJLH)x-^!bg{{QaBedIsMr&)Fx0Qd#;e)>ZI&fZSZ4A_jJ$)YG(hzgis{8CD|2jvXn zl*h5eom&9K{v@?`O?3>6Eg&Yu@&l=#`Yln<1PS*10y=Bg03ZY?7V->Whj@gXDGk?* z6e%jLy9dk&64<1_uVx8@Q1{jf^=#Uz02_*G;{SqZ&JbRHAR3;_P$q7fLta*MIgYAh zfUgu}Ai4oRXCF6+BVGuc?FRKgfj9D&+K|sLB6UsGG;N8py+rri{4$Ndl7m(SGGwxp zxa-~H2l9nvuEsXApup9P=y<>PtJ$f+VxT<*blJCor%W|phaNvFt1<~c3+M2V&`^HPFI!I}1rbSa%g#mJYH&>jB}y zpc@IKfW6kuG`v4^lKB$422|DW`aZt=DrBD0Ob*fs+enU$jn$Y+GkPd-5!7iF)M*$9 z8mlwujD(ZVUk^rrg+YQK%)AExa+U7+4l3o-^CuD;G1}|h);3hI@=bsO{PzLxK~1_6 zAG)ZWSIYQGwkU3{7*Sd)ObNl!&~wJ~Wk$4wlctA6%dRU!nm1EB6c6IaM7f&wlz^K; zdK$QIRL~ap<=K!G!2QBO$^f4vY?3T4f}2c1tbQkMOa=|37;9Z}7CNd7>rtChpsk@1 z15>iR)@5nO0=uvZ2xF87U0Q~B@Pt9(-Q5n3v4p=qZGJ$3lWhdwVmgIqC?S@uT25Yl zCIo5n_UxPMblr$;V^>#K3I7HzP;l>Pi?2_$)L#%eq!)W8hsA`xp1hiv5Ik}W&<+Yv z2RRVPE}+q9;3A_u1(RWxzpULzWcPENE6+bMqp=X3 zG(h0&j0q|6`JhHcicVt7dpij4k^${1uHz?|D)sF-cH8;4Oli5CF7DacI{F+}xn137 z)i&&8O4ee9TEFy1jvaidh6lvv!Z}7+ZEst3Fq7!iF6BG~x%k3qr!ZeynhYXIgt+HQ4JKu(oW^X9+}cyRZ_JZpGufst z=U2B?P|evS67lt$qlm)iVz{M$5jp7M**Vytwt;)knG89}2o=V9pWZgyf38HFvb!O* z3mZ2-N`PH#a7pC8&%=X+9T5b7>C?HSxeR!JKlojb|3^DSaRf+5KBvDnEH(rGUxr;W Lu`n(%ydC}Dl}pPo literal 0 HcmV?d00001 diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index 6c2f7737..d9bb2d39 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -141,7 +141,9 @@ std::unordered_set KNOWN_OUTPUTS = { "energy_uncertainty", "features", "non_conservative_forces", - "non_conservative_stress" + "non_conservative_stress", + "displacements", + "momenta" }; void ModelCapabilitiesHolder::set_outputs(torch::Dict outputs) { @@ -1082,6 +1084,11 @@ static std::map KNOWN_QUANTITIES = { // alternative names {"eV/A^3", "eV/Angstrom^3"}, }}}, + {"momentum", Quantity{/* name */ "momentum", /* baseline */ "sqrt(eV*u)", { + {"sqrt(eV*u)", 1.0}, + }, { + // alternative names + }}}, }; bool metatomic_torch::valid_quantity(const std::string& quantity) { diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index fc84d7e3..8f33e648 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -109,7 +109,7 @@ TEST_CASE("Models metadata") { struct WarningHandler: public torch::WarningHandler { virtual ~WarningHandler() override = default; void process(const torch::Warning& warning) override { - CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length pressure] are supported"); + CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length momentum pressure] are supported"); } }; diff --git a/python/metatomic_torch/metatomic/torch/outputs.py b/python/metatomic_torch/metatomic/torch/outputs.py index 5d45a648..937b2f9a 100644 --- a/python/metatomic_torch/metatomic/torch/outputs.py +++ b/python/metatomic_torch/metatomic/torch/outputs.py @@ -51,6 +51,10 @@ def _check_outputs( _check_non_conservative_forces(value, systems, request, selected_atoms) elif name == "non_conservative_stress": _check_non_conservative_stress(value, systems, request) + elif name == "displacements": + _check_displacements(value, systems, request) + elif name == "momenta": + _check_momenta(value, systems, request) else: # this is a non-standard output, there is nothing to check continue @@ -263,6 +267,106 @@ def _check_non_conservative_stress( ) +def _check_displacements( + value: TensorMap, + systems: List[System], + request: ModelOutput, +): + """ + Check output metadata for displacements. + """ + # Ensure the output contains a single block with the expected key + _validate_single_block("displacements", value) + + # Check samples values from systems + _validate_atomic_samples( + "displacements", value, systems, request, selected_atoms=None + ) + + displacements_block = value.block_by_id(0) + + # Check that the block has correct "Cartesian-form" components + if len(displacements_block.components) != 1: + raise ValueError( + "invalid components for 'displacements' output: " + f"expected one component, got {len(displacements_block.components)}" + ) + expected_component = Labels( + "xyz", torch.tensor([[0], [1], [2]], device=value.device) + ) + if displacements_block.components[0] != expected_component: + raise ValueError( + f"invalid components for 'displacements' output: " + f"expected {expected_component}, got {displacements_block.components[0]}" + ) + + expected_properties = Labels( + "displacements", torch.tensor([[0]], device=value.device) + ) + message = "`Labels('displacements', [[0]])`" + + if displacements_block.properties != expected_properties: + raise ValueError( + f"invalid properties for 'displacements' output: expected {message}, " + f"got {displacements_block.properties}" + ) + + # Should not have any gradients + if len(displacements_block.gradients_list()) > 0: + raise ValueError( + "invalid gradients for 'displacements' output: " + f"expected no gradients, found {displacements_block.gradients_list()}" + ) + + +def _check_momenta( + value: TensorMap, + systems: List[System], + request: ModelOutput, +): + """ + Check output metadata for momenta. + """ + # Ensure the output contains a single block with the expected key + _validate_single_block("momenta", value) + + # Check samples values from systems + _validate_atomic_samples("momenta", value, systems, request, selected_atoms=None) + + momenta_block = value.block_by_id(0) + + # Check that the block has correct "Cartesian-form" components + if len(momenta_block.components) != 1: + raise ValueError( + "invalid components for 'momenta' output: " + f"expected one component, got {len(momenta_block.components)}" + ) + expected_component = Labels( + "xyz", torch.tensor([[0], [1], [2]], device=value.device) + ) + if momenta_block.components[0] != expected_component: + raise ValueError( + f"invalid components for 'momenta' output: " + f"expected {expected_component}, got {momenta_block.components[0]}" + ) + + expected_properties = Labels("momenta", torch.tensor([[0]], device=value.device)) + message = "`Labels('momenta', [[0]])`" + + if momenta_block.properties != expected_properties: + raise ValueError( + f"invalid properties for 'momenta' output: expected {message}, " + f"got {momenta_block.properties}" + ) + + # Should not have any gradients + if len(momenta_block.gradients_list()) > 0: + raise ValueError( + "invalid gradients for 'momenta' output: " + f"expected no gradients, found {momenta_block.gradients_list()}" + ) + + def _validate_atomic_samples( name: str, value: TensorMap, diff --git a/python/metatomic_torch/tests/outputs.py b/python/metatomic_torch/tests/outputs.py index b085ae13..73b43ad6 100644 --- a/python/metatomic_torch/tests/outputs.py +++ b/python/metatomic_torch/tests/outputs.py @@ -105,6 +105,74 @@ def __init__(self): super().__init__("features") +class DisplacementsMomentaModel(torch.nn.Module): + """A model predicting displacements and momenta""" + + def __init__(self): + super().__init__() + self.output_names = ["displacements", "momenta"] + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + assert "displacements" in outputs + assert "momenta" in outputs + assert outputs["displacements"].per_atom + assert outputs["momenta"].per_atom + assert selected_atoms is None + + sample_values = torch.stack( + [ + torch.concatenate( + [ + torch.full( + (len(system),), + i_system, + ) + for i_system, system in enumerate(systems) + ], + ), + torch.concatenate( + [ + torch.arange( + len(system), + ) + for system in systems + ], + ), + ], + dim=1, + ) + samples = Labels( + names=["system", "atom"], + values=sample_values, + ) + + blocks = [] + for output_name in self.output_names: + block = TensorBlock( + values=torch.tensor( + [[[0.0], [1.0], [2.0]]] * sum(len(system) for system in systems), + dtype=torch.float64, + ), + samples=samples, + components=[Labels("xyz", torch.tensor([[0], [1], [2]]))], + properties=Labels( + output_name, + torch.tensor([[0]]), + ), + ) + blocks.append(block) + + return { + output_name: TensorMap(Labels("_", torch.tensor([[0]])), [block]) + for output_name, block in zip(self.output_names, blocks) + } + + def test_energy_ensemble_model(system, get_capabilities): model = EnergyEnsembleModel() capabilities = get_capabilities("energy_ensemble") @@ -161,3 +229,44 @@ def test_features_model(system, get_capabilities): assert features.block().properties.names == ["energy"] assert features.block().components == [] assert len(result["features"].blocks()) == 1 + + +def test_displacements_momenta_model(system): + model = DisplacementsMomentaModel() + outputs = { + "displacements": ModelOutput(per_atom=True), + "momenta": ModelOutput(per_atom=True), + } + capabilities = ModelCapabilities( + length_unit="angstrom", + atomic_types=[1, 2, 3], + interaction_range=4.3, + outputs=outputs, + supported_devices=["cpu"], + dtype="float64", + ) + atomistic = AtomisticModel(model.eval(), ModelMetadata(), capabilities) + + options = ModelEvaluationOptions(outputs=outputs) + + result = atomistic([system, system], options, check_consistency=True) + assert "displacements" in result + assert "momenta" in result + + displacements = result["displacements"] + assert displacements.keys == Labels("_", torch.tensor([[0]])) + assert list(displacements.block().values.shape) == [6, 3, 1] + assert displacements.block().samples.names == ["system", "atom"] + assert displacements.block().properties.names == ["displacements"] + assert displacements.block().components == [ + Labels("xyz", torch.tensor([[0], [1], [2]])) + ] + assert len(result["displacements"].blocks()) == 1 + + momenta = result["momenta"] + assert momenta.keys == Labels("_", torch.tensor([[0]])) + assert list(momenta.block().values.shape) == [6, 3, 1] + assert momenta.block().samples.names == ["system", "atom"] + assert momenta.block().properties.names == ["momenta"] + assert momenta.block().components == [Labels("xyz", torch.tensor([[0], [1], [2]]))] + assert len(result["momenta"].blocks()) == 1 From 1c1cb29ac423a95c313e067224a91bcacb0f5354 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 23 Jul 2025 11:40:07 +0200 Subject: [PATCH 2/2] Change displacements to positions --- .../src/outputs/displacements-and-momenta.rst | 26 +++++----- docs/src/outputs/index.rst | 10 ++-- ...ements-output.png => positions-output.png} | Bin metatomic-torch/src/model.cpp | 2 +- .../metatomic/torch/outputs.py | 46 ++++++++---------- python/metatomic_torch/tests/outputs.py | 32 ++++++------ 6 files changed, 56 insertions(+), 60 deletions(-) rename docs/static/images/{displacements-output.png => positions-output.png} (100%) diff --git a/docs/src/outputs/displacements-and-momenta.rst b/docs/src/outputs/displacements-and-momenta.rst index 1535c465..737b576e 100644 --- a/docs/src/outputs/displacements-and-momenta.rst +++ b/docs/src/outputs/displacements-and-momenta.rst @@ -1,16 +1,16 @@ -.. _displacements-output: +.. _positions-output: -Displacements +positions ^^^^^^^^^^^^^ -Displacements are differences between atomic positions at two different times. +positions are differences between atomic positions at two different times. They can be used to predict the next configuration in molecular dynamics (see, e.g., https://arxiv.org/pdf/2505.19350). -In metatomic models, they are associated with the ``"displacements"`` +In metatomic models, they are associated with the ``"positions"`` key in the model outputs, and must adhere to the following metadata schema: -.. list-table:: Metadata for displacements +.. list-table:: Metadata for positions :widths: 2 3 7 :header-rows: 1 @@ -21,13 +21,13 @@ key in the model outputs, and must adhere to the following metadata schema: * - keys - ``"_"`` - the keys must have a single dimension named ``"_"``, with a single - entry set to ``0``. Displacements are always a + entry set to ``0``. positions are always a :py:class:`metatensor.torch.TensorMap` with a single block. * - samples - ``["system", "atom"]`` - the samples must be named ``["system", "atom"]``, since - displacements are always per-atom. + positions are always per-atom. ``"system"`` must range from 0 to the number of systems given as an input to the model. ``"atom"`` must range between 0 and the number of @@ -37,17 +37,17 @@ key in the model outputs, and must adhere to the following metadata schema: * - components - ``"xyz"`` - - displacements must have a single component dimension named + - positions must have a single component dimension named ``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The - displacements are always 3D vectors, and the order of the + positions are always 3D vectors, and the order of the components is x, y, z. * - properties - - ``"displacements"`` - - displacements must have a single property dimension named - ``"displacements"``, with a single entry set to ``0``. + - ``"positions"`` + - positions must have a single property dimension named + ``"positions"``, with a single entry set to ``0``. -At the moment, displacements are not integrated into any simulation engines. +At the moment, positions are not integrated into any simulation engines. .. _momenta-output: diff --git a/docs/src/outputs/index.rst b/docs/src/outputs/index.rst index b0219f7d..f43a2e59 100644 --- a/docs/src/outputs/index.rst +++ b/docs/src/outputs/index.rst @@ -20,7 +20,7 @@ section to these pages. energy features non_conservative - displacements-and-momenta + positions-and-momenta Physical quantities @@ -77,13 +77,13 @@ quantities, i.e. quantities with a well-defined physical meaning. Stress directly predicted by the model, not derived from the potential energy. - .. grid-item-card:: Displacements - :link: displacements-output + .. grid-item-card:: positions + :link: positions-output :link-type: ref - .. image:: /../static/images/displacements-output.png + .. image:: /../static/images/positions-output.png - Atomic displacements predicted by the model, to be used in ML-driven simulations. + Atomic positions predicted by the model, to be used in ML-driven simulations. .. grid-item-card:: Momenta :link: momenta-output diff --git a/docs/static/images/displacements-output.png b/docs/static/images/positions-output.png similarity index 100% rename from docs/static/images/displacements-output.png rename to docs/static/images/positions-output.png diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index d9bb2d39..2698139b 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -142,7 +142,7 @@ std::unordered_set KNOWN_OUTPUTS = { "features", "non_conservative_forces", "non_conservative_stress", - "displacements", + "positions", "momenta" }; diff --git a/python/metatomic_torch/metatomic/torch/outputs.py b/python/metatomic_torch/metatomic/torch/outputs.py index 937b2f9a..f7a4e781 100644 --- a/python/metatomic_torch/metatomic/torch/outputs.py +++ b/python/metatomic_torch/metatomic/torch/outputs.py @@ -51,8 +51,8 @@ def _check_outputs( _check_non_conservative_forces(value, systems, request, selected_atoms) elif name == "non_conservative_stress": _check_non_conservative_stress(value, systems, request) - elif name == "displacements": - _check_displacements(value, systems, request) + elif name == "positions": + _check_positions(value, systems, request) elif name == "momenta": _check_momenta(value, systems, request) else: @@ -267,55 +267,51 @@ def _check_non_conservative_stress( ) -def _check_displacements( +def _check_positions( value: TensorMap, systems: List[System], request: ModelOutput, ): """ - Check output metadata for displacements. + Check output metadata for positions. """ # Ensure the output contains a single block with the expected key - _validate_single_block("displacements", value) + _validate_single_block("positions", value) # Check samples values from systems - _validate_atomic_samples( - "displacements", value, systems, request, selected_atoms=None - ) + _validate_atomic_samples("positions", value, systems, request, selected_atoms=None) - displacements_block = value.block_by_id(0) + positions_block = value.block_by_id(0) # Check that the block has correct "Cartesian-form" components - if len(displacements_block.components) != 1: + if len(positions_block.components) != 1: raise ValueError( - "invalid components for 'displacements' output: " - f"expected one component, got {len(displacements_block.components)}" + "invalid components for 'positions' output: " + f"expected one component, got {len(positions_block.components)}" ) expected_component = Labels( "xyz", torch.tensor([[0], [1], [2]], device=value.device) ) - if displacements_block.components[0] != expected_component: + if positions_block.components[0] != expected_component: raise ValueError( - f"invalid components for 'displacements' output: " - f"expected {expected_component}, got {displacements_block.components[0]}" + f"invalid components for 'positions' output: " + f"expected {expected_component}, got {positions_block.components[0]}" ) - expected_properties = Labels( - "displacements", torch.tensor([[0]], device=value.device) - ) - message = "`Labels('displacements', [[0]])`" + expected_properties = Labels("positions", torch.tensor([[0]], device=value.device)) + message = "`Labels('positions', [[0]])`" - if displacements_block.properties != expected_properties: + if positions_block.properties != expected_properties: raise ValueError( - f"invalid properties for 'displacements' output: expected {message}, " - f"got {displacements_block.properties}" + f"invalid properties for 'positions' output: expected {message}, " + f"got {positions_block.properties}" ) # Should not have any gradients - if len(displacements_block.gradients_list()) > 0: + if len(positions_block.gradients_list()) > 0: raise ValueError( - "invalid gradients for 'displacements' output: " - f"expected no gradients, found {displacements_block.gradients_list()}" + "invalid gradients for 'positions' output: " + f"expected no gradients, found {positions_block.gradients_list()}" ) diff --git a/python/metatomic_torch/tests/outputs.py b/python/metatomic_torch/tests/outputs.py index 73b43ad6..59e2a59f 100644 --- a/python/metatomic_torch/tests/outputs.py +++ b/python/metatomic_torch/tests/outputs.py @@ -105,12 +105,12 @@ def __init__(self): super().__init__("features") -class DisplacementsMomentaModel(torch.nn.Module): - """A model predicting displacements and momenta""" +class positionsMomentaModel(torch.nn.Module): + """A model predicting positions and momenta""" def __init__(self): super().__init__() - self.output_names = ["displacements", "momenta"] + self.output_names = ["positions", "momenta"] def forward( self, @@ -118,9 +118,9 @@ def forward( outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: - assert "displacements" in outputs + assert "positions" in outputs assert "momenta" in outputs - assert outputs["displacements"].per_atom + assert outputs["positions"].per_atom assert outputs["momenta"].per_atom assert selected_atoms is None @@ -231,10 +231,10 @@ def test_features_model(system, get_capabilities): assert len(result["features"].blocks()) == 1 -def test_displacements_momenta_model(system): - model = DisplacementsMomentaModel() +def test_positions_momenta_model(system): + model = positionsMomentaModel() outputs = { - "displacements": ModelOutput(per_atom=True), + "positions": ModelOutput(per_atom=True), "momenta": ModelOutput(per_atom=True), } capabilities = ModelCapabilities( @@ -250,18 +250,18 @@ def test_displacements_momenta_model(system): options = ModelEvaluationOptions(outputs=outputs) result = atomistic([system, system], options, check_consistency=True) - assert "displacements" in result + assert "positions" in result assert "momenta" in result - displacements = result["displacements"] - assert displacements.keys == Labels("_", torch.tensor([[0]])) - assert list(displacements.block().values.shape) == [6, 3, 1] - assert displacements.block().samples.names == ["system", "atom"] - assert displacements.block().properties.names == ["displacements"] - assert displacements.block().components == [ + positions = result["positions"] + assert positions.keys == Labels("_", torch.tensor([[0]])) + assert list(positions.block().values.shape) == [6, 3, 1] + assert positions.block().samples.names == ["system", "atom"] + assert positions.block().properties.names == ["positions"] + assert positions.block().components == [ Labels("xyz", torch.tensor([[0], [1], [2]])) ] - assert len(result["displacements"].blocks()) == 1 + assert len(result["positions"].blocks()) == 1 momenta = result["momenta"] assert momenta.keys == Labels("_", torch.tensor([[0]]))