Skip to content

Commit 70b7d20

Browse files
committed
Raise a descriptive exception when global batch size is not divisible by the number of workers
1 parent c9e7d10 commit 70b7d20

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

distributed_embeddings/python/layers/dist_model_parallel.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,8 +622,20 @@ def get_weights(self, all_ranks=False):
622622

623623
@tf_utils.shape_type_conversion
624624
def build(self, input_shape):
625+
if input_shape is not None:
626+
# Do some checks to detect cases that are not supported
627+
if not isinstance(input_shape, list):
628+
input_shape = [input_shape]
629+
batch_sizes = [shape[0] for shape in input_shape]
630+
batch_sizes = hvd.allgather(batch_sizes).numpy().tolist()
631+
if len(set(batch_sizes)) > 1:
632+
raise ValueError(F"All input need to have same batchsize. got {set(batch_sizes)}.")
633+
if not self.dp_input:
634+
if batch_sizes[0] % self.world_size > 0:
635+
raise ValueError(
636+
F"Global batchsize {batch_sizes[0]} not divisible workers count {self.world_size}.")
625637
for layer in self.local_embedding_layers:
626-
layer.build(input_shape)
638+
layer.build(input_shape[0] if input_shape else None)
627639
self.built = True
628640

629641
def call(self, inputs): # pylint: disable=missing-function-docstring

distributed_embeddings/python/layers/dist_model_parallel_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def gen_inputs(self, table_sizes, input_to_table_map=None, mp_input_ids=None):
129129
dp_inputs = [
130130
t[self.hvd_rank * local_batch:(self.hvd_rank + 1) * local_batch] for t in global_inputs
131131
]
132-
mp_inputs = [global_inputs[i] for i in mp_input_ids] if mp_input_ids else None
132+
mp_inputs = [global_inputs[i] for i in mp_input_ids] if mp_input_ids else []
133133

134134
return dp_inputs, mp_inputs
135135

@@ -362,6 +362,21 @@ def test_set_weight_uninitialized(self):
362362
test_model.dist_embeddings.set_weights(ref_weights[:num_tables])
363363
test_model.dense.set_weights(ref_weights[num_tables:])
364364

365+
def test_indivisible_batch(self):
366+
table_sizes = self.gen_table_sizes()
367+
368+
ref_model = EmbeddingListModel(table_sizes, distribute=False)
369+
test_model = EmbeddingListModel(table_sizes, distribute=True, strategy='basic', dp_input=False)
370+
371+
# First generate model-parallel batches whose size is divisible by world_size. We then use (batch_size - 1),
372+
# which will be indivisible by any world_size greater than 1, since consecutive integers are coprime
373+
mp_input_ids = test_model.dist_embeddings.strategy.input_ids_list[self.hvd_rank]
374+
dp_inputs, mp_inputs = self.gen_inputs(table_sizes, mp_input_ids=mp_input_ids)
375+
mp_inputs = [inp[1:] for inp in mp_inputs]
376+
if self.hvd_size > 1:
377+
with self.assertRaisesRegex(ValueError, "not divisible"):
378+
self.run_and_test(ref_model, dp_inputs, test_model, mp_inputs)
379+
365380

366381
if __name__ == "__main__":
367382
test.main()

0 commit comments

Comments
 (0)