Added model wrapper for DLRM #3128

Open · wants to merge 1 commit into main

34 changes: 33 additions & 1 deletion torchrec/models/dlrm.py
@@ -7,11 +7,12 @@

 # pyre-strict

-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from torch import nn
 from torchrec.datasets.utils import Batch
+from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.modules.crossnet import LowRankCrossNet
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.mlp import MLP
@@ -899,3 +900,34 @@ def forward(
         loss = self.loss_fn(logits, batch.labels.float())

         return loss, (loss.detach(), logits.detach(), batch.labels.detach())
+
+
+class DLRMWrapper(DLRM):
+    # pyre-ignore[14, 15]
+    def forward(
+        self, model_input: ModelInput
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward pass for the DLRMWrapper.
+
+        Args:
+            model_input (ModelInput): Contains dense and sparse features.
+
+        Returns:
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+                If training, returns (loss, prediction). Otherwise, returns prediction.
+        """
+        pred = super().forward(
+            dense_features=model_input.float_features,
+            sparse_features=model_input.idlist_features,  # pyre-ignore[6]
+        )
+
+        if self.training:
+            # Calculate loss and return both loss and prediction
+            loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                pred.squeeze(), model_input.label
+            )
+            return (loss, pred)
+        else:
+            # Return just the prediction
+            return pred
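
For context, here is a minimal usage sketch of the new wrapper outside the test harness (illustrative only, not part of this PR; the table size, layer sizes, and random inputs are made-up values). Since super().forward() returns logits of shape (B, 1), pred.squeeze() lines up with the (B,)-shaped labels fed to binary_cross_entropy_with_logits:

import torch
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.models.dlrm import DLRMWrapper
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

B = 2  # batch size

# One small embedding table; DLRM requires dense_arch_layer_sizes[-1]
# to equal the embedding dimension (8 here).
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        )
    ]
)
model = DLRMWrapper(
    embedding_bag_collection=ebc,
    dense_in_features=4,
    dense_arch_layer_sizes=[16, 8],
    over_arch_layer_sizes=[4, 1],
)

model_input = ModelInput(
    float_features=torch.rand(B, 4),
    idlist_features=KeyedJaggedTensor.from_offsets_sync(
        keys=["f1"],
        values=torch.tensor([1, 2, 3], dtype=torch.long),
        offsets=torch.tensor([0, 2, 3], dtype=torch.long),  # bags of size 2 and 1
    ),
    idscore_features=None,
    label=torch.rand(B),
)

model.train()
loss, pred = model(model_input)  # training mode: (loss, prediction)

model.eval()
pred = model(model_input)  # eval mode: prediction only
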
109 changes: 109 additions & 0 deletions torchrec/models/tests/test_dlrm.py
@@ -8,11 +8,15 @@
 # pyre-strict

 import unittest
+from dataclasses import dataclass
+from typing import List

 import torch
+from parameterized import parameterized
 from torch import nn
 from torch.testing import FileCheck  # @manual
 from torchrec.datasets.utils import Batch
+from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.fx import symbolic_trace
 from torchrec.ir.serializer import JsonSerializer
 from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
@@ -23,6 +27,7 @@
     DLRM_DCN,
     DLRM_Projection,
     DLRMTrain,
+    DLRMWrapper,
     InteractionArch,
     InteractionDCNArch,
     InteractionProjectionArch,
@@ -1283,3 +1288,107 @@ def test_export_serialization(self) -> None:
         deserialized_logits = deserialized_model(features, sparse_features)

         self.assertEqual(deserialized_logits.size(), (B, 1))
+
+
+class DLRMWrapperTest(unittest.TestCase):
+    @dataclass
+    class WrapperTestParams:
+        # input parameters
+        embedding_configs: List[EmbeddingBagConfig]
+        sparse_feature_keys: List[str]
+        sparse_feature_values: List[int]
+        sparse_feature_offsets: List[int]
+        # expected output parameters
+        expected_output_size: tuple[int, ...]
+
+    @parameterized.expand(
+        [
+            (
+                "basic_with_multiple_features",
+                WrapperTestParams(
+                    embedding_configs=[
+                        EmbeddingBagConfig(
+                            name="t1",
+                            embedding_dim=8,
+                            num_embeddings=100,
+                            feature_names=["f1", "f3"],
+                        ),
+                        EmbeddingBagConfig(
+                            name="t2",
+                            embedding_dim=8,
+                            num_embeddings=100,
+                            feature_names=["f2"],
+                        ),
+                    ],
+                    sparse_feature_keys=["f1", "f3", "f2"],
+                    sparse_feature_values=[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3],
+                    sparse_feature_offsets=[0, 2, 4, 6, 8, 10, 11],
+                    expected_output_size=(2, 1),
+                ),
+            ),
+            (
+                "empty_sparse_features",
+                WrapperTestParams(
+                    embedding_configs=[
+                        EmbeddingBagConfig(
+                            name="t1",
+                            embedding_dim=8,
+                            num_embeddings=100,
+                            feature_names=["f1"],
+                        ),
+                    ],
+                    sparse_feature_keys=["f1"],
+                    sparse_feature_values=[],
+                    sparse_feature_offsets=[0, 0, 0],
+                    expected_output_size=(2, 1),
+                ),
+            ),
+        ]
+    )
+    def test_wrapper_functionality(
+        self, _test_name: str, test_params: WrapperTestParams
+    ) -> None:
+        B = 2
+        D = 8
+        dense_in_features = 100
+
+        ebc = EmbeddingBagCollection(tables=test_params.embedding_configs)
+
+        dlrm_wrapper = DLRMWrapper(
+            embedding_bag_collection=ebc,
+            dense_in_features=dense_in_features,
+            dense_arch_layer_sizes=[20, D],
+            over_arch_layer_sizes=[5, 1],
+        )
+
+        # Create ModelInput
+        dense_features = torch.rand((B, dense_in_features))
+        sparse_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=test_params.sparse_feature_keys,
+            values=torch.tensor(test_params.sparse_feature_values, dtype=torch.long),
+            offsets=torch.tensor(test_params.sparse_feature_offsets, dtype=torch.long),
+        )
+
+        model_input = ModelInput(
+            float_features=dense_features,
+            idlist_features=sparse_features,
+            idscore_features=None,
+            label=torch.rand((B,)),
+        )
+
+        # Test eval mode - should return just logits
+        dlrm_wrapper.eval()
+        logits = dlrm_wrapper(model_input)
+        self.assertIsInstance(logits, torch.Tensor)
+        self.assertEqual(logits.size(), test_params.expected_output_size)
+
+        # Test training mode - should return (loss, logits) tuple
+        dlrm_wrapper.train()
+        result = dlrm_wrapper(model_input)
+        self.assertIsInstance(result, tuple)
+        self.assertEqual(len(result), 2)
+        loss, pred = result
+        self.assertIsInstance(loss, torch.Tensor)
+        self.assertIsInstance(pred, torch.Tensor)
+        self.assertEqual(loss.size(), ())  # scalar loss
+        self.assertEqual(pred.size(), test_params.expected_output_size)
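
A note on the KeyedJaggedTensor construction in the first test case: offsets are key-major, so with keys ["f1", "f3", "f2"] and batch size 2, the 7 offsets delimit 3 keys x 2 samples = 6 bags over the 11 values (and in the empty case, offsets [0, 0, 0] delimit 2 empty bags for the single key). A small sketch of this, reusing the test's values:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f3", "f2"],
    values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3], dtype=torch.long),
    offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11], dtype=torch.long),
)
# Per-bag lengths are the successive differences of the offsets.
print(kjt.lengths())  # tensor([2, 2, 2, 2, 2, 1])
# Indexing by key returns that feature's jagged values across the batch.
print(kjt["f1"].values())  # tensor([1, 2, 4, 5])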