
Commit 8efb9dc (parent a764a6f)

Do not mix and match metatensor-core and metatensor-torch in gap

2 files changed (+36, -168 lines)

src/metatrain/gap/model.py (36 additions, 162 deletions)
@@ -6,10 +6,7 @@
 import numpy as np
 import scipy
 import torch
-from metatensor import Labels, TensorBlock, TensorMap
-from metatensor.torch import Labels as TorchLabels
-from metatensor.torch import TensorBlock as TorchTensorBlock
-from metatensor.torch import TensorMap as TorchTensorMap
+from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatomic.torch import (
     AtomisticModel,
     ModelCapabilities,
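The import change above is the heart of the commit: instead of importing Labels, TensorBlock, and TensorMap from metatensor-core and aliasing the metatensor-torch variants as TorchLabels and friends, the module now uses the torch backend exclusively. A minimal sketch (not part of the commit) of building a TensorMap from metatensor.torch objects only, mirroring the dummy constructions later in this file:

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap

# every piece is torch-backed, so nothing needs converting at export time
keys = Labels(["_"], torch.tensor([[0]]))
block = TensorBlock(
    values=torch.zeros(1, 1),
    samples=Labels(["system"], torch.tensor([[0]])),
    components=[],
    properties=Labels(["property"], torch.tensor([[0]])),
)
tmap = TensorMap(keys, [block])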
@@ -121,14 +118,14 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None:
         self._sampler = _FPS(n_to_select=self.hypers["krr"]["num_sparse_points"])

         # set it do dummy keys, these are properly set during training
-        self._keys = TorchLabels.empty("_")
+        self._keys = Labels.empty("_")

-        dummy_weights = TorchTensorMap(
-            TorchLabels(["_"], torch.tensor([[0]])),
+        dummy_weights = TensorMap(
+            Labels(["_"], torch.tensor([[0]])),
             [mts.block_from_array(torch.empty(1, 1))],
         )
-        dummy_X_pseudo = TorchTensorMap(
-            TorchLabels(["_"], torch.tensor([[0]])),
+        dummy_X_pseudo = TensorMap(
+            Labels(["_"], torch.tensor([[0]])),
             [mts.block_from_array(torch.empty(1, 1))],
         )
         self._subset_of_regressors_torch = TorchSubsetofRegressors(
@@ -138,7 +135,7 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None:
                 "aggregate_names": ["atom", "center_type"],
             },
         )
-        self._species_labels: TorchLabels = TorchLabels.empty("_")
+        self._species_labels: Labels = Labels.empty("_")

         # additive models: these are handled by the trainer at training
         # time, and they are added to the output at evaluation time
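Both placeholders lean on mts.block_from_array. As a hedged aside (behavior inferred from the metatensor API, not something this commit states), it wraps a bare 2-D array in a TensorBlock with auto-generated sample and property Labels, which is all a dummy weight needs:

import torch
import metatensor.torch as mts

# a 1x1 placeholder block; its labels are generated automatically
block = mts.block_from_array(torch.empty(1, 1))
print(block.samples.names, block.properties.names)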
@@ -186,27 +183,27 @@ def forward(
         self,
         systems: List[System],
         outputs: Dict[str, ModelOutput],
-        selected_atoms: Optional[TorchLabels] = None,
-    ) -> Dict[str, TorchTensorMap]:
+        selected_atoms: Optional[Labels] = None,
+    ) -> Dict[str, TensorMap]:
         soap_features = self._soap_torch_calculator(
             systems, selected_samples=selected_atoms
         )
         # move keys and species labels to device
         self._keys = self._keys.to(systems[0].device)
         self._species_labels = self._species_labels.to(systems[0].device)

-        new_blocks: List[TorchTensorBlock] = []
+        new_blocks: List[TensorBlock] = []
         # HACK: to add a block of zeros if there are missing species
         # which were present at training time
         # (with samples "system", "atom" = 0, 0)
         # given the values are all zeros, it does not introduce an error
-        dummyblock: TorchTensorBlock = TorchTensorBlock(
+        dummyblock = TensorBlock(
             values=torch.zeros(
                 (1, len(soap_features[0].properties)),
                 dtype=systems[0].positions.dtype,
                 device=systems[0].device,
             ),
-            samples=TorchLabels(
+            samples=Labels(
                 ["system", "atom"],
                 torch.tensor([[0, 0]], dtype=torch.int, device=systems[0].device),
             ),
@@ -215,7 +212,7 @@ def forward(
         )
         if len(soap_features[0].gradients_list()) > 0:
             for idx, grad in enumerate(soap_features[0].gradients_list()):
-                dummyblock_grad: TorchTensorBlock = TorchTensorBlock(
+                dummyblock_grad = TensorBlock(
                     values=torch.zeros(
                         (
                             1,
@@ -225,7 +222,7 @@ def forward(
                         dtype=systems[0].positions.dtype,
                         device=systems[0].device,
                     ),
-                    samples=TorchLabels(
+                    samples=Labels(
                         ["sample", "system", "atom"],
                         torch.tensor(
                             [[0, 0, 0]], dtype=torch.int, device=systems[0].device
@@ -242,15 +239,15 @@ def forward(
                 new_blocks.append(soap_features.block(key))
             else:
                 new_blocks.append(dummyblock)
-        soap_features = TorchTensorMap(keys=self._species_labels, blocks=new_blocks)
+        soap_features = TensorMap(keys=self._species_labels, blocks=new_blocks)
         soap_features = soap_features.keys_to_samples("center_type")
         # here, we move to properties to use metatensor operations to aggregate
         # later on. Perhaps we could retain the sparsity all the way to the kernels
         # of the soap features with a lot more implementation effort
         soap_features = soap_features.keys_to_properties(
             ["neighbor_1_type", "neighbor_2_type"]
         )
-        soap_features = TorchTensorMap(self._keys, soap_features.blocks())
+        soap_features = TensorMap(self._keys, soap_features.blocks())
         output_key = list(outputs.keys())[0]
         energies = self._subset_of_regressors_torch(soap_features)
         return_dict = {output_key: energies}
@@ -475,9 +472,9 @@ def __init__(

     def aggregate_kernel(
         self,
-        kernel: TorchTensorMap,
+        kernel: TensorMap,
         are_pseudo_points: Tuple[bool, bool] = (False, False),
-    ) -> TorchTensorMap:
+    ) -> TensorMap:
         if not are_pseudo_points[0]:
             kernel = mts.sum_over_samples(kernel, self._aggregate_names)
         if not are_pseudo_points[1]:
@@ -488,17 +485,15 @@ def aggregate_kernel(

     def forward(
         self,
-        tensor1: TorchTensorMap,
-        tensor2: TorchTensorMap,
+        tensor1: TensorMap,
+        tensor2: TensorMap,
         are_pseudo_points: Tuple[bool, bool] = (False, False),
-    ) -> TorchTensorMap:
+    ) -> TensorMap:
         return self.aggregate_kernel(
             self.compute_kernel(tensor1, tensor2), are_pseudo_points
         )

-    def compute_kernel(
-        self, tensor1: TorchTensorMap, tensor2: TorchTensorMap
-    ) -> TorchTensorMap:
+    def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap) -> TensorMap:
         raise NotImplementedError("compute_kernel needs to be implemented.")

@@ -512,7 +507,7 @@ def __init__(
         super().__init__(aggregate_names, structurewise_aggregate)
         self._degree = degree

-    def compute_kernel(self, tensor1: TorchTensorMap, tensor2: TorchTensorMap):
+    def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap):
         return mts.pow(mts.dot(tensor1, tensor2), self._degree)

@@ -546,10 +541,6 @@ def fit(self, X: TensorMap):  # -> GreedySelector:
         :param X:
             Training vectors.
         """
-        if isinstance(X, torch.ScriptObject):
-            X = torch_tensor_map_to_core(X)
-        assert isinstance(X[0].values, np.ndarray)
-
         if len(X.component_names) != 0:
             raise ValueError("Only blocks with no components are supported.")

@@ -578,7 +569,9 @@ def fit(self, X: TensorMap):  # -> GreedySelector:

         blocks.append(
             TensorBlock(
-                values=np.zeros([len(samples), len(properties)], dtype=np.int32),
+                values=torch.zeros(
+                    [len(samples), len(properties)], dtype=torch.int32
+                ),
                 samples=samples,
                 components=[],
                 properties=properties,
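The guard removed from fit two hunks up deserves a note: metatensor-torch objects are TorchScript custom classes, so isinstance(X, torch.ScriptObject) was a cheap backend test, and torch input used to be converted to core before fitting. With a single backend that branch is dead code. A small illustration, assuming the standard metatensor-torch class machinery:

import torch
import metatensor.torch as mts
from metatensor.torch import Labels, TensorMap

tmap = TensorMap(
    Labels(["_"], torch.tensor([[0]])),
    [mts.block_from_array(torch.empty(1, 1))],
)
# torch-backend objects are TorchScript classes; this is what the old guard tested
print(isinstance(tmap, torch.ScriptObject))  # expected: True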
@@ -596,12 +589,6 @@ def transform(self, X: TensorMap) -> TensorMap:
         :returns:
             The selected subset of the input.
         """
-        if isinstance(X, torch.ScriptObject):
-            use_mts_torch = True
-            X = torch_tensor_map_to_core(X)
-        else:
-            use_mts_torch = False
-
         blocks = []
         for key, block in X.items():
             block_support = self.support.block(key)
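This hunk and the next one strip the same round-trip from transform: the input is sliced and returned as a metatensor.torch TensorMap directly, with no use_mts_torch bookkeeping. A hypothetical usage sketch (X is a made-up single-block feature map; _FPS is the selector defined in this file):

import torch
import metatensor.torch as mts
from metatensor.torch import Labels, TensorMap

# hypothetical feature map: 10 samples, 4 properties, one block
X = TensorMap(
    Labels(["_"], torch.tensor([[0]])),
    [mts.block_from_array(torch.randn(10, 4))],
)
selector = _FPS(n_to_select=4)
X_sparse = selector.fit_transform(X)  # torch-backed in, torch-backed out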
@@ -614,10 +601,7 @@ def transform(self, X: TensorMap) -> TensorMap:
             new_block = mts.slice_block(block, "samples", block_support.samples)
             blocks.append(new_block)

-        X_reduced = TensorMap(X.keys, blocks)
-        if use_mts_torch:
-            X_reduced = core_tensor_map_to_torch(X_reduced)
-        return X_reduced
+        return TensorMap(X.keys, blocks)

     def fit_transform(self, X: TensorMap) -> TensorMap:
         """Fit to data, then transform it.
@@ -628,112 +612,6 @@ def fit_transform(self, X: TensorMap) -> TensorMap:
         return self.fit(X).transform(X)


-def torch_tensor_map_to_core(torch_tensor: TorchTensorMap):
-    torch_blocks = []
-    for _, torch_block in torch_tensor.items():
-        torch_blocks.append(torch_tensor_block_to_core(torch_block))
-    torch_keys = torch_labels_to_core(torch_tensor.keys)
-    return TensorMap(torch_keys, torch_blocks)
-
-
-def torch_tensor_block_to_core(torch_block: TorchTensorBlock):
-    """Transforms a tensor block from metatensor-torch to metatensor-torch
-    :param torch_block:
-        tensor block from metatensor-torch
-    :returns torch_block:
-        tensor block from metatensor-torch
-    """
-    block = TensorBlock(
-        values=torch_block.values.detach().cpu().numpy(),
-        samples=torch_labels_to_core(torch_block.samples),
-        components=[
-            torch_labels_to_core(component) for component in torch_block.components
-        ],
-        properties=torch_labels_to_core(torch_block.properties),
-    )
-    for parameter, gradient in torch_block.gradients():
-        block.add_gradient(
-            parameter=parameter,
-            gradient=TensorBlock(
-                values=gradient.values.detach().cpu().numpy(),
-                samples=torch_labels_to_core(gradient.samples),
-                components=[
-                    torch_labels_to_core(component) for component in gradient.components
-                ],
-                properties=torch_labels_to_core(gradient.properties),
-            ),
-        )
-    return block
-
-
-def torch_labels_to_core(torch_labels: TorchLabels):
-    """Transforms labels from metatensor-torch to metatensor-torch
-    :param torch_block:
-        tensor block from metatensor-torch
-    :returns torch_block:
-        labels from metatensor-torch
-    """
-    return Labels(torch_labels.names, torch_labels.values.detach().cpu().numpy())
-
-
-###
-
-
-def core_tensor_map_to_torch(core_tensor: TensorMap):
-    """Transforms a tensor map from metatensor-core to metatensor-torch
-    :param core_tensor:
-        tensor map from metatensor-core
-    :returns torch_tensor:
-        tensor map from metatensor-torch
-    """
-
-    torch_blocks = []
-    for _, core_block in core_tensor.items():
-        torch_blocks.append(core_tensor_block_to_torch(core_block))
-    torch_keys = core_labels_to_torch(core_tensor.keys)
-    return TorchTensorMap(torch_keys, torch_blocks)
-
-
-def core_tensor_block_to_torch(core_block: TensorBlock):
-    """Transforms a tensor block from metatensor-core to metatensor-torch
-    :param core_block:
-        tensor block from metatensor-core
-    :returns torch_block:
-        tensor block from metatensor-torch
-    """
-    block = TorchTensorBlock(
-        values=torch.tensor(core_block.values),
-        samples=core_labels_to_torch(core_block.samples),
-        components=[
-            core_labels_to_torch(component) for component in core_block.components
-        ],
-        properties=core_labels_to_torch(core_block.properties),
-    )
-    for parameter, gradient in core_block.gradients():
-        block.add_gradient(
-            parameter=parameter,
-            gradient=TorchTensorBlock(
-                values=torch.tensor(gradient.values),
-                samples=core_labels_to_torch(gradient.samples),
-                components=[
-                    core_labels_to_torch(component) for component in gradient.components
-                ],
-                properties=core_labels_to_torch(gradient.properties),
-            ),
-        )
-    return block
-
-
-def core_labels_to_torch(core_labels: Labels):
-    """Transforms labels from metatensor-core to metatensor-torch
-    :param core_block:
-        tensor block from metatensor-core
-    :returns torch_block:
-        labels from metatensor-torch
-    """
-    return TorchLabels(core_labels.names, torch.tensor(core_labels.values))
-
-
 class SubsetOfRegressors:
     def __init__(
         self,
@@ -809,10 +687,6 @@ def fit(
         if not isinstance(alpha_forces, float):
            raise ValueError("alpha must either be a float")

-        X = X.to(arrays="numpy")
-        X_pseudo = X_pseudo.to(arrays="numpy")
-        y = y.to(arrays="numpy")
-
         if self._kernel is None:
             # _set_kernel only returns None if kernel type is precomputed
             k_nm = X
@@ -831,11 +705,11 @@ def fit(
             structures = torch.unique(k_nm_block.samples["system"])
             n_atoms_per_structure = []
             for structure in structures:
-                n_atoms = np.sum(X_block.samples["system"] == structure)
+                n_atoms = torch.sum(X_block.samples["system"] == structure)
                 n_atoms_per_structure.append(float(n_atoms))

-            n_atoms_per_structure = np.array(n_atoms_per_structure)
-            normalization = np.sqrt(n_atoms_per_structure)
+            n_atoms_per_structure = torch.tensor(n_atoms_per_structure)
+            normalization = torch.sqrt(n_atoms_per_structure)

             if not (np.allclose(alpha_energy, 0.0)):
                 normalization /= alpha_energy
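The same backend rule drives this hunk (an inference from the change, not spelled out in the commit message): with metatensor.torch, a Labels column such as samples["system"] comes back as a torch.Tensor, so the per-structure atom counts and the square-root normalization can stay in torch rather than detouring through numpy:

import torch
from metatensor.torch import Labels

samples = Labels(["system", "atom"], torch.tensor([[0, 0], [0, 1], [1, 0]]))
# column lookup yields a torch.Tensor, so torch reductions apply directly
n_atoms = torch.sum(samples["system"] == 0)  # tensor(2)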
@@ -871,7 +745,7 @@ def fit(
         self._solver.fit(k_nm_reg, y_reg)

         weight_block = TensorBlock(
-            values=self._solver.weights.T,
+            values=torch.as_tensor(self._solver.weights.T),
             samples=y_block.properties,
             components=k_nm_block.components,
             properties=k_nm_block.properties,
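One bridge remains, by necessity: the linear solver runs on numpy/scipy (an assumption about the solver internals; only the wrapping call appears in the diff), so its weights come back as a numpy array, and torch.as_tensor carries them into the torch-backed TensorBlock, sharing memory with the numpy buffer where dtype and layout allow:

import numpy as np
import torch

weights = np.ones((3, 2))            # stand-in for self._solver.weights
values = torch.as_tensor(weights.T)  # wraps the array, copying only if needed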
@@ -901,17 +775,17 @@ def predict(self, T: TensorMap) -> TensorMap:

     def export_torch_script_model(self):
         return TorchSubsetofRegressors(
-            core_tensor_map_to_torch(self._weights),
-            core_tensor_map_to_torch(self._X_pseudo),
+            self._weights,
+            self._X_pseudo,
             self._kernel_kwargs,
         )


 class TorchSubsetofRegressors(torch.nn.Module):
     def __init__(
         self,
-        weights: TorchTensorMap,
-        X_pseudo: TorchTensorMap,
+        weights: TensorMap,
+        X_pseudo: TensorMap,
         kernel_kwargs: Optional[dict] = None,
     ):
         super().__init__()
@@ -923,7 +797,7 @@ def __init__(
         # Set the kernel
         self._kernel = TorchAggregatePolynomial(**kernel_kwargs)

-    def forward(self, T: TorchTensorMap) -> TorchTensorMap:
+    def forward(self, T: TensorMap) -> TensorMap:
         """
         :param T:
             features
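The payoff of the single-backend rule (a plausible motivation sketch; the commit message only says not to mix the two libraries): every attribute of TorchSubsetofRegressors is now a metatensor-torch object, i.e. a TorchScript class, so the exported module can be scripted without the conversion helpers this commit deletes:

# hypothetical usage; subset_of_regressors is a fitted SubsetOfRegressors
torch_model = subset_of_regressors.export_torch_script_model()
scripted = torch.jit.script(torch_model)  # no core-to-torch round-trip needed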
