@@ -129,7 +129,7 @@ def gen_inputs(self, table_sizes, input_to_table_map=None, mp_input_ids=None):
129129 dp_inputs = [
130130 t [self .hvd_rank * local_batch :(self .hvd_rank + 1 ) * local_batch ] for t in global_inputs
131131 ]
132- mp_inputs = [global_inputs [i ] for i in mp_input_ids ] if mp_input_ids else None
132+ mp_inputs = [global_inputs [i ] for i in mp_input_ids ] if mp_input_ids else []
133133
134134 return dp_inputs , mp_inputs
135135
@@ -362,6 +362,21 @@ def test_set_weight_uninitialized(self):
362362 test_model .dist_embeddings .set_weights (ref_weights [:num_tables ])
363363 test_model .dense .set_weights (ref_weights [num_tables :])
364364
def test_indivisible_batch(self):
    """Check that an indivisible model-parallel batch is rejected.

    Builds a non-distributed reference model and a distributed model that
    takes model-parallel input (``dp_input=False``), removes the first
    element from every model-parallel input, and expects ``run_and_test``
    to raise ``ValueError`` with a message matching "not divisible".

    With a single worker nothing is asserted: the models and inputs are
    still built, but ``run_and_test`` is not called.
    """
    sizes = self.gen_table_sizes()

    reference = EmbeddingListModel(sizes, distribute=False)
    distributed = EmbeddingListModel(sizes, distribute=True, strategy='basic', dp_input=False)

    # Start from model-parallel batches whose size is divisible by the world
    # size, then drop one element from each. Consecutive integers are
    # coprime, so (batch_size - 1) cannot be divisible by any world size
    # greater than 1.
    rank_input_ids = distributed.dist_embeddings.strategy.input_ids_list[self.hvd_rank]
    dp_inputs, mp_inputs = self.gen_inputs(sizes, mp_input_ids=rank_input_ids)
    trimmed_mp_inputs = [batch[1:] for batch in mp_inputs]

    # The divisibility error is only expected with more than one worker.
    if self.hvd_size <= 1:
        return

    with self.assertRaisesRegex(ValueError, "not divisible"):
        self.run_and_test(reference, dp_inputs, distributed, trimmed_mp_inputs)
365380
# Script entry point: hand control to the test runner only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    test.main()
0 commit comments