|
8 | 8 | # pyre-strict
|
9 | 9 |
|
10 | 10 | import unittest
|
| 11 | +from dataclasses import dataclass |
| 12 | +from typing import List |
11 | 13 |
|
12 | 14 | import torch
|
| 15 | +from parameterized import parameterized |
13 | 16 | from torch import nn
|
14 | 17 | from torch.testing import FileCheck # @manual
|
15 | 18 | from torchrec.datasets.utils import Batch
|
| 19 | +from torchrec.distributed.test_utils.test_input import ModelInput |
16 | 20 | from torchrec.fx import symbolic_trace
|
17 | 21 | from torchrec.ir.serializer import JsonSerializer
|
18 | 22 | from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
|
|
23 | 27 | DLRM_DCN,
|
24 | 28 | DLRM_Projection,
|
25 | 29 | DLRMTrain,
|
| 30 | + DLRMWrapper, |
26 | 31 | InteractionArch,
|
27 | 32 | InteractionDCNArch,
|
28 | 33 | InteractionProjectionArch,
|
@@ -1283,3 +1288,93 @@ def test_export_serialization(self) -> None:
|
1283 | 1288 | deserialized_logits = deserialized_model(features, sparse_features)
|
1284 | 1289 |
|
1285 | 1290 | self.assertEqual(deserialized_logits.size(), (B, 1))
|
| 1291 | + |
| 1292 | + |
class DLRMWrapperTest(unittest.TestCase):
    """Parameterized smoke tests for DLRMWrapper's ModelInput entry point."""

    @dataclass
    class WrapperTestParams:
        # input parameters
        embedding_configs: List[EmbeddingBagConfig]
        sparse_feature_keys: List[str]
        sparse_feature_values: List[int]
        sparse_feature_offsets: List[int]
        # expected output parameters
        expected_output_size: tuple[int, ...]

    @parameterized.expand(
        [
            (
                "basic_with_multiple_features",
                WrapperTestParams(
                    embedding_configs=[
                        EmbeddingBagConfig(
                            name="t1",
                            embedding_dim=8,
                            num_embeddings=100,
                            feature_names=["f1", "f3"],
                        ),
                        EmbeddingBagConfig(
                            name="t2",
                            embedding_dim=8,
                            num_embeddings=100,
                            feature_names=["f2"],
                        ),
                    ],
                    sparse_feature_keys=["f1", "f3", "f2"],
                    sparse_feature_values=[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3],
                    sparse_feature_offsets=[0, 2, 4, 6, 8, 10, 11],
                    expected_output_size=(2, 1),
                ),
            ),
            (
                "empty_sparse_features",
                WrapperTestParams(
                    embedding_configs=[
                        EmbeddingBagConfig(
                            name="t1",
                            embedding_dim=8,
                            num_embeddings=100,
                            feature_names=["f1"],
                        ),
                    ],
                    sparse_feature_keys=["f1"],
                    sparse_feature_values=[],
                    sparse_feature_offsets=[0, 0, 0],
                    expected_output_size=(2, 1),
                ),
            ),
        ]
    )
    def test_wrapper_functionality(
        self, _test_name: str, test_params: WrapperTestParams
    ) -> None:
        """Build a DLRMWrapper, feed it a ModelInput, and verify the logits shape."""
        batch_size = 2
        embedding_dim = 8  # must equal embedding_dim of the configs above
        num_dense = 100

        # Wrapper over an EBC built from the test case's table configs.
        wrapper = DLRMWrapper(
            embedding_bag_collection=EmbeddingBagCollection(
                tables=test_params.embedding_configs
            ),
            dense_in_features=num_dense,
            dense_arch_layer_sizes=[20, embedding_dim],
            over_arch_layer_sizes=[5, 1],
        )

        # Dense component of the batch.
        dense = torch.rand((batch_size, num_dense))

        # Sparse component assembled from the flat values/offsets lists.
        kjt = KeyedJaggedTensor.from_offsets_sync(
            keys=test_params.sparse_feature_keys,
            values=torch.tensor(test_params.sparse_feature_values, dtype=torch.long),
            offsets=torch.tensor(test_params.sparse_feature_offsets, dtype=torch.long),
        )

        batch = ModelInput(
            float_features=dense,
            idlist_features=kjt,
            idscore_features=None,
            label=torch.rand((batch_size,)),
        )

        # Forward through the wrapper and check only the output shape.
        self.assertEqual(wrapper(batch).size(), test_params.expected_output_size)
0 commit comments