1414# limitations under the License.
1515#
1616
17+ import numpy as np
1718import pytest
1819
1920import merlin .models .tf as mm
21+ from merlin .dataloader .ops .embeddings import EmbeddingOperator
2022from merlin .io import Dataset
2123from merlin .schema import Tags
2224
@@ -29,11 +31,11 @@ def test_dlrm_block(testing_data: Dataset):
2931 bottom_block = mm .MLPBlock ([64 ]),
3032 top_block = mm .DenseResidualBlock (),
3133 )
32- features = mm .sample_batch (testing_data , batch_size = 100 , include_targets = False )
34+ features = mm .sample_batch (testing_data , batch_size = 10 , include_targets = False )
3335 outputs = dlrm (features )
3436 num_features = len (schema .select_by_tag (Tags .CATEGORICAL )) + 1
3537 dot_product_dim = (num_features - 1 ) * num_features // 2
36- assert list (outputs .shape ) == [100 , dot_product_dim + 64 ]
38+ assert list (outputs .shape ) == [10 , dot_product_dim + 64 ]
3739
3840
3941def test_dlrm_block_no_top_block (testing_data : Dataset ):
@@ -43,19 +45,19 @@ def test_dlrm_block_no_top_block(testing_data: Dataset):
4345 embedding_dim = 64 ,
4446 bottom_block = mm .MLPBlock ([64 ]),
4547 )
46- outputs = dlrm (mm .sample_batch (testing_data , batch_size = 100 , include_targets = False ))
48+ outputs = dlrm (mm .sample_batch (testing_data , batch_size = 10 , include_targets = False ))
4749 num_features = len (schema .select_by_tag (Tags .CATEGORICAL )) + 1
4850 dot_product_dim = (num_features - 1 ) * num_features // 2
4951
50- assert list (outputs .shape ) == [100 , dot_product_dim ]
52+ assert list (outputs .shape ) == [10 , dot_product_dim ]
5153
5254
def test_dlrm_block_no_continuous_features(testing_data: Dataset):
    """Check DLRMBlock output shape when the schema has no continuous features.

    The block is built from the dataset schema with every ``Tags.CONTINUOUS``
    column removed and no ``bottom_block``; the output is expected to have
    the width of the ``MLPBlock([32])`` top block.
    """
    batch_size = 10
    schema = testing_data.schema.remove_by_tag(Tags.CONTINUOUS)
    dlrm = mm.DLRMBlock(schema, embedding_dim=64, top_block=mm.MLPBlock([32]))

    # The batch is sampled from the full dataset, not from the reduced schema.
    features = mm.sample_batch(testing_data, batch_size=batch_size, include_targets=False)
    outputs = dlrm(features)

    assert list(outputs.shape) == [batch_size, 32]
5961
6062
6163def test_dlrm_block_no_categ_features (testing_data : Dataset ):
@@ -70,9 +72,9 @@ def test_dlrm_block_no_categ_features(testing_data: Dataset):
def test_dlrm_block_single_categ_feature(testing_data: Dataset):
    """Check DLRMBlock output shape when the schema holds a single feature.

    The schema is narrowed to the column(s) tagged ``Tags.ITEM_ID``; the
    output is expected to have the width of the ``MLPBlock([32])`` top block.
    """
    batch_size = 10
    schema = testing_data.schema.select_by_tag([Tags.ITEM_ID])
    dlrm = mm.DLRMBlock(schema, embedding_dim=64, top_block=mm.MLPBlock([32]))

    # The batch is sampled from the full dataset, not from the reduced schema.
    features = mm.sample_batch(testing_data, batch_size=batch_size, include_targets=False)
    outputs = dlrm(features)

    assert list(outputs.shape) == [batch_size, 32]
7678
7779
7880def test_dlrm_block_no_schema ():
@@ -120,6 +122,43 @@ def test_dlrm_with_embeddings(testing_data: Dataset):
120122 bottom_block = mm .MLPBlock ([embedding_dim ]),
121123 top_block = mm .MLPBlock ([top_dim ]),
122124 )
123- outputs = dlrm (mm .sample_batch (testing_data , batch_size = 100 , include_targets = False ))
125+ outputs = dlrm (mm .sample_batch (testing_data , batch_size = 10 , include_targets = False ))
124126
125- assert list (outputs .shape ) == [100 , 4 ]
127+ assert list (outputs .shape ) == [10 , 4 ]
128+
129+
def test_dlrm_with_pretrained_embeddings(testing_data: Dataset):
    """Check DLRMBlock output shape when fed trainable and pre-trained embeddings.

    A random table with one row per possible ``item_id`` value is attached to
    each batch by an ``EmbeddingOperator`` loader transform. The DLRM block
    then receives a ``ParallelBlock`` combining regular ``Embeddings`` for the
    categorical columns with ``PretrainedEmbeddings`` for the columns tagged
    ``Tags.EMBEDDING`` in the loader's output schema.
    """
    embedding_dim = 12
    top_dim = 4
    batch_size = 10

    # One row per id in [0, int_domain.max]; the values are irrelevant because
    # only the output shape is asserted.
    item_cardinality = testing_data.schema["item_id"].int_domain.max + 1
    pretrained_embedding = np.random.rand(item_cardinality, 12)

    loader = mm.Loader(
        testing_data,
        batch_size=batch_size,
        transforms=[
            EmbeddingOperator(
                pretrained_embedding,
                lookup_key="item_id",
                embedding_name="pretrained_item_embeddings",
            ),
        ],
    )
    # Use the loader's schema rather than the dataset's: presumably it also
    # describes the column added by the transform -- TODO confirm.
    schema = loader.output_schema

    embeddings = mm.Embeddings(schema.select_by_tag(Tags.CATEGORICAL), dim=embedding_dim)
    pretrained_embeddings = mm.PretrainedEmbeddings(
        schema.select_by_tag(Tags.EMBEDDING),
        output_dims=embedding_dim,
    )

    dlrm = mm.DLRMBlock(
        schema,
        embeddings=mm.ParallelBlock(embeddings, pretrained_embeddings),
        bottom_block=mm.MLPBlock([embedding_dim]),
        top_block=mm.MLPBlock([top_dim]),
    )
    outputs = dlrm(mm.sample_batch(loader, include_targets=False))

    assert list(outputs.shape) == [batch_size, top_dim]
0 commit comments