109 changes: 90 additions & 19 deletions src/metatrain/utils/llpr.py
@@ -1,4 +1,5 @@
from typing import Callable, Dict, List, Optional
from collections import defaultdict
from typing import Callable, DefaultDict, Dict, List, Optional

import metatensor.torch
import numpy as np
@@ -33,26 +34,34 @@ class LLPRUncertaintyModel(torch.nn.Module):
def __init__(
self,
model: torch.jit.RecursiveScriptModule,
num_subtargets: Optional[Dict[str, int]] = None,
# TODO: read `num_subtargets` from capabilities instead of user input
) -> None:
super().__init__()

self.model = model
self.ll_feat_size = self.model.module.last_layer_feature_size

# update capabilities: now we have additional outputs for the uncertainty
old_capabilities = self.model.capabilities()
additional_capabilities = {}
self.uncertainty_multipliers = {}
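# `num_subtargets[name]` is the number of properties (sub-targets) of each
# output; it defaults to 1 for single-property targets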
self.num_subtargets: DefaultDict[str, int] = defaultdict(lambda: 1, num_subtargets or {})
for name, output in old_capabilities.outputs.items():

if is_auxiliary_output(name):
continue # auxiliary output
continue # skip auxiliary outputs

uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty"

additional_capabilities[uncertainty_name] = ModelOutput(
quantity="",
unit=f"({output.unit})^2",
per_atom=True,
)
self.uncertainty_multipliers[uncertainty_name] = 1.0

# TODO: read `num_subtargets` from capabilities instead of user input
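# start from a multiplier of 1 for every sub-target; `calibrate()` overwrites
# these with values estimated on a validation set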
self.uncertainty_multipliers[uncertainty_name] = torch.ones(self.num_subtargets[name])

self.capabilities = ModelCapabilities(
outputs={**old_capabilities.outputs, **additional_capabilities},
atomic_types=old_capabilities.atomic_types,
@@ -163,7 +172,7 @@ def forward(
if name.startswith("mtt::aux::") and name.endswith("_uncertainty"):
requested_uncertainties.append(name)

for name in requested_uncertainties:
# recover the name of the output this uncertainty refers to (uncertainty
# names are built as "mtt::aux::<name>_uncertainty" from the output name)
base_name = name[len("mtt::aux::") : -len("_uncertainty")]
orig_name = base_name if base_name in outputs else f"mtt::{base_name}"
ll_features = return_dict[
name.replace("_uncertainty", "_last_layer_features")
]
@@ -176,6 +185,7 @@ def forward(
self.inv_covariances[name],
ll_features.block().values,
).unsqueeze(1)
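# `one_over_pr_values` holds, for each sample, the last-layer features
# contracted with the stored inverse covariance (f Sigma^-1 f^T), i.e. the
# inverse of the local prediction rigidity: the raw, uncalibrated variance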

one_over_pr = TensorMap(
keys=Labels(
names=["_"],
@@ -185,23 +195,56 @@
),
blocks=[
TensorBlock(
values=one_over_pr_values,
# TODO: multiple samples case (e.g. per_atom=True)
values=one_over_pr_values.expand(-1, self.num_subtargets[orig_name]),
samples=ll_features.block().samples,
components=ll_features.block().components,
properties=Labels(
names=["_"],
values=torch.tensor(
[[0]], device=ll_features.block().values.device
),
),
properties=Labels.range(
"properties", self.num_subtargets[orig_name]
).to(ll_features.block().values.device),
)
],
)

# TODO: save multipliers directly as tensormap if possible
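# wrap the per-sub-target calibration multipliers in a TensorMap with the same
# samples, components and properties as `one_over_pr`, so the two maps can be
# multiplied entry-wise below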
tsm_uq_multipliers = TensorMap(
keys=Labels(
names=["_"],
values=torch.tensor(
[[0]], device=ll_features.block().values.device
),
),
blocks=[
TensorBlock(
# TODO: multiple samples case (e.g. per_atom=True)
values=self.uncertainty_multipliers[name]
.to(ll_features.block().values.device)
.expand(len(systems), -1),
samples=ll_features.block().samples,
components=ll_features.block().components,
properties=Labels.range("properties",
self.num_subtargets[orig_name],
).to(ll_features.block().values.device),
#(
# names=["_"],
# values=torch.tensor(
# [[0]], device=ll_features.block().values.device
# ),
#),
)
],
)

return_dict[name] = metatensor.torch.multiply(
one_over_pr, self.uncertainty_multipliers[name]
one_over_pr,
tsm_uq_multipliers,
)


# now deal with potential ensembles (see generate_ensemble method)
requested_ensembles: List[str] = []
for name in outputs.keys():
@@ -491,36 +534,43 @@ def calibrate(self, valid_loader: DataLoader):
This data loader should be generated from a dataset from the
``Dataset`` class in ``metatrain.utils.data``.
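
A minimal usage sketch (the wrapped ``exported_model``, the loaders and the
``"mtt::dos"`` entry are placeholders, not part of this module)::

    llpr_model = LLPRUncertaintyModel(exported_model, {"mtt::dos": n_dos_points})
    # ... accumulate and invert the last-layer covariances first ...
    llpr_model.calibrate(valid_loader)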
"""

# calibrate the LLPR
# TODO: in the future, we might want to have one calibration factor per
# property for outputs with multiple properties

device = next(iter(self.covariances.values())).device
dtype = next(iter(self.covariances.values())).dtype

all_predictions = {} # type: ignore
all_targets = {} # type: ignore
all_uncertainties = {} # type: ignore

for batch in valid_loader:
systems, targets = batch
systems = [system.to(device=device, dtype=dtype) for system in systems]
targets = {
name: target.to(device=device, dtype=dtype)
for name, target in targets.items()
}

# evaluate the targets and their uncertainties, not per atom
requested_outputs = {}
for name in targets:

requested_outputs[name] = ModelOutput(
quantity="",
unit="",
per_atom=False,
)

uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty"
requested_outputs[uncertainty_name] = ModelOutput(
quantity="",
unit="",
per_atom=False,
)

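# a single forward pass gives both the predictions and the corresponding
# uncertainty estimates for every target requested above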
outputs = self.forward(systems, requested_outputs)

for name, target in targets.items():
uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty"
if name not in all_predictions:
@@ -543,15 +593,36 @@ def calibrate(self, valid_loader: DataLoader):

for name in all_predictions:
# compute the uncertainty multiplier
residuals = all_predictions[name] - all_targets[name]
uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty"
uncertainties = all_uncertainties[uncertainty_name]
self.uncertainty_multipliers[uncertainty_name] = torch.mean(
residuals**2 / uncertainties
).item()

# exceptional case for DOS
# if name == "mtt::dos":
# residuals = agonistic_residual() ### TODO: with WB
# self.uncertainty_multipliers[uncertainty_name] = torch.mean(
# residuals**2 / uncertainties,
# axis=0,
# )

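# the multiplier rescales the raw uncertainties so that, averaged over the
# validation set, the calibrated variance matches the squared residuals
# (per sub-target when num_subtargets > 1)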
# generic case with num_subtargets > 1
if self.num_subtargets[name] > 1:
residuals = all_predictions[name] - all_targets[name]
self.uncertainty_multipliers[uncertainty_name] = torch.mean(
residuals**2 / uncertainties,
dim=0,
)  # one multiplier per sub-target

else:
residuals = all_predictions[name] - all_targets[name]
self.uncertainty_multipliers[uncertainty_name] = torch.mean(
residuals**2 / uncertainties
).reshape(1)  # keep a 1-element tensor so `forward` can expand it


self.is_calibrated = True



def generate_ensemble(
self, weight_tensors: Dict[str, torch.Tensor], n_members: int
) -> None: