Added model wrappers for DeepFM and DLRM #3115

Open · wants to merge 1 commit into main
19 changes: 19 additions & 0 deletions torchrec/models/deepfm.py
@@ -11,6 +11,7 @@

import torch
from torch import nn
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.modules.deepfm import DeepFM, FactorizationMachine
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -350,3 +351,21 @@ def forward(
)
logits = self.over_arch(concatenated_dense)
return logits


class SimpleDeepFMNNWrapper(SimpleDeepFMNN):
# pyre-ignore[14]
def forward(self, model_input: ModelInput) -> torch.Tensor:
"""
Forward pass for the SimpleDeepFMNNWrapper.

Args:
model_input (ModelInput): Contains dense and sparse features.

Returns:
torch.Tensor: Output tensor from the SimpleDeepFMNN model.
"""
return super().forward(
dense_features=model_input.float_features,
sparse_features=model_input.idlist_features, # pyre-ignore[6]
)
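
For reviewers, a minimal usage sketch of the new wrapper, mirroring the tests added below. The table config, sizes, and feature values here are illustrative choices, not prescribed by this PR:

```python
import torch
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.models.deepfm import SimpleDeepFMNNWrapper
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

B, D, num_dense = 2, 8, 10
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
        )
    ]
)
model = SimpleDeepFMNNWrapper(
    num_dense_features=num_dense,
    embedding_bag_collection=ebc,
    hidden_layer_size=20,
    deep_fm_dimension=5,
)
# The wrapper takes a single ModelInput instead of separate dense/sparse args.
model_input = ModelInput(
    float_features=torch.rand(B, num_dense),
    idlist_features=KeyedJaggedTensor.from_offsets_sync(
        keys=["f1"],
        values=torch.tensor([1, 2, 3]),
        offsets=torch.tensor([0, 2, 3]),
    ),
    idscore_features=None,
    label=torch.rand(B),
)
logits = model(model_input)  # shape: (B, 1)
```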
19 changes: 19 additions & 0 deletions torchrec/models/dlrm.py
@@ -12,6 +12,7 @@
import torch
from torch import nn
from torchrec.datasets.utils import Batch
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.modules.crossnet import LowRankCrossNet
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mlp import MLP
@@ -899,3 +900,21 @@ def forward(
loss = self.loss_fn(logits, batch.labels.float())

return loss, (loss.detach(), logits.detach(), batch.labels.detach())


class DLRMWrapper(DLRM):
# pyre-ignore[14]
def forward(self, model_input: ModelInput) -> torch.Tensor:
"""
Forward pass for the DLRMWrapper.

Args:
model_input (ModelInput): Contains dense and sparse features.

Returns:
torch.Tensor: Output tensor from the DLRM model.
"""
return super().forward(
dense_features=model_input.float_features,
sparse_features=model_input.idlist_features, # pyre-ignore[6]
)
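
Both wrappers expose the same single-argument signature, which is the point of the change: harness code can drive either model without knowing its per-model keyword arguments. A minimal sketch of that uniformity (the `eval_step` helper is hypothetical, not part of this PR):

```python
import torch
from torchrec.distributed.test_utils.test_input import ModelInput


def eval_step(model: torch.nn.Module, model_input: ModelInput) -> torch.Tensor:
    # Works for SimpleDeepFMNNWrapper, DLRMWrapper, or any module whose
    # forward accepts a single ModelInput.
    model.eval()
    with torch.no_grad():
        return model(model_input)
```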
121 changes: 120 additions & 1 deletion torchrec/models/tests/test_deepfm.py
@@ -11,8 +11,14 @@

import torch
from torch.testing import FileCheck # @manual
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.fx import symbolic_trace, Tracer
from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN
from torchrec.models.deepfm import (
DenseArch,
FMInteractionArch,
SimpleDeepFMNN,
SimpleDeepFMNNWrapper,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -210,5 +216,118 @@ def test_fx_script(self) -> None:
self.assertEqual(logits.size(), (B, 1))


class SimpleDeepFMNNWrapperTest(unittest.TestCase):
def test_basic(self) -> None:
B = 2
D = 8
num_dense_features = 100
eb1_config = EmbeddingBagConfig(
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
)
eb2_config = EmbeddingBagConfig(
name="t2",
embedding_dim=D,
num_embeddings=100,
feature_names=["f2"],
)

ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

deepfm_wrapper = SimpleDeepFMNNWrapper(
num_dense_features=num_dense_features,
embedding_bag_collection=ebc,
hidden_layer_size=20,
deep_fm_dimension=5,
)

# Create ModelInput with both dense and sparse features
dense_features = torch.rand((B, num_dense_features))
sparse_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f3", "f2"],
values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
)

model_input = ModelInput(
float_features=dense_features,
idlist_features=sparse_features,
idscore_features=None,
label=torch.rand((B,)),
)

logits = deepfm_wrapper(model_input)
self.assertEqual(logits.size(), (B, 1))

def test_no_sparse_features(self) -> None:
B = 2
D = 8
num_dense_features = 100
eb1_config = EmbeddingBagConfig(
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
)

ebc = EmbeddingBagCollection(tables=[eb1_config])

deepfm_wrapper = SimpleDeepFMNNWrapper(
num_dense_features=num_dense_features,
embedding_bag_collection=ebc,
hidden_layer_size=20,
deep_fm_dimension=5,
)

# Create ModelInput with empty sparse features that match expected feature names
dense_features = torch.rand((B, num_dense_features))
empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1"],
values=torch.tensor([], dtype=torch.long),
offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
)

model_input = ModelInput(
float_features=dense_features,
idlist_features=empty_sparse_features,
idscore_features=None,
label=torch.rand((B,)),
)

logits = deepfm_wrapper(model_input)
self.assertEqual(logits.size(), (B, 1))

def test_empty_sparse_features(self) -> None:
B = 2
D = 8
num_dense_features = 100
eb1_config = EmbeddingBagConfig(
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
)

ebc = EmbeddingBagCollection(tables=[eb1_config])

deepfm_wrapper = SimpleDeepFMNNWrapper(
num_dense_features=num_dense_features,
embedding_bag_collection=ebc,
hidden_layer_size=20,
deep_fm_dimension=5,
)

# Create ModelInput with empty sparse features that match expected feature names
dense_features = torch.rand((B, num_dense_features))
empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1"],
values=torch.tensor([], dtype=torch.long),
offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
)

model_input = ModelInput(
float_features=dense_features,
idlist_features=empty_sparse_features,
idscore_features=None,
label=torch.rand((B,)),
)

logits = deepfm_wrapper(model_input)
self.assertEqual(logits.size(), (B, 1))


if __name__ == "__main__":
unittest.main()
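
A note on the empty-feature tests above (and the matching DLRM tests below): an "empty" KeyedJaggedTensor is built from an all-zero offsets tensor of length `num_keys * B + 1`, which encodes zero values for every (key, sample) pair. A standalone illustration of the pattern:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

B = 2  # batch size; with one key, offsets needs 1 * B + 1 entries
empty_kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([], dtype=torch.long),
    offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
)
print(empty_kjt.lengths())  # tensor([0, 0]) -- zero values per sample
```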
115 changes: 115 additions & 0 deletions torchrec/models/tests/test_dlrm.py
@@ -13,6 +13,7 @@
from torch import nn
from torch.testing import FileCheck # @manual
from torchrec.datasets.utils import Batch
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.fx import symbolic_trace
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
@@ -23,6 +24,7 @@
DLRM_DCN,
DLRM_Projection,
DLRMTrain,
DLRMWrapper,
InteractionArch,
InteractionDCNArch,
InteractionProjectionArch,
@@ -1283,3 +1285,116 @@ def test_export_serialization(self) -> None:
deserialized_logits = deserialized_model(features, sparse_features)

self.assertEqual(deserialized_logits.size(), (B, 1))


class DLRMWrapperTest(unittest.TestCase):
def test_basic(self) -> None:
B = 2
D = 8
dense_in_features = 100
eb1_config = EmbeddingBagConfig(
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
)
eb2_config = EmbeddingBagConfig(
name="t2",
embedding_dim=D,
num_embeddings=100,
feature_names=["f2"],
)

ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

dlrm_wrapper = DLRMWrapper(
embedding_bag_collection=ebc,
dense_in_features=dense_in_features,
dense_arch_layer_sizes=[20, D],
over_arch_layer_sizes=[5, 1],
)

# Create ModelInput with both dense and sparse features
dense_features = torch.rand((B, dense_in_features))
sparse_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f3", "f2"],
values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
)

model_input = ModelInput(
float_features=dense_features,
idlist_features=sparse_features,
idscore_features=None,
label=torch.rand((B,)),
)

logits = dlrm_wrapper(model_input)
self.assertEqual(logits.size(), (B, 1))

def test_no_sparse_features(self) -> None:
B = 2
D = 8
dense_in_features = 100
eb1_config = EmbeddingBagConfig(
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
)

ebc = EmbeddingBagCollection(tables=[eb1_config])

dlrm_wrapper = DLRMWrapper(
embedding_bag_collection=ebc,
dense_in_features=dense_in_features,
dense_arch_layer_sizes=[20, D],
over_arch_layer_sizes=[5, 1],
)

# Create ModelInput with empty sparse features that match expected feature names
dense_features = torch.rand((B, dense_in_features))
empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1"],
values=torch.tensor([], dtype=torch.long),
offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
)

model_input = ModelInput(
float_features=dense_features,
idlist_features=empty_sparse_features,
idscore_features=None,
label=torch.rand((B,)),
)

logits = dlrm_wrapper(model_input)
self.assertEqual(logits.size(), (B, 1))

def test_empty_sparse_features(self) -> None:
B = 2
D = 8
dense_in_features = 100
eb1_config = EmbeddingBagConfig(
name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
)

ebc = EmbeddingBagCollection(tables=[eb1_config])

dlrm_wrapper = DLRMWrapper(
embedding_bag_collection=ebc,
dense_in_features=dense_in_features,
dense_arch_layer_sizes=[20, D],
over_arch_layer_sizes=[5, 1],
)

# Create ModelInput with empty sparse features that match expected feature names
dense_features = torch.rand((B, dense_in_features))
empty_sparse_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1"],
values=torch.tensor([], dtype=torch.long),
offsets=torch.tensor([0] * (B + 1), dtype=torch.long),
)

model_input = ModelInput(
float_features=dense_features,
idlist_features=empty_sparse_features,
idscore_features=None,
label=torch.rand((B,)),
)

logits = dlrm_wrapper(model_input)
self.assertEqual(logits.size(), (B, 1))