 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Distributed Embedding layers and utils"""
+import math
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras.utils import tf_utils
 import horovod.tensorflow as hvd
+from distributed_embeddings.python.ops.embedding_lookup_ops import read_var_no_copy
 from .embedding import Embedding
 
 
@@ -48,19 +51,21 @@ def __init__(self,
       self.input_ids_list = [list(range(len(input_table_map)))]
       self.table_ids_list = [list(range(len(embeddings)))]
       return
+
     # Create (maybe) sliced configs
     sliced_configs, self.sliced_out_ranges = self.create_sliced_configs(
         world_size, column_slice_threshold, input_table_map)
     # Apply strategy and save nested list containing table indices by rank
     self.table_ids_list = self.apply_stragety(strategy, world_size, sliced_configs)
-    # Nested list to split embedding output from each rank into tables
-    self.widths_list = []
+
     # Nested list containing input indices by rank
     self.input_ids_list = []
     # Nested list containing local input to local table map by rank
     self.local_map_list = []
     # Nested list containing local configs by rank
     self.local_configs_list = []
+    # All local widths, ordered by rank, flattened into a single list
+    self.widths_list_flat = []
     # Each worker loops over all ranks to get a global view of the strategy
     for rank_table_ids in self.table_ids_list:
       # calculate stats needed for each rank
@@ -73,20 +78,22 @@ def __init__(self,
             rank_input_ids.append(k)
             rank_input_map.append(m)
       self.local_configs_list.append(rank_configs)
-      self.widths_list.append(rank_widths)
+      self.widths_list_flat += rank_widths
       self.input_ids_list.append(rank_input_ids)
       self.local_map_list.append(rank_input_map)
-    # List of total embedding widths to split embedding output by rank after alltoall
-    self.total_local_widths = [sum(widths) for widths in self.widths_list]
+
     # List that maps local inputs to local table
     self.local_input_table_map = self.local_map_list[rank]
+
     # flatten self.input_ids_list
     worker_order_input_ids = [item for sublist in self.input_ids_list for item in sublist]
+
     # List of indices to shuffle worker ordered embedding outputs back to original order
     self.rev_global_input_ids = [
         index
         for _, index in sorted(zip(worker_order_input_ids, range(len(worker_order_input_ids))))
     ]
+
     # List of configs to create local embedding layers
     self.local_configs = self.local_configs_list[rank]
 
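For reference, the rev_global_input_ids computed above is just the inverse permutation of worker_order_input_ids (an argsort); a minimal standalone sketch with hypothetical values:

worker_order_input_ids = [2, 0, 3, 1]   # hypothetical: input ids in worker (rank) order
rev = [index for _, index in sorted(zip(worker_order_input_ids,
                                        range(len(worker_order_input_ids))))]
# rev == [1, 3, 0, 2]; indexing worker-ordered outputs with rev restores the input order
outputs_in_worker_order = ['out2', 'out0', 'out3', 'out1']
assert [outputs_in_worker_order[i] for i in rev] == ['out0', 'out1', 'out2', 'out3']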
@@ -286,18 +293,17 @@ def _call_base(self, inputs): # pylint: disable=missing-param-doc,missing-type-
         for m, inp in zip(self.strategy.local_input_table_map, inputs)
     ]
 
-    # concat last axis to make all2all slice correct, and reshape to make later split easier
     # TODO(Deyu): currently assumes 2D outputs with the same batch size; ideally should support the general case
-    local_bs = inputs[0].shape[0] // self.world_size
-    mp_outs = tf.reshape(tf.concat(mp_outs, axis=-1), [-1, local_bs])
+    mp_outs = [tf.reshape(mp_out, [self.world_size, -1]) for mp_out in mp_outs]
+    mp_outs = tf.reshape(tf.concat(mp_outs, axis=1), [-1])
+    # cast before alltoall according to dtype policy
+    mp_outs = tf.cast(mp_outs, self.compute_dtype)
     dp_outs = hvd.alltoall(mp_outs, name='out_mp_to_dp')
-    dp_outs = [
-        tf.reshape(t, [local_bs, -1]) for t in tf.split(dp_outs, self.strategy.total_local_widths)
-    ]
-    # split each worker result and re-order using id
-    worker_order_res = []
-    for dp_out, widths in zip(dp_outs, self.strategy.widths_list):
-      worker_order_res += tf.split(dp_out, widths, 1)
+    local_bs = inputs[0].shape[0] // self.world_size
+    num_elements = [local_bs * item for item in self.strategy.widths_list_flat]
+    split_outs = tf.split(dp_outs, num_elements)
+    worker_order_res = [tf.reshape(split_out, [local_bs, -1]) for split_out in split_outs]
+
     # reorder outputs to be same as inputs order
     result = [worker_order_res[index] for index in self.strategy.rev_global_input_ids]
     return result
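A rough sketch of the layout math behind the new reshape/concat/alltoall/split path, using numpy in place of hvd.alltoall (all sizes are hypothetical, not part of the patch):

import numpy as np

world_size, local_bs = 2, 3                 # hypothetical sizes
widths = [4, 2]                             # output_dim of this worker's two local tables
global_bs = world_size * local_bs
mp_outs = [np.arange(global_bs * w).reshape(global_bs, w) for w in widths]

# [global_bs, w] -> [world_size, local_bs * w]: row r is the slice destined for rank r
per_rank = [m.reshape(world_size, -1) for m in mp_outs]
send_buf = np.concatenate(per_rank, axis=1).reshape(-1)

# after alltoall each rank holds one contiguous block per source worker;
# splitting by local_bs * width recovers the [local_bs, width] tables
block_for_rank0 = send_buf[:local_bs * sum(widths)]
splits = np.split(block_for_rank0, np.cumsum([local_bs * w for w in widths])[:-1])
tables = [s.reshape(local_bs, -1) for s in splits]
assert [t.shape for t in tables] == [(3, 4), (3, 2)]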
@@ -309,70 +315,149 @@ def _concat_column_slice_outputs(self, outs):
       outs[start:end] = [tf.concat(outs[start:end], axis=-1)]
     return outs
 
-  def set_weights(self, weights):  # pylint: disable=missing-param-doc,missing-type-doc
+  def set_weights(self, weights, chunk=134217728, use_lock=False):
     """Sets the weights of the layer, from NumPy arrays.
 
-    This override expects global weights for all tables as input.
+    Args:
+      weights (list): list containing global weights for all tables.
+          Each item can be either a numpy array or a file path to load from.
+      chunk (int): maximum number of elements per chunk when setting weights on GPU in chunks.
+          This is rounded to a number of rows based on the weight shape.
+      use_lock (bool): If true, set weights rank by rank in lock step to avoid OOM. Default False.
     """
-    if self.world_size == 1:
-      sliced_local_weights = weights
-    else:
+    if use_lock:
+      for _ in range(self.rank):
+        hvd.broadcast_object(0)
+
+    if self.world_size > 1:
       slice_info = [[rank_tids.count(tid)
                      for rank_tids in self.strategy.table_ids_list]
                     for tid in range(len(weights))]
-      local_weights = [weights[index] for index in self.strategy.table_ids_list[self.rank]]
+      weights = [weights[index] for index in self.strategy.table_ids_list[self.rank]]
+      if isinstance(weights[0], str):
+        weights = [np.load(file=path, mmap_mode='r') for path in weights]
       local_info = [slice_info[index] for index in self.strategy.table_ids_list[self.rank]]
+      # array to handle the case of multiple slices of the same table
+      # TODO(Deyu): avoid this by merging those tables again after finding the strategy
+      rank_ids = self.strategy.table_ids_list[self.rank]
+      index_offset = [rank_ids[:i].count(rank_id) for i, rank_id in enumerate(rank_ids)]
 
-      def _slice_weight_for_rank(weight, info, global_rank):
+      def _slice_weight_for_rank(weight, info, global_rank, offset):
         num_columns = weight.shape[1]
         num_slices = sum(info)
         column_per_slice = num_columns // num_slices
         remainder = num_columns % num_slices
-        rank = info[:global_rank].count(1)
+        rank = sum(info[:global_rank]) + offset
 
         start = column_per_slice * rank + min(rank, remainder)
         rank += 1
         end = column_per_slice * rank + min(rank, remainder)
         return weight[:, start:end]
 
-      sliced_local_weights = [
-          _slice_weight_for_rank(weight, info, self.rank)
-          for weight, info in zip(local_weights, local_info)
+      weights = [
+          _slice_weight_for_rank(weight, info, self.rank, offset)
+          for weight, info, offset in zip(weights, local_info, index_offset)
       ]
-    super().set_weights(sliced_local_weights)
+    # variable.assign and copy-on-write create an extra copy of the weight, which causes OOM,
+    # so we scatter-update in chunks of ~128M elements here instead of simply calling
+    # super().set_weights(weights)
+    for weight, arr in zip(self.weights, weights):
+      if arr.size <= chunk:
+        weight.assign(arr)
+      else:
+        chunk_size_dim0 = chunk // weight.shape[1]
+        num_chunks = math.ceil(weight.shape[0] / chunk_size_dim0)
+        last_size = weight.shape[0] - chunk_size_dim0 * (num_chunks - 1)
+        chunk_sizes = [chunk_size_dim0] * (num_chunks - 1) + [last_size]
+        for i in range(num_chunks):
+          start = i * chunk_size_dim0
+          end = start + chunk_sizes[i]
+          indices = tf.range(start=start, limit=end, dtype=tf.int64)
+          update = tf.IndexedSlices(values=arr[start:end],
+                                    indices=indices,
+                                    dense_shape=weight.shape)
+          weight.scatter_update(sparse_delta=update)
+    del weights
+
+    if use_lock:
+      for _ in range(self.world_size - self.rank):
+        hvd.broadcast_object(0)
+
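The column split in _slice_weight_for_rank above distributes the num_columns % num_slices leftover columns to the lowest slice indices; a small self-contained check of that arithmetic (the helper name _slice_bounds is illustrative only, not part of the patch):

def _slice_bounds(num_columns, num_slices, rank):
  # same start/end arithmetic as _slice_weight_for_rank, on indices only
  column_per_slice = num_columns // num_slices
  remainder = num_columns % num_slices
  start = column_per_slice * rank + min(rank, remainder)
  end = column_per_slice * (rank + 1) + min(rank + 1, remainder)
  return start, end

# 10 columns over 4 slices -> widths 3, 3, 2, 2
assert [_slice_bounds(10, 4, r) for r in range(4)] == [(0, 3), (3, 6), (6, 8), (8, 10)]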
+  # 1-D split that works beyond the 32-bit indexing limit that TF supports
+  def _split_1d(self, tensor, lengths):
+    # choose a number close to the int32 limit as the maximum chunk size;
+    # this handles tensors with up to int32_max squared elements
+    chunking_threshold = 2147483646
+    if tensor.shape[0] <= chunking_threshold:
+      return tf.split(tensor, lengths)
+    num_chunks = math.ceil(tensor.shape[0] / chunking_threshold)
+    padding_len = math.ceil(tensor.shape[0] / num_chunks) * num_chunks - tensor.shape[0]
+    padded_tensor = tf.concat([tensor, tf.zeros(padding_len, tensor.dtype)], axis=0)
+    tensor_list = tf.unstack(tf.reshape(padded_tensor, [num_chunks, -1]))
+    result = []
+    for length in lengths:
+      this_slice = []
+      while length > 0:
+        if length > tensor_list[0].shape[0]:
+          this_slice.append(tensor_list.pop(0))
+        else:
+          this_slice.append(tensor_list[0][:length])
+          tensor_list[0] = tensor_list[0][length:]
+        length -= this_slice[-1].shape[0]
+      result.append(tf.concat(this_slice, axis=0))
+    return result
 
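A numpy sketch of the chunking idea in _split_1d, with a tiny threshold so the chunked path is exercised (in the real method that path only activates near the int32 limit; all names and sizes here are hypothetical):

import math
import numpy as np

def split_1d_sketch(arr, lengths, chunking_threshold=4):
  # pad to a multiple of num_chunks, view as rows, then walk the rows so no
  # single split ever sees more than chunking_threshold elements
  if arr.shape[0] <= chunking_threshold:
    return np.split(arr, np.cumsum(lengths)[:-1])
  num_chunks = math.ceil(arr.shape[0] / chunking_threshold)
  pad = math.ceil(arr.shape[0] / num_chunks) * num_chunks - arr.shape[0]
  rows = list(np.concatenate([arr, np.zeros(pad, arr.dtype)]).reshape(num_chunks, -1))
  result = []
  for length in lengths:
    parts = []
    while length > 0:
      if length > rows[0].shape[0]:
        parts.append(rows.pop(0))
      else:
        parts.append(rows[0][:length])
        rows[0] = rows[0][length:]
      length -= parts[-1].shape[0]
    result.append(np.concatenate(parts))
  return result

out = split_1d_sketch(np.arange(10), [3, 3, 4])
assert [o.tolist() for o in out] == [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]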
-  def get_weights(self):
+  def get_weights(self, all_ranks=False):
     """Returns the current weights of the layer, as NumPy arrays.
 
     This override outputs global weights for all tables.
+    Args:
+      all_ranks (bool): If true, return weights on all ranks, otherwise only on rank 0.
+          Default False.
     """
+    # avoid copy-on-read on dense access
+    local_weights = [read_var_no_copy(w) for w in self.weights]
     if self.world_size == 1:
-      return [weight.numpy() for weight in self.weights]
+      return [w.numpy() for w in local_weights]
+
+    # MPI segfaults on indices beyond the 32-bit range, so we gather weights chunk by chunk here;
+    # choose a number not too close to the int32 limit as the maximum chunk size, just to be safe
+    chunking_threshold = 2000000000
+    num_chunks = 1
+    for local_configs in self.strategy.local_configs_list:
+      total_elements = sum([c['input_dim'] * c['output_dim'] for c in local_configs])
+      num_chunks = max(num_chunks, math.ceil(self.world_size * total_elements / chunking_threshold))
 
-    # mpi segfault on large sizes so we gather weights chunk by chunk here
-    num_chunks = 8
     with tf.device('CPU:0'):
-      local_weights = tf.concat([tf.reshape(w, [-1]) for w in self.weights], axis=0)
+      local_weights = tf.concat([tf.reshape(w, [-1]) for w in local_weights], axis=0)
       chunk_size = local_weights.shape[0] // num_chunks
       last_size = local_weights.shape[0] - chunk_size * (num_chunks - 1)
       chunk_sizes = [chunk_size] * (num_chunks - 1) + [last_size]
-      local_weights = tf.split(local_weights, chunk_sizes)
+      local_weights = self._split_1d(local_weights, chunk_sizes)
+      # communicate chunk sizes
       all_sizes = hvd.allgather(chunk_sizes)
 
     # collect all chunks and split to reverse allgather concat
     chunks = []
     for i, w in enumerate(local_weights):
-      chunks += tf.split(hvd.allgather(w), all_sizes[i::num_chunks])
+      w = hvd.allgather(w)
+      if all_ranks or self.rank == 0:
+        chunks += self._split_1d(w, all_sizes[i::num_chunks])
+    if not chunks:
+      return []
+
     # re-construct all local weights from chunks
     local_weights = []
     for i in range(self.world_size):
       local_weights.append(tf.concat(chunks[i::self.world_size], axis=0))
+    del chunks
+
     # split flat local weights into correct sizes
     weights = []
     for local_weight, local_configs in zip(local_weights, self.strategy.local_configs_list):
       local_shapes = [[c['input_dim'], c['output_dim']] for c in local_configs]
       local_sizes = [shape[0] * shape[1] for shape in local_shapes]
-      flat_weights = tf.split(local_weight, local_sizes)
+      flat_weights = self._split_1d(local_weight, local_sizes)
       weights += [tf.reshape(weight, shape) for weight, shape in zip(flat_weights, local_shapes)]
     # restore original table order
     # flatten self.strategy.table_ids_list
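The chunk bookkeeping above relies on hvd.allgather concatenating in rank order, so chunks ends up chunk-major and chunks[r::world_size] re-assembles rank r's flat weights; a numpy sketch of that interleaving (sizes hypothetical, not part of the patch):

import numpy as np

world_size, num_chunks = 2, 3
per_rank = [np.arange(6) + 100 * r for r in range(world_size)]   # flat local weights per rank
sent = [np.split(w, num_chunks) for w in per_rank]               # chunks each rank contributes

chunks = []
for i in range(num_chunks):
  gathered = np.concatenate([sent[r][i] for r in range(world_size)])  # stands in for hvd.allgather
  chunks += np.split(gathered, world_size)                            # all_sizes[i::num_chunks]

rebuilt = [np.concatenate(chunks[r::world_size]) for r in range(world_size)]
assert all(np.array_equal(a, b) for a, b in zip(rebuilt, per_rank))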
@@ -408,6 +493,7 @@ def call(self, inputs): # pylint: disable=missing-function-docstring
         self.local_embedding_layers[m](inp)
         for m, inp in zip(self.strategy.local_input_table_map, inputs)
     ]
+    outputs = [tf.cast(output, self.compute_dtype) for output in outputs]
     return outputs
 
   # TODO(skyw): Revisit the logic of selecting call functions for different strategies
@@ -460,7 +546,10 @@ def gradient(self, target, sources, output_gradients=None):
         dp_vars.append(var)
         dp_grads.append(grad)
         split_infos.append((False, len(dp_grads) - 1))
-    dp_grads = self._allreduce_grads(dp_grads, dp_vars)  # pylint: disable=protected-access
+    # TODO(Deyu): make sure not reusing _allreduce_grads doesn't lead to any issue
+    dp_grads = [
+        hvd.allreduce(g, name=f'dp_gradient_{i}', op=hvd.Average) for i, g in enumerate(dp_grads)
+    ]
     # put gradients back in original order
     grads = []
     for info in split_infos:
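For reference, hvd.allreduce with op=hvd.Average returns the element-wise mean across ranks, which is what the replaced fused _allreduce_grads call computed for these dense data-parallel gradients by default; a trivial numpy check of that semantics (values hypothetical):

import numpy as np

world_size = 4
per_rank_grads = [np.full(3, r + 1.0) for r in range(world_size)]   # g_0 .. g_3
g_avg = sum(per_rank_grads) / world_size                            # what op=hvd.Average yields
assert np.allclose(g_avg, np.full(3, 2.5))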