From ee97a4ee5e9326e4f9e5b536f7a63af9e853a868 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Mon, 14 Jul 2025 12:25:23 +0200 Subject: [PATCH 1/2] Import metatensor.torch as mts --- src/metatrain/deprecated/pet/model.py | 6 +-- src/metatrain/experimental/nanopet/model.py | 6 +-- .../nanopet/tests/test_functionality.py | 12 ++---- src/metatrain/gap/model.py | 34 +++++++---------- src/metatrain/gap/trainer.py | 7 ++-- src/metatrain/pet/model.py | 4 +- src/metatrain/pet/modules/compatibility.py | 4 +- src/metatrain/pet/tests/test_functionality.py | 12 ++---- .../pet/tests/test_pet_compatibility.py | 8 ++-- src/metatrain/soap_bpnn/model.py | 16 +++----- src/metatrain/soap_bpnn/spherical.py | 12 +++--- .../soap_bpnn/tests/test_functionality.py | 12 ++---- .../utils/additive/old_composition.py | 22 +++++------ src/metatrain/utils/additive/remove.py | 6 +-- src/metatrain/utils/additive/zbl.py | 4 +- .../utils/data/readers/metatensor.py | 8 ++-- .../utils/data/writers/metatensor.py | 6 +-- src/metatrain/utils/data/writers/writers.py | 4 +- src/metatrain/utils/llpr.py | 4 +- src/metatrain/utils/scaler.py | 8 ++-- tests/utils/data/test_readers_metatensor.py | 38 ++++++++----------- tests/utils/data/test_writers.py | 4 +- tests/utils/test_additive.py | 30 ++++++--------- tests/utils/test_sum_over_atoms.py | 6 +-- tests/utils/test_transfer.py | 6 +-- 25 files changed, 119 insertions(+), 160 deletions(-) diff --git a/src/metatrain/deprecated/pet/model.py b/src/metatrain/deprecated/pet/model.py index 33b0dba86..a88b0a831 100644 --- a/src/metatrain/deprecated/pet/model.py +++ b/src/metatrain/deprecated/pet/model.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Literal, Optional -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ( @@ -219,7 +219,7 @@ def forward( values=predictions, ) if selected_atoms is not None: - block = metatensor.torch.slice_block(block, "samples", selected_atoms) + block = mts.slice_block(block, "samples", selected_atoms) output_tmap = TensorMap(keys=empty_labels, blocks=[block]) if not outputs[output_name].per_atom: output_tmap = sum_over_atoms(output_tmap) @@ -238,7 +238,7 @@ def forward( selected_atoms, ) for output_name in additive_contributions: - output_quantities[output_name] = metatensor.torch.add( + output_quantities[output_name] = mts.add( output_quantities[output_name], additive_contributions[output_name], ) diff --git a/src/metatrain/experimental/nanopet/model.py b/src/metatrain/experimental/nanopet/model.py index 35dfc4f94..b55bc61b1 100644 --- a/src/metatrain/experimental/nanopet/model.py +++ b/src/metatrain/experimental/nanopet/model.py @@ -2,7 +2,7 @@ from math import prod from typing import Any, Dict, List, Literal, Optional -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ( @@ -517,7 +517,7 @@ def forward( if selected_atoms is not None: for output_name, tmap in atomic_properties_tmap_dict.items(): - atomic_properties_tmap_dict[output_name] = metatensor.torch.slice( + atomic_properties_tmap_dict[output_name] = mts.slice( tmap, axis="samples", selection=selected_atoms ) @@ -541,7 +541,7 @@ def forward( selected_atoms, ) for name in additive_contributions: - return_dict[name] = metatensor.torch.add( + return_dict[name] = mts.add( return_dict[name], additive_contributions[name], ) diff --git a/src/metatrain/experimental/nanopet/tests/test_functionality.py b/src/metatrain/experimental/nanopet/tests/test_functionality.py index af2737ef6..8bd7adb18 100644 --- a/src/metatrain/experimental/nanopet/tests/test_functionality.py +++ b/src/metatrain/experimental/nanopet/tests/test_functionality.py @@ -1,4 +1,4 @@ -import metatensor.torch +import metatensor.torch as mts import pytest import torch from jsonschema.exceptions import ValidationError @@ -152,7 +152,7 @@ def test_prediction_subset_atoms(): system_far_away_dimer, model.requested_neighbor_lists() ) - selection_labels = metatensor.torch.Labels( + selection_labels = mts.Labels( names=["system", "atom"], values=torch.tensor([[0, 0], [0, 2], [0, 3]]), ) @@ -168,13 +168,9 @@ def test_prediction_subset_atoms(): selected_atoms=selection_labels, ) - assert not metatensor.torch.allclose( - energy_monomer["energy"], energy_dimer["energy"] - ) + assert not mts.allclose(energy_monomer["energy"], energy_dimer["energy"]) - assert metatensor.torch.allclose( - energy_monomer["energy"], energy_monomer_in_dimer["energy"] - ) + assert mts.allclose(energy_monomer["energy"], energy_monomer_in_dimer["energy"]) torch.set_default_dtype(default_dtype_before) diff --git a/src/metatrain/gap/model.py b/src/metatrain/gap/model.py index 7fdc92a5c..7b7988230 100644 --- a/src/metatrain/gap/model.py +++ b/src/metatrain/gap/model.py @@ -2,7 +2,7 @@ import featomic import featomic.torch -import metatensor.torch +import metatensor.torch as mts import numpy as np import scipy import torch @@ -125,11 +125,11 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: dummy_weights = TorchTensorMap( TorchLabels(["_"], torch.tensor([[0]])), - [metatensor.torch.block_from_array(torch.empty(1, 1))], + [mts.block_from_array(torch.empty(1, 1))], ) dummy_X_pseudo = TorchTensorMap( TorchLabels(["_"], torch.tensor([[0]])), - [metatensor.torch.block_from_array(torch.empty(1, 1))], + [mts.block_from_array(torch.empty(1, 1))], ) self._subset_of_regressors_torch = TorchSubsetofRegressors( dummy_weights, @@ -268,7 +268,7 @@ def forward( selected_atoms, ) for name in additive_contributions: - return_dict[name] = metatensor.torch.add( + return_dict[name] = mts.add( return_dict[name], additive_contributions[name], ) @@ -421,7 +421,7 @@ def aggregate_kernel( self, kernel: TensorMap, are_pseudo_points: Tuple[bool, bool] = (False, False) ) -> TensorMap: if not are_pseudo_points[0]: - kernel = metatensor.sum_over_samples(kernel, self._aggregate_names) + kernel = mts.sum_over_samples(kernel, self._aggregate_names) if not are_pseudo_points[1]: raise NotImplementedError( "properties dimension cannot be aggregated for the moment" @@ -453,7 +453,7 @@ def __init__( self._degree = degree def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap): - return metatensor.pow(metatensor.dot(tensor1, tensor2), self._degree) + return mts.pow(mts.dot(tensor1, tensor2), self._degree) class TorchAggregateKernel(torch.nn.Module): @@ -479,7 +479,7 @@ def aggregate_kernel( are_pseudo_points: Tuple[bool, bool] = (False, False), ) -> TorchTensorMap: if not are_pseudo_points[0]: - kernel = metatensor.torch.sum_over_samples(kernel, self._aggregate_names) + kernel = mts.sum_over_samples(kernel, self._aggregate_names) if not are_pseudo_points[1]: raise NotImplementedError( "properties dimension cannot be aggregated for the moment" @@ -513,9 +513,7 @@ def __init__( self._degree = degree def compute_kernel(self, tensor1: TorchTensorMap, tensor2: TorchTensorMap): - return metatensor.torch.pow( - metatensor.torch.dot(tensor1, tensor2), self._degree - ) + return mts.pow(mts.dot(tensor1, tensor2), self._degree) class _FPS: @@ -609,13 +607,11 @@ def transform(self, X: TensorMap) -> TensorMap: block_support = self.support.block(key) if self._selection_type == "feature": - new_block = metatensor.slice_block( + new_block = mts.slice_block( block, "properties", block_support.properties ) elif self._selection_type == "sample": - new_block = metatensor.slice_block( - block, "samples", block_support.samples - ) + new_block = mts.slice_block(block, "samples", block_support.samples) blocks.append(new_block) X_reduced = TensorMap(X.keys, blocks) @@ -832,16 +828,14 @@ def fit( k_nm_block = k_nm.block(key) k_mm_block = k_mm.block(key) X_block = X.block(key) - structures = metatensor.operations._dispatch.unique( - k_nm_block.samples["system"] - ) + structures = torch.unique(k_nm_block.samples["system"]) n_atoms_per_structure = [] for structure in structures: n_atoms = np.sum(X_block.samples["system"] == structure) n_atoms_per_structure.append(float(n_atoms)) n_atoms_per_structure = np.array(n_atoms_per_structure) - normalization = metatensor.operations._dispatch.sqrt(n_atoms_per_structure) + normalization = np.sqrt(n_atoms_per_structure) if not (np.allclose(alpha_energy, 0.0)): normalization /= alpha_energy @@ -903,7 +897,7 @@ def predict(self, T: TensorMap) -> TensorMap: k_tm = T else: k_tm = self._kernel(T, self._X_pseudo, are_pseudo_points=(False, True)) - return metatensor.dot(k_tm, self._weights) + return mts.dot(k_tm, self._weights) def export_torch_script_model(self): return TorchSubsetofRegressors( @@ -940,4 +934,4 @@ def forward(self, T: TorchTensorMap) -> TorchTensorMap: self._X_pseudo = self._X_pseudo.to(T.device) k_tm = self._kernel(T, self._X_pseudo, are_pseudo_points=(False, True)) - return metatensor.torch.dot(k_tm, self._weights) + return mts.dot(k_tm, self._weights) diff --git a/src/metatrain/gap/trainer.py b/src/metatrain/gap/trainer.py index e95eec7a2..92eb0cd42 100644 --- a/src/metatrain/gap/trainer.py +++ b/src/metatrain/gap/trainer.py @@ -2,8 +2,7 @@ from pathlib import Path from typing import Any, Dict, List, Literal, Union -import metatensor -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import TensorMap @@ -65,7 +64,7 @@ def train( "equivariant learning which is not supported yet." ) train_dataset = train_datasets[0] - train_y = metatensor.torch.join( + train_y = mts.join( [sample[output_name] for sample in train_dataset], axis="samples", remove_tensor_name=True, @@ -121,7 +120,7 @@ def train( f"should be smaller than the number of environments ({lens})" ) sparse_points = model._sampler.fit_transform(train_tensor) - sparse_points = metatensor.operations.remove_gradients(sparse_points) + sparse_points = mts.remove_gradients(sparse_points) alpha_energy = self.hypers["regularizer"] if self.hypers["regularizer_forces"] is None: alpha_forces = alpha_energy diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index e021cb016..b4e2a4f63 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -2,7 +2,7 @@ from math import prod from typing import Any, Dict, List, Literal, Optional -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.operations._add import _add_block_block @@ -621,7 +621,7 @@ def forward( if selected_atoms is not None: for output_name, tmap in atomic_predictions_tmap_dict.items(): - atomic_predictions_tmap_dict[output_name] = metatensor.torch.slice( + atomic_predictions_tmap_dict[output_name] = mts.slice( tmap, axis="samples", selection=selected_atoms ) diff --git a/src/metatrain/pet/modules/compatibility.py b/src/metatrain/pet/modules/compatibility.py index 72c17eb46..71f7fbbe7 100644 --- a/src/metatrain/pet/modules/compatibility.py +++ b/src/metatrain/pet/modules/compatibility.py @@ -1,6 +1,6 @@ from typing import Dict -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap @@ -68,7 +68,7 @@ def convert_model_state_dict_from_legacy_pet( ) new_model_state_dict["additive_models.0.energy_composition_buffer"] = ( - metatensor.torch.save_buffer(weights) + mts.save_buffer(weights) ) species_to_species_index = torch.full( diff --git a/src/metatrain/pet/tests/test_functionality.py b/src/metatrain/pet/tests/test_functionality.py index b1e3529cd..b30dda4d8 100644 --- a/src/metatrain/pet/tests/test_functionality.py +++ b/src/metatrain/pet/tests/test_functionality.py @@ -1,4 +1,4 @@ -import metatensor.torch +import metatensor.torch as mts import pytest import torch from jsonschema.exceptions import ValidationError @@ -176,7 +176,7 @@ def test_prediction_subset_atoms(): system_far_away_dimer, model.requested_neighbor_lists() ) - selection_labels = metatensor.torch.Labels( + selection_labels = mts.Labels( names=["system", "atom"], values=torch.tensor([[0, 0], [0, 2], [0, 3]]), ) @@ -192,13 +192,9 @@ def test_prediction_subset_atoms(): selected_atoms=selection_labels, ) - assert not metatensor.torch.allclose( - energy_monomer["energy"], energy_dimer["energy"] - ) + assert not mts.allclose(energy_monomer["energy"], energy_dimer["energy"]) - assert metatensor.torch.allclose( - energy_monomer["energy"], energy_monomer_in_dimer["energy"] - ) + assert mts.allclose(energy_monomer["energy"], energy_monomer_in_dimer["energy"]) torch.set_default_dtype(default_dtype_before) diff --git a/src/metatrain/pet/tests/test_pet_compatibility.py b/src/metatrain/pet/tests/test_pet_compatibility.py index eeba1be89..a8cbeb7a0 100644 --- a/src/metatrain/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/pet/tests/test_pet_compatibility.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse from urllib.request import urlretrieve -import metatensor.torch +import metatensor.torch as mts import torch from metatomic.torch import ModelOutput @@ -765,8 +765,10 @@ def test_last_layer_features_compatibility(): pet_last_layer_features = pet_predictions["mtt::aux::energy_last_layer_features"] - assert metatensor.torch.allclose( - nativepet_last_layer_features, pet_last_layer_features, atol=1e-6 + assert mts.allclose( + nativepet_last_layer_features, + pet_last_layer_features, + atol=1e-6, ) diff --git a/src/metatrain/soap_bpnn/model.py b/src/metatrain/soap_bpnn/model.py index a656fc753..d700616a2 100644 --- a/src/metatrain/soap_bpnn/model.py +++ b/src/metatrain/soap_bpnn/model.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Literal, Optional -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.learn.nn import Linear as LinearMap @@ -441,9 +441,7 @@ def forward( sample_values[:, 1], ) if selected_atoms is not None: - soap_features = metatensor.torch.slice( - soap_features, "samples", selected_atoms - ) + soap_features = mts.slice(soap_features, "samples", selected_atoms) device = soap_features.block(0).values.device @@ -459,9 +457,7 @@ def forward( # first, send center_type to the samples dimension and make sure the # ordering is the same as in the systems merged_features = ( - metatensor.torch.sort( - features.keys_to_samples("center_type"), axes="samples" - ) + mts.sort(features.keys_to_samples("center_type"), axes="samples") .block() .values ) @@ -473,7 +469,7 @@ def forward( ) # also sort the original features to avoid problems - features = metatensor.torch.sort(features, axes="samples") + features = mts.sort(features, axes="samples") # split the long-range features back to center types center_types = torch.concatenate([system.types for system in systems]) @@ -497,7 +493,7 @@ def forward( ) # combine short- and long-range features - features = metatensor.torch.add(features, long_range_features) + features = mts.add(features, long_range_features) # output the hidden features, if requested: if "features" in outputs: @@ -632,7 +628,7 @@ def forward( selected_atoms, ) for name in additive_contributions: - return_dict[name] = metatensor.torch.add( + return_dict[name] = mts.add( return_dict[name], additive_contributions[name], ) diff --git a/src/metatrain/soap_bpnn/spherical.py b/src/metatrain/soap_bpnn/spherical.py index 5a849b26e..a7a54f7ec 100644 --- a/src/metatrain/soap_bpnn/spherical.py +++ b/src/metatrain/soap_bpnn/spherical.py @@ -3,7 +3,7 @@ import copy from typing import Dict, Optional -import metatensor.torch +import metatensor.torch as mts import numpy as np import sphericart.torch import torch @@ -92,7 +92,7 @@ def forward( atom_index_in_structure, ) if selected_atoms is not None: - spherical_expansion = metatensor.torch.slice( + spherical_expansion = mts.slice( spherical_expansion, "samples", selected_atoms ) @@ -104,7 +104,7 @@ def forward( ) # drop all L=0 blocks - spherical_expansion = metatensor.torch.drop_blocks( + spherical_expansion = mts.drop_blocks( spherical_expansion, keys=Labels( ["o3_lambda", "o3_sigma"], torch.tensor([[0, 1]], device=device) @@ -358,11 +358,9 @@ def forward( atom_index_in_structure, ) if selected_atoms is not None: - lambda_basis = metatensor.torch.slice( - lambda_basis, "samples", selected_atoms - ) + lambda_basis = mts.slice(lambda_basis, "samples", selected_atoms) lambda_basis = lambda_basis.keys_to_properties(self.neighbor_species_labels) - lambda_basis = metatensor.torch.drop_blocks( + lambda_basis = mts.drop_blocks( lambda_basis, keys=Labels( ["o3_lambda", "o3_sigma"], diff --git a/src/metatrain/soap_bpnn/tests/test_functionality.py b/src/metatrain/soap_bpnn/tests/test_functionality.py index cfe6138e1..8fe8c5737 100644 --- a/src/metatrain/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/soap_bpnn/tests/test_functionality.py @@ -1,6 +1,6 @@ import copy -import metatensor.torch +import metatensor.torch as mts import pytest import torch from jsonschema.exceptions import ValidationError @@ -100,7 +100,7 @@ def test_prediction_subset_atoms(): system_far_away_dimer, requested_neighbor_lists ) - selection_labels = metatensor.torch.Labels( + selection_labels = mts.Labels( names=["system", "atom"], values=torch.tensor([[0, 0], [0, 2], [0, 3]]), ) @@ -116,13 +116,9 @@ def test_prediction_subset_atoms(): selected_atoms=selection_labels, ) - assert not metatensor.torch.allclose( - energy_monomer["energy"], energy_dimer["energy"] - ) + assert not mts.allclose(energy_monomer["energy"], energy_dimer["energy"]) - assert metatensor.torch.allclose( - energy_monomer["energy"], energy_monomer_in_dimer["energy"] - ) + assert mts.allclose(energy_monomer["energy"], energy_monomer_in_dimer["energy"]) def test_output_last_layer_features(): diff --git a/src/metatrain/utils/additive/old_composition.py b/src/metatrain/utils/additive/old_composition.py index 11814a9d3..e3038b0f4 100644 --- a/src/metatrain/utils/additive/old_composition.py +++ b/src/metatrain/utils/additive/old_composition.py @@ -2,7 +2,7 @@ import warnings from typing import Dict, List, Optional, Union -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap from metatomic.torch import ModelOutput, System @@ -231,7 +231,7 @@ def train_model( # there is no center type, we need to add it # and we will rely on the fact that per-atom targets # should be in the same order as the atoms in the system - targets[target_key] = metatensor.torch.append_dimension( + targets[target_key] = mts.append_dimension( targets[target_key], "samples", "center_type", @@ -285,7 +285,7 @@ def train_model( if self.dataset_info.targets[target_key].per_atom: # hack: metatensor.join doesn't work on single blocks; # create TensorMaps, join, and then extract the joined block - joined_blocks = metatensor.torch.join( + joined_blocks = mts.join( [ TensorMap( keys=Labels.single(), @@ -296,8 +296,8 @@ def train_model( axis="samples", remove_tensor_name=True, ).block() - weights_tensor = metatensor.torch.sort_block( - metatensor.torch.mean_over_samples_block( + weights_tensor = mts.sort_block( + mts.mean_over_samples_block( joined_blocks, [ n @@ -338,9 +338,9 @@ def train_model( # make sure to update the weights buffer with the new weights self.register_buffer( target_key + "_composition_buffer", - metatensor.torch.save_buffer( - self.weights[target_key].to("cpu", torch.float64) - ).to(device), + mts.save_buffer(self.weights[target_key].to("cpu", torch.float64)).to( + device + ), ) def restart(self, dataset_info: DatasetInfo) -> "OldCompositionModel": @@ -458,7 +458,7 @@ def forward( # apply selected_atoms to the composition if needed if selected_atoms is not None: - composition_result_dict[output_name] = metatensor.torch.slice( + composition_result_dict[output_name] = mts.slice( composition_result_dict[output_name], "samples", selected_atoms ) @@ -523,7 +523,7 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None: ) self.register_buffer( target_name + "_composition_buffer", - metatensor.torch.save_buffer(fake_weights), + mts.save_buffer(fake_weights), ) def weights_to(self, device: torch.device, dtype: torch.dtype): @@ -561,7 +561,7 @@ def sync_tensor_maps(self): # Reload the weights of the (old) targets, which are not stored in the model # state_dict, from the buffers for k in self.dataset_info.targets: - self.weights[k] = metatensor.torch.load_buffer( + self.weights[k] = mts.load_buffer( self.__getattr__(k + "_composition_buffer") ) diff --git a/src/metatrain/utils/additive/remove.py b/src/metatrain/utils/additive/remove.py index 039f6899e..b9598ba2f 100644 --- a/src/metatrain/utils/additive/remove.py +++ b/src/metatrain/utils/additive/remove.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import TensorMap from metatensor.torch.operations._add import _add_block_block @@ -57,7 +57,7 @@ def remove_additive( blocks = [] for block_key, old_block in additive_contribution[target_key].items(): device = targets[target_key].block(block_key).values.device - block = metatensor.torch.TensorBlock( + block = mts.TensorBlock( values=old_block.values.detach().to(device=device), samples=targets[target_key].block(block_key).samples, components=[c.to(device=device) for c in old_block.components], @@ -71,7 +71,7 @@ def remove_additive( ) block.add_gradient( gradient_name, - metatensor.torch.TensorBlock( + mts.TensorBlock( values=gradient.values.detach(), samples=targets[target_key] .block(block_key) diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py index f37ff7f80..efe9a71c4 100644 --- a/src/metatrain/utils/additive/zbl.py +++ b/src/metatrain/utils/additive/zbl.py @@ -1,7 +1,7 @@ import logging from typing import Dict, List, Optional -import metatensor.torch +import metatensor.torch as mts import torch from ase.data import covalent_radii from metatensor.torch import Labels, TensorBlock, TensorMap @@ -198,7 +198,7 @@ def forward( # apply selected_atoms to the composition if needed if selected_atoms is not None: - targets_out[target_key] = metatensor.torch.slice( + targets_out[target_key] = mts.slice( targets_out[target_key], "samples", selected_atoms ) diff --git a/src/metatrain/utils/data/readers/metatensor.py b/src/metatrain/utils/data/readers/metatensor.py index d5a70fff6..8768ae6f3 100644 --- a/src/metatrain/utils/data/readers/metatensor.py +++ b/src/metatrain/utils/data/readers/metatensor.py @@ -1,6 +1,6 @@ from typing import List, Tuple -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import System @@ -22,7 +22,7 @@ def read_systems(filename: str) -> List[System]: def _wrapped_metatensor_read(filename) -> TensorMap: try: - return metatensor.torch.load(filename) + return mts.load(filename) except Exception as e: raise ValueError(f"Failed to read '{filename}' with torch: {e}") from e @@ -54,7 +54,7 @@ def read_energy(target: DictConfig) -> Tuple[TensorMap, TargetInfo]: ) ) ] - tensor_maps = metatensor.torch.split(tensor_map, "samples", selections) + tensor_maps = mts.split(tensor_map, "samples", selections) return tensor_maps, target_info @@ -79,7 +79,7 @@ def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]: ) for i in torch.unique(tensor_map.block(0).samples.column("system")) ] - tensor_maps = metatensor.torch.split(tensor_map, "samples", selections) + tensor_maps = mts.split(tensor_map, "samples", selections) return tensor_maps, target_info diff --git a/src/metatrain/utils/data/writers/metatensor.py b/src/metatrain/utils/data/writers/metatensor.py index 49297c625..8846b0f40 100644 --- a/src/metatrain/utils/data/writers/metatensor.py +++ b/src/metatrain/utils/data/writers/metatensor.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ModelCapabilities, System @@ -41,7 +41,7 @@ def finish(self): # write out .mts files (writes one file per target) filename_base = Path(self.filename).stem for prediction_name, prediction_tmap in predictions.items(): - metatensor.torch.save( + mts.save( filename_base + "_" + prediction_name + ".mts", prediction_tmap.to("cpu").to(torch.float64), ) @@ -95,7 +95,7 @@ def _concatenate_tensormaps( system_counter += n_systems return { - target: metatensor.torch.join( + target: mts.join( [pred[target] for pred in tensormaps_shifted_systems], axis="samples", remove_tensor_name=True, diff --git a/src/metatrain/utils/data/writers/writers.py b/src/metatrain/utils/data/writers/writers.py index 43e368c9f..4eee7c36c 100644 --- a/src/metatrain/utils/data/writers/writers.py +++ b/src/metatrain/utils/data/writers/writers.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ModelCapabilities, System @@ -46,7 +46,7 @@ def _split_tensormaps( for i in range(len(systems)) ] batch_predictions_split = { - key: metatensor.torch.split(tensormap, "samples", split_selection) + key: mts.split(tensormap, "samples", split_selection) for key, tensormap in batch_predictions.items() } diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py index 4a00bf8ea..c7f14ad5e 100644 --- a/src/metatrain/utils/llpr.py +++ b/src/metatrain/utils/llpr.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union -import metatensor.torch +import metatensor.torch as mts import numpy as np import torch from metatensor.torch import Labels, TensorBlock, TensorMap @@ -241,7 +241,7 @@ def forward( ], ) - return_dict[uncertainty_name] = metatensor.torch.multiply( + return_dict[uncertainty_name] = mts.multiply( uncertainty, float(self._get_multiplier(uncertainty_name).item()) ) diff --git a/src/metatrain/utils/scaler.py b/src/metatrain/utils/scaler.py index c50a7534b..058d3b4c7 100644 --- a/src/metatrain/utils/scaler.py +++ b/src/metatrain/utils/scaler.py @@ -1,6 +1,6 @@ from typing import Dict, List, Union -import metatensor.torch +import metatensor.torch as mts import numpy as np import torch from metatensor.torch import TensorMap @@ -188,7 +188,7 @@ def forward( scale = float( self.scales[self.output_name_to_output_index[target_key]].item() ) - scaled_target = metatensor.torch.multiply(target, scale) + scaled_target = mts.multiply(target, scale) scaled_outputs[target_key] = scaled_target else: scaled_outputs[target_key] = target @@ -246,8 +246,6 @@ def remove_scale( scale = float( scaler.scales[scaler.output_name_to_output_index[target_key]].item() ) - scaled_targets[target_key] = metatensor.torch.multiply( - targets[target_key], 1.0 / scale - ) + scaled_targets[target_key] = mts.multiply(targets[target_key], 1.0 / scale) return scaled_targets diff --git a/tests/utils/data/test_readers_metatensor.py b/tests/utils/data/test_readers_metatensor.py index 3cdb7aded..76422800b 100644 --- a/tests/utils/data/test_readers_metatensor.py +++ b/tests/utils/data/test_readers_metatensor.py @@ -1,4 +1,4 @@ -import metatensor.torch +import metatensor.torch as mts import numpy as np import pytest import torch @@ -136,13 +136,11 @@ def test_read_energy(tmpdir, energy_tensor_map): } with tmpdir.as_cwd(): - metatensor.torch.save("energy.mts", energy_tensor_map) + mts.save("energy.mts", energy_tensor_map) tensor_maps, _ = read_energy(OmegaConf.create(conf)) - tensor_map = metatensor.torch.join( - tensor_maps, axis="samples", remove_tensor_name=True - ) - assert metatensor.torch.equal(tensor_map, energy_tensor_map) + tensor_map = mts.join(tensor_maps, axis="samples", remove_tensor_name=True) + assert mts.equal(tensor_map, energy_tensor_map) def test_read_generic_scalar(tmpdir, scalar_tensor_map): @@ -158,13 +156,11 @@ def test_read_generic_scalar(tmpdir, scalar_tensor_map): } with tmpdir.as_cwd(): - metatensor.torch.save("generic.mts", scalar_tensor_map) + mts.save("generic.mts", scalar_tensor_map) tensor_maps, _ = read_generic(OmegaConf.create(conf)) - tensor_map = metatensor.torch.join( - tensor_maps, axis="samples", remove_tensor_name=True - ) - assert metatensor.torch.equal(tensor_map, scalar_tensor_map) + tensor_map = mts.join(tensor_maps, axis="samples", remove_tensor_name=True) + assert mts.equal(tensor_map, scalar_tensor_map) def test_read_generic_spherical(tmpdir, spherical_tensor_map): @@ -187,13 +183,11 @@ def test_read_generic_spherical(tmpdir, spherical_tensor_map): } with tmpdir.as_cwd(): - metatensor.torch.save("generic.mts", spherical_tensor_map) + mts.save("generic.mts", spherical_tensor_map) tensor_maps, _ = read_generic(OmegaConf.create(conf)) - tensor_map = metatensor.torch.join( - tensor_maps, axis="samples", remove_tensor_name=True - ) - assert metatensor.torch.equal(tensor_map, spherical_tensor_map) + tensor_map = mts.join(tensor_maps, axis="samples", remove_tensor_name=True) + assert mts.equal(tensor_map, spherical_tensor_map) def test_read_generic_cartesian(tmpdir, cartesian_tensor_map): @@ -213,19 +207,17 @@ def test_read_generic_cartesian(tmpdir, cartesian_tensor_map): } with tmpdir.as_cwd(): - metatensor.torch.save("generic.mts", cartesian_tensor_map) + mts.save("generic.mts", cartesian_tensor_map) tensor_maps, _ = read_generic(OmegaConf.create(conf)) - tensor_map = metatensor.torch.join( - tensor_maps, axis="samples", remove_tensor_name=True - ) + tensor_map = mts.join(tensor_maps, axis="samples", remove_tensor_name=True) - assert metatensor.torch.equal(tensor_map, cartesian_tensor_map) + assert mts.equal(tensor_map, cartesian_tensor_map) def test_read_errors(tmpdir, energy_tensor_map, scalar_tensor_map): with tmpdir.as_cwd(): - metatensor.torch.save("energy.mts", energy_tensor_map) + mts.save("energy.mts", energy_tensor_map) conf = { "quantity": "energy", @@ -255,7 +247,7 @@ def test_read_errors(tmpdir, energy_tensor_map, scalar_tensor_map): read_energy(OmegaConf.create(conf)) conf["forces"] = False - metatensor.torch.save("scalar.mts", scalar_tensor_map) + mts.save("scalar.mts", scalar_tensor_map) conf["read_from"] = "scalar.mts" with pytest.raises(ValueError, match="Unexpected samples"): diff --git a/tests/utils/data/test_writers.py b/tests/utils/data/test_writers.py index 5cf703cc5..4ba7d524e 100644 --- a/tests/utils/data/test_writers.py +++ b/tests/utils/data/test_writers.py @@ -1,6 +1,6 @@ from typing import Dict, List, Tuple -import metatensor.torch +import metatensor.torch as mts import pytest import torch from metatensor.torch import Labels, TensorBlock, TensorMap @@ -239,7 +239,7 @@ def test_write_predictions(filename, fileformat, cell, monkeypatch, tmp_path): assert frame.info["stress"].shape == (3, 3) elif filename.endswith(".mts"): - tensormap = metatensor.torch.load(filename.split(".")[0] + "_energy.mts") + tensormap = mts.load(filename.split(".")[0] + "_energy.mts") assert tensormap.block().values.shape == (2, 1) assert tensormap.block().gradient("positions").values.shape == (4, 3, 1) if cell is not None: diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index c7f518209..fb1b27241 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -1,7 +1,7 @@ import logging from pathlib import Path -import metatensor.torch +import metatensor.torch as mts import pytest import torch from metatensor.torch import Labels, TensorBlock, TensorMap @@ -308,7 +308,7 @@ def test_old_composition_model_predict(): assert output["mtt::U0"].block().values.shape != (5, 1) # with selected_atoms - selected_atoms = metatensor.torch.Labels( + selected_atoms = mts.Labels( names=["system", "atom"], values=torch.tensor([[0, 0]]), ) @@ -397,7 +397,7 @@ def test_composition_model_predict(device): assert output["mtt::U0"].block().values.device.type == device # with selected_atoms - selected_atoms = metatensor.torch.Labels( + selected_atoms = mts.Labels( names=["system", "atom"], values=torch.tensor([[0, 0]]), ).to(device=device) @@ -413,7 +413,7 @@ def test_composition_model_predict(device): assert output["mtt::U0"].block().values.device.type == device # with selected_atoms - selected_atoms = metatensor.torch.Labels( + selected_atoms = mts.Labels( names=["system"], values=torch.tensor([[0]]), ).to(device=device) @@ -527,7 +527,7 @@ def test_old_remove_additive(): composition_model.train_model(dataset, []) # concatenate all targets - targets["mtt::U0"] = metatensor.torch.join(targets["mtt::U0"], axis="samples") + targets["mtt::U0"] = mts.join(targets["mtt::U0"], axis="samples") std_before = targets["mtt::U0"].block().values.std().item() remove_additive(systems, targets, composition_model, target_info) @@ -576,7 +576,7 @@ def test_remove_additive(): composition_model.train_model(dataloader, additive_models=[]) # concatenate all targets - targets["mtt::U0"] = metatensor.torch.join(targets["mtt::U0"], axis="samples") + targets["mtt::U0"] = mts.join(targets["mtt::U0"], axis="samples") std_before = targets["mtt::U0"].block().values.std().item() remove_additive(systems, targets, composition_model, target_info) @@ -851,7 +851,7 @@ def test_zbl(): assert output["mtt::U0"].block().values.shape != (5, 1) # with selected_atoms - selected_atoms = metatensor.torch.Labels( + selected_atoms = mts.Labels( names=["system", "atom"], values=torch.tensor([[0, 0]]), ) @@ -947,12 +947,8 @@ def test_old_composition_model_train_per_atom(where_is_center_type): tensor_map_1 = tensor_map_1.keys_to_samples("center_type") tensor_map_2 = tensor_map_2.keys_to_samples("center_type") if where_is_center_type == "nowhere": - tensor_map_1 = metatensor.torch.remove_dimension( - tensor_map_1, "samples", "center_type" - ) - tensor_map_2 = metatensor.torch.remove_dimension( - tensor_map_2, "samples", "center_type" - ) + tensor_map_1 = mts.remove_dimension(tensor_map_1, "samples", "center_type") + tensor_map_2 = mts.remove_dimension(tensor_map_2, "samples", "center_type") energies = [tensor_map_1, tensor_map_2] dataset = Dataset.from_dict({"system": systems, "energy": energies}) @@ -1068,12 +1064,8 @@ def test_composition_model_train_per_atom(where_is_center_type): tensor_map_1 = tensor_map_1.keys_to_samples("center_type") tensor_map_2 = tensor_map_2.keys_to_samples("center_type") if where_is_center_type == "nowhere": - tensor_map_1 = metatensor.torch.remove_dimension( - tensor_map_1, "samples", "center_type" - ) - tensor_map_2 = metatensor.torch.remove_dimension( - tensor_map_2, "samples", "center_type" - ) + tensor_map_1 = mts.remove_dimension(tensor_map_1, "samples", "center_type") + tensor_map_2 = mts.remove_dimension(tensor_map_2, "samples", "center_type") energies = [tensor_map_1, tensor_map_2] dataset = Dataset.from_dict({"system": systems, "energy": energies}) diff --git a/tests/utils/test_sum_over_atoms.py b/tests/utils/test_sum_over_atoms.py index b3163ca36..640148375 100644 --- a/tests/utils/test_sum_over_atoms.py +++ b/tests/utils/test_sum_over_atoms.py @@ -1,4 +1,4 @@ -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorBlock, TensorMap @@ -35,9 +35,9 @@ def test_sum_over_atoms(): # Call the sum_over_atoms function summed_tensor_map = sum_over_atoms(tensor_map) - summed_tensor_map_ref = metatensor.torch.sum_over_samples( + summed_tensor_map_ref = mts.sum_over_samples( tensor_map, sample_names=["atom"], ) - assert metatensor.torch.allclose(summed_tensor_map, summed_tensor_map_ref) + assert mts.allclose(summed_tensor_map, summed_tensor_map_ref) diff --git a/tests/utils/test_transfer.py b/tests/utils/test_transfer.py index 5e8f5e39f..a1021a0c1 100644 --- a/tests/utils/test_transfer.py +++ b/tests/utils/test_transfer.py @@ -1,4 +1,4 @@ -import metatensor.torch +import metatensor.torch as mts import torch from metatensor.torch import Labels, TensorMap from metatomic.torch import System @@ -15,7 +15,7 @@ def test_batch_to_dtype(): ) targets = TensorMap( keys=Labels.single(), - blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))], + blocks=[mts.block_from_array(torch.tensor([[1.0]]))], ) systems = [system] @@ -41,7 +41,7 @@ def test_batch_to_device(): ) targets = TensorMap( keys=Labels.single(), - blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))], + blocks=[mts.block_from_array(torch.tensor([[1.0]]))], ) systems = [system] From a49e6132c44102fc680a8e3b0ab993b78320e7c7 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Mon, 14 Jul 2025 13:59:43 +0200 Subject: [PATCH 2/2] Do not mix and match metatensor-core and metatensor-torch in gap --- src/metatrain/gap/model.py | 198 +++++++---------------------------- src/metatrain/gap/trainer.py | 6 -- 2 files changed, 36 insertions(+), 168 deletions(-) diff --git a/src/metatrain/gap/model.py b/src/metatrain/gap/model.py index 7b7988230..bf38372b7 100644 --- a/src/metatrain/gap/model.py +++ b/src/metatrain/gap/model.py @@ -6,10 +6,7 @@ import numpy as np import scipy import torch -from metatensor import Labels, TensorBlock, TensorMap -from metatensor.torch import Labels as TorchLabels -from metatensor.torch import TensorBlock as TorchTensorBlock -from metatensor.torch import TensorMap as TorchTensorMap +from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ( AtomisticModel, ModelCapabilities, @@ -121,14 +118,14 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: self._sampler = _FPS(n_to_select=self.hypers["krr"]["num_sparse_points"]) # set it do dummy keys, these are properly set during training - self._keys = TorchLabels.empty("_") + self._keys = Labels.empty("_") - dummy_weights = TorchTensorMap( - TorchLabels(["_"], torch.tensor([[0]])), + dummy_weights = TensorMap( + Labels(["_"], torch.tensor([[0]])), [mts.block_from_array(torch.empty(1, 1))], ) - dummy_X_pseudo = TorchTensorMap( - TorchLabels(["_"], torch.tensor([[0]])), + dummy_X_pseudo = TensorMap( + Labels(["_"], torch.tensor([[0]])), [mts.block_from_array(torch.empty(1, 1))], ) self._subset_of_regressors_torch = TorchSubsetofRegressors( @@ -138,7 +135,7 @@ def __init__(self, hypers: Dict, dataset_info: DatasetInfo) -> None: "aggregate_names": ["atom", "center_type"], }, ) - self._species_labels: TorchLabels = TorchLabels.empty("_") + self._species_labels: Labels = Labels.empty("_") # additive models: these are handled by the trainer at training # time, and they are added to the output at evaluation time @@ -186,8 +183,8 @@ def forward( self, systems: List[System], outputs: Dict[str, ModelOutput], - selected_atoms: Optional[TorchLabels] = None, - ) -> Dict[str, TorchTensorMap]: + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: soap_features = self._soap_torch_calculator( systems, selected_samples=selected_atoms ) @@ -195,18 +192,18 @@ def forward( self._keys = self._keys.to(systems[0].device) self._species_labels = self._species_labels.to(systems[0].device) - new_blocks: List[TorchTensorBlock] = [] + new_blocks: List[TensorBlock] = [] # HACK: to add a block of zeros if there are missing species # which were present at training time # (with samples "system", "atom" = 0, 0) # given the values are all zeros, it does not introduce an error - dummyblock: TorchTensorBlock = TorchTensorBlock( + dummyblock = TensorBlock( values=torch.zeros( (1, len(soap_features[0].properties)), dtype=systems[0].positions.dtype, device=systems[0].device, ), - samples=TorchLabels( + samples=Labels( ["system", "atom"], torch.tensor([[0, 0]], dtype=torch.int, device=systems[0].device), ), @@ -215,7 +212,7 @@ def forward( ) if len(soap_features[0].gradients_list()) > 0: for idx, grad in enumerate(soap_features[0].gradients_list()): - dummyblock_grad: TorchTensorBlock = TorchTensorBlock( + dummyblock_grad = TensorBlock( values=torch.zeros( ( 1, @@ -225,7 +222,7 @@ def forward( dtype=systems[0].positions.dtype, device=systems[0].device, ), - samples=TorchLabels( + samples=Labels( ["sample", "system", "atom"], torch.tensor( [[0, 0, 0]], dtype=torch.int, device=systems[0].device @@ -242,7 +239,7 @@ def forward( new_blocks.append(soap_features.block(key)) else: new_blocks.append(dummyblock) - soap_features = TorchTensorMap(keys=self._species_labels, blocks=new_blocks) + soap_features = TensorMap(keys=self._species_labels, blocks=new_blocks) soap_features = soap_features.keys_to_samples("center_type") # here, we move to properties to use metatensor operations to aggregate # later on. Perhaps we could retain the sparsity all the way to the kernels @@ -250,7 +247,7 @@ def forward( soap_features = soap_features.keys_to_properties( ["neighbor_1_type", "neighbor_2_type"] ) - soap_features = TorchTensorMap(self._keys, soap_features.blocks()) + soap_features = TensorMap(self._keys, soap_features.blocks()) output_key = list(outputs.keys())[0] energies = self._subset_of_regressors_torch(soap_features) return_dict = {output_key: energies} @@ -475,9 +472,9 @@ def __init__( def aggregate_kernel( self, - kernel: TorchTensorMap, + kernel: TensorMap, are_pseudo_points: Tuple[bool, bool] = (False, False), - ) -> TorchTensorMap: + ) -> TensorMap: if not are_pseudo_points[0]: kernel = mts.sum_over_samples(kernel, self._aggregate_names) if not are_pseudo_points[1]: @@ -488,17 +485,15 @@ def aggregate_kernel( def forward( self, - tensor1: TorchTensorMap, - tensor2: TorchTensorMap, + tensor1: TensorMap, + tensor2: TensorMap, are_pseudo_points: Tuple[bool, bool] = (False, False), - ) -> TorchTensorMap: + ) -> TensorMap: return self.aggregate_kernel( self.compute_kernel(tensor1, tensor2), are_pseudo_points ) - def compute_kernel( - self, tensor1: TorchTensorMap, tensor2: TorchTensorMap - ) -> TorchTensorMap: + def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap) -> TensorMap: raise NotImplementedError("compute_kernel needs to be implemented.") @@ -512,7 +507,7 @@ def __init__( super().__init__(aggregate_names, structurewise_aggregate) self._degree = degree - def compute_kernel(self, tensor1: TorchTensorMap, tensor2: TorchTensorMap): + def compute_kernel(self, tensor1: TensorMap, tensor2: TensorMap): return mts.pow(mts.dot(tensor1, tensor2), self._degree) @@ -546,10 +541,6 @@ def fit(self, X: TensorMap): # -> GreedySelector: :param X: Training vectors. """ - if isinstance(X, torch.ScriptObject): - X = torch_tensor_map_to_core(X) - assert isinstance(X[0].values, np.ndarray) - if len(X.component_names) != 0: raise ValueError("Only blocks with no components are supported.") @@ -578,7 +569,9 @@ def fit(self, X: TensorMap): # -> GreedySelector: blocks.append( TensorBlock( - values=np.zeros([len(samples), len(properties)], dtype=np.int32), + values=torch.zeros( + [len(samples), len(properties)], dtype=torch.int32 + ), samples=samples, components=[], properties=properties, @@ -596,12 +589,6 @@ def transform(self, X: TensorMap) -> TensorMap: :returns: The selected subset of the input. """ - if isinstance(X, torch.ScriptObject): - use_mts_torch = True - X = torch_tensor_map_to_core(X) - else: - use_mts_torch = False - blocks = [] for key, block in X.items(): block_support = self.support.block(key) @@ -614,10 +601,7 @@ def transform(self, X: TensorMap) -> TensorMap: new_block = mts.slice_block(block, "samples", block_support.samples) blocks.append(new_block) - X_reduced = TensorMap(X.keys, blocks) - if use_mts_torch: - X_reduced = core_tensor_map_to_torch(X_reduced) - return X_reduced + return TensorMap(X.keys, blocks) def fit_transform(self, X: TensorMap) -> TensorMap: """Fit to data, then transform it. @@ -628,112 +612,6 @@ def fit_transform(self, X: TensorMap) -> TensorMap: return self.fit(X).transform(X) -def torch_tensor_map_to_core(torch_tensor: TorchTensorMap): - torch_blocks = [] - for _, torch_block in torch_tensor.items(): - torch_blocks.append(torch_tensor_block_to_core(torch_block)) - torch_keys = torch_labels_to_core(torch_tensor.keys) - return TensorMap(torch_keys, torch_blocks) - - -def torch_tensor_block_to_core(torch_block: TorchTensorBlock): - """Transforms a tensor block from metatensor-torch to metatensor-torch - :param torch_block: - tensor block from metatensor-torch - :returns torch_block: - tensor block from metatensor-torch - """ - block = TensorBlock( - values=torch_block.values.detach().cpu().numpy(), - samples=torch_labels_to_core(torch_block.samples), - components=[ - torch_labels_to_core(component) for component in torch_block.components - ], - properties=torch_labels_to_core(torch_block.properties), - ) - for parameter, gradient in torch_block.gradients(): - block.add_gradient( - parameter=parameter, - gradient=TensorBlock( - values=gradient.values.detach().cpu().numpy(), - samples=torch_labels_to_core(gradient.samples), - components=[ - torch_labels_to_core(component) for component in gradient.components - ], - properties=torch_labels_to_core(gradient.properties), - ), - ) - return block - - -def torch_labels_to_core(torch_labels: TorchLabels): - """Transforms labels from metatensor-torch to metatensor-torch - :param torch_block: - tensor block from metatensor-torch - :returns torch_block: - labels from metatensor-torch - """ - return Labels(torch_labels.names, torch_labels.values.detach().cpu().numpy()) - - -### - - -def core_tensor_map_to_torch(core_tensor: TensorMap): - """Transforms a tensor map from metatensor-core to metatensor-torch - :param core_tensor: - tensor map from metatensor-core - :returns torch_tensor: - tensor map from metatensor-torch - """ - - torch_blocks = [] - for _, core_block in core_tensor.items(): - torch_blocks.append(core_tensor_block_to_torch(core_block)) - torch_keys = core_labels_to_torch(core_tensor.keys) - return TorchTensorMap(torch_keys, torch_blocks) - - -def core_tensor_block_to_torch(core_block: TensorBlock): - """Transforms a tensor block from metatensor-core to metatensor-torch - :param core_block: - tensor block from metatensor-core - :returns torch_block: - tensor block from metatensor-torch - """ - block = TorchTensorBlock( - values=torch.tensor(core_block.values), - samples=core_labels_to_torch(core_block.samples), - components=[ - core_labels_to_torch(component) for component in core_block.components - ], - properties=core_labels_to_torch(core_block.properties), - ) - for parameter, gradient in core_block.gradients(): - block.add_gradient( - parameter=parameter, - gradient=TorchTensorBlock( - values=torch.tensor(gradient.values), - samples=core_labels_to_torch(gradient.samples), - components=[ - core_labels_to_torch(component) for component in gradient.components - ], - properties=core_labels_to_torch(gradient.properties), - ), - ) - return block - - -def core_labels_to_torch(core_labels: Labels): - """Transforms labels from metatensor-core to metatensor-torch - :param core_block: - tensor block from metatensor-core - :returns torch_block: - labels from metatensor-torch - """ - return TorchLabels(core_labels.names, torch.tensor(core_labels.values)) - - class SubsetOfRegressors: def __init__( self, @@ -809,10 +687,6 @@ def fit( if not isinstance(alpha_forces, float): raise ValueError("alpha must either be a float") - X = X.to(arrays="numpy") - X_pseudo = X_pseudo.to(arrays="numpy") - y = y.to(arrays="numpy") - if self._kernel is None: # _set_kernel only returns None if kernel type is precomputed k_nm = X @@ -831,11 +705,11 @@ def fit( structures = torch.unique(k_nm_block.samples["system"]) n_atoms_per_structure = [] for structure in structures: - n_atoms = np.sum(X_block.samples["system"] == structure) + n_atoms = torch.sum(X_block.samples["system"] == structure) n_atoms_per_structure.append(float(n_atoms)) - n_atoms_per_structure = np.array(n_atoms_per_structure) - normalization = np.sqrt(n_atoms_per_structure) + n_atoms_per_structure = torch.tensor(n_atoms_per_structure) + normalization = torch.sqrt(n_atoms_per_structure) if not (np.allclose(alpha_energy, 0.0)): normalization /= alpha_energy @@ -871,7 +745,7 @@ def fit( self._solver.fit(k_nm_reg, y_reg) weight_block = TensorBlock( - values=self._solver.weights.T, + values=torch.as_tensor(self._solver.weights.T), samples=y_block.properties, components=k_nm_block.components, properties=k_nm_block.properties, @@ -901,8 +775,8 @@ def predict(self, T: TensorMap) -> TensorMap: def export_torch_script_model(self): return TorchSubsetofRegressors( - core_tensor_map_to_torch(self._weights), - core_tensor_map_to_torch(self._X_pseudo), + self._weights, + self._X_pseudo, self._kernel_kwargs, ) @@ -910,8 +784,8 @@ def export_torch_script_model(self): class TorchSubsetofRegressors(torch.nn.Module): def __init__( self, - weights: TorchTensorMap, - X_pseudo: TorchTensorMap, + weights: TensorMap, + X_pseudo: TensorMap, kernel_kwargs: Optional[dict] = None, ): super().__init__() @@ -923,7 +797,7 @@ def __init__( # Set the kernel self._kernel = TorchAggregatePolynomial(**kernel_kwargs) - def forward(self, T: TorchTensorMap) -> TorchTensorMap: + def forward(self, T: TensorMap) -> TensorMap: """ :param T: features diff --git a/src/metatrain/gap/trainer.py b/src/metatrain/gap/trainer.py index 92eb0cd42..843bac8ff 100644 --- a/src/metatrain/gap/trainer.py +++ b/src/metatrain/gap/trainer.py @@ -4,7 +4,6 @@ import metatensor.torch as mts import torch -from metatensor.torch import TensorMap from metatrain.utils.abc import TrainerInterface from metatrain.utils.additive import remove_additive @@ -15,7 +14,6 @@ ) from . import GAP -from .model import torch_tensor_map_to_core class Trainer(TrainerInterface): @@ -107,10 +105,6 @@ def train( train_tensor = train_tensor.keys_to_properties( ["neighbor_1_type", "neighbor_2_type"] ) - # change backend - train_tensor = TensorMap(train_y.keys, train_tensor.blocks()) - train_tensor = torch_tensor_map_to_core(train_tensor) - train_y = torch_tensor_map_to_core(train_y) logging.info("Selecting sparse points") lens = len(train_tensor[0].values)