
Commit 8efbd36
Add dataloader pre-trained embeddings support to Merlin Models (#1083)
* Created PretrainedEmbeddings and changed other blocks to support pre-trained embeddings
* Using SequenceAggregator with pre-trained embeddings
* Added a test for aggregating sequences of pre-trained embeddings
* Fixed bugs in graph mode when using EmbeddingOperator, as the last dim was undefined
* Reduced the cardinality of the testing and sequence_testing fixtures from 51996 to 100, in order to speed up tests
* Changed the test of Transformers with pre-trained embeddings to use sequence_testing_data
* Fixed tests
* Fixed tests and added a test of pre-trained embeddings with masked language modeling
* Added tests with pre-trained embeddings for DLRM and DCN
* Fixed PopularityBasedSamplerV2, which raised an error when the sampled item equals the item-id cardinality
* Linting fix
* Fixed failing test
* Implemented Sara's suggestions on pre-trained embeddings
* Fixed test

Co-authored-by: edknv <[email protected]>
1 parent 0931633 commit 8efbd36
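For context, a minimal sketch of the workflow this commit enables: pre-trained vectors are joined onto batches by the dataloader's EmbeddingOperator (referenced in the new docstrings below) and consumed by the new PretrainedEmbeddings branch of InputBlockV2. This is a hedged sketch, not part of the commit: the file path, column names, and the "click" target are illustrative.

    import numpy as np
    import merlin.models.tf as mm
    from merlin.dataloader.ops.embeddings import EmbeddingOperator
    from merlin.io import Dataset

    # Hypothetical lookup table: row i holds the pre-trained embedding of item id i.
    item_embeddings = np.random.rand(101, 64)

    train = Dataset("train.parquet")  # hypothetical dataset
    loader = mm.Loader(
        train,
        batch_size=1024,
        transforms=[
            EmbeddingOperator(
                item_embeddings,
                lookup_key="item_id",
                embedding_name="pretrained_item_embeddings",
            )
        ],
    )

    # The operator tags the new feature with Tags.EMBEDDING, so InputBlockV2
    # (see merlin/models/tf/inputs/base.py below) routes it automatically.
    model = mm.Model(
        mm.InputBlockV2(loader.output_schema),
        mm.MLPBlock([64, 32]),
        mm.BinaryClassificationTask("click"),  # hypothetical target column
    )
    model.compile(optimizer="adam")
    model.fit(loader, epochs=1)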

File tree: 22 files changed, +963 -78 lines


merlin/datasets/entertainment/music_streaming/schema.json

Lines changed: 10 additions & 0 deletions

@@ -42,6 +42,16 @@
         "tag": [
           "categorical",
           "item"
+        ],
+        "extraMetadata": [
+          {
+            "_dims": [
+              [
+                0.0,
+                null
+              ]
+            ]
+          }
         ]
       }
     },
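A note on the new extraMetadata._dims entries here and in the schema files below (an informed reading, not spelled out in the commit): each inner pair appears to give (min, max) bounds for one dimension of the column's values, with [0.0, null] denoting a dimension of unbounded size, as merlin-core uses when serializing ragged-list shapes. A sketch of inspecting the parsed result, assuming merlin-core's TensorflowMetadata JSON loader:

    from merlin.schema.io.tensorflow_metadata import TensorflowMetadata

    # Hypothetical path to one of the schema files changed in this commit.
    json_text = open("schema.json").read()
    schema = TensorflowMetadata.from_json(json_text).to_merlin_schema()

    for col in schema:
        # Expect the parsed dims/value-count bounds, e.g. (0, None), in properties.
        print(col.name, col.properties)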

merlin/datasets/testing/schema.json

Lines changed: 13 additions & 4 deletions

@@ -6,7 +6,7 @@
       "intDomain": {
         "name": "user_id",
         "min": "1",
-        "max": "1797",
+        "max": "90",
         "isCategorical": true
       },
       "annotation": {
@@ -90,7 +90,6 @@
         "tag": [
           "continuous",
           "item"
-
         ]
       }
     },
@@ -100,14 +99,24 @@
       "intDomain": {
         "name": "item_id",
         "min": "1",
-        "max": "51996",
+        "max": "100",
         "isCategorical": true
       },
       "annotation": {
         "tag": [
           "item_id",
           "categorical",
           "item"
+        ],
+        "extraMetadata": [
+          {
+            "_dims": [
+              [
+                0.0,
+                null
+              ]
+            ]
+          }
         ]
       }
     },
@@ -121,7 +130,7 @@
       "intDomain": {
         "name": "categories",
         "min": "1",
-        "max": "331",
+        "max": "70",
         "isCategorical": true
       },
       "annotation": {

merlin/datasets/testing/sequence_testing/schema.json

Lines changed: 12 additions & 2 deletions

@@ -6,14 +6,24 @@
       "intDomain": {
         "name": "test_user_id",
         "min": "1",
-        "max": "1797",
+        "max": "90",
         "isCategorical": true
       },
       "annotation": {
         "tag": [
           "categorical",
           "user_id",
           "user"
+        ],
+        "extraMetadata": [
+          {
+            "_dims": [
+              [
+                0.0,
+                null
+              ]
+            ]
+          }
         ]
       }
     },
@@ -122,7 +132,7 @@
       "intDomain": {
         "name": "item_id_seq",
         "min": "1",
-        "max": "51996",
+        "max": "100",
         "isCategorical": true
       },
       "annotation": {

merlin/models/tf/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -58,6 +58,7 @@
     ConcatFeatures,
     ElementwiseSum,
     ElementwiseSumItemMulti,
+    SequenceAggregator,
     StackFeatures,
 )
 from merlin.models.tf.core.base import (
@@ -86,6 +87,7 @@
     Embeddings,
     EmbeddingTable,
     FeatureConfig,
+    PretrainedEmbeddings,
     SequenceEmbeddingFeatures,
     TableConfig,
 )
@@ -215,6 +217,7 @@
     "EmbeddingTable",
     "AverageEmbeddingsByWeightFeature",
     "Embeddings",
+    "PretrainedEmbeddings",
     "FeatureConfig",
     "TableConfig",
     "ParallelPredictionBlock",
@@ -236,6 +239,7 @@
     "Filter",
     "ParallelBlock",
     "StackFeatures",
+    "SequenceAggregator",
     "DotProductInteraction",
     "FMPairwiseInteraction",
     "FMBlock",

merlin/models/tf/blocks/mlp.py

Lines changed: 6 additions & 1 deletion

@@ -96,7 +96,12 @@ def MLPBlock(

     for idx, dim in enumerate(dimensions):
         dropout_layer = None
-        activation_idx = activation if isinstance(activation, str) else activation[idx]
+        activation = activation or "linear"
+        if isinstance(activation, str):
+            activation_idx = activation
+        else:
+            activation_idx = activation[idx]
+
         if no_activation_last_layer and idx == len(dimensions) - 1:
             activation_idx = "linear"
         else:
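The change above makes activation=None a valid argument by falling back to "linear", which the new PretrainedEmbeddings block relies on for its projection layers (see merlin/models/tf/inputs/embedding.py below). A quick sketch of both call styles:

    import merlin.models.tf as mm

    # Previously this raised on activation[idx]; now None means "linear".
    projection = mm.MLPBlock([64], activation=None)

    # A list still assigns one activation per layer, as before.
    mlp = mm.MLPBlock([128, 64], activation=["relu", "tanh"])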

merlin/models/tf/core/aggregation.py

Lines changed: 10 additions & 4 deletions

@@ -414,8 +414,10 @@ def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
                 outputs[k] = v
             return outputs
         else:
-            assert len(inputs.shape) == 3, "Tensor inputs should be 3-D"
-            return combiner(inputs, axis=self.axis, **kwargs)
+            if inputs.get_shape().rank > self.axis + 1:
+                return combiner(inputs, axis=self.axis, **kwargs)
+            else:
+                return inputs

     def parse_combiner(self, combiner):
         if isinstance(combiner, str):
@@ -441,8 +443,12 @@ def compute_output_shape(self, input_shape):
                 outputs[k] = v
             return outputs
         else:
-            batch_size, _, last_dim = input_shape
-            return batch_size, last_dim
+            if len(input_shape) > self.axis + 1:
+                return tf.TensorShape(
+                    list(input_shape)[: self.axis] + list(input_shape)[self.axis + 1 :]
+                )
+            else:
+                return input_shape

     def get_config(self):
         config = super().get_config()
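With the rank check replacing the hard 3-D assert, SequenceAggregator combines inputs that have a sequence dimension and passes anything else through unchanged. A minimal sketch, assuming the string-combiner parsing shown above and the default aggregation axis of 1:

    import tensorflow as tf
    import merlin.models.tf as mm

    agg = mm.SequenceAggregator("mean")

    seq = tf.random.uniform((8, 10, 64))  # (batch, seq length, embedding dim)
    print(agg(seq).shape)                 # (8, 64): combined over the sequence axis

    flat = tf.random.uniform((8, 64))     # no sequence dimension
    print(agg(flat).shape)                # (8, 64): returned as-is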

merlin/models/tf/inputs/base.py

Lines changed: 15 additions & 3 deletions

@@ -29,6 +29,7 @@
     EmbeddingFeatures,
     EmbeddingOptions,
     Embeddings,
+    PretrainedEmbeddings,
     SequenceEmbeddingFeatures,
 )
 from merlin.schema import Schema, Tags, TagsType
@@ -206,13 +207,15 @@ def InputBlock(
 INPUT_TAG_TO_BLOCK: Dict[Tags, Callable[[Schema], Layer]] = {
     Tags.CONTINUOUS: Continuous,
     Tags.CATEGORICAL: Embeddings,
+    Tags.EMBEDDING: PretrainedEmbeddings,
 }


 def InputBlockV2(
     schema: Optional[Schema] = None,
     categorical: Union[Tags, Layer] = Tags.CATEGORICAL,
     continuous: Union[Tags, Layer] = Tags.CONTINUOUS,
+    pretrained_embeddings: Union[Tags, Layer] = Tags.EMBEDDING,
     pre: Optional[BlockType] = None,
     post: Optional[BlockType] = None,
     aggregation: Optional[TabularAggregationType] = "concat",
@@ -262,11 +265,15 @@ def InputBlockV2(
     categorical : Union[Tags, Layer], defaults to `Tags.CATEGORICAL`
         A block or column-selector to use for categorical-features.
         If a column-selector is provided (either a schema or tags), the selector
-        will be passed to `Embeddings` to infer the embedding tables from the column-selector.
+        will be passed to `Embeddings()` to infer the embedding tables from the column-selector.
     continuous : Union[Tags, Layer], defaults to `Tags.CONTINUOUS`
         A block to use for continuous-features.
         If a column-selector is provided (either a schema or tags), the selector
-        will be passed to `Continuous` to infer the features from the column-selector.
+        will be passed to `Continuous()` to infer the features from the column-selector.
+    pretrained_embeddings : Union[Tags, Layer], defaults to `Tags.EMBEDDING`
+        A block to use for pre-trained embeddings.
+        If a column-selector is provided (either a schema or tags), the selector
+        will be passed to `PretrainedEmbeddings()` to infer the features from the column-selector.
     pre : Optional[BlockType], optional
         Transformation block to apply before the embeddings lookup, by default None
     post : Optional[BlockType], optional
@@ -297,7 +304,12 @@ def InputBlockV2(
         )
         categorical = branches["embeddings"]

-    unparsed = {"categorical": categorical, "continuous": continuous, **branches}
+    unparsed = {
+        "categorical": categorical,
+        "continuous": continuous,
+        "pretrained_embeddings": pretrained_embeddings,
+        **branches,
+    }
     parsed = {}
     for name, branch in unparsed.items():
         if isinstance(branch, Layer):
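With the new branch, InputBlockV2 routes Tags.EMBEDDING features alongside the categorical and continuous ones. A hedged sketch, assuming schema is the loader.output_schema from the first sketch above:

    import merlin.models.tf as mm
    from merlin.schema import Tags

    # Default behavior: Tags.EMBEDDING columns go to PretrainedEmbeddings(...).
    inputs = mm.InputBlockV2(schema)

    # Or configure the branch explicitly, e.g. to project and L2-normalize:
    inputs = mm.InputBlockV2(
        schema,
        pretrained_embeddings=mm.PretrainedEmbeddings(
            schema.select_by_tag(Tags.EMBEDDING),
            output_dims=64,
            sequence_combiner="mean",
            normalizer="l2-norm",
        ),
    )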

merlin/models/tf/inputs/embedding.py

Lines changed: 109 additions & 10 deletions

@@ -27,8 +27,9 @@
 import merlin.io
 from merlin.core.dispatch import DataFrameType
 from merlin.io import Dataset
-from merlin.models.tf.blocks.mlp import InitializerType, RegularizerType
-from merlin.models.tf.core.base import Block, BlockType
+from merlin.models.tf.blocks.mlp import InitializerType, MLPBlock, RegularizerType
+from merlin.models.tf.core.aggregation import SequenceAggregator
+from merlin.models.tf.core.base import Block, BlockType, NoOp, block_registry
 from merlin.models.tf.core.combinators import ParallelBlock, SequentialBlock
 from merlin.models.tf.core.tabular import (
     TABULAR_MODULE_PARAMS_DOCSTRING,
@@ -423,11 +424,7 @@ def _call_table(self, inputs, **kwargs):
         if inputs.shape.as_list()[-1] == 1:
             inputs = tf.squeeze(inputs, axis=-1)
         out = call_layer(self.table, inputs, **kwargs)
-        if len(out.get_shape()) > 2 and self.sequence_combiner is not None:
-            if isinstance(self.sequence_combiner, tf.keras.layers.Layer):
-                out = call_layer(self.sequence_combiner, out, **kwargs)
-            elif isinstance(self.sequence_combiner, str):
-                out = process_str_sequence_combiner(out, self.sequence_combiner, **kwargs)
+        out = process_sequence_combiner(out, self.sequence_combiner, **kwargs)

         if self.l2_batch_regularization_factor > 0:
             self.add_loss(self.l2_batch_regularization_factor * tf.reduce_sum(tf.square(out)))
@@ -625,6 +622,95 @@ def _get_dim(col, embedding_dims, infer_dim_fn):
     return dim


+def PretrainedEmbeddings(
+    schema: Schema,
+    output_dims: Optional[Union[Dict[str, int], int]] = None,
+    sequence_combiner: Optional[Union[CombinerType, Dict[str, CombinerType]]] = "mean",
+    normalizer: Union[str, tf.keras.layers.Layer] = None,
+    pre: Optional[BlockType] = None,
+    post: Optional[BlockType] = None,
+    aggregation: Optional[TabularAggregationType] = None,
+    block_name: str = "pretrained_embeddings",
+    **kwargs,
+) -> ParallelBlock:
+    """Creates a ParallelBlock with a branch for each pre-trained embedding feature
+    in the schema.
+
+    Parameters
+    ----------
+    schema: Schema
+        Schema of the input data, with the pre-trained embeddings.
+        You will typically pass schema.select_by_tag(Tags.EMBEDDING), as that is the tag
+        added to pre-trained embedding features when using the
+        merlin.dataloader.ops.embeddings.EmbeddingOperator
+    output_dims: Optional[Union[Dict[str, int], int]], optional
+        If provided, projects features to the specified dim(s).
+        If an int, all features are projected to that dim.
+        If a dict, only the features named in the dict are mapped to the specified dims,
+        for example {"feature_name": projection_dim, ...}. By default None
+    sequence_combiner: Optional[Union[str, tf.keras.layers.Layer]], optional
+        A string ("mean", "sum", "max", "min") or Layer specifying
+        how to combine the second dimension of
+        the pre-trained embeddings if it is 3D.
+        By default "mean"
+    normalizer: Union[str, tf.keras.layers.Layer], optional
+        A Layer (e.g. mm.L2Norm()) or string ("l2-norm") applied to the
+        pre-trained embeddings after projection and sequence combining.
+        Default is None (no normalization)
+    pre: Optional[BlockType], optional
+        Transformation block to apply before the embeddings lookup, by default None
+    post: Optional[BlockType], optional
+        Transformation block to apply after the embeddings lookup, by default None
+    aggregation: Optional[TabularAggregationType], optional
+        Transformation block to apply for aggregating the inputs, by default None
+    block_name: str, optional
+        Name of the block, by default "pretrained_embeddings"
+    Returns
+    -------
+    ParallelBlock
+        Returns a parallel block with a branch for each pre-trained embedding
+    """
+
+    tables = {}
+
+    for col in schema:
+        table_name = col.name
+
+        tables[table_name] = NoOp()
+
+        if output_dims:
+            new_dim = output_dims
+            if isinstance(output_dims, dict):
+                if table_name in output_dims:
+                    new_dim = output_dims[table_name]
+                else:
+                    new_dim = None
+            if new_dim:
+                tables[table_name] = MLPBlock([new_dim], activation=None)
+
+        if sequence_combiner:
+            if isinstance(sequence_combiner, str):
+                sequence_combiner = SequenceAggregator(sequence_combiner)
+
+            tables[table_name] = SequentialBlock([tables[table_name], sequence_combiner])
+
+        if normalizer:
+            normalizer = block_registry.parse(normalizer)
+            tables[table_name] = SequentialBlock([tables[table_name], normalizer])
+
+    return ParallelBlock(
+        tables,
+        pre=pre,
+        post=post,
+        aggregation=aggregation,
+        name=block_name,
+        schema=schema,
+        **kwargs,
+    )
+
+
 @tf.keras.utils.register_keras_serializable(package="merlin.models")
 class AverageEmbeddingsByWeightFeature(tf.keras.layers.Layer):
     def __init__(self, weight_feature_name: str, axis=1, **kwargs):
@@ -1215,17 +1301,28 @@ def serialize_feature_config(feature_config: FeatureConfig) -> Dict[str, Any]:
     return outputs


+def process_sequence_combiner(inputs, combiner, **kwargs):
+    result = inputs
+    if len(inputs.get_shape()) > 2 and combiner:
+        if isinstance(combiner, tf.keras.layers.Layer):
+            result = call_layer(combiner, inputs, **kwargs)
+        elif isinstance(combiner, str):
+            result = process_str_sequence_combiner(inputs, combiner, **kwargs)
+
+    return result
+
+
 def process_str_sequence_combiner(
     inputs: Union[tf.Tensor, tf.RaggedTensor], combiner: str, **kwargs
 ) -> tf.Tensor:
-    """Process inputs with str sequence combiners ("mean" or "sum")
+    """Process inputs with str sequence combiners ("mean", "sum" or "max")

     Parameters
     ----------
     inputs : Union[tf.Tensor, tf.RaggedTensor]
         Input 3D tensor (batch size, seq length, embedding dim)
     combiner : str
-        The combiner: "mean" or "sum"
+        The combiner: "mean", "sum" or "max"

     Returns
     -------
@@ -1238,9 +1335,11 @@ def process_str_sequence_combiner(
         combiner = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1))
     elif combiner == "sum":
         combiner = tf.keras.layers.Lambda(lambda x: tf.reduce_sum(x, axis=1))
+    elif combiner == "max":
+        combiner = tf.keras.layers.Lambda(lambda x: tf.reduce_max(x, axis=1))
     else:
         raise ValueError(
-            "Only 'mean' and 'sum' str combiners is implemented for dense"
+            "Only 'mean', 'sum', and 'max' str combiners are implemented for dense"
             " list/multi-hot embedded features. You can also"
             " provide a tf.keras.layers.Layer instance as a sequence combiner."
         )
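Putting the pieces together: each pre-trained feature gets a NoOp branch by default, optionally wrapped in an MLPBlock projection, a SequenceAggregator, and a normalizer. A hedged usage sketch (schema as in the first sketch above; the feature name is illustrative), also exercising the newly supported "max" string combiner:

    import merlin.models.tf as mm
    from merlin.schema import Tags

    pretrained = mm.PretrainedEmbeddings(
        schema.select_by_tag(Tags.EMBEDDING),
        output_dims={"pretrained_item_embeddings": 32},  # project only this feature
        sequence_combiner="max",  # parsed into a SequenceAggregator over axis 1
        normalizer="l2-norm",
        aggregation="concat",
    )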
