diff --git a/torchrec/models/dlrm.py b/torchrec/models/dlrm.py
index ad1975eba..2bdf2e4fe 100644
--- a/torchrec/models/dlrm.py
+++ b/torchrec/models/dlrm.py
@@ -7,11 +7,12 @@
 
 # pyre-strict
 
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from torchrec.datasets.utils import Batch
+from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.modules.crossnet import LowRankCrossNet
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.mlp import MLP
@@ -899,3 +900,34 @@ def forward(
         loss = self.loss_fn(logits, batch.labels.float())
 
         return loss, (loss.detach(), logits.detach(), batch.labels.detach())
+
+
+class DLRMWrapper(DLRM):
+    # pyre-ignore[14, 15]
+    def forward(
+        self, model_input: ModelInput
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward pass for the DLRMWrapper.
+
+        Args:
+            model_input (ModelInput): Contains dense and sparse features.
+
+        Returns:
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+                If training, returns (loss, prediction). Otherwise, returns prediction.
+        """
+        pred = super().forward(
+            dense_features=model_input.float_features,
+            sparse_features=model_input.idlist_features,  # pyre-ignore[6]
+        )
+
+        if self.training:
+            # Calculate loss and return both loss and prediction
+            loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                pred.squeeze(), model_input.label
+            )
+            return (loss, pred)
+        else:
+            # Return just the prediction
+            return pred
diff --git a/torchrec/models/tests/test_dlrm.py b/torchrec/models/tests/test_dlrm.py
index e01976404..e8febb5e6 100644
--- a/torchrec/models/tests/test_dlrm.py
+++ b/torchrec/models/tests/test_dlrm.py
@@ -8,11 +8,15 @@
 # pyre-strict
 
 import unittest
+from dataclasses import dataclass
+from typing import List
 
 import torch
+from parameterized import parameterized
 from torch import nn
 from torch.testing import FileCheck  # @manual
 from torchrec.datasets.utils import Batch
+from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.fx import symbolic_trace
 from torchrec.ir.serializer import JsonSerializer
 from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
@@ -23,6 +27,7 @@
     DLRM_DCN,
     DLRM_Projection,
     DLRMTrain,
+    DLRMWrapper,
     InteractionArch,
     InteractionDCNArch,
     InteractionProjectionArch,
@@ -1283,3 +1288,107 @@ def test_export_serialization(self) -> None:
         deserialized_logits = deserialized_model(features, sparse_features)
 
         self.assertEqual(deserialized_logits.size(), (B, 1))
+
+
+class DLRMWrapperTest(unittest.TestCase):
+    @dataclass
+    class WrapperTestParams:
+        # input parameters
+        embedding_configs: List[EmbeddingBagConfig]
+        sparse_feature_keys: List[str]
+        sparse_feature_values: List[int]
+        sparse_feature_offsets: List[int]
+        # expected output parameters
+        expected_output_size: tuple[int, ...]
+
+    @parameterized.expand(
+        [
+            (
+                "basic_with_multiple_features",
+                WrapperTestParams(
+                    embedding_configs=[
+                        EmbeddingBagConfig(
+                            name="t1",
+                            embedding_dim=8,
+                            num_embeddings=100,
+                            feature_names=["f1", "f3"],
+                        ),
+                        EmbeddingBagConfig(
+                            name="t2",
+                            embedding_dim=8,
+                            num_embeddings=100,
+                            feature_names=["f2"],
+                        ),
+                    ],
+                    sparse_feature_keys=["f1", "f3", "f2"],
+                    sparse_feature_values=[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3],
+                    sparse_feature_offsets=[0, 2, 4, 6, 8, 10, 11],
+                    expected_output_size=(2, 1),
+                ),
+            ),
+            (
+                "empty_sparse_features",
+                WrapperTestParams(
+                    embedding_configs=[
+                        EmbeddingBagConfig(
+                            name="t1",
+                            embedding_dim=8,
+                            num_embeddings=100,
+                            feature_names=["f1"],
+                        ),
+                    ],
+                    sparse_feature_keys=["f1"],
+                    sparse_feature_values=[],
+                    sparse_feature_offsets=[0, 0, 0],
+                    expected_output_size=(2, 1),
+                ),
+            ),
+        ]
+    )
+    def test_wrapper_functionality(
+        self, _test_name: str, test_params: WrapperTestParams
+    ) -> None:
+        B = 2
+        D = 8
+        dense_in_features = 100
+
+        ebc = EmbeddingBagCollection(tables=test_params.embedding_configs)
+
+        dlrm_wrapper = DLRMWrapper(
+            embedding_bag_collection=ebc,
+            dense_in_features=dense_in_features,
+            dense_arch_layer_sizes=[20, D],
+            over_arch_layer_sizes=[5, 1],
+        )
+
+        # Create ModelInput
+        dense_features = torch.rand((B, dense_in_features))
+        sparse_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=test_params.sparse_feature_keys,
+            values=torch.tensor(test_params.sparse_feature_values, dtype=torch.long),
+            offsets=torch.tensor(test_params.sparse_feature_offsets, dtype=torch.long),
+        )
+
+        model_input = ModelInput(
+            float_features=dense_features,
+            idlist_features=sparse_features,
+            idscore_features=None,
+            label=torch.rand((B,)),
+        )
+
+        # Test eval mode - should return just logits
+        dlrm_wrapper.eval()
+        logits = dlrm_wrapper(model_input)
+        self.assertIsInstance(logits, torch.Tensor)
+        self.assertEqual(logits.size(), test_params.expected_output_size)
+
+        # Test training mode - should return (loss, logits) tuple
+        dlrm_wrapper.train()
+        result = dlrm_wrapper(model_input)
+        self.assertIsInstance(result, tuple)
+        self.assertEqual(len(result), 2)
+        loss, pred = result
+        self.assertIsInstance(loss, torch.Tensor)
+        self.assertIsInstance(pred, torch.Tensor)
+        self.assertEqual(loss.size(), ())  # scalar loss
+        self.assertEqual(pred.size(), test_params.expected_output_size)
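
Usage sketch (outside the patch): a minimal example of driving the new DLRMWrapper
end to end, mirroring the test above. The table name, feature name, and layer sizes
here are illustrative, not prescribed by the patch; the constructor arguments come
from the existing DLRM base class.

    import torch
    from torchrec.distributed.test_utils.test_input import ModelInput
    from torchrec.models.dlrm import DLRMWrapper
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    B = 2  # batch size

    # One embedding table feeding one sparse feature (illustrative sizes).
    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
            )
        ]
    )
    model = DLRMWrapper(
        embedding_bag_collection=ebc,
        dense_in_features=100,
        dense_arch_layer_sizes=[20, 8],  # last size must equal embedding_dim
        over_arch_layer_sizes=[5, 1],
    )

    model_input = ModelInput(
        float_features=torch.rand((B, 100)),
        idlist_features=KeyedJaggedTensor.from_offsets_sync(
            keys=["f1"],
            values=torch.tensor([1, 2, 3], dtype=torch.long),
            # one offsets entry per (feature, sample), plus a final terminator
            offsets=torch.tensor([0, 2, 3], dtype=torch.long),
        ),
        idscore_features=None,
        label=torch.rand((B,)),
    )

    model.train()
    loss, pred = model(model_input)  # training mode: (loss, prediction)

    model.eval()
    pred = model(model_input)  # eval mode: prediction only

Returning (loss, pred) in training mode matches the (losses, output) convention that
torchrec's train pipelines expect from wrapped models, which is what makes this
wrapper usable as a pipeline model while plain DLRM only emits logits.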