diff --git a/torchrec/metrics/tensor_weighted_avg.py b/torchrec/metrics/tensor_weighted_avg.py index 2c582f4c0..580432351 100644 --- a/torchrec/metrics/tensor_weighted_avg.py +++ b/torchrec/metrics/tensor_weighted_avg.py @@ -25,6 +25,13 @@ def get_mean(value_sum: torch.Tensor, num_samples: torch.Tensor) -> torch.Tensor class TensorWeightedAvgMetricComputation(RecMetricComputation): + """ + This class implements the RecMetricComputation for tensor weighted average. + + It is a sibling to WeightedAvgMetricComputation, but it computes the weighted average of a tensor + passed in as a required input instead of the predictions tensor. + """ + def __init__( self, *args: Any, @@ -116,15 +123,6 @@ class TensorWeightedAvgMetric(RecMetric): _namespace: MetricNamespace = MetricNamespace.WEIGHTED_AVG _computation_class: Type[RecMetricComputation] = TensorWeightedAvgMetricComputation - def __init__( - self, - # pyre-ignore Missing parameter annotation [2] - *args, - **kwargs: Dict[str, Any], - ) -> None: - - super().__init__(*args, **kwargs) - def _get_task_kwargs( self, task_config: Union[RecTaskInfo, List[RecTaskInfo]] ) -> Dict[str, Any]: diff --git a/torchrec/metrics/test_utils/__init__.py b/torchrec/metrics/test_utils/__init__.py index 0a1085195..4fbf1b8e2 100644 --- a/torchrec/metrics/test_utils/__init__.py +++ b/torchrec/metrics/test_utils/__init__.py @@ -87,6 +87,7 @@ def gen_test_tasks( label_name=f"{task_name}-label", prediction_name=f"{task_name}-prediction", weight_name=f"{task_name}-weight", + tensor_name=f"{task_name}-tensor", ) for task_name in task_names ] @@ -131,7 +132,10 @@ def _aggregate( @staticmethod @abc.abstractmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: pass @@ -161,6 +165,7 @@ def compute( aggregated_model_out[task_info.label_name], aggregated_model_out[task_info.prediction_name], aggregated_model_out[task_info.weight_name], + aggregated_model_out[task_info.tensor_name or "tensor"], ) if self._compute_lifetime_metric: self._aggregate(lifetime_states[task_info.name], states) @@ -170,6 +175,7 @@ def compute( model_outs[i][task_info.label_name], model_outs[i][task_info.prediction_name], model_outs[i][task_info.weight_name], + model_outs[i][task_info.tensor_name or "tensor"], ) if self._local_compute_lifetime_metric: self._aggregate(local_lifetime_states[task_info.name], local_states) @@ -252,6 +258,7 @@ def rec_metric_value_test_helper( label_name=task.label_name, prediction_name=task.prediction_name, weight_name=task.weight_name, + tensor_name=task.tensor_name or "tensor", batch_size=batch_size, n_classes=n_classes, weight_value=weight_value, @@ -288,8 +295,11 @@ def get_target_rec_metric_value( **kwargs, ) for i in range(nsteps): - labels, predictions, weights, _ = parse_task_model_outputs( - tasks, model_outs[i] + # Get required_inputs_list from the target metric + required_inputs_list = list(target_metric_obj.get_required_inputs()) + + labels, predictions, weights, required_inputs = parse_task_model_outputs( + tasks, model_outs[i], required_inputs_list ) if target_compute_mode in [ RecComputeMode.FUSED_TASKS_COMPUTATION, @@ -302,7 +312,10 @@ def get_target_rec_metric_value( if timestamps is not None: time_mock.return_value = timestamps[i] target_metric_obj.update( - predictions=predictions, labels=labels, weights=weights + predictions=predictions, + 
labels=labels, + weights=weights, + required_inputs=required_inputs, ) result_metrics = target_metric_obj.compute() result_metrics.update(target_metric_obj.local_compute()) @@ -422,7 +435,7 @@ def sync_test_helper( # pyre-ignore[6]: Incompatible parameter type kwargs["number_of_classes"] = n_classes - auc = target_clazz( + target_metric_obj = target_clazz( world_size=world_size, batch_size=batch_size, my_rank=rank, @@ -440,6 +453,7 @@ def sync_test_helper( label_name=task.label_name, prediction_name=task.prediction_name, weight_name=task.weight_name, + tensor_name=task.tensor_name or "tensor", batch_size=batch_size, n_classes=n_classes, weight_value=weight_value, @@ -450,19 +464,32 @@ def sync_test_helper( model_outs = [] model_outs.append({k: v for d in _model_outs for k, v in d.items()}) + # Get required_inputs from the target metric + required_inputs_list = list(target_metric_obj.get_required_inputs()) + # we send an uneven number of tensors to each rank to test that GPU sync works if rank == 0: for _ in range(3): - labels, predictions, weights, _ = parse_task_model_outputs( - tasks, model_outs[0] + labels, predictions, weights, required_inputs = parse_task_model_outputs( + tasks, model_outs[0], required_inputs_list + ) + target_metric_obj.update( + predictions=predictions, + labels=labels, + weights=weights, + required_inputs=required_inputs, ) - auc.update(predictions=predictions, labels=labels, weights=weights) elif rank == 1: for _ in range(1): - labels, predictions, weights, _ = parse_task_model_outputs( - tasks, model_outs[0] + labels, predictions, weights, required_inputs = parse_task_model_outputs( + tasks, model_outs[0], required_inputs_list + ) + target_metric_obj.update( + predictions=predictions, + labels=labels, + weights=weights, + required_inputs=required_inputs, ) - auc.update(predictions=predictions, labels=labels, weights=weights) # check against test metric test_metrics: TestRecMetricOutput = ({}, {}, {}, {}) @@ -474,7 +501,7 @@ def sync_test_helper( model_outs = model_outs * 2 test_metrics = test_metric_obj.compute(model_outs, 2, batch_window_size, None) - res = auc.compute() + res = target_metric_obj.compute() if rank == 0: # Serving Calibration uses Calibration naming inconsistently @@ -490,21 +517,31 @@ def sync_test_helper( ) # we also test the case where other rank has more tensors than rank 0 - auc.reset() + target_metric_obj.reset() if rank == 0: for _ in range(1): - labels, predictions, weights, _ = parse_task_model_outputs( - tasks, model_outs[0] + labels, predictions, weights, required_inputs = parse_task_model_outputs( + tasks, model_outs[0], required_inputs_list + ) + target_metric_obj.update( + predictions=predictions, + labels=labels, + weights=weights, + required_inputs=required_inputs, ) - auc.update(predictions=predictions, labels=labels, weights=weights) elif rank == 1: for _ in range(3): - labels, predictions, weights, _ = parse_task_model_outputs( - tasks, model_outs[0] + labels, predictions, weights, required_inputs = parse_task_model_outputs( + tasks, model_outs[0], required_inputs_list + ) + target_metric_obj.update( + predictions=predictions, + labels=labels, + weights=weights, + required_inputs=required_inputs, ) - auc.update(predictions=predictions, labels=labels, weights=weights) - res = auc.compute() + res = target_metric_obj.compute() if rank == 0: # Serving Calibration uses Calibration naming inconsistently diff --git a/torchrec/metrics/tests/test_accuracy.py b/torchrec/metrics/tests/test_accuracy.py index fa46b3e87..c0d8f976a 100644 --- 
a/torchrec/metrics/tests/test_accuracy.py +++ b/torchrec/metrics/tests/test_accuracy.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Iterable, Type, Union +from typing import Dict, Iterable, Optional, Type, Union import torch from torch import no_grad @@ -31,7 +31,10 @@ class TestAccuracyMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() accuracy_sum = torch.sum(weights * ((predictions >= 0.5) == labels)) diff --git a/torchrec/metrics/tests/test_auc.py b/torchrec/metrics/tests/test_auc.py index 36f389c86..98ae2b497 100644 --- a/torchrec/metrics/tests/test_auc.py +++ b/torchrec/metrics/tests/test_auc.py @@ -70,7 +70,10 @@ def _aggregate( @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: return { "predictions": predictions, diff --git a/torchrec/metrics/tests/test_auprc.py b/torchrec/metrics/tests/test_auprc.py index e7172639a..57308004a 100644 --- a/torchrec/metrics/tests/test_auprc.py +++ b/torchrec/metrics/tests/test_auprc.py @@ -61,7 +61,10 @@ def _aggregate( @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: return { "predictions": predictions, diff --git a/torchrec/metrics/tests/test_cali_free_ne.py b/torchrec/metrics/tests/test_cali_free_ne.py index 328dd7931..2deeeea95 100644 --- a/torchrec/metrics/tests/test_cali_free_ne.py +++ b/torchrec/metrics/tests/test_cali_free_ne.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Type +from typing import Dict, Optional, Type import torch from torchrec.metrics.cali_free_ne import ( @@ -34,7 +34,10 @@ class TestCaliFreeNEMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: cross_entropy = compute_cross_entropy( labels, predictions, weights, TestCaliFreeNEMetric.eta diff --git a/torchrec/metrics/tests/test_calibration.py b/torchrec/metrics/tests/test_calibration.py index 6a2304485..64422cb3c 100644 --- a/torchrec/metrics/tests/test_calibration.py +++ b/torchrec/metrics/tests/test_calibration.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Type +from typing import Dict, Optional, Type import torch from torchrec.metrics.calibration import CalibrationMetric @@ -25,7 +25,10 @@ class TestCalibrationMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: calibration_num = torch.sum(predictions * weights) calibration_denom = torch.sum(labels * weights) diff --git a/torchrec/metrics/tests/test_ctr.py 
b/torchrec/metrics/tests/test_ctr.py index efd45752c..fe5a96ab5 100644 --- a/torchrec/metrics/tests/test_ctr.py +++ b/torchrec/metrics/tests/test_ctr.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Type +from typing import Dict, Optional, Type import torch from torchrec.metrics.ctr import CTRMetric @@ -25,7 +25,10 @@ class TestCTRMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: ctr_num = torch.sum(labels * weights) ctr_denom = torch.sum(weights) diff --git a/torchrec/metrics/tests/test_gauc.py b/torchrec/metrics/tests/test_gauc.py index 513988cff..0c05c73c9 100644 --- a/torchrec/metrics/tests/test_gauc.py +++ b/torchrec/metrics/tests/test_gauc.py @@ -9,7 +9,7 @@ import unittest -from typing import Dict +from typing import Dict, Optional import torch from torchrec.metrics.gauc import compute_gauc_3d, compute_window_auc, GAUCMetric @@ -24,6 +24,7 @@ def _get_states( labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: gauc_res = compute_gauc_3d(predictions, labels, weights) return { diff --git a/torchrec/metrics/tests/test_hindsight_target_pr.py b/torchrec/metrics/tests/test_hindsight_target_pr.py index 5cc9e406d..6130068be 100644 --- a/torchrec/metrics/tests/test_hindsight_target_pr.py +++ b/torchrec/metrics/tests/test_hindsight_target_pr.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Type +from typing import Dict, Optional, Type import torch from torchrec.metrics.hindsight_target_pr import ( @@ -32,7 +32,10 @@ class TestHindsightTargetPRMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) @@ -59,7 +62,10 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: class TestHindsightTargetPrecisionMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) @@ -89,7 +95,10 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: class TestHindsightTargetRecallMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double) diff --git a/torchrec/metrics/tests/test_mae.py b/torchrec/metrics/tests/test_mae.py index 7f7737e45..77c891d79 100644 --- a/torchrec/metrics/tests/test_mae.py +++ b/torchrec/metrics/tests/test_mae.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Type +from typing import 
Dict, Optional, Type import torch from torchrec.metrics.mae import compute_mae, MAEMetric @@ -25,7 +25,10 @@ class TestMAEMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() error_sum = torch.sum(weights * torch.abs(labels - predictions)) diff --git a/torchrec/metrics/tests/test_mse.py b/torchrec/metrics/tests/test_mse.py index d1b1b7615..c778039dc 100644 --- a/torchrec/metrics/tests/test_mse.py +++ b/torchrec/metrics/tests/test_mse.py @@ -9,7 +9,7 @@ import unittest from functools import partial, update_wrapper -from typing import Callable, Dict, Type +from typing import Callable, Dict, Optional, Type import torch from torchrec.metrics.mse import compute_mse, compute_r_squared, compute_rmse, MSEMetric @@ -26,7 +26,10 @@ class TestMSEMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() error_sum = torch.sum(weights * torch.square(labels - predictions)) @@ -49,7 +52,10 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: class TestRMSEMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() error_sum = torch.sum(weights * torch.square(labels - predictions)) @@ -69,7 +75,10 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: class TestRSquaredMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() error_sum = torch.sum(weights * torch.square(labels - predictions)) diff --git a/torchrec/metrics/tests/test_multiclass_recall.py b/torchrec/metrics/tests/test_multiclass_recall.py index d0c736b69..2f83311fc 100644 --- a/torchrec/metrics/tests/test_multiclass_recall.py +++ b/torchrec/metrics/tests/test_multiclass_recall.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Type +from typing import Dict, Optional, Type import torch from torchrec.metrics.multiclass_recall import ( @@ -37,6 +37,7 @@ def _get_states( labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: states = get_multiclass_recall_states( predictions, labels, weights, TestMulticlassRecallMetric.n_classes diff --git a/torchrec/metrics/tests/test_ne.py b/torchrec/metrics/tests/test_ne.py index 4a5a5359d..e2dcfc254 100644 --- a/torchrec/metrics/tests/test_ne.py +++ b/torchrec/metrics/tests/test_ne.py @@ -9,7 +9,7 @@ import unittest from functools import partial, update_wrapper -from typing import Callable, Dict, Type +from typing import Callable, Dict, Optional, Type import torch from torchrec.metrics.ne import ( @@ -36,7 +36,10 
@@ class TestNEMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: cross_entropy = compute_cross_entropy( labels, predictions, weights, TestNEMetric.eta @@ -74,7 +77,10 @@ class TestLoglossMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: cross_entropy = compute_cross_entropy( labels, predictions, weights, TestNEMetric.eta diff --git a/torchrec/metrics/tests/test_precision.py b/torchrec/metrics/tests/test_precision.py index 8a58485f6..4b313e0f9 100644 --- a/torchrec/metrics/tests/test_precision.py +++ b/torchrec/metrics/tests/test_precision.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Iterable, Type, Union +from typing import Dict, Iterable, Optional, Type, Union import torch from torch import no_grad @@ -31,7 +31,10 @@ class TestPrecisionMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() true_pos_sum = torch.sum(weights * ((predictions >= 0.5) * labels)) diff --git a/torchrec/metrics/tests/test_rauc.py b/torchrec/metrics/tests/test_rauc.py index be9a7dd4b..cad60ccb8 100644 --- a/torchrec/metrics/tests/test_rauc.py +++ b/torchrec/metrics/tests/test_rauc.py @@ -68,7 +68,10 @@ def _aggregate( @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: return { "predictions": predictions, diff --git a/torchrec/metrics/tests/test_recall.py b/torchrec/metrics/tests/test_recall.py index d09faf464..0f7fad82e 100644 --- a/torchrec/metrics/tests/test_recall.py +++ b/torchrec/metrics/tests/test_recall.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Iterable, Type, Union +from typing import Dict, Iterable, Optional, Type, Union import torch from torch import no_grad @@ -31,7 +31,10 @@ class TestRecallMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: predictions = predictions.double() true_pos_sum = torch.sum(weights * ((predictions >= 0.5) * labels)) diff --git a/torchrec/metrics/tests/test_serving_calibration.py b/torchrec/metrics/tests/test_serving_calibration.py index 810a69bfb..ea954c736 100644 --- a/torchrec/metrics/tests/test_serving_calibration.py +++ b/torchrec/metrics/tests/test_serving_calibration.py @@ -9,7 +9,7 @@ # pyre-strict import unittest -from typing import Dict, Type +from typing import Dict, Optional, Type import torch from torchrec.metrics.rec_metric import RecComputeMode, RecMetric @@ -28,7 +28,10 @@ class TestServingCalibrationMetric(TestMetric): 
@staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: calibration_num = torch.sum(predictions * weights) calibration_denom = torch.sum(labels * weights) diff --git a/torchrec/metrics/tests/test_serving_ne.py b/torchrec/metrics/tests/test_serving_ne.py index 888bc1278..1a993d4b2 100644 --- a/torchrec/metrics/tests/test_serving_ne.py +++ b/torchrec/metrics/tests/test_serving_ne.py @@ -8,8 +8,7 @@ # pyre-strict import unittest - -from typing import Dict, Type +from typing import Dict, Optional, Type import torch @@ -31,7 +30,10 @@ class TestNEMetric(TestMetric): @staticmethod def _get_states( - labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: cross_entropy = compute_cross_entropy( labels, predictions, weights, TestNEMetric.eta diff --git a/torchrec/metrics/tests/test_tensor_weighted_avg.py b/torchrec/metrics/tests/test_tensor_weighted_avg.py new file mode 100644 index 000000000..79041693d --- /dev/null +++ b/torchrec/metrics/tests/test_tensor_weighted_avg.py @@ -0,0 +1,506 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import unittest +from typing import Dict, Iterable, Optional, Type, Union + +import torch +from torchrec.metrics.metrics_config import RecTaskInfo +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecMetricException +from torchrec.metrics.tensor_weighted_avg import get_mean, TensorWeightedAvgMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + + +WORLD_SIZE = 4 +METRIC_NAMESPACE: str = TensorWeightedAvgMetric._namespace.value + + +class TestTensorWeightedAvgMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + """ + Compute states for tensor weighted average. + + For TensorWeightedAvgMetric, we use the 'required_inputs_tensor' parameter + which contains the actual tensor to compute the weighted average on. 
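+
+        For example, with tensor values [2.0, 4.0] and weights [1.0, 3.0], the
+        weighted average is (2.0 * 1.0 + 4.0 * 3.0) / (1.0 + 3.0) = 3.5.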
+ """ + + if required_inputs_tensor is None: + raise ValueError("required_inputs_tensor cannot be None") + + # Compute weighted sum and weighted num samples using the target tensor + weighted_sum = (required_inputs_tensor * weights).sum(dim=-1) + weighted_num_samples = weights.sum(dim=-1) + + return { + "weighted_sum": weighted_sum, + "weighted_num_samples": weighted_num_samples, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + return get_mean(states["weighted_sum"], states["weighted_num_samples"]) + + +class TensorWeightedAvgMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = TensorWeightedAvgMetric + target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION + + def test_tensor_weighted_avg_unfused(self) -> None: + """Test TensorWeightedAvgMetric with UNFUSED_TASKS_COMPUTATION.""" + rec_metric_value_test_launcher( + target_clazz=TensorWeightedAvgMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestTensorWeightedAvgMetric, + metric_name=METRIC_NAMESPACE, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_tensor_weighted_avg_fused_fails(self) -> None: + """Test that TensorWeightedAvgMetric fails with FUSED_TASKS_COMPUTATION as expected.""" + # This test verifies the current limitation - FUSED mode should fail + with self.assertRaisesRegex( + RecMetricException, "expects task_config to be RecTaskInfo not" + ): + rec_metric_value_test_launcher( + target_clazz=TensorWeightedAvgMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestTensorWeightedAvgMetric, + metric_name=METRIC_NAMESPACE, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_tensor_weighted_avg_single_task(self) -> None: + rec_metric_value_test_launcher( + target_clazz=TensorWeightedAvgMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestTensorWeightedAvgMetric, + metric_name=METRIC_NAMESPACE, + task_names=["single_task"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +class TensorWeightedAvgGPUSyncTest(unittest.TestCase): + """GPU synchronization tests for TensorWeightedAvgMetric.""" + + def test_sync_tensor_weighted_avg(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=TensorWeightedAvgMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestTensorWeightedAvgMetric, + metric_name=METRIC_NAMESPACE, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) + + +class TensorWeightedAvgFunctionalityTest(unittest.TestCase): + """Test basic functionality of TensorWeightedAvgMetric.""" + + def test_tensor_weighted_avg_basic_functionality(self) -> None: + + tasks = [ + RecTaskInfo( + name="test_task", + label_name="test_label", + prediction_name="test_pred", + weight_name="test_weight", + tensor_name="test_tensor", + weighted=True, + ) + ] + metric = TensorWeightedAvgMetric( + world_size=1, + my_rank=0, + batch_size=4, + tasks=tasks, + compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, 
+ window_size=100, + ) + + self.assertIsNotNone(metric) + self.assertEqual(len(metric._metrics_computations), 1) + + computation = metric._metrics_computations[0] + self.assertEqual(computation.tensor_name, "test_tensor") + self.assertTrue(computation.weighted) + + def test_tensor_weighted_avg_unweighted_task(self) -> None: + + # Create an unweighted task + tasks = [ + RecTaskInfo( + name="unweighted_task", + label_name="test_label", + prediction_name="test_pred", + weight_name="test_weight", + tensor_name="test_tensor", + weighted=False, + ) + ] + + metric = TensorWeightedAvgMetric( + world_size=1, + my_rank=0, + batch_size=4, + tasks=tasks, + compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size=100, + ) + + computation = metric._metrics_computations[0] + self.assertEqual(computation.tensor_name, "test_tensor") + self.assertFalse(computation.weighted) + + def test_tensor_weighted_avg_missing_tensor_name_throws_exception(self) -> None: + + # Create task with None tensor_name + tasks = [ + RecTaskInfo( + name="test_task", + label_name="test_label", + prediction_name="test_pred", + weight_name="test_weight", + tensor_name=None, + weighted=True, + ) + ] + + with self.assertRaisesRegex(RecMetricException, "tensor_name"): + TensorWeightedAvgMetric( + world_size=1, + my_rank=0, + batch_size=4, + tasks=tasks, + compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size=100, + ) + + def test_tensor_weighted_avg_required_inputs_validation(self) -> None: + tasks = [ + RecTaskInfo( + name="test_task", + label_name="test_label", + prediction_name="test_pred", + weight_name="test_weight", + tensor_name="test_tensor", + weighted=True, + ) + ] + + metric = TensorWeightedAvgMetric( + world_size=1, + my_rank=0, + batch_size=2, + tasks=tasks, + compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + window_size=100, + ) + + # Test that required inputs are correctly identified + required_inputs = metric.get_required_inputs() + self.assertIn("test_tensor", required_inputs) + + # Test update with missing required inputs should fail + with self.assertRaisesRegex(RecMetricException, "required_inputs"): + metric.update( + predictions={"test_task": torch.tensor([0.1, 0.2])}, + labels={"test_task": torch.tensor([1.0, 0.0])}, + weights={"test_task": torch.tensor([1.0, 2.0])}, + ) + + +def generate_tensor_model_outputs_cases() -> Iterable[Dict[str, torch.Tensor]]: + """Generate test cases with known inputs and expected tensor weighted average outputs.""" + return [ + # Basic weighted case + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "tensors": torch.tensor([[2.0, 4.0, 6.0, 8.0, 10.0]]), + "weights": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), + # Expected: (2.0*0.1 + 4.0*0.2 + 6.0*0.3 + 8.0*0.4 + 10.0*0.5) / (0.1+0.2+0.3+0.4+0.5) = 11/1.5 = 7.3333 + "expected_tensor_weighted_avg": torch.tensor([7.3333]), + }, + # Uniform weights (should equal simple average) + { + "labels": torch.tensor([[1, 0, 1, 0]]), + "predictions": torch.tensor([[0.5, 0.5, 0.5, 0.5]]), + "tensors": torch.tensor([[1.0, 3.0, 5.0, 7.0]]), + "weights": torch.tensor([[1.0, 1.0, 1.0, 1.0]]), + # Expected: (1.0 + 3.0 + 5.0 + 7.0) / 4 = 16/4 = 4.0 + "expected_tensor_weighted_avg": torch.tensor([4.0]), + }, + # No weights (should default to uniform weights) + { + "labels": torch.tensor([[1, 0, 1]]), + "predictions": torch.tensor([[0.3, 0.7, 0.5]]), + "tensors": torch.tensor([[2.0, 8.0, 5.0]]), + # Expected: (2.0 + 8.0 + 5.0) / 3 = 15/3 = 5.0 + 
"expected_tensor_weighted_avg": torch.tensor([5.0]), + }, + # Single non-zero weight + { + "labels": torch.tensor([[1, 0, 1, 0]]), + "predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), + "tensors": torch.tensor([[10.0, 20.0, 30.0, 40.0]]), + "weights": torch.tensor([[0.0, 0.0, 1.0, 0.0]]), + # Expected: only third element contributes: 30.0/1.0 = 30.0 + "expected_tensor_weighted_avg": torch.tensor([30.0]), + }, + # All weights zero (should result in NaN) + { + "labels": torch.tensor([[1, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8]]), + "tensors": torch.tensor([[1.0, 2.0, 3.0]]), + "weights": torch.tensor([[0.0, 0.0, 0.0]]), + "expected_tensor_weighted_avg": torch.tensor([float("nan")]), + }, + # Negative tensor values + { + "labels": torch.tensor([[1, 0, 1]]), + "predictions": torch.tensor([[0.1, 0.5, 0.9]]), + "tensors": torch.tensor([[-2.0, 4.0, -6.0]]), + "weights": torch.tensor([[0.5, 0.3, 0.2]]), + # Expected: (-2.0*0.5 + 4.0*0.3 + -6.0*0.2) / (0.5+0.3+0.2) = (-1.0 + 1.2 - 1.2) / 1.0 = -1.0 + "expected_tensor_weighted_avg": torch.tensor([-1.0]), + }, + ] + + +class TensorWeightedAvgValueTest(unittest.TestCase): + """This set of tests verify the computation logic of tensor weighted avg in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. + """ + + @torch.no_grad() + def _test_tensor_weighted_avg_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + tensors: torch.Tensor, + weights: Optional[torch.Tensor], + expected_tensor_weighted_avg: torch.Tensor, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[1] + task_list = [] + + predictions_dict: Dict[str, torch.Tensor] = {} + labels_dict: Dict[str, torch.Tensor] = {} + weights_dict: Optional[Dict[str, torch.Tensor]] = ( + {} if weights is not None else None + ) + required_inputs_dict: Dict[str, torch.Tensor] = {} + + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + tensor_name="test_tensor", + weighted=True, + ) + task_list.append(task_info) + predictions_dict[task_info.name] = predictions[i] + labels_dict[task_info.name] = labels[i] + + # Ensure tensor_name is not None before using as dict key + tensor_name = task_info.tensor_name + if tensor_name is not None: + required_inputs_dict[tensor_name] = tensors[i] + + if weights is not None and weights_dict is not None: + weights_dict[task_info.name] = weights[i] + + inputs: Dict[str, Union[Dict[str, torch.Tensor], None]] = { + "predictions": predictions_dict, + "labels": labels_dict, + "weights": weights_dict, + "required_inputs": required_inputs_dict, + } + tensor_weighted_avg = TensorWeightedAvgMetric( + world_size=1, + my_rank=0, + batch_size=batch_size, + tasks=task_list, + ) + tensor_weighted_avg.update(**inputs) + actual_tensor_weighted_avg = tensor_weighted_avg.compute() + + for task_id, task in enumerate(task_list): + cur_actual_tensor_weighted_avg = actual_tensor_weighted_avg[ + f"weighted_avg-{task.name}|window_weighted_avg" + ] + cur_expected_tensor_weighted_avg = expected_tensor_weighted_avg[ + task_id + ].unsqueeze(dim=0) + + if cur_expected_tensor_weighted_avg.isnan().any(): + self.assertTrue(cur_actual_tensor_weighted_avg.isnan().any()) + else: + torch.testing.assert_close( + cur_actual_tensor_weighted_avg, + cur_expected_tensor_weighted_avg, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_tensor_weighted_avg}, 
Expected: {cur_expected_tensor_weighted_avg}", + ) + + def test_tensor_weighted_avg_computation_correctness(self) -> None: + """Test tensor weighted average computation correctness with known values.""" + test_data = generate_tensor_model_outputs_cases() + for inputs in test_data: + try: + # Extract and validate inputs + labels = inputs["labels"] + predictions = inputs["predictions"] + tensors = inputs["tensors"] + weights = inputs["weights"] if "weights" in inputs else None + expected = inputs["expected_tensor_weighted_avg"] + + # Call helper with properly typed arguments + self._test_tensor_weighted_avg_helper( + labels=labels, + predictions=predictions, + tensors=tensors, + weights=weights, + expected_tensor_weighted_avg=expected, + ) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise + + def test_tensor_weighted_vs_unweighted_computation(self) -> None: + """Test that weighted and unweighted computations produce different results when weights vary.""" + # Test data with non-uniform weights + labels = torch.tensor([[1, 0, 1, 0]]) + predictions = torch.tensor([[0.5, 0.5, 0.5, 0.5]]) + required_inputs_tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) + varying_weights = torch.tensor([[0.1, 0.2, 0.3, 0.4]]) + + # Weighted: (1.0*0.1 + 2.0*0.2 + 3.0*0.3 + 4.0*0.4) / (0.1+0.2+0.3+0.4) = 3.0/1.0 = 3.0 + expected_weighted_avg = torch.tensor([3.0]) + # Unweighted: (1.0 + 2.0 + 3.0 + 4.0) / 4 = 10.0/4 = 2.5 + expected_unweighted_avg = torch.tensor([2.5]) + + # Create weighted task + weighted_task = RecTaskInfo( + name="weighted_task", + label_name="label", + prediction_name="prediction", + weight_name="weight", + tensor_name="test_tensor", + weighted=True, + ) + + # Create unweighted task + unweighted_task = RecTaskInfo( + name="unweighted_task", + label_name="label", + prediction_name="prediction", + weight_name="weight", + tensor_name="test_tensor", + weighted=False, + ) + # Test weighted computation + weighted_metric = TensorWeightedAvgMetric( + world_size=1, + my_rank=0, + batch_size=4, + tasks=[weighted_task], + ) + + weighted_metric.update( + predictions={"weighted_task": predictions[0]}, + labels={"weighted_task": labels[0]}, + weights={"weighted_task": varying_weights[0]}, + required_inputs={"test_tensor": required_inputs_tensor[0]}, + ) + + weighted_result = weighted_metric.compute() + + # Test unweighted computation + unweighted_metric = TensorWeightedAvgMetric( + world_size=1, + my_rank=0, + batch_size=4, + tasks=[unweighted_task], + ) + + unweighted_metric.update( + predictions={"unweighted_task": predictions[0]}, + labels={"unweighted_task": labels[0]}, + weights={"unweighted_task": varying_weights[0]}, # ignored + required_inputs={"test_tensor": required_inputs_tensor[0]}, + ) + + unweighted_result = unweighted_metric.compute() + + # Results should be different + weighted_value = weighted_result[ + "weighted_avg-weighted_task|window_weighted_avg" + ] + unweighted_value = unweighted_result[ + "weighted_avg-unweighted_task|window_weighted_avg" + ] + + torch.testing.assert_close( + weighted_value, + expected_weighted_avg, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {weighted_value}, Expected: {expected_weighted_avg}", + ) + + torch.testing.assert_close( + unweighted_value, + expected_unweighted_avg, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {unweighted_value}, Expected: {expected_unweighted_avg}", + ) diff --git a/torchrec/metrics/tests/test_tower_qps.py b/torchrec/metrics/tests/test_tower_qps.py index 
7dd91010d..5b0393b45 100644
--- a/torchrec/metrics/tests/test_tower_qps.py
+++ b/torchrec/metrics/tests/test_tower_qps.py
@@ -53,7 +53,10 @@ def __init__(
     # or weights
     @staticmethod
     def _get_states(
-        labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
+        labels: torch.Tensor,
+        predictions: torch.Tensor,
+        weights: torch.Tensor,
+        required_inputs_tensor: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
         return {}
 
diff --git a/torchrec/metrics/tests/test_unweighted_ne.py b/torchrec/metrics/tests/test_unweighted_ne.py
index 5a18178d0..97f9e8df6 100644
--- a/torchrec/metrics/tests/test_unweighted_ne.py
+++ b/torchrec/metrics/tests/test_unweighted_ne.py
@@ -8,7 +8,7 @@
 # pyre-strict
 
 import unittest
-from typing import Dict, Type
+from typing import Dict, Optional, Type
 
 import torch
 from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
@@ -34,7 +34,10 @@ class TestUnweightedNEMetric(TestMetric):
 
     @staticmethod
     def _get_states(
-        labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
+        labels: torch.Tensor,
+        predictions: torch.Tensor,
+        weights: torch.Tensor,
+        required_inputs_tensor: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
         # Override the weights to be all ones
         weights = torch.ones_like(labels)
diff --git a/torchrec/metrics/tests/test_weighted_avg.py b/torchrec/metrics/tests/test_weighted_avg.py
index 226c06748..5ecd75cd5 100644
--- a/torchrec/metrics/tests/test_weighted_avg.py
+++ b/torchrec/metrics/tests/test_weighted_avg.py
@@ -26,7 +26,10 @@ class TestWeightedAvgMetric(TestMetric):
 
     @staticmethod
     def _get_states(
-        labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
+        labels: torch.Tensor,
+        predictions: torch.Tensor,
+        weights: torch.Tensor,
+        required_inputs_tensor: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
         return {