@@ -178,6 +178,16 @@ def create_sliced_configs(self, world_size, column_slice_threshold, input_table_
178178 sliced_out_ranges (list): each element is list of 2 integers, representing output ranges need
179179 to be concatenated to re-form output due to above slice.
180180 """
181+ # TODO(Deyu): in auto slice, when there are equal sized tables, allow slicing only some of them.
182+ # With fewer tables than workers, we try our best to slice into worker-count slices (may go over).
183+ if column_slice_threshold is None :
184+ table_sizes = [config ['input_dim' ] * config ['output_dim' ] for config in self .global_configs ]
185+ while world_size > len (table_sizes ):
186+ table_sizes .sort ()
187+ column_slice_threshold = table_sizes [- 1 ] - 1
188+ cur_max_size = table_sizes .pop (- 1 )
189+ table_sizes += [cur_max_size // 2 , cur_max_size // 2 ]
190+
181191 sliced_configs = []
182192 for global_config in self .global_configs :
183193 maybe_sliced_config = self .maybe_slice_table_column (global_config , column_slice_threshold ,
@@ -300,8 +310,11 @@ class DistributedEmbedding(tf.keras.layers.Layer):
300310 embeddings (list of keras Embedding layers): embedding tables to be distributed
301311 strategy (str): A string indicates how embedding tables are distributed.
302312 Choices are [“basic”, “memory_balanced”]. Default "basic"
303- column_slice_threshold (int or None): If not None, embedding tables with more elements than
304- column_slice_threshold will be divide into N even pieces alone embedded width dimension.
313+ column_slice_threshold (int or None): If None, column slicing only happens when there are
314+ more workers than tables. In that case, column_slice_threshold will be chosen automatically
315+ so that each worker receives at least one slice.
316+ If not None, embedding tables with more elements than column_slice_threshold will be divided
317+ into N even pieces along the embedded width dimension.
305318 N is the smallest power of 2 that makes each slice smaller than column_slice_threshold. Default None.
306319 row_slice (TBD): Describe how which embedding needs to be row sliced
307320 dp_input (bool): If True, takes data parallel input, i.e. in shape
@@ -342,8 +355,10 @@ def __init__(self,
342355 strategy ,
343356 input_table_map = input_table_map ,
344357 column_slice_threshold = column_slice_threshold )
345- if len (self .strategy .global_configs ) < self .world_size :
346- raise NotImplementedError
358+ # Handle explicit threshold or corner cases, in which a worker may receive no configs
359+ if not all (rank_configs for rank_configs in self .strategy .local_configs ):
360+ raise ValueError ("Not enough table after slicing to run on all worker."
361+ "Try decrease column_slice_threshold or decrease worker count" )
347362
348363 # create local embeddings
349364 self .local_embedding_layers = []
0 commit comments