2727import merlin .io
2828from merlin .core .dispatch import DataFrameType
2929from merlin .io import Dataset
30- from merlin .models .tf .blocks .mlp import InitializerType , RegularizerType
31- from merlin .models .tf .core .base import Block , BlockType
30+ from merlin .models .tf .blocks .mlp import InitializerType , MLPBlock , RegularizerType
31+ from merlin .models .tf .core .aggregation import SequenceAggregator
32+ from merlin .models .tf .core .base import Block , BlockType , NoOp , block_registry
3233from merlin .models .tf .core .combinators import ParallelBlock , SequentialBlock
3334from merlin .models .tf .core .tabular import (
3435 TABULAR_MODULE_PARAMS_DOCSTRING ,
@@ -423,11 +424,7 @@ def _call_table(self, inputs, **kwargs):
423424 if inputs .shape .as_list ()[- 1 ] == 1 :
424425 inputs = tf .squeeze (inputs , axis = - 1 )
425426 out = call_layer (self .table , inputs , ** kwargs )
426- if len (out .get_shape ()) > 2 and self .sequence_combiner is not None :
427- if isinstance (self .sequence_combiner , tf .keras .layers .Layer ):
428- out = call_layer (self .sequence_combiner , out , ** kwargs )
429- elif isinstance (self .sequence_combiner , str ):
430- out = process_str_sequence_combiner (out , self .sequence_combiner , ** kwargs )
427+ out = process_sequence_combiner (out , self .sequence_combiner , ** kwargs )
431428
432429 if self .l2_batch_regularization_factor > 0 :
433430 self .add_loss (self .l2_batch_regularization_factor * tf .reduce_sum (tf .square (out )))
@@ -625,6 +622,95 @@ def _get_dim(col, embedding_dims, infer_dim_fn):
625622 return dim
626623
627624
def PretrainedEmbeddings(
    schema: Schema,
    output_dims: Optional[Union[Dict[str, int], int]] = None,
    sequence_combiner: Optional[Union[CombinerType, Dict[str, CombinerType]]] = "mean",
    normalizer: Union[str, tf.keras.layers.Layer] = None,
    pre: Optional[BlockType] = None,
    post: Optional[BlockType] = None,
    aggregation: Optional[TabularAggregationType] = None,
    block_name: str = "pretrained_embeddings",
    **kwargs,
) -> ParallelBlock:
    """Creates a ParallelBlock with a branch for each pre-trained embedding feature
    in the schema.

    Parameters
    ----------
    schema : Schema
        Schema of the input data, with the pre-trained embeddings.
        You typically will pass schema.select_by_tag(Tags.EMBEDDING), as that is the tag
        added to pre-trained embedding features when using the
        merlin.dataloader.ops.embeddings.EmbeddingOperator
    output_dims : Optional[Union[Dict[str, int], int]], optional
        If provided, it projects features to specified dim(s).
        If an int, all features are projected to that dim.
        If a dict, only features provided in the dict will be mapped to the specified dim,
        for example {"feature_name": projection_dim, ...}. By default None
    sequence_combiner : Optional[Union[CombinerType, Dict[str, CombinerType]]], optional
        A string ("mean", "sum", "max", "min") or Layer specifying how to combine
        the second dimension of the pre-trained embeddings if it is 3D.
        A dict maps feature names to their combiner; features absent from the
        dict are left uncombined. By default "mean"
    normalizer : Union[str, tf.keras.layers.Layer], optional
        A Layer (e.g. mm.L2Norm()) or string ("l2-norm") to be applied
        to pre-trained embeddings after projected and sequence combined
        Default is None (no normalization)
    pre : Optional[BlockType], optional
        Transformation block to apply before the embeddings lookup, by default None
    post : Optional[BlockType], optional
        Transformation block to apply after the embeddings lookup, by default None
    aggregation : Optional[TabularAggregationType], optional
        Transformation block to apply for aggregating the inputs, by default None
    block_name : str, optional
        Name of the block, by default "pretrained_embeddings"

    Returns
    -------
    ParallelBlock
        Returns a parallel block with a branch for each pre-trained embedding
    """
    # Parse the normalizer once up-front; the same (stateless) layer instance is
    # shared across all feature branches, matching Keras layer-reuse semantics.
    if normalizer is not None:
        normalizer = block_registry.parse(normalizer)

    tables = {}

    for col in schema:
        table_name = col.name

        # Identity branch by default; wrapped below with projection/combiner/normalizer.
        tables[table_name] = NoOp()

        if output_dims:
            # An int applies to every feature; a dict only to the listed features.
            if isinstance(output_dims, dict):
                new_dim = output_dims.get(table_name)
            else:
                new_dim = output_dims
            if new_dim:
                # Linear projection (no activation) to the requested dim.
                tables[table_name] = MLPBlock([new_dim], activation=None)

        # Resolve the per-feature combiner: a dict selects by feature name,
        # anything else applies uniformly to all features.
        combiner = sequence_combiner
        if isinstance(combiner, dict):
            combiner = combiner.get(table_name)
        if combiner:
            if isinstance(combiner, str):
                combiner = SequenceAggregator(combiner)
            tables[table_name] = SequentialBlock([tables[table_name], combiner])

        if normalizer is not None:
            tables[table_name] = SequentialBlock([tables[table_name], normalizer])

    return ParallelBlock(
        tables,
        pre=pre,
        post=post,
        aggregation=aggregation,
        name=block_name,
        schema=schema,
        **kwargs,
    )
712+
713+
628714@tf .keras .utils .register_keras_serializable (package = "merlin.models" )
629715class AverageEmbeddingsByWeightFeature (tf .keras .layers .Layer ):
630716 def __init__ (self , weight_feature_name : str , axis = 1 , ** kwargs ):
@@ -1215,17 +1301,28 @@ def serialize_feature_config(feature_config: FeatureConfig) -> Dict[str, Any]:
12151301 return outputs
12161302
12171303
def process_sequence_combiner(inputs, combiner, **kwargs):
    """Collapse the sequence axis of ``inputs`` with ``combiner`` when applicable.

    The combiner is only applied when ``inputs`` has more than 2 dims
    (i.e. a sequence axis is present) and ``combiner`` is truthy; otherwise
    ``inputs`` is returned unchanged. A Layer combiner is invoked via
    ``call_layer``; a string combiner is dispatched to
    ``process_str_sequence_combiner``. Any other combiner type is ignored.
    """
    has_sequence_axis = len(inputs.get_shape()) > 2
    if not (has_sequence_axis and combiner):
        return inputs

    if isinstance(combiner, tf.keras.layers.Layer):
        return call_layer(combiner, inputs, **kwargs)
    if isinstance(combiner, str):
        return process_str_sequence_combiner(inputs, combiner, **kwargs)

    return inputs
1313+
1314+
12181315def process_str_sequence_combiner (
12191316 inputs : Union [tf .Tensor , tf .RaggedTensor ], combiner : str , ** kwargs
12201317) -> tf .Tensor :
1221- """Process inputs with str sequence combiners ("mean" or "sum ")
1318+ """Process inputs with str sequence combiners ("mean", "sum" or "max ")
12221319
12231320 Parameters
12241321 ----------
12251322 inputs : Union[tf.Tensor, tf.RaggedTensor]
12261323 Input 3D tensor (batch size, seq length, embedding dim)
12271324 combiner : str
1228- The combiner: "mean" or "sum "
1325+ The combiner: "mean", "sum" or "max "
12291326
12301327 Returns
12311328 -------
@@ -1238,9 +1335,11 @@ def process_str_sequence_combiner(
12381335 combiner = tf .keras .layers .Lambda (lambda x : tf .reduce_mean (x , axis = 1 ))
12391336 elif combiner == "sum" :
12401337 combiner = tf .keras .layers .Lambda (lambda x : tf .reduce_sum (x , axis = 1 ))
1338+ elif combiner == "max" :
1339+ combiner = tf .keras .layers .Lambda (lambda x : tf .reduce_max (x , axis = 1 ))
12411340 else :
12421341 raise ValueError (
1243- "Only 'mean' and 'sum ' str combiners is implemented for dense"
1342+ "Only 'mean', 'sum', and 'max ' str combiners is implemented for dense"
12441343 " list/multi-hot embedded features. You can also"
12451344 " provide a tf.keras.layers.Layer instance as a sequence combiner."
12461345 )
0 commit comments