Skip to content

Commit f998f61

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Added model wrapper for DLRM
Summary: * Added model wrapper for DLRM. The wrapper will take ModelInput as an only parameter in the forward method. * Added the parameterized unit tests to cover the model's wrapper Differential Revision: D77167717
1 parent 2f8d08c commit f998f61

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

torchrec/models/dlrm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from torch import nn
1414
from torchrec.datasets.utils import Batch
15+
from torchrec.distributed.test_utils.test_input import ModelInput
1516
from torchrec.modules.crossnet import LowRankCrossNet
1617
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1718
from torchrec.modules.mlp import MLP
@@ -899,3 +900,21 @@ def forward(
899900
loss = self.loss_fn(logits, batch.labels.float())
900901

901902
return loss, (loss.detach(), logits.detach(), batch.labels.detach())
903+
904+
905+
class DLRMWrapper(DLRM):
906+
# pyre-ignore[14]
907+
def forward(self, model_input: ModelInput) -> torch.Tensor:
908+
"""
909+
Forward pass for the DLRMWrapper.
910+
911+
Args:
912+
model_input (ModelInput): Contains dense and sparse features.
913+
914+
Returns:
915+
torch.Tensor: Output tensor from the DLRM model.
916+
"""
917+
return super().forward(
918+
dense_features=model_input.float_features,
919+
sparse_features=model_input.idlist_features, # pyre-ignore[6]
920+
)

torchrec/models/tests/test_dlrm.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
# pyre-strict
99

1010
import unittest
11+
from dataclasses import dataclass
12+
from typing import List
1113

1214
import torch
15+
from parameterized import parameterized
1316
from torch import nn
1417
from torch.testing import FileCheck # @manual
1518
from torchrec.datasets.utils import Batch
19+
from torchrec.distributed.test_utils.test_input import ModelInput
1620
from torchrec.fx import symbolic_trace
1721
from torchrec.ir.serializer import JsonSerializer
1822
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
@@ -23,6 +27,7 @@
2327
DLRM_DCN,
2428
DLRM_Projection,
2529
DLRMTrain,
30+
DLRMWrapper,
2631
InteractionArch,
2732
InteractionDCNArch,
2833
InteractionProjectionArch,
@@ -1283,3 +1288,93 @@ def test_export_serialization(self) -> None:
12831288
deserialized_logits = deserialized_model(features, sparse_features)
12841289

12851290
self.assertEqual(deserialized_logits.size(), (B, 1))
1291+
1292+
1293+
class DLRMWrapperTest(unittest.TestCase):
1294+
@dataclass
1295+
class WrapperTestParams:
1296+
# input parameters
1297+
embedding_configs: List[EmbeddingBagConfig]
1298+
sparse_feature_keys: List[str]
1299+
sparse_feature_values: List[int]
1300+
sparse_feature_offsets: List[int]
1301+
# expected output parameters
1302+
expected_output_size: tuple[int, ...]
1303+
1304+
@parameterized.expand(
1305+
[
1306+
(
1307+
"basic_with_multiple_features",
1308+
WrapperTestParams(
1309+
embedding_configs=[
1310+
EmbeddingBagConfig(
1311+
name="t1",
1312+
embedding_dim=8,
1313+
num_embeddings=100,
1314+
feature_names=["f1", "f3"],
1315+
),
1316+
EmbeddingBagConfig(
1317+
name="t2",
1318+
embedding_dim=8,
1319+
num_embeddings=100,
1320+
feature_names=["f2"],
1321+
),
1322+
],
1323+
sparse_feature_keys=["f1", "f3", "f2"],
1324+
sparse_feature_values=[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3],
1325+
sparse_feature_offsets=[0, 2, 4, 6, 8, 10, 11],
1326+
expected_output_size=(2, 1),
1327+
),
1328+
),
1329+
(
1330+
"empty_sparse_features",
1331+
WrapperTestParams(
1332+
embedding_configs=[
1333+
EmbeddingBagConfig(
1334+
name="t1",
1335+
embedding_dim=8,
1336+
num_embeddings=100,
1337+
feature_names=["f1"],
1338+
),
1339+
],
1340+
sparse_feature_keys=["f1"],
1341+
sparse_feature_values=[],
1342+
sparse_feature_offsets=[0, 0, 0],
1343+
expected_output_size=(2, 1),
1344+
),
1345+
),
1346+
]
1347+
)
1348+
def test_wrapper_functionality(
1349+
self, _test_name: str, test_params: WrapperTestParams
1350+
) -> None:
1351+
B = 2
1352+
D = 8
1353+
dense_in_features = 100
1354+
1355+
ebc = EmbeddingBagCollection(tables=test_params.embedding_configs)
1356+
1357+
dlrm_wrapper = DLRMWrapper(
1358+
embedding_bag_collection=ebc,
1359+
dense_in_features=dense_in_features,
1360+
dense_arch_layer_sizes=[20, D],
1361+
over_arch_layer_sizes=[5, 1],
1362+
)
1363+
1364+
# Create ModelInput
1365+
dense_features = torch.rand((B, dense_in_features))
1366+
sparse_features = KeyedJaggedTensor.from_offsets_sync(
1367+
keys=test_params.sparse_feature_keys,
1368+
values=torch.tensor(test_params.sparse_feature_values, dtype=torch.long),
1369+
offsets=torch.tensor(test_params.sparse_feature_offsets, dtype=torch.long),
1370+
)
1371+
1372+
model_input = ModelInput(
1373+
float_features=dense_features,
1374+
idlist_features=sparse_features,
1375+
idscore_features=None,
1376+
label=torch.rand((B,)),
1377+
)
1378+
1379+
logits = dlrm_wrapper(model_input)
1380+
self.assertEqual(logits.size(), test_params.expected_output_size)

0 commit comments

Comments
 (0)