@@ -74,6 +74,9 @@ def __init__(self,
7474 # column_slice can be used to enable more table concat, so keep it in single process
7575 self .column_slice_threshold = column_slice_threshold
7676 self .global_configs = [e .get_config () for e in embeddings ]
77+ # Insert layer type information into the config dicts
78+ for config , embedding in zip (self .global_configs , embeddings ):
79+ config ['layer_type' ] = type (embedding )
7780 if input_table_map is None :
7881 input_table_map = list (range (len (embeddings )))
7982
@@ -274,8 +277,10 @@ def _create_concat(self, table_configs, input_maps):
274277 for concat_config in concat_configs :
275278 input_dims = concat_config .pop ('input_dims' )
276279 if len (input_dims ) > 1 :
277- orig_initializer = initializers .deserialize (concat_config ['embeddings_initializer' ])
278- concat_config ['embeddings_initializer' ] = ConcatInitializer (orig_initializer , input_dims )
280+ # TODO(deyuf): custom layers that lack an initializer entry are still concatenated, but their initializer is not wrapped
281+ if 'embeddings_initializer' in concat_config :
282+ orig_initializer = initializers .deserialize (concat_config ['embeddings_initializer' ])
283+ concat_config ['embeddings_initializer' ] = ConcatInitializer (orig_initializer , input_dims )
279284
280285 # record weight offsets for get/set.
281286 weight_offsets = [concat_config .pop ('offsets' , None ) for concat_config in concat_configs ]
@@ -363,8 +368,12 @@ def __init__(self,
363368 # create local embeddings
364369 self .local_embedding_layers = []
365370 for config in self .strategy .local_configs [self .rank ]:
366- config ['synchronization' ] = tf .VariableSynchronization .NONE
367- self .local_embedding_layers .append (Embedding .from_config (config ))
371+ layer_type = config .pop ('layer_type' )
372+ # For stock keras Embedding, we switch underlying layer for better performance
373+ # If inputs are custom layers, original layer will be used
374+ # TODO(deyuf): Check functionality coverage, add fallback or type picking api
375+ layer_type = Embedding if layer_type == tf .keras .layers .Embedding else layer_type
376+ self .local_embedding_layers .append (layer_type .from_config (config ))
368377 self .offsets = [
369378 None if offset == 0 else tf .constant ([offset ], dtype = tf .int64 )
370379 for offset in self .strategy .local_input_offsets [self .rank ]
@@ -651,6 +660,11 @@ def build(self, input_shape):
651660 F"Global batchsize { batch_sizes [0 ]} not divisible workers count { self .world_size } ." )
652661 for layer in self .local_embedding_layers :
653662 layer .build (input_shape [0 ] if input_shape else None )
663+ for var in layer .trainable_weights :
664+ # Mark local (model-parallel) variables. Use the prefix "de" (distributed embeddings) to avoid attribute-name conflicts.
665+ var .de_local = True
666+ # Set the built flag so build() is not triggered again, which would recreate the weights and drop the flag above
667+ layer .built = True
654668 self .built = True
655669
656670 def call (self , inputs ): # pylint: disable=missing-function-docstring
@@ -671,7 +685,7 @@ def broadcast_variables(model_vars, root_rank=0): # pylint: disable=missing-any
671685 dp_vars = []
672686 mp_vars = []
673687 for var in model_vars :
674- if var . synchronization == tf . VariableSynchronization . NONE :
688+ if hasattr ( var , 'de_local' ) :
675689 mp_vars .append (var )
676690 else :
677691 dp_vars .append (var )
@@ -693,7 +707,7 @@ def gradient(self, target, sources, output_gradients=None):
693707 mp_grads = []
694708 split_infos = []
695709 for grad , var in zip (gradients , sources ):
696- if var . synchronization == tf . VariableSynchronization . NONE :
710+ if hasattr ( var , 'de_local' ) :
697711 if isinstance (grad , tf .IndexedSlices ):
698712 mp_grads .append (tf .IndexedSlices (grad .values / hvd .size (), grad .indices ,
699713 grad .dense_shape ))
0 commit comments