Skip to content

Commit bc7ac95

Browse files
Added tests with pretrained embeddings for DLRM and DCN
1 parent be9d5fb commit bc7ac95

File tree

4 files changed

+111
-10
lines changed

4 files changed

+111
-10
lines changed

merlin/datasets/entertainment/music_streaming/schema.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@
4242
"tag": [
4343
"categorical",
4444
"item"
45+
],
46+
"extraMetadata": [
47+
{
48+
"_dims": [
49+
[
50+
0.0,
51+
null
52+
]
53+
]
54+
}
4555
]
4656
}
4757
},

merlin/datasets/testing/schema.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@
107107
"item_id",
108108
"categorical",
109109
"item"
110+
],
111+
"extraMetadata": [
112+
{
113+
"_dims": [
114+
[
115+
0.0,
116+
null
117+
]
118+
]
119+
}
110120
]
111121
}
112122
},

tests/unit/tf/blocks/test_dlrm.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
# limitations under the License.
1515
#
1616

17+
import numpy as np
1718
import pytest
1819

1920
import merlin.models.tf as mm
21+
from merlin.dataloader.ops.embeddings import EmbeddingOperator
2022
from merlin.io import Dataset
2123
from merlin.schema import Tags
2224

@@ -29,11 +31,11 @@ def test_dlrm_block(testing_data: Dataset):
2931
bottom_block=mm.MLPBlock([64]),
3032
top_block=mm.DenseResidualBlock(),
3133
)
32-
features = mm.sample_batch(testing_data, batch_size=100, include_targets=False)
34+
features = mm.sample_batch(testing_data, batch_size=10, include_targets=False)
3335
outputs = dlrm(features)
3436
num_features = len(schema.select_by_tag(Tags.CATEGORICAL)) + 1
3537
dot_product_dim = (num_features - 1) * num_features // 2
36-
assert list(outputs.shape) == [100, dot_product_dim + 64]
38+
assert list(outputs.shape) == [10, dot_product_dim + 64]
3739

3840

3941
def test_dlrm_block_no_top_block(testing_data: Dataset):
@@ -43,19 +45,19 @@ def test_dlrm_block_no_top_block(testing_data: Dataset):
4345
embedding_dim=64,
4446
bottom_block=mm.MLPBlock([64]),
4547
)
46-
outputs = dlrm(mm.sample_batch(testing_data, batch_size=100, include_targets=False))
48+
outputs = dlrm(mm.sample_batch(testing_data, batch_size=10, include_targets=False))
4749
num_features = len(schema.select_by_tag(Tags.CATEGORICAL)) + 1
4850
dot_product_dim = (num_features - 1) * num_features // 2
4951

50-
assert list(outputs.shape) == [100, dot_product_dim]
52+
assert list(outputs.shape) == [10, dot_product_dim]
5153

5254

5355
def test_dlrm_block_no_continuous_features(testing_data: Dataset):
5456
schema = testing_data.schema.remove_by_tag(Tags.CONTINUOUS)
5557
dlrm = mm.DLRMBlock(schema, embedding_dim=64, top_block=mm.MLPBlock([32]))
56-
outputs = dlrm(mm.sample_batch(testing_data, batch_size=100, include_targets=False))
58+
outputs = dlrm(mm.sample_batch(testing_data, batch_size=10, include_targets=False))
5759

58-
assert list(outputs.shape) == [100, 32]
60+
assert list(outputs.shape) == [10, 32]
5961

6062

6163
def test_dlrm_block_no_categ_features(testing_data: Dataset):
@@ -70,9 +72,9 @@ def test_dlrm_block_no_categ_features(testing_data: Dataset):
7072
def test_dlrm_block_single_categ_feature(testing_data: Dataset):
7173
schema = testing_data.schema.select_by_tag([Tags.ITEM_ID])
7274
dlrm = mm.DLRMBlock(schema, embedding_dim=64, top_block=mm.MLPBlock([32]))
73-
outputs = dlrm(mm.sample_batch(testing_data, batch_size=100, include_targets=False))
75+
outputs = dlrm(mm.sample_batch(testing_data, batch_size=10, include_targets=False))
7476

75-
assert list(outputs.shape) == [100, 32]
77+
assert list(outputs.shape) == [10, 32]
7678

7779

7880
def test_dlrm_block_no_schema():
@@ -120,6 +122,43 @@ def test_dlrm_with_embeddings(testing_data: Dataset):
120122
bottom_block=mm.MLPBlock([embedding_dim]),
121123
top_block=mm.MLPBlock([top_dim]),
122124
)
123-
outputs = dlrm(mm.sample_batch(testing_data, batch_size=100, include_targets=False))
125+
outputs = dlrm(mm.sample_batch(testing_data, batch_size=10, include_targets=False))
124126

125-
assert list(outputs.shape) == [100, 4]
127+
assert list(outputs.shape) == [10, 4]
128+
129+
130+
def test_dlrm_with_pretrained_embeddings(testing_data: Dataset):
131+
embedding_dim = 12
132+
top_dim = 4
133+
134+
item_cardinality = testing_data.schema["item_id"].int_domain.max + 1
135+
pretrained_embedding = np.random.rand(item_cardinality, 12)
136+
137+
loader = mm.Loader(
138+
testing_data,
139+
batch_size=10,
140+
transforms=[
141+
EmbeddingOperator(
142+
pretrained_embedding,
143+
lookup_key="item_id",
144+
embedding_name="pretrained_item_embeddings",
145+
),
146+
],
147+
)
148+
schema = loader.output_schema
149+
150+
embeddings = mm.Embeddings(schema.select_by_tag(Tags.CATEGORICAL), dim=embedding_dim)
151+
pretrained_embeddings = mm.PretrainedEmbeddings(
152+
schema.select_by_tag(Tags.EMBEDDING),
153+
output_dims=embedding_dim,
154+
)
155+
156+
dlrm = mm.DLRMBlock(
157+
schema,
158+
embeddings=mm.ParallelBlock(embeddings, pretrained_embeddings),
159+
bottom_block=mm.MLPBlock([embedding_dim]),
160+
top_block=mm.MLPBlock([top_dim]),
161+
)
162+
outputs = dlrm(mm.sample_batch(loader, include_targets=False))
163+
164+
assert list(outputs.shape) == [10, 4]

tests/unit/tf/models/test_ranking.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from tensorflow.keras import regularizers
2121

2222
import merlin.models.tf as mm
23+
from merlin.dataloader.ops.embeddings import EmbeddingOperator
2324
from merlin.datasets.synthetic import generate_data
2425
from merlin.io import Dataset
2526
from merlin.models.tf.transforms.features import expected_input_cols_from_schema
@@ -164,6 +165,47 @@ def test_dcn_model(music_streaming_data, stacked, run_eagerly):
164165
testing_utils.model_test(model, music_streaming_data, run_eagerly=run_eagerly)
165166

166167

168+
@pytest.mark.parametrize("run_eagerly", [True, False])
169+
def test_dcn_model_with_pretrained_embeddings(music_streaming_data: Dataset, run_eagerly):
170+
music_streaming_data.schema = music_streaming_data.schema.select_by_name(
171+
["item_id", "item_category", "user_age", "click"]
172+
)
173+
174+
cardinality = music_streaming_data.schema["item_category"].int_domain.max + 1
175+
pretrained_embedding = np.random.rand(cardinality, 12)
176+
177+
loader = mm.Loader(
178+
music_streaming_data,
179+
batch_size=10,
180+
transforms=[
181+
EmbeddingOperator(
182+
pretrained_embedding,
183+
lookup_key="item_category",
184+
embedding_name="pretrained_category_embeddings",
185+
),
186+
],
187+
)
188+
schema = loader.output_schema
189+
190+
pretrained_embeddings = mm.PretrainedEmbeddings(
191+
schema.select_by_tag(Tags.EMBEDDING),
192+
output_dims=16,
193+
)
194+
195+
input_block = mm.InputBlockV2(schema, pretrained_embeddings=pretrained_embeddings)
196+
197+
model = mm.DCNModel(
198+
schema,
199+
input_block=input_block,
200+
depth=1,
201+
deep_block=mm.MLPBlock([2]),
202+
stacked=True,
203+
prediction_tasks=mm.BinaryOutput("click"),
204+
)
205+
206+
testing_utils.model_test(model, loader, run_eagerly=run_eagerly)
207+
208+
167209
@pytest.mark.parametrize("run_eagerly", [True, False])
168210
def test_deepfm_model_only_categ_feats(music_streaming_data, run_eagerly):
169211
music_streaming_data.schema = music_streaming_data.schema.select_by_name(

0 commit comments

Comments
 (0)