Skip to content

Commit a581613

Browse files
committed
Add automatic slicing when there are more workers than tables and no column_slice_threshold is set. These cases now run without raising NotImplementedError.
1 parent 70b7d20 commit a581613

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

distributed_embeddings/python/layers/dist_model_parallel.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,16 @@ def create_sliced_configs(self, world_size, column_slice_threshold, input_table_
178178
sliced_out_ranges (list): each element is list of 2 integers, representing output ranges need
179179
to be concatenated to re-form output due to above slice.
180180
"""
181+
# TODO(Deyu): in auto slice and when there are equal sized tables, allow slice some of them
182+
# With fewer tables than workers, we try our best to slice into worker-count slices (may go over)
183+
if column_slice_threshold is None:
184+
table_sizes = [config['input_dim'] * config['output_dim'] for config in self.global_configs]
185+
while world_size > len(table_sizes):
186+
table_sizes.sort()
187+
column_slice_threshold = table_sizes[-1] - 1
188+
cur_max_size = table_sizes.pop(-1)
189+
table_sizes += [cur_max_size // 2, cur_max_size // 2]
190+
181191
sliced_configs = []
182192
for global_config in self.global_configs:
183193
maybe_sliced_config = self.maybe_slice_table_column(global_config, column_slice_threshold,
@@ -300,8 +310,11 @@ class DistributedEmbedding(tf.keras.layers.Layer):
300310
embeddings (list of keras Embedding layers): embedding tables to be distributed
301311
strategy (str): A string indicates how embedding tables are distributed.
302312
Choices are [“basic”, “memory_balanced”]. Default "basic"
303-
column_slice_threshold (int or None): If not None, embedding tables with more elements than
304-
column_slice_threshold will be divide into N even pieces alone embedded width dimension.
313+
column_slice_threshold (int or None): If None, column slicing only happens when there are more
314+
workers than tables. In that case, column_slice_threshold will be chosen automatically
315+
so each worker receives at least one slice.
316+
If not None, embedding tables with more elements than column_slice_threshold will be divided
317+
into N even pieces along the embedding width dimension.
305318
N is smallest power of 2 makes each slice smaller than column_slice_threshold. Default None.
306319
row_slice (TBD): Describe how which embedding needs to be row sliced
307320
dp_input (bool): If True, takes data parallel input, i.e. in shape
@@ -342,8 +355,10 @@ def __init__(self,
342355
strategy,
343356
input_table_map=input_table_map,
344357
column_slice_threshold=column_slice_threshold)
345-
if len(self.strategy.global_configs) < self.world_size:
346-
raise NotImplementedError
358+
# Handle explicit threshold or corner cases, in which worker may receive no configs
359+
if not all(rank_configs for rank_configs in self.strategy.local_configs):
360+
raise ValueError("Not enough table after slicing to run on all worker."
361+
"Try decrease column_slice_threshold or decrease worker count")
347362

348363
# create local embeddings
349364
self.local_embedding_layers = []

distributed_embeddings/python/layers/dist_model_parallel_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,11 @@ def __init__(self, *args, **kwargs):
9696
def gen_table_sizes(self, num_tables=None):
9797
random.seed(self.seed)
9898
if num_tables is None:
99-
num_tables = random.randint(self.hvd_size, 2 * self.hvd_size)
99+
num_tables = random.randint(1, 2 * self.hvd_size)
100100
table_sizes = []
101101
for _ in range(num_tables):
102102
table_height = random.randint(3, 20)
103-
table_width = random.randint(3, 15)
103+
table_width = random.randint(4, 15)
104104
table_sizes.append([table_height, table_width])
105105
return table_sizes
106106

@@ -278,7 +278,7 @@ def test_column_slice_merge(self):
278278
self.assertEqual(len(tables), len(set(tables)))
279279

280280
def test_column_slice_threshold(self):
281-
table_sizes = self.gen_table_sizes()
281+
table_sizes = self.gen_table_sizes(self.hvd_size + 1)
282282
ref_model = EmbeddingListModel(table_sizes, distribute=False)
283283
test_model = EmbeddingListModel(table_sizes,
284284
distribute=True,
@@ -377,6 +377,15 @@ def test_indivisible_batch(self):
377377
with self.assertRaisesRegex(ValueError, "not divisible"):
378378
self.run_and_test(ref_model, dp_inputs, test_model, mp_inputs)
379379

380+
def test_fewer_tables_than_workers(self):
381+
table_sizes = self.gen_table_sizes(1)
382+
383+
ref_model = EmbeddingListModel(table_sizes, distribute=False)
384+
test_model = EmbeddingListModel(table_sizes, distribute=True, strategy='memory_balanced')
385+
386+
dp_inputs, _ = self.gen_inputs(table_sizes)
387+
self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
388+
380389

381390
if __name__ == "__main__":
382391
test.main()

0 commit comments

Comments
 (0)