
Unit test for TensorWeightedAvg RecMetric #3139

Closed
16 changes: 7 additions & 9 deletions torchrec/metrics/tensor_weighted_avg.py
@@ -25,6 +25,13 @@ def get_mean(value_sum: torch.Tensor, num_samples: torch.Tensor) -> torch.Tensor


class TensorWeightedAvgMetricComputation(RecMetricComputation):
"""
This class implements the RecMetricComputation for tensor weighted average.

It is a sibling to WeightedAvgMetricComputation, but it computes the weighted average of a tensor
passed in as a required input instead of the predictions tensor.
"""

def __init__(
self,
*args: Any,
@@ -116,15 +123,6 @@ class TensorWeightedAvgMetric(RecMetric):
_namespace: MetricNamespace = MetricNamespace.WEIGHTED_AVG
_computation_class: Type[RecMetricComputation] = TensorWeightedAvgMetricComputation

def __init__(
self,
# pyre-ignore Missing parameter annotation [2]
*args,
**kwargs: Dict[str, Any],
) -> None:

super().__init__(*args, **kwargs)

def _get_task_kwargs(
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
) -> Dict[str, Any]:
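For context, a minimal single-rank usage sketch of the metric this PR tests. It is not part of the diff: the import paths, constructor arguments, dict keys, and tensor shapes are assumptions based on the surrounding code. The key point is that the weighted average is taken over the tensor named by RecTaskInfo.tensor_name, supplied through required_inputs, rather than over predictions.

```python
# Hedged sketch, not part of this PR: single-rank use of TensorWeightedAvgMetric.
# Import paths, dict keys, and shapes are illustrative assumptions.
import torch

from torchrec.metrics.rec_metric import RecTaskInfo  # assumed import location
from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric

task = RecTaskInfo(
    name="t1",
    label_name="t1-label",
    prediction_name="t1-prediction",
    weight_name="t1-weight",
    tensor_name="t1-tensor",
)
metric = TensorWeightedAvgMetric(
    world_size=1,
    my_rank=0,
    batch_size=4,
    tasks=[task],
)

# The metric averages the required-input tensor ("t1-tensor"), weighted by the
# per-example weights; predictions are still passed but are not what is averaged.
metric.update(
    predictions={"t1": torch.rand(4)},
    labels={"t1": torch.ones(4)},
    weights={"t1": torch.tensor([1.0, 1.0, 2.0, 2.0])},
    required_inputs={"t1-tensor": torch.tensor([1.0, 2.0, 3.0, 4.0])},
)

# Expected lifetime value: (1*1 + 1*2 + 2*3 + 2*4) / (1 + 1 + 2 + 2) = 17/6 ≈ 2.83.
print(metric.compute())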
77 changes: 57 additions & 20 deletions torchrec/metrics/test_utils/__init__.py
@@ -87,6 +87,7 @@ def gen_test_tasks(
label_name=f"{task_name}-label",
prediction_name=f"{task_name}-prediction",
weight_name=f"{task_name}-weight",
tensor_name=f"{task_name}-tensor",
)
for task_name in task_names
]
@@ -131,7 +132,10 @@ def _aggregate(
@staticmethod
@abc.abstractmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
pass

@@ -161,6 +165,7 @@ def compute(
aggregated_model_out[task_info.label_name],
aggregated_model_out[task_info.prediction_name],
aggregated_model_out[task_info.weight_name],
aggregated_model_out[task_info.tensor_name or "tensor"],
)
if self._compute_lifetime_metric:
self._aggregate(lifetime_states[task_info.name], states)
@@ -170,6 +175,7 @@ def compute(
model_outs[i][task_info.label_name],
model_outs[i][task_info.prediction_name],
model_outs[i][task_info.weight_name],
model_outs[i][task_info.tensor_name or "tensor"],
)
if self._local_compute_lifetime_metric:
self._aggregate(local_lifetime_states[task_info.name], local_states)
@@ -252,6 +258,7 @@ def rec_metric_value_test_helper(
label_name=task.label_name,
prediction_name=task.prediction_name,
weight_name=task.weight_name,
tensor_name=task.tensor_name or "tensor",
batch_size=batch_size,
n_classes=n_classes,
weight_value=weight_value,
@@ -288,8 +295,11 @@ def get_target_rec_metric_value(
**kwargs,
)
for i in range(nsteps):
labels, predictions, weights, _ = parse_task_model_outputs(
tasks, model_outs[i]
# Get required_inputs_list from the target metric
required_inputs_list = list(target_metric_obj.get_required_inputs())

labels, predictions, weights, required_inputs = parse_task_model_outputs(
tasks, model_outs[i], required_inputs_list
)
if target_compute_mode in [
RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -302,7 +312,10 @@
if timestamps is not None:
time_mock.return_value = timestamps[i]
target_metric_obj.update(
predictions=predictions, labels=labels, weights=weights
predictions=predictions,
labels=labels,
weights=weights,
required_inputs=required_inputs,
)
result_metrics = target_metric_obj.compute()
result_metrics.update(target_metric_obj.local_compute())
@@ -422,7 +435,7 @@ def sync_test_helper(
# pyre-ignore[6]: Incompatible parameter type
kwargs["number_of_classes"] = n_classes

auc = target_clazz(
target_metric_obj = target_clazz(
world_size=world_size,
batch_size=batch_size,
my_rank=rank,
@@ -440,6 +453,7 @@
label_name=task.label_name,
prediction_name=task.prediction_name,
weight_name=task.weight_name,
tensor_name=task.tensor_name or "tensor",
batch_size=batch_size,
n_classes=n_classes,
weight_value=weight_value,
@@ -450,19 +464,32 @@
model_outs = []
model_outs.append({k: v for d in _model_outs for k, v in d.items()})

# Get required_inputs from the target metric
required_inputs_list = list(target_metric_obj.get_required_inputs())

# we send an uneven number of tensors to each rank to test that GPU sync works
if rank == 0:
for _ in range(3):
labels, predictions, weights, _ = parse_task_model_outputs(
tasks, model_outs[0]
labels, predictions, weights, required_inputs = parse_task_model_outputs(
tasks, model_outs[0], required_inputs_list
)
target_metric_obj.update(
predictions=predictions,
labels=labels,
weights=weights,
required_inputs=required_inputs,
)
auc.update(predictions=predictions, labels=labels, weights=weights)
elif rank == 1:
for _ in range(1):
labels, predictions, weights, _ = parse_task_model_outputs(
tasks, model_outs[0]
labels, predictions, weights, required_inputs = parse_task_model_outputs(
tasks, model_outs[0], required_inputs_list
)
target_metric_obj.update(
predictions=predictions,
labels=labels,
weights=weights,
required_inputs=required_inputs,
)
auc.update(predictions=predictions, labels=labels, weights=weights)

# check against test metric
test_metrics: TestRecMetricOutput = ({}, {}, {}, {})
@@ -474,7 +501,7 @@
model_outs = model_outs * 2
test_metrics = test_metric_obj.compute(model_outs, 2, batch_window_size, None)

res = auc.compute()
res = target_metric_obj.compute()

if rank == 0:
# Serving Calibration uses Calibration naming inconsistently
@@ -490,21 +517,31 @@
)

# we also test the case where other rank has more tensors than rank 0
auc.reset()
target_metric_obj.reset()
if rank == 0:
for _ in range(1):
labels, predictions, weights, _ = parse_task_model_outputs(
tasks, model_outs[0]
labels, predictions, weights, required_inputs = parse_task_model_outputs(
tasks, model_outs[0], required_inputs_list
)
target_metric_obj.update(
predictions=predictions,
labels=labels,
weights=weights,
required_inputs=required_inputs,
)
auc.update(predictions=predictions, labels=labels, weights=weights)
elif rank == 1:
for _ in range(3):
labels, predictions, weights, _ = parse_task_model_outputs(
tasks, model_outs[0]
labels, predictions, weights, required_inputs = parse_task_model_outputs(
tasks, model_outs[0], required_inputs_list
)
target_metric_obj.update(
predictions=predictions,
labels=labels,
weights=weights,
required_inputs=required_inputs,
)
auc.update(predictions=predictions, labels=labels, weights=weights)

res = auc.compute()
res = target_metric_obj.compute()

if rank == 0:
# Serving Calibration uses Calibration naming inconsistently
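The new optional required_inputs_tensor argument on TestMetric._get_states lets a reference test metric derive its expected states from the required-input tensor instead of from predictions. Below is a hedged sketch of such a subclass, assuming TestMetric's default _aggregate; the class name and state keys are illustrative and this is not the PR's actual test class, which is not shown in this excerpt.

```python
# Hedged sketch: a TestMetric that mirrors a tensor weighted average.
# Class name and state keys are illustrative, not taken from this PR.
from typing import Dict, Optional

import torch

from torchrec.metrics.test_utils import TestMetric


class TensorWeightedAvgTestMetric(TestMetric):
    @staticmethod
    def _get_states(
        labels: torch.Tensor,
        predictions: torch.Tensor,
        weights: torch.Tensor,
        required_inputs_tensor: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        assert required_inputs_tensor is not None
        # Accumulate the weighted sum of the required-input tensor, not predictions.
        return {
            "weighted_sum": torch.sum(weights * required_inputs_tensor),
            "weighted_num_samples": torch.sum(weights),
        }

    @staticmethod
    def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
        return states["weighted_sum"] / states["weighted_num_samples"]
```

The remaining test files in this diff only accept the new optional parameter and leave it unused, since those metrics compute their reference states from labels, predictions, and weights alone.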
7 changes: 5 additions & 2 deletions torchrec/metrics/tests/test_accuracy.py
@@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Dict, Iterable, Type, Union
from typing import Dict, Iterable, Optional, Type, Union

import torch
from torch import no_grad
@@ -31,7 +31,10 @@
class TestAccuracyMetric(TestMetric):
@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
predictions = predictions.double()
accuracy_sum = torch.sum(weights * ((predictions >= 0.5) == labels))
5 changes: 4 additions & 1 deletion torchrec/metrics/tests/test_auc.py
@@ -70,7 +70,10 @@ def _aggregate(

@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
return {
"predictions": predictions,
5 changes: 4 additions & 1 deletion torchrec/metrics/tests/test_auprc.py
@@ -61,7 +61,10 @@ def _aggregate(

@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
return {
"predictions": predictions,
7 changes: 5 additions & 2 deletions torchrec/metrics/tests/test_cali_free_ne.py
@@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Dict, Type
from typing import Dict, Optional, Type

import torch
from torchrec.metrics.cali_free_ne import (
@@ -34,7 +34,10 @@ class TestCaliFreeNEMetric(TestMetric):

@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
cross_entropy = compute_cross_entropy(
labels, predictions, weights, TestCaliFreeNEMetric.eta
7 changes: 5 additions & 2 deletions torchrec/metrics/tests/test_calibration.py
@@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Dict, Type
from typing import Dict, Optional, Type

import torch
from torchrec.metrics.calibration import CalibrationMetric
@@ -25,7 +25,10 @@
class TestCalibrationMetric(TestMetric):
@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
calibration_num = torch.sum(predictions * weights)
calibration_denom = torch.sum(labels * weights)
7 changes: 5 additions & 2 deletions torchrec/metrics/tests/test_ctr.py
@@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Dict, Type
from typing import Dict, Optional, Type

import torch
from torchrec.metrics.ctr import CTRMetric
@@ -25,7 +25,10 @@
class TestCTRMetric(TestMetric):
@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
ctr_num = torch.sum(labels * weights)
ctr_denom = torch.sum(weights)
3 changes: 2 additions & 1 deletion torchrec/metrics/tests/test_gauc.py
@@ -9,7 +9,7 @@

import unittest

from typing import Dict
from typing import Dict, Optional

import torch
from torchrec.metrics.gauc import compute_gauc_3d, compute_window_auc, GAUCMetric
@@ -24,6 +24,7 @@ def _get_states(
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
gauc_res = compute_gauc_3d(predictions, labels, weights)
return {
17 changes: 13 additions & 4 deletions torchrec/metrics/tests/test_hindsight_target_pr.py
@@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Dict, Type
from typing import Dict, Optional, Type

import torch
from torchrec.metrics.hindsight_target_pr import (
@@ -32,7 +32,10 @@
class TestHindsightTargetPRMetric(TestMetric):
@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
predictions = predictions.double()
tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
@@ -59,7 +62,10 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
class TestHindsightTargetPrecisionMetric(TestMetric):
@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
predictions = predictions.double()
tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
@@ -89,7 +95,10 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
class TestHindsightTargetRecallMetric(TestMetric):
@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
predictions = predictions.double()
tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
7 changes: 5 additions & 2 deletions torchrec/metrics/tests/test_mae.py
@@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Dict, Type
from typing import Dict, Optional, Type

import torch
from torchrec.metrics.mae import compute_mae, MAEMetric
@@ -25,7 +25,10 @@
class TestMAEMetric(TestMetric):
@staticmethod
def _get_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
required_inputs_tensor: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
predictions = predictions.double()
error_sum = torch.sum(weights * torch.abs(labels - predictions))