6
6
import numpy as np
7
7
import scipy
8
8
import torch
9
- from metatensor import Labels , TensorBlock , TensorMap
10
- from metatensor .torch import Labels as TorchLabels
11
- from metatensor .torch import TensorBlock as TorchTensorBlock
12
- from metatensor .torch import TensorMap as TorchTensorMap
9
+ from metatensor .torch import Labels , TensorBlock , TensorMap
13
10
from metatomic .torch import (
14
11
AtomisticModel ,
15
12
ModelCapabilities ,
@@ -121,14 +118,14 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None:
121
118
self ._sampler = _FPS (n_to_select = self .hypers ["krr" ]["num_sparse_points" ])
122
119
123
120
# set it do dummy keys, these are properly set during training
124
- self ._keys = TorchLabels .empty ("_" )
121
+ self ._keys = Labels .empty ("_" )
125
122
126
- dummy_weights = TorchTensorMap (
127
- TorchLabels (["_" ], torch .tensor ([[0 ]])),
123
+ dummy_weights = TensorMap (
124
+ Labels (["_" ], torch .tensor ([[0 ]])),
128
125
[mts .block_from_array (torch .empty (1 , 1 ))],
129
126
)
130
- dummy_X_pseudo = TorchTensorMap (
131
- TorchLabels (["_" ], torch .tensor ([[0 ]])),
127
+ dummy_X_pseudo = TensorMap (
128
+ Labels (["_" ], torch .tensor ([[0 ]])),
132
129
[mts .block_from_array (torch .empty (1 , 1 ))],
133
130
)
134
131
self ._subset_of_regressors_torch = TorchSubsetofRegressors (
@@ -138,7 +135,7 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None:
138
135
"aggregate_names" : ["atom" , "center_type" ],
139
136
},
140
137
)
141
- self ._species_labels : TorchLabels = TorchLabels .empty ("_" )
138
+ self ._species_labels : Labels = Labels .empty ("_" )
142
139
143
140
# additive models: these are handled by the trainer at training
144
141
# time, and they are added to the output at evaluation time
@@ -186,27 +183,27 @@ def forward(
186
183
self ,
187
184
systems : List [System ],
188
185
outputs : Dict [str , ModelOutput ],
189
- selected_atoms : Optional [TorchLabels ] = None ,
190
- ) -> Dict [str , TorchTensorMap ]:
186
+ selected_atoms : Optional [Labels ] = None ,
187
+ ) -> Dict [str , TensorMap ]:
191
188
soap_features = self ._soap_torch_calculator (
192
189
systems , selected_samples = selected_atoms
193
190
)
194
191
# move keys and species labels to device
195
192
self ._keys = self ._keys .to (systems [0 ].device )
196
193
self ._species_labels = self ._species_labels .to (systems [0 ].device )
197
194
198
- new_blocks : List [TorchTensorBlock ] = []
195
+ new_blocks : List [TensorBlock ] = []
199
196
# HACK: to add a block of zeros if there are missing species
200
197
# which were present at training time
201
198
# (with samples "system", "atom" = 0, 0)
202
199
# given the values are all zeros, it does not introduce an error
203
- dummyblock : TorchTensorBlock = TorchTensorBlock (
200
+ dummyblock = TensorBlock (
204
201
values = torch .zeros (
205
202
(1 , len (soap_features [0 ].properties )),
206
203
dtype = systems [0 ].positions .dtype ,
207
204
device = systems [0 ].device ,
208
205
),
209
- samples = TorchLabels (
206
+ samples = Labels (
210
207
["system" , "atom" ],
211
208
torch .tensor ([[0 , 0 ]], dtype = torch .int , device = systems [0 ].device ),
212
209
),
@@ -215,7 +212,7 @@ def forward(
215
212
)
216
213
if len (soap_features [0 ].gradients_list ()) > 0 :
217
214
for idx , grad in enumerate (soap_features [0 ].gradients_list ()):
218
- dummyblock_grad : TorchTensorBlock = TorchTensorBlock (
215
+ dummyblock_grad = TensorBlock (
219
216
values = torch .zeros (
220
217
(
221
218
1 ,
@@ -225,7 +222,7 @@ def forward(
225
222
dtype = systems [0 ].positions .dtype ,
226
223
device = systems [0 ].device ,
227
224
),
228
- samples = TorchLabels (
225
+ samples = Labels (
229
226
["sample" , "system" , "atom" ],
230
227
torch .tensor (
231
228
[[0 , 0 , 0 ]], dtype = torch .int , device = systems [0 ].device
@@ -242,15 +239,15 @@ def forward(
242
239
new_blocks .append (soap_features .block (key ))
243
240
else :
244
241
new_blocks .append (dummyblock )
245
- soap_features = TorchTensorMap (keys = self ._species_labels , blocks = new_blocks )
242
+ soap_features = TensorMap (keys = self ._species_labels , blocks = new_blocks )
246
243
soap_features = soap_features .keys_to_samples ("center_type" )
247
244
# here, we move to properties to use metatensor operations to aggregate
248
245
# later on. Perhaps we could retain the sparsity all the way to the kernels
249
246
# of the soap features with a lot more implementation effort
250
247
soap_features = soap_features .keys_to_properties (
251
248
["neighbor_1_type" , "neighbor_2_type" ]
252
249
)
253
- soap_features = TorchTensorMap (self ._keys , soap_features .blocks ())
250
+ soap_features = TensorMap (self ._keys , soap_features .blocks ())
254
251
output_key = list (outputs .keys ())[0 ]
255
252
energies = self ._subset_of_regressors_torch (soap_features )
256
253
return_dict = {output_key : energies }
@@ -475,9 +472,9 @@ def __init__(
475
472
476
473
def aggregate_kernel (
477
474
self ,
478
- kernel : TorchTensorMap ,
475
+ kernel : TensorMap ,
479
476
are_pseudo_points : Tuple [bool , bool ] = (False , False ),
480
- ) -> TorchTensorMap :
477
+ ) -> TensorMap :
481
478
if not are_pseudo_points [0 ]:
482
479
kernel = mts .sum_over_samples (kernel , self ._aggregate_names )
483
480
if not are_pseudo_points [1 ]:
@@ -488,17 +485,15 @@ def aggregate_kernel(
488
485
489
486
def forward (
490
487
self ,
491
- tensor1 : TorchTensorMap ,
492
- tensor2 : TorchTensorMap ,
488
+ tensor1 : TensorMap ,
489
+ tensor2 : TensorMap ,
493
490
are_pseudo_points : Tuple [bool , bool ] = (False , False ),
494
- ) -> TorchTensorMap :
491
+ ) -> TensorMap :
495
492
return self .aggregate_kernel (
496
493
self .compute_kernel (tensor1 , tensor2 ), are_pseudo_points
497
494
)
498
495
499
- def compute_kernel (
500
- self , tensor1 : TorchTensorMap , tensor2 : TorchTensorMap
501
- ) -> TorchTensorMap :
496
+ def compute_kernel (self , tensor1 : TensorMap , tensor2 : TensorMap ) -> TensorMap :
502
497
raise NotImplementedError ("compute_kernel needs to be implemented." )
503
498
504
499
@@ -512,7 +507,7 @@ def __init__(
512
507
super ().__init__ (aggregate_names , structurewise_aggregate )
513
508
self ._degree = degree
514
509
515
- def compute_kernel (self , tensor1 : TorchTensorMap , tensor2 : TorchTensorMap ):
510
+ def compute_kernel (self , tensor1 : TensorMap , tensor2 : TensorMap ):
516
511
return mts .pow (mts .dot (tensor1 , tensor2 ), self ._degree )
517
512
518
513
@@ -546,10 +541,6 @@ def fit(self, X: TensorMap): # -> GreedySelector:
546
541
:param X:
547
542
Training vectors.
548
543
"""
549
- if isinstance (X , torch .ScriptObject ):
550
- X = torch_tensor_map_to_core (X )
551
- assert isinstance (X [0 ].values , np .ndarray )
552
-
553
544
if len (X .component_names ) != 0 :
554
545
raise ValueError ("Only blocks with no components are supported." )
555
546
@@ -578,7 +569,9 @@ def fit(self, X: TensorMap): # -> GreedySelector:
578
569
579
570
blocks .append (
580
571
TensorBlock (
581
- values = np .zeros ([len (samples ), len (properties )], dtype = np .int32 ),
572
+ values = torch .zeros (
573
+ [len (samples ), len (properties )], dtype = torch .int32
574
+ ),
582
575
samples = samples ,
583
576
components = [],
584
577
properties = properties ,
@@ -596,12 +589,6 @@ def transform(self, X: TensorMap) -> TensorMap:
596
589
:returns:
597
590
The selected subset of the input.
598
591
"""
599
- if isinstance (X , torch .ScriptObject ):
600
- use_mts_torch = True
601
- X = torch_tensor_map_to_core (X )
602
- else :
603
- use_mts_torch = False
604
-
605
592
blocks = []
606
593
for key , block in X .items ():
607
594
block_support = self .support .block (key )
@@ -614,10 +601,7 @@ def transform(self, X: TensorMap) -> TensorMap:
614
601
new_block = mts .slice_block (block , "samples" , block_support .samples )
615
602
blocks .append (new_block )
616
603
617
- X_reduced = TensorMap (X .keys , blocks )
618
- if use_mts_torch :
619
- X_reduced = core_tensor_map_to_torch (X_reduced )
620
- return X_reduced
604
+ return TensorMap (X .keys , blocks )
621
605
622
606
def fit_transform (self , X : TensorMap ) -> TensorMap :
623
607
"""Fit to data, then transform it.
@@ -628,112 +612,6 @@ def fit_transform(self, X: TensorMap) -> TensorMap:
628
612
return self .fit (X ).transform (X )
629
613
630
614
631
- def torch_tensor_map_to_core (torch_tensor : TorchTensorMap ):
632
- torch_blocks = []
633
- for _ , torch_block in torch_tensor .items ():
634
- torch_blocks .append (torch_tensor_block_to_core (torch_block ))
635
- torch_keys = torch_labels_to_core (torch_tensor .keys )
636
- return TensorMap (torch_keys , torch_blocks )
637
-
638
-
639
- def torch_tensor_block_to_core (torch_block : TorchTensorBlock ):
640
- """Transforms a tensor block from metatensor-torch to metatensor-torch
641
- :param torch_block:
642
- tensor block from metatensor-torch
643
- :returns torch_block:
644
- tensor block from metatensor-torch
645
- """
646
- block = TensorBlock (
647
- values = torch_block .values .detach ().cpu ().numpy (),
648
- samples = torch_labels_to_core (torch_block .samples ),
649
- components = [
650
- torch_labels_to_core (component ) for component in torch_block .components
651
- ],
652
- properties = torch_labels_to_core (torch_block .properties ),
653
- )
654
- for parameter , gradient in torch_block .gradients ():
655
- block .add_gradient (
656
- parameter = parameter ,
657
- gradient = TensorBlock (
658
- values = gradient .values .detach ().cpu ().numpy (),
659
- samples = torch_labels_to_core (gradient .samples ),
660
- components = [
661
- torch_labels_to_core (component ) for component in gradient .components
662
- ],
663
- properties = torch_labels_to_core (gradient .properties ),
664
- ),
665
- )
666
- return block
667
-
668
-
669
- def torch_labels_to_core (torch_labels : TorchLabels ):
670
- """Transforms labels from metatensor-torch to metatensor-torch
671
- :param torch_block:
672
- tensor block from metatensor-torch
673
- :returns torch_block:
674
- labels from metatensor-torch
675
- """
676
- return Labels (torch_labels .names , torch_labels .values .detach ().cpu ().numpy ())
677
-
678
-
679
- ###
680
-
681
-
682
- def core_tensor_map_to_torch (core_tensor : TensorMap ):
683
- """Transforms a tensor map from metatensor-core to metatensor-torch
684
- :param core_tensor:
685
- tensor map from metatensor-core
686
- :returns torch_tensor:
687
- tensor map from metatensor-torch
688
- """
689
-
690
- torch_blocks = []
691
- for _ , core_block in core_tensor .items ():
692
- torch_blocks .append (core_tensor_block_to_torch (core_block ))
693
- torch_keys = core_labels_to_torch (core_tensor .keys )
694
- return TorchTensorMap (torch_keys , torch_blocks )
695
-
696
-
697
- def core_tensor_block_to_torch (core_block : TensorBlock ):
698
- """Transforms a tensor block from metatensor-core to metatensor-torch
699
- :param core_block:
700
- tensor block from metatensor-core
701
- :returns torch_block:
702
- tensor block from metatensor-torch
703
- """
704
- block = TorchTensorBlock (
705
- values = torch .tensor (core_block .values ),
706
- samples = core_labels_to_torch (core_block .samples ),
707
- components = [
708
- core_labels_to_torch (component ) for component in core_block .components
709
- ],
710
- properties = core_labels_to_torch (core_block .properties ),
711
- )
712
- for parameter , gradient in core_block .gradients ():
713
- block .add_gradient (
714
- parameter = parameter ,
715
- gradient = TorchTensorBlock (
716
- values = torch .tensor (gradient .values ),
717
- samples = core_labels_to_torch (gradient .samples ),
718
- components = [
719
- core_labels_to_torch (component ) for component in gradient .components
720
- ],
721
- properties = core_labels_to_torch (gradient .properties ),
722
- ),
723
- )
724
- return block
725
-
726
-
727
- def core_labels_to_torch (core_labels : Labels ):
728
- """Transforms labels from metatensor-core to metatensor-torch
729
- :param core_block:
730
- tensor block from metatensor-core
731
- :returns torch_block:
732
- labels from metatensor-torch
733
- """
734
- return TorchLabels (core_labels .names , torch .tensor (core_labels .values ))
735
-
736
-
737
615
class SubsetOfRegressors :
738
616
def __init__ (
739
617
self ,
@@ -809,10 +687,6 @@ def fit(
809
687
if not isinstance (alpha_forces , float ):
810
688
raise ValueError ("alpha must either be a float" )
811
689
812
- X = X .to (arrays = "numpy" )
813
- X_pseudo = X_pseudo .to (arrays = "numpy" )
814
- y = y .to (arrays = "numpy" )
815
-
816
690
if self ._kernel is None :
817
691
# _set_kernel only returns None if kernel type is precomputed
818
692
k_nm = X
@@ -831,11 +705,11 @@ def fit(
831
705
structures = torch .unique (k_nm_block .samples ["system" ])
832
706
n_atoms_per_structure = []
833
707
for structure in structures :
834
- n_atoms = np .sum (X_block .samples ["system" ] == structure )
708
+ n_atoms = torch .sum (X_block .samples ["system" ] == structure )
835
709
n_atoms_per_structure .append (float (n_atoms ))
836
710
837
- n_atoms_per_structure = np . array (n_atoms_per_structure )
838
- normalization = np .sqrt (n_atoms_per_structure )
711
+ n_atoms_per_structure = torch . tensor (n_atoms_per_structure )
712
+ normalization = torch .sqrt (n_atoms_per_structure )
839
713
840
714
if not (np .allclose (alpha_energy , 0.0 )):
841
715
normalization /= alpha_energy
@@ -871,7 +745,7 @@ def fit(
871
745
self ._solver .fit (k_nm_reg , y_reg )
872
746
873
747
weight_block = TensorBlock (
874
- values = self ._solver .weights .T ,
748
+ values = torch . as_tensor ( self ._solver .weights .T ) ,
875
749
samples = y_block .properties ,
876
750
components = k_nm_block .components ,
877
751
properties = k_nm_block .properties ,
@@ -901,17 +775,17 @@ def predict(self, T: TensorMap) -> TensorMap:
901
775
902
776
def export_torch_script_model (self ):
903
777
return TorchSubsetofRegressors (
904
- core_tensor_map_to_torch ( self ._weights ) ,
905
- core_tensor_map_to_torch ( self ._X_pseudo ) ,
778
+ self ._weights ,
779
+ self ._X_pseudo ,
906
780
self ._kernel_kwargs ,
907
781
)
908
782
909
783
910
784
class TorchSubsetofRegressors (torch .nn .Module ):
911
785
def __init__ (
912
786
self ,
913
- weights : TorchTensorMap ,
914
- X_pseudo : TorchTensorMap ,
787
+ weights : TensorMap ,
788
+ X_pseudo : TensorMap ,
915
789
kernel_kwargs : Optional [dict ] = None ,
916
790
):
917
791
super ().__init__ ()
@@ -923,7 +797,7 @@ def __init__(
923
797
# Set the kernel
924
798
self ._kernel = TorchAggregatePolynomial (** kernel_kwargs )
925
799
926
- def forward (self , T : TorchTensorMap ) -> TorchTensorMap :
800
+ def forward (self , T : TensorMap ) -> TensorMap :
927
801
"""
928
802
:param T:
929
803
features
0 commit comments