Skip to content

Commit 05ec6a0

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Added model wrappers for DeepFM and DLRM
Summary: * Added model wrappers for DeepFM and DLRM. The wrapper will take ModelInput as an only parameter in the forward method. * Added the unit tests to cover the models' wrappers Differential Revision: D76916471
1 parent ab1cbe1 commit 05ec6a0

File tree

4 files changed

+268
-1
lines changed

4 files changed

+268
-1
lines changed

torchrec/models/deepfm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torch
1313
from torch import nn
14+
from torchrec.distributed.test_utils.test_input import ModelInput
1415
from torchrec.modules.deepfm import DeepFM, FactorizationMachine
1516
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1617
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -350,3 +351,21 @@ def forward(
350351
)
351352
logits = self.over_arch(concatenated_dense)
352353
return logits
354+
355+
356+
class SimpleDeepFMNNWrapper(SimpleDeepFMNN):
357+
# pyre-ignore[14]
358+
def forward(self, model_input: ModelInput) -> torch.Tensor:
359+
"""
360+
Forward pass for the SimpleDeepFMNNWrapper.
361+
362+
Args:
363+
model_input (ModelInput): Contains dense and sparse features.
364+
365+
Returns:
366+
torch.Tensor: Output tensor from the SimpleDeepFMNN model.
367+
"""
368+
return super().forward(
369+
dense_features=model_input.float_features,
370+
sparse_features=model_input.idlist_features, # pyre-ignore[6]
371+
)

torchrec/models/dlrm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from torch import nn
1414
from torchrec.datasets.utils import Batch
15+
from torchrec.distributed.test_utils.test_input import ModelInput
1516
from torchrec.modules.crossnet import LowRankCrossNet
1617
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1718
from torchrec.modules.mlp import MLP
@@ -899,3 +900,21 @@ def forward(
899900
loss = self.loss_fn(logits, batch.labels.float())
900901

901902
return loss, (loss.detach(), logits.detach(), batch.labels.detach())
903+
904+
905+
class DLRMWrapper(DLRM):
906+
# pyre-ignore[14]
907+
def forward(self, model_input: ModelInput) -> torch.Tensor:
908+
"""
909+
Forward pass for the DLRMWrapper.
910+
911+
Args:
912+
model_input (ModelInput): Contains dense and sparse features.
913+
914+
Returns:
915+
torch.Tensor: Output tensor from the DLRM model.
916+
"""
917+
return super().forward(
918+
dense_features=model_input.float_features,
919+
sparse_features=model_input.idlist_features, # pyre-ignore[6]
920+
)

torchrec/models/tests/test_deepfm.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import torch
1313
from torch.testing import FileCheck # @manual
1414
from torchrec.fx import symbolic_trace, Tracer
15-
from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN
15+
from torchrec.distributed.test_utils.test_input import ModelInput
16+
from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN, SimpleDeepFMNNWrapper
1617
from torchrec.modules.embedding_configs import EmbeddingBagConfig
1718
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1819
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -210,5 +211,118 @@ def test_fx_script(self) -> None:
210211
self.assertEqual(logits.size(), (B, 1))
211212

212213

214+
class SimpleDeepFMNNWrapperTest(unittest.TestCase):
215+
def test_basic(self) -> None:
216+
B = 2
217+
D = 8
218+
num_dense_features = 100
219+
eb1_config = EmbeddingBagConfig(
220+
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
221+
)
222+
eb2_config = EmbeddingBagConfig(
223+
name="t2",
224+
embedding_dim=D,
225+
num_embeddings=100,
226+
feature_names=["f2"],
227+
)
228+
229+
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
230+
231+
deepfm_wrapper = SimpleDeepFMNNWrapper(
232+
num_dense_features=num_dense_features,
233+
embedding_bag_collection=ebc,
234+
hidden_layer_size=20,
235+
deep_fm_dimension=5,
236+
)
237+
238+
# Create ModelInput with both dense and sparse features
239+
dense_features = torch.rand((B, num_dense_features))
240+
sparse_features = KeyedJaggedTensor.from_offsets_sync(
241+
keys=["f1", "f3", "f2"],
242+
values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
243+
offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
244+
)
245+
246+
model_input = ModelInput(
247+
float_features=dense_features,
248+
idlist_features=sparse_features,
249+
idscore_features=None,
250+
label=torch.rand((B,)),
251+
)
252+
253+
logits = deepfm_wrapper(model_input)
254+
self.assertEqual(logits.size(), (B, 1))
255+
256+
def test_no_sparse_features(self) -> None:
257+
B = 2
258+
D = 8
259+
num_dense_features = 100
260+
eb1_config = EmbeddingBagConfig(
261+
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
262+
)
263+
264+
ebc = EmbeddingBagCollection(tables=[eb1_config])
265+
266+
deepfm_wrapper = SimpleDeepFMNNWrapper(
267+
num_dense_features=num_dense_features,
268+
embedding_bag_collection=ebc,
269+
hidden_layer_size=20,
270+
deep_fm_dimension=5,
271+
)
272+
273+
# Create ModelInput with empty sparse features that match expected feature names
274+
dense_features = torch.rand((B, num_dense_features))
275+
empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
276+
keys=["f1"],
277+
values=torch.tensor([], dtype=torch.long),
278+
offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
279+
)
280+
281+
model_input = ModelInput(
282+
float_features=dense_features,
283+
idlist_features=empty_sparse_features,
284+
idscore_features=None,
285+
label=torch.rand((B,)),
286+
)
287+
288+
logits = deepfm_wrapper(model_input)
289+
self.assertEqual(logits.size(), (B, 1))
290+
291+
def test_empty_sparse_features(self) -> None:
292+
B = 2
293+
D = 8
294+
num_dense_features = 100
295+
eb1_config = EmbeddingBagConfig(
296+
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
297+
)
298+
299+
ebc = EmbeddingBagCollection(tables=[eb1_config])
300+
301+
deepfm_wrapper = SimpleDeepFMNNWrapper(
302+
num_dense_features=num_dense_features,
303+
embedding_bag_collection=ebc,
304+
hidden_layer_size=20,
305+
deep_fm_dimension=5,
306+
)
307+
308+
# Create ModelInput with empty sparse features that match expected feature names
309+
dense_features = torch.rand((B, num_dense_features))
310+
empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
311+
keys=["f1"],
312+
values=torch.tensor([], dtype=torch.long),
313+
offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
314+
)
315+
316+
model_input = ModelInput(
317+
float_features=dense_features,
318+
idlist_features=empty_sparse_features,
319+
idscore_features=None,
320+
label=torch.rand((B,)),
321+
)
322+
323+
logits = deepfm_wrapper(model_input)
324+
self.assertEqual(logits.size(), (B, 1))
325+
326+
213327
if __name__ == "__main__":
214328
unittest.main()

torchrec/models/tests/test_dlrm.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch import nn
1414
from torch.testing import FileCheck # @manual
1515
from torchrec.datasets.utils import Batch
16+
from torchrec.distributed.test_utils.test_input import ModelInput
1617
from torchrec.fx import symbolic_trace
1718
from torchrec.ir.serializer import JsonSerializer
1819
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
@@ -23,6 +24,7 @@
2324
DLRM_DCN,
2425
DLRM_Projection,
2526
DLRMTrain,
27+
DLRMWrapper,
2628
InteractionArch,
2729
InteractionDCNArch,
2830
InteractionProjectionArch,
@@ -1283,3 +1285,116 @@ def test_export_serialization(self) -> None:
12831285
deserialized_logits = deserialized_model(features, sparse_features)
12841286

12851287
self.assertEqual(deserialized_logits.size(), (B, 1))
1288+
1289+
1290+
class DLRMWrapperTest(unittest.TestCase):
1291+
def test_basic(self) -> None:
1292+
B = 2
1293+
D = 8
1294+
dense_in_features = 100
1295+
eb1_config = EmbeddingBagConfig(
1296+
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
1297+
)
1298+
eb2_config = EmbeddingBagConfig(
1299+
name="t2",
1300+
embedding_dim=D,
1301+
num_embeddings=100,
1302+
feature_names=["f2"],
1303+
)
1304+
1305+
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
1306+
1307+
dlrm_wrapper = DLRMWrapper(
1308+
embedding_bag_collection=ebc,
1309+
dense_in_features=dense_in_features,
1310+
dense_arch_layer_sizes=[20, D],
1311+
over_arch_layer_sizes=[5, 1],
1312+
)
1313+
1314+
# Create ModelInput with both dense and sparse features
1315+
dense_features = torch.rand((B, dense_in_features))
1316+
sparse_features = KeyedJaggedTensor.from_offsets_sync(
1317+
keys=["f1", "f3", "f2"],
1318+
values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
1319+
offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
1320+
)
1321+
1322+
model_input = ModelInput(
1323+
float_features=dense_features,
1324+
idlist_features=sparse_features,
1325+
idscore_features=None,
1326+
label=torch.rand((B,)),
1327+
)
1328+
1329+
logits = dlrm_wrapper(model_input)
1330+
self.assertEqual(logits.size(), (B, 1))
1331+
1332+
def test_no_sparse_features(self) -> None:
1333+
B = 2
1334+
D = 8
1335+
dense_in_features = 100
1336+
eb1_config = EmbeddingBagConfig(
1337+
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
1338+
)
1339+
1340+
ebc = EmbeddingBagCollection(tables=[eb1_config])
1341+
1342+
dlrm_wrapper = DLRMWrapper(
1343+
embedding_bag_collection=ebc,
1344+
dense_in_features=dense_in_features,
1345+
dense_arch_layer_sizes=[20, D],
1346+
over_arch_layer_sizes=[5, 1],
1347+
)
1348+
1349+
# Create ModelInput with empty sparse features that match expected feature names
1350+
dense_features = torch.rand((B, dense_in_features))
1351+
empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
1352+
keys=["f1"],
1353+
values=torch.tensor([], dtype=torch.long),
1354+
offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
1355+
)
1356+
1357+
model_input = ModelInput(
1358+
float_features=dense_features,
1359+
idlist_features=empty_sparse_features,
1360+
idscore_features=None,
1361+
label=torch.rand((B,)),
1362+
)
1363+
1364+
logits = dlrm_wrapper(model_input)
1365+
self.assertEqual(logits.size(), (B, 1))
1366+
1367+
def test_empty_sparse_features(self) -> None:
1368+
B = 2
1369+
D = 8
1370+
dense_in_features = 100
1371+
eb1_config = EmbeddingBagConfig(
1372+
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
1373+
)
1374+
1375+
ebc = EmbeddingBagCollection(tables=[eb1_config])
1376+
1377+
dlrm_wrapper = DLRMWrapper(
1378+
embedding_bag_collection=ebc,
1379+
dense_in_features=dense_in_features,
1380+
dense_arch_layer_sizes=[20, D],
1381+
over_arch_layer_sizes=[5, 1],
1382+
)
1383+
1384+
# Create ModelInput with empty sparse features that match expected feature names
1385+
dense_features = torch.rand((B, dense_in_features))
1386+
empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
1387+
keys=["f1"],
1388+
values=torch.tensor([], dtype=torch.long),
1389+
offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
1390+
)
1391+
1392+
model_input = ModelInput(
1393+
float_features=dense_features,
1394+
idlist_features=empty_sparse_features,
1395+
idscore_features=None,
1396+
label=torch.rand((B,)),
1397+
)
1398+
1399+
logits = dlrm_wrapper(model_input)
1400+
self.assertEqual(logits.size(), (B, 1))

0 commit comments

Comments
 (0)