Skip to content

Commit 1d4d1e5

Browse files
jeffkbkim authored and facebook-github-bot committed
Unit test for TensorWeightedAvg RecMetric (#3139)
Summary: Pull Request resolved: #3139. The main test lives in `test_tensor_weighted_avg.py` (its format is similar to `test_weighted_avg.py`). This patch adds a unit test for TensorWeightedAvgMetric. TensorWeightedAvgMetric allows users to pass a tensor that is not a prediction/weight/label, together with a corresponding tensor name, into a dict called "required_inputs". The purpose is to compute and emit the weighted average of a tensor of their choice. However, the existing framework for the rec metric testing module does not support processing tensors under "required_inputs". The main modifications to other tests are changes to the TestMetric's _get_states() method signature so that it can process a tensor passed into "required_inputs". Increases code coverage to 97.7%. [Previous external tests covered 35%](https://www.internalfb.com/quality/coverage?query=%7B%22key%22%3A%22AND%22%2C%22children%22%3A[%7B%22key%22%3A%22CONTAINS_ANY_OF_FBIDS%22%2C%22field%22%3A%22CODE_QUALITY_ONCALL%22%2C%22value%22%3A[%7B%22title%22%3A%22torchrec%20(torchrec%20%3A%20Zain%20Huda)%22%2C%22fbid%22%3A%22711323163133477%22%2C%22photo%22%3A%22%22%7D]%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_STARTS_WITH%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22fbcode%2Ftorchrec%2Ffb%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_CONTAINS_TEXT%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22tests%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_CONTAINS_TEXT%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22__init__.py%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_STARTS_WITH%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22fbcode%2Ftorchrec%2Fgithub%2F%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_STARTS_WITH%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22fbcode%2Ftorchrec%2Fcsrc%2F%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_STARTS_WITH%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22fbcode%2Ftorchrec%2Fdatasets%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_STARTS_WITH%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22fbcode%2Ftorchrec%2Finference%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_STARTS_WITH%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22fbcode%2Ftorchrec%2Fpt2%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_NOT_CONTAINS_TEXT%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22benchmark%22%7D%2C%7B%22key%22%3A%22CODE_QUALITY_STARTS_WITH%22%2C%22field%22%3A%22CODE_QUALITY_FILE_PATH%22%2C%22value%22%3A%22fbcode%2Ftorchrec%2Fmetrics%2F%22%7D]%7D&is_e2e=false&show_all_partitions=false&split_graph=true) Reviewed By: kausv. Differential Revision: D77169334. fbshipit-source-id: db2a27d602622441ad3a8c0da8430a9891cf647f
1 parent 3ef5b37 commit 1d4d1e5

23 files changed

+679
-68
lines changed

torchrec/metrics/tensor_weighted_avg.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ def get_mean(value_sum: torch.Tensor, num_samples: torch.Tensor) -> torch.Tensor
2525

2626

2727
class TensorWeightedAvgMetricComputation(RecMetricComputation):
28+
"""
29+
This class implements the RecMetricComputation for tensor weighted average.
30+
31+
It is a sibling to WeightedAvgMetricComputation, but it computes the weighted average of a tensor
32+
passed in as a required input instead of the predictions tensor.
33+
"""
34+
2835
def __init__(
2936
self,
3037
*args: Any,
@@ -116,15 +123,6 @@ class TensorWeightedAvgMetric(RecMetric):
116123
_namespace: MetricNamespace = MetricNamespace.WEIGHTED_AVG
117124
_computation_class: Type[RecMetricComputation] = TensorWeightedAvgMetricComputation
118125

119-
def __init__(
120-
self,
121-
# pyre-ignore Missing parameter annotation [2]
122-
*args,
123-
**kwargs: Dict[str, Any],
124-
) -> None:
125-
126-
super().__init__(*args, **kwargs)
127-
128126
def _get_task_kwargs(
129127
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
130128
) -> Dict[str, Any]:

torchrec/metrics/test_utils/__init__.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def gen_test_tasks(
8787
label_name=f"{task_name}-label",
8888
prediction_name=f"{task_name}-prediction",
8989
weight_name=f"{task_name}-weight",
90+
tensor_name=f"{task_name}-tensor",
9091
)
9192
for task_name in task_names
9293
]
@@ -131,7 +132,10 @@ def _aggregate(
131132
@staticmethod
132133
@abc.abstractmethod
133134
def _get_states(
134-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
135+
labels: torch.Tensor,
136+
predictions: torch.Tensor,
137+
weights: torch.Tensor,
138+
required_inputs_tensor: Optional[torch.Tensor] = None,
135139
) -> Dict[str, torch.Tensor]:
136140
pass
137141

@@ -161,6 +165,7 @@ def compute(
161165
aggregated_model_out[task_info.label_name],
162166
aggregated_model_out[task_info.prediction_name],
163167
aggregated_model_out[task_info.weight_name],
168+
aggregated_model_out[task_info.tensor_name or "tensor"],
164169
)
165170
if self._compute_lifetime_metric:
166171
self._aggregate(lifetime_states[task_info.name], states)
@@ -170,6 +175,7 @@ def compute(
170175
model_outs[i][task_info.label_name],
171176
model_outs[i][task_info.prediction_name],
172177
model_outs[i][task_info.weight_name],
178+
model_outs[i][task_info.tensor_name or "tensor"],
173179
)
174180
if self._local_compute_lifetime_metric:
175181
self._aggregate(local_lifetime_states[task_info.name], local_states)
@@ -252,6 +258,7 @@ def rec_metric_value_test_helper(
252258
label_name=task.label_name,
253259
prediction_name=task.prediction_name,
254260
weight_name=task.weight_name,
261+
tensor_name=task.tensor_name or "tensor",
255262
batch_size=batch_size,
256263
n_classes=n_classes,
257264
weight_value=weight_value,
@@ -288,8 +295,11 @@ def get_target_rec_metric_value(
288295
**kwargs,
289296
)
290297
for i in range(nsteps):
291-
labels, predictions, weights, _ = parse_task_model_outputs(
292-
tasks, model_outs[i]
298+
# Get required_inputs_list from the target metric
299+
required_inputs_list = list(target_metric_obj.get_required_inputs())
300+
301+
labels, predictions, weights, required_inputs = parse_task_model_outputs(
302+
tasks, model_outs[i], required_inputs_list
293303
)
294304
if target_compute_mode in [
295305
RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -302,7 +312,10 @@ def get_target_rec_metric_value(
302312
if timestamps is not None:
303313
time_mock.return_value = timestamps[i]
304314
target_metric_obj.update(
305-
predictions=predictions, labels=labels, weights=weights
315+
predictions=predictions,
316+
labels=labels,
317+
weights=weights,
318+
required_inputs=required_inputs,
306319
)
307320
result_metrics = target_metric_obj.compute()
308321
result_metrics.update(target_metric_obj.local_compute())
@@ -422,7 +435,7 @@ def sync_test_helper(
422435
# pyre-ignore[6]: Incompatible parameter type
423436
kwargs["number_of_classes"] = n_classes
424437

425-
auc = target_clazz(
438+
target_metric_obj = target_clazz(
426439
world_size=world_size,
427440
batch_size=batch_size,
428441
my_rank=rank,
@@ -440,6 +453,7 @@ def sync_test_helper(
440453
label_name=task.label_name,
441454
prediction_name=task.prediction_name,
442455
weight_name=task.weight_name,
456+
tensor_name=task.tensor_name or "tensor",
443457
batch_size=batch_size,
444458
n_classes=n_classes,
445459
weight_value=weight_value,
@@ -450,19 +464,32 @@ def sync_test_helper(
450464
model_outs = []
451465
model_outs.append({k: v for d in _model_outs for k, v in d.items()})
452466

467+
# Get required_inputs from the target metric
468+
required_inputs_list = list(target_metric_obj.get_required_inputs())
469+
453470
# we send an uneven number of tensors to each rank to test that GPU sync works
454471
if rank == 0:
455472
for _ in range(3):
456-
labels, predictions, weights, _ = parse_task_model_outputs(
457-
tasks, model_outs[0]
473+
labels, predictions, weights, required_inputs = parse_task_model_outputs(
474+
tasks, model_outs[0], required_inputs_list
475+
)
476+
target_metric_obj.update(
477+
predictions=predictions,
478+
labels=labels,
479+
weights=weights,
480+
required_inputs=required_inputs,
458481
)
459-
auc.update(predictions=predictions, labels=labels, weights=weights)
460482
elif rank == 1:
461483
for _ in range(1):
462-
labels, predictions, weights, _ = parse_task_model_outputs(
463-
tasks, model_outs[0]
484+
labels, predictions, weights, required_inputs = parse_task_model_outputs(
485+
tasks, model_outs[0], required_inputs_list
486+
)
487+
target_metric_obj.update(
488+
predictions=predictions,
489+
labels=labels,
490+
weights=weights,
491+
required_inputs=required_inputs,
464492
)
465-
auc.update(predictions=predictions, labels=labels, weights=weights)
466493

467494
# check against test metric
468495
test_metrics: TestRecMetricOutput = ({}, {}, {}, {})
@@ -474,7 +501,7 @@ def sync_test_helper(
474501
model_outs = model_outs * 2
475502
test_metrics = test_metric_obj.compute(model_outs, 2, batch_window_size, None)
476503

477-
res = auc.compute()
504+
res = target_metric_obj.compute()
478505

479506
if rank == 0:
480507
# Serving Calibration uses Calibration naming inconsistently
@@ -490,21 +517,31 @@ def sync_test_helper(
490517
)
491518

492519
# we also test the case where other rank has more tensors than rank 0
493-
auc.reset()
520+
target_metric_obj.reset()
494521
if rank == 0:
495522
for _ in range(1):
496-
labels, predictions, weights, _ = parse_task_model_outputs(
497-
tasks, model_outs[0]
523+
labels, predictions, weights, required_inputs = parse_task_model_outputs(
524+
tasks, model_outs[0], required_inputs_list
525+
)
526+
target_metric_obj.update(
527+
predictions=predictions,
528+
labels=labels,
529+
weights=weights,
530+
required_inputs=required_inputs,
498531
)
499-
auc.update(predictions=predictions, labels=labels, weights=weights)
500532
elif rank == 1:
501533
for _ in range(3):
502-
labels, predictions, weights, _ = parse_task_model_outputs(
503-
tasks, model_outs[0]
534+
labels, predictions, weights, required_inputs = parse_task_model_outputs(
535+
tasks, model_outs[0], required_inputs_list
536+
)
537+
target_metric_obj.update(
538+
predictions=predictions,
539+
labels=labels,
540+
weights=weights,
541+
required_inputs=required_inputs,
504542
)
505-
auc.update(predictions=predictions, labels=labels, weights=weights)
506543

507-
res = auc.compute()
544+
res = target_metric_obj.compute()
508545

509546
if rank == 0:
510547
# Serving Calibration uses Calibration naming inconsistently

torchrec/metrics/tests/test_accuracy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import Dict, Iterable, Type, Union
11+
from typing import Dict, Iterable, Optional, Type, Union
1212

1313
import torch
1414
from torch import no_grad
@@ -31,7 +31,10 @@
3131
class TestAccuracyMetric(TestMetric):
3232
@staticmethod
3333
def _get_states(
34-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
34+
labels: torch.Tensor,
35+
predictions: torch.Tensor,
36+
weights: torch.Tensor,
37+
required_inputs_tensor: Optional[torch.Tensor] = None,
3538
) -> Dict[str, torch.Tensor]:
3639
predictions = predictions.double()
3740
accuracy_sum = torch.sum(weights * ((predictions >= 0.5) == labels))

torchrec/metrics/tests/test_auc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ def _aggregate(
7070

7171
@staticmethod
7272
def _get_states(
73-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
73+
labels: torch.Tensor,
74+
predictions: torch.Tensor,
75+
weights: torch.Tensor,
76+
required_inputs_tensor: Optional[torch.Tensor] = None,
7477
) -> Dict[str, torch.Tensor]:
7578
return {
7679
"predictions": predictions,

torchrec/metrics/tests/test_auprc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ def _aggregate(
6161

6262
@staticmethod
6363
def _get_states(
64-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
64+
labels: torch.Tensor,
65+
predictions: torch.Tensor,
66+
weights: torch.Tensor,
67+
required_inputs_tensor: Optional[torch.Tensor] = None,
6568
) -> Dict[str, torch.Tensor]:
6669
return {
6770
"predictions": predictions,

torchrec/metrics/tests/test_cali_free_ne.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import Dict, Type
11+
from typing import Dict, Optional, Type
1212

1313
import torch
1414
from torchrec.metrics.cali_free_ne import (
@@ -34,7 +34,10 @@ class TestCaliFreeNEMetric(TestMetric):
3434

3535
@staticmethod
3636
def _get_states(
37-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
37+
labels: torch.Tensor,
38+
predictions: torch.Tensor,
39+
weights: torch.Tensor,
40+
required_inputs_tensor: Optional[torch.Tensor] = None,
3841
) -> Dict[str, torch.Tensor]:
3942
cross_entropy = compute_cross_entropy(
4043
labels, predictions, weights, TestCaliFreeNEMetric.eta

torchrec/metrics/tests/test_calibration.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import Dict, Type
11+
from typing import Dict, Optional, Type
1212

1313
import torch
1414
from torchrec.metrics.calibration import CalibrationMetric
@@ -25,7 +25,10 @@
2525
class TestCalibrationMetric(TestMetric):
2626
@staticmethod
2727
def _get_states(
28-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
28+
labels: torch.Tensor,
29+
predictions: torch.Tensor,
30+
weights: torch.Tensor,
31+
required_inputs_tensor: Optional[torch.Tensor] = None,
2932
) -> Dict[str, torch.Tensor]:
3033
calibration_num = torch.sum(predictions * weights)
3134
calibration_denom = torch.sum(labels * weights)

torchrec/metrics/tests/test_ctr.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import Dict, Type
11+
from typing import Dict, Optional, Type
1212

1313
import torch
1414
from torchrec.metrics.ctr import CTRMetric
@@ -25,7 +25,10 @@
2525
class TestCTRMetric(TestMetric):
2626
@staticmethod
2727
def _get_states(
28-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
28+
labels: torch.Tensor,
29+
predictions: torch.Tensor,
30+
weights: torch.Tensor,
31+
required_inputs_tensor: Optional[torch.Tensor] = None,
2932
) -> Dict[str, torch.Tensor]:
3033
ctr_num = torch.sum(labels * weights)
3134
ctr_denom = torch.sum(weights)

torchrec/metrics/tests/test_gauc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import unittest
1111

12-
from typing import Dict
12+
from typing import Dict, Optional
1313

1414
import torch
1515
from torchrec.metrics.gauc import compute_gauc_3d, compute_window_auc, GAUCMetric
@@ -24,6 +24,7 @@ def _get_states(
2424
labels: torch.Tensor,
2525
predictions: torch.Tensor,
2626
weights: torch.Tensor,
27+
required_inputs_tensor: Optional[torch.Tensor] = None,
2728
) -> Dict[str, torch.Tensor]:
2829
gauc_res = compute_gauc_3d(predictions, labels, weights)
2930
return {

torchrec/metrics/tests/test_hindsight_target_pr.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import Dict, Type
11+
from typing import Dict, Optional, Type
1212

1313
import torch
1414
from torchrec.metrics.hindsight_target_pr import (
@@ -32,7 +32,10 @@
3232
class TestHindsightTargetPRMetric(TestMetric):
3333
@staticmethod
3434
def _get_states(
35-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
35+
labels: torch.Tensor,
36+
predictions: torch.Tensor,
37+
weights: torch.Tensor,
38+
required_inputs_tensor: Optional[torch.Tensor] = None,
3639
) -> Dict[str, torch.Tensor]:
3740
predictions = predictions.double()
3841
tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
@@ -59,7 +62,10 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
5962
class TestHindsightTargetPrecisionMetric(TestMetric):
6063
@staticmethod
6164
def _get_states(
62-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
65+
labels: torch.Tensor,
66+
predictions: torch.Tensor,
67+
weights: torch.Tensor,
68+
required_inputs_tensor: Optional[torch.Tensor] = None,
6369
) -> Dict[str, torch.Tensor]:
6470
predictions = predictions.double()
6571
tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
@@ -89,7 +95,10 @@ def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
8995
class TestHindsightTargetRecallMetric(TestMetric):
9096
@staticmethod
9197
def _get_states(
92-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
98+
labels: torch.Tensor,
99+
predictions: torch.Tensor,
100+
weights: torch.Tensor,
101+
required_inputs_tensor: Optional[torch.Tensor] = None,
93102
) -> Dict[str, torch.Tensor]:
94103
predictions = predictions.double()
95104
tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)

torchrec/metrics/tests/test_mae.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import Dict, Type
11+
from typing import Dict, Optional, Type
1212

1313
import torch
1414
from torchrec.metrics.mae import compute_mae, MAEMetric
@@ -25,7 +25,10 @@
2525
class TestMAEMetric(TestMetric):
2626
@staticmethod
2727
def _get_states(
28-
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
28+
labels: torch.Tensor,
29+
predictions: torch.Tensor,
30+
weights: torch.Tensor,
31+
required_inputs_tensor: Optional[torch.Tensor] = None,
2932
) -> Dict[str, torch.Tensor]:
3033
predictions = predictions.double()
3134
error_sum = torch.sum(weights * torch.abs(labels - predictions))

0 commit comments

Comments
 (0)