Commit d40c589

version 0.1 updates
1 parent 4e74e34 commit d40c589

7 files changed: +171 additions, -39 deletions

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -18,7 +18,7 @@ See more details at [User Guide](https://nvidia-merlin.github.io/distributed-emb
 
 ## Installation
 ### Requirements
-Python 3, CUDA 11 or newer, TensorFlow 2.6.0 or newer
+Python 3, CUDA 11 or newer, TensorFlow 2
 ### Containers ###
 You can build inside 22.03 or later NGC TF2 [image](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tensorflow):
 ```bash
````

build_pip_pkg.sh

Lines changed: 1 addition & 0 deletions
```diff
@@ -11,6 +11,7 @@ echo "=== Copy TensorFlow Custom op files"
 cp setup.py "${TMPDIR}"
 cp MANIFEST.in "${TMPDIR}"
 cp requirements.txt "${TMPDIR}"
+cp version.txt "${TMPDIR}"
 rsync -avm -L --exclude='*_test.py' distributed_embeddings "${TMPDIR}"
 
 pushd ${TMPDIR}
```

distributed_embeddings/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -15,3 +15,4 @@
 """Distributed embedding API."""
 
 from distributed_embeddings.python.ops.embedding_lookup_ops import embedding_lookup
+from .version import __version__
```

distributed_embeddings/python/layers/dist_model_parallel.py

Lines changed: 125 additions & 36 deletions
```diff
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Distributed Embedding layers and utils"""
+import math
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras.utils import tf_utils
 import horovod.tensorflow as hvd
+from distributed_embeddings.python.ops.embedding_lookup_ops import read_var_no_copy
 from .embedding import Embedding
 
 
@@ -48,19 +51,21 @@ def __init__(self,
       self.input_ids_list = [list(range(len(input_table_map)))]
       self.table_ids_list = [list(range(len(embeddings)))]
       return
+
     # Create (maybe) sliced configs
     sliced_configs, self.sliced_out_ranges = self.create_sliced_configs(
         world_size, column_slice_threshold, input_table_map)
     # Apply strategy and save nested list containing table indices by rank
     self.table_ids_list = self.apply_stragety(strategy, world_size, sliced_configs)
-    # Nested list to split embedding output from each rank into tables
-    self.widths_list = []
+
     # Nested list containing input indices by rank
     self.input_ids_list = []
     # Nested list containing local input to local table map by rank
     self.local_map_list = []
     # Nested list containing local configs by rank
     self.local_configs_list = []
+    # All of local widths ordered by rank flat into single list
+    self.widths_list_flat = []
     # Each worker loop over all rank to get global view of strategy
     for rank_table_ids in self.table_ids_list:
       # calculate stats needed for each rank
@@ -73,20 +78,22 @@ def __init__(self,
         rank_input_ids.append(k)
         rank_input_map.append(m)
       self.local_configs_list.append(rank_configs)
-      self.widths_list.append(rank_widths)
+      self.widths_list_flat += rank_widths
       self.input_ids_list.append(rank_input_ids)
       self.local_map_list.append(rank_input_map)
-    # List of total embedding widths to split embedding output by rank after alltoall
-    self.total_local_widths = [sum(widths) for widths in self.widths_list]
+
     # List that maps local inputs to local table
     self.local_input_table_map = self.local_map_list[rank]
+
     # flatten self.input_ids_list
     worker_order_input_ids = [item for sublist in self.input_ids_list for item in sublist]
+
     # List of indices to shuffle worker ordered embedding outputs back to original order
     self.rev_global_input_ids = [
         index
         for _, index in sorted(zip(worker_order_input_ids, range(len(worker_order_input_ids))))
     ]
+
     # List of configs to create local embedding layers
     self.local_configs = self.local_configs_list[rank]
 
@@ -286,18 +293,17 @@ def _call_base(self, inputs):  # pylint: disable=missing-param-doc,missing-type-
         for m, inp in zip(self.strategy.local_input_table_map, inputs)
     ]
 
-    # concat last axis to make all2all slice correct, and reshape to make later split easier
     # TODO(Deyu): current assume 2D with same batch for all output, ideally should support general case
-    local_bs = inputs[0].shape[0] // self.world_size
-    mp_outs = tf.reshape(tf.concat(mp_outs, axis=-1), [-1, local_bs])
+    mp_outs = [tf.reshape(mp_out, [self.world_size, -1]) for mp_out in mp_outs]
+    mp_outs = tf.reshape(tf.concat(mp_outs, axis=1), [-1])
+    # cast before alltoall according to dtype policy
+    mp_outs = tf.cast(mp_outs, self.compute_dtype)
     dp_outs = hvd.alltoall(mp_outs, name='out_mp_to_dp')
-    dp_outs = [
-        tf.reshape(t, [local_bs, -1]) for t in tf.split(dp_outs, self.strategy.total_local_widths)
-    ]
-    # split each worker result and re-order using id
-    worker_order_res = []
-    for dp_out, widths in zip(dp_outs, self.strategy.widths_list):
-      worker_order_res += tf.split(dp_out, widths, 1)
+    local_bs = inputs[0].shape[0] // self.world_size
+    num_elements = [local_bs * item for item in self.strategy.widths_list_flat]
+    split_outs = tf.split(dp_outs, num_elements)
+    worker_order_res = [tf.reshape(split_out, [local_bs, -1]) for split_out in split_outs]
+
     # reorder outputs to be same as inputs order
     result = [worker_order_res[index] for index in self.strategy.rev_global_input_ids]
     return result
@@ -309,70 +315,149 @@ def _concat_column_slice_outputs(self, outs):
       outs[start:end] = [tf.concat(outs[start:end], axis=-1)]
     return outs
 
-  def set_weights(self, weights):  # pylint: disable=missing-param-doc,missing-type-doc
+  def set_weights(self, weights, chunk=134217728, use_lock=False):
     """Sets the weights of the layer, from NumPy arrays.
 
-    This override expects global weights for all tables as input.
+    Args:
+      weights (list): list containing global weights for all table.
+          item in the list can be either numpy array or file path to load from.
+      chunk (int): max number of elements per chunk when set weight on GPU by chunks.
+          this will be round to number of rows base on weight shape.
+      use_lock (bool): If true, set weights rank by rank in lock step to avoid OOM. Default False.
     """
-    if self.world_size == 1:
-      sliced_local_weights = weights
-    else:
+    if use_lock:
+      for _ in range(self.rank):
+        hvd.broadcast_object(0)
+
+    if self.world_size > 1:
       slice_info = [[rank_tids.count(tid)
                      for rank_tids in self.strategy.table_ids_list]
                     for tid in range(len(weights))]
-      local_weights = [weights[index] for index in self.strategy.table_ids_list[self.rank]]
+      weights = [weights[index] for index in self.strategy.table_ids_list[self.rank]]
+      if isinstance(weights[0], str):
+        weights = [np.load(file=path, mmap_mode='r') for path in weights]
      local_info = [slice_info[index] for index in self.strategy.table_ids_list[self.rank]]
+      # array to handle multiple slice into same table case
+      # TODO(Deyu): avoid this by merge those table again after find strategy
+      rank_ids = self.strategy.table_ids_list[self.rank]
+      index_offset = [rank_ids[:i].count(rank_id) for i, rank_id in enumerate(rank_ids)]
 
-      def _slice_weight_for_rank(weight, info, global_rank):
+      def _slice_weight_for_rank(weight, info, global_rank, offset):
        num_columns = weight.shape[1]
        num_slices = sum(info)
        column_per_slice = num_columns // num_slices
        remainder = num_columns % num_slices
-        rank = info[:global_rank].count(1)
+        rank = sum(info[:global_rank]) + offset
 
        start = column_per_slice * rank + min(rank, remainder)
        rank += 1
        end = column_per_slice * rank + min(rank, remainder)
        return weight[:, start:end]
 
-      sliced_local_weights = [
-          _slice_weight_for_rank(weight, info, self.rank)
-          for weight, info in zip(local_weights, local_info)
+      weights = [
+          _slice_weight_for_rank(weight, info, self.rank, offset)
+          for weight, info, offset in zip(weights, local_info, index_offset)
      ]
-    super().set_weights(sliced_local_weights)
+    # variable.assign and copy-on-write creates extra copy of weight that causes OOM
+    # so here we scatter update by ~128M elements chunks instead of just do
+    # super().set_weights(weights)
+    for weight, arr in zip(self.weights, weights):
+      if arr.size <= chunk:
+        weight.assign(arr)
+      else:
+        chunk_size_dim0 = chunk // weight.shape[1]
+        num_chunks = math.ceil(weight.shape[0] / chunk_size_dim0)
+        last_size = weight.shape[0] - chunk_size_dim0 * (num_chunks - 1)
+        chunk_sizes = [chunk_size_dim0] * (num_chunks - 1) + [last_size]
+        for i in range(num_chunks):
+          start = i * chunk_size_dim0
+          end = start + chunk_sizes[i]
+          indices = tf.range(start=start, limit=end, dtype=tf.int64)
+          update = tf.IndexedSlices(values=arr[start:end],
+                                    indices=indices,
+                                    dense_shape=weight.shape)
+          weight.scatter_update(sparse_delta=update)
+    del weights
+
+    if use_lock:
+      for _ in range(self.world_size - self.rank):
+        hvd.broadcast_object(0)
+
+  # 1d split that works beyond 32bit indexing limit TF support
+  def _split_1d(self, tensor, lengths):
+    # choose a number close to int32 limit as maximum chunk size
+    # This will handle tensor with size up to square of int32_max
+    chunking_threshold = 2147483646
+    if tensor.shape[0] <= chunking_threshold:
+      return tf.split(tensor, lengths)
+    num_chunks = math.ceil(tensor.shape[0] / chunking_threshold)
+    padding_len = math.ceil(tensor.shape[0] / num_chunks) * num_chunks - tensor.shape[0]
+    padded_tensor = tf.concat([tensor, tf.zeros(padding_len, tensor.dtype)], axis=0)
+    tensor_list = tf.unstack(tf.reshape(padded_tensor, [num_chunks, -1]))
+    result = []
+    for length in lengths:
+      this_slice = []
+      while length > 0:
+        if length > tensor_list[0].shape[0]:
+          this_slice.append(tensor_list.pop(0))
+        else:
+          this_slice.append(tensor_list[0][:length])
+          tensor_list[0] = tensor_list[0][length:]
+        length -= this_slice[-1].shape[0]
+      result.append(tf.concat(this_slice, axis=0))
+    return result
 
-  def get_weights(self):
+  def get_weights(self, all_ranks=False):
     """Returns the current weights of the layer, as NumPy arrays.
 
     This override outputs global weights for all tables.
+    Args:
+      all_ranks (bool): If true, return weights in all ranks, otherwise only in rank 0.
+          Default False.
     """
+    # avoid copy-on-read on dense access
+    local_weights = [read_var_no_copy(w) for w in self.weights]
     if self.world_size == 1:
-      return [weight.numpy() for weight in self.weights]
+      return [w.numpy() for w in local_weights]
+
+    # mpi segfault on over 32bit range index, so we gather weights chunk by chunk here
+    # choose a number not very close to int32 limit as maximum chunk size just to be safe
+    chunking_threshold = 2000000000
+    num_chunks = 1
+    for local_configs in self.strategy.local_configs_list:
+      total_elements = sum([c['input_dim'] * c['output_dim'] for c in local_configs])
+      num_chunks = max(num_chunks, math.ceil(self.world_size * total_elements / chunking_threshold))
 
-    # mpi segfault on large sizes so we gather weights chunk by chunk here
-    num_chunks = 8
     with tf.device('CPU:0'):
-      local_weights = tf.concat([tf.reshape(w, [-1]) for w in self.weights], axis=0)
+      local_weights = tf.concat([tf.reshape(w, [-1]) for w in local_weights], axis=0)
      chunk_size = local_weights.shape[0] // num_chunks
      last_size = local_weights.shape[0] - chunk_size * (num_chunks - 1)
      chunk_sizes = [chunk_size] * (num_chunks - 1) + [last_size]
-      local_weights = tf.split(local_weights, chunk_sizes)
+      local_weights = self._split_1d(local_weights, chunk_sizes)
+      # communicate chunk sizes
      all_sizes = hvd.allgather(chunk_sizes)
 
      # collect all chunks and split to reverse allgather concat
      chunks = []
      for i, w in enumerate(local_weights):
-        chunks += tf.split(hvd.allgather(w), all_sizes[i::num_chunks])
+        w = hvd.allgather(w)
+        if all_ranks or self.rank == 0:
+          chunks += self._split_1d(w, all_sizes[i::num_chunks])
+      if not chunks:
+        return []
+
      # re-construct all local weights from chunks
      local_weights = []
      for i in range(self.world_size):
        local_weights.append(tf.concat(chunks[i::self.world_size], axis=0))
+      del chunks
+
      # split flat local weights into correct sizes
      weights = []
      for local_weight, local_configs in zip(local_weights, self.strategy.local_configs_list):
        local_shapes = [[c['input_dim'], c['output_dim']] for c in local_configs]
        local_sizes = [shape[0] * shape[1] for shape in local_shapes]
-        flat_weights = tf.split(local_weight, local_sizes)
+        flat_weights = self._split_1d(local_weight, local_sizes)
        weights += [tf.reshape(weight, shape) for weight, shape in zip(flat_weights, local_shapes)]
      # restore original table order
      # flatten self.strategy.table_ids_list
@@ -408,6 +493,7 @@ def call(self, inputs):  # pylint: disable=missing-function-docstring
         self.local_embedding_layers[m](inp)
         for m, inp in zip(self.strategy.local_input_table_map, inputs)
     ]
+    outputs = [tf.cast(output, self.compute_dtype) for output in outputs]
     return outputs
 
   # TODO(skyw): Revisit logics of selecting call functions for different strategy
@@ -460,7 +546,10 @@ def gradient(self, target, sources, output_gradients=None):
         dp_vars.append(var)
         dp_grads.append(grad)
         split_infos.append((False, len(dp_grads) - 1))
-    dp_grads = self._allreduce_grads(dp_grads, dp_vars)  # pylint: disable=protected-access
+    # TODO(Deyu): make sure not reusing _allreduce_grads doesn't lead to any issue
+    dp_grads = [
+        hvd.allreduce(g, name=f'dp_gradient_{i}', op=hvd.Average) for i, g in enumerate(dp_grads)
+    ]
     # put gradients back in original order
     grads = []
     for info in split_infos:
```
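For context, here is a minimal usage sketch of the reworked `set_weights`/`get_weights` API shown above. It is not from the repository: it assumes Horovod is initialized and that `dist_emb` is an already constructed `DistributedEmbedding` layer from `dist_model_parallel`; the file names and array shapes are illustrative only.

```python
# Hedged sketch (not from the repository): restoring and dumping global
# embedding weights with the chunked, optionally lock-stepped API above.
import numpy as np
import horovod.tensorflow as hvd

hvd.init()

dist_emb = ...  # assumed: an existing dist_model_parallel.DistributedEmbedding layer

# One entry per embedding table, in global table order. Per the new code path,
# entries may be in-memory arrays or paths to .npy files (loaded with mmap_mode='r').
global_weights = ['table0.npy', np.zeros((1000, 16), dtype=np.float32)]

# `chunk` caps elements per scatter_update (rounded to whole rows);
# `use_lock=True` serializes ranks to keep peak memory down.
dist_emb.set_weights(global_weights, chunk=134217728, use_lock=True)

# Gather the full, globally ordered weights back on every rank. With the
# default all_ranks=False, only rank 0 receives them; other ranks get [].
full_weights = dist_emb.get_weights(all_ranks=True)
```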
distributed_embeddings/python/layers/dist_model_parallel_test.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -158,7 +158,7 @@ def run_and_test(self, ref_model, ref_inputs, test_model, test_inputs):
     optimizer.apply_gradients(zip(ref_grads, ref_model.variables))
     optimizer.apply_gradients(zip(test_grads, test_model.variables))
     ref_weights = ref_model.get_weights()
-    test_weights = test_model.dist_embeddings.get_weights() + test_model.dense.get_weights()
+    test_weights = test_model.dist_embeddings.get_weights(True) + test_model.dense.get_weights()
 
     for ref_w, test_w in zip(ref_weights, test_weights):
       # assert close here since order of accumulations(inputs and batch dim) might have changed
@@ -269,6 +269,18 @@ def test_column_slice_threshold(self):
     dp_inputs, _ = self.gen_inputs(table_sizes)
     self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
 
+  def test_column_slice_dup_worker(self):
+    table_sizes = [[10, 4], [11, 2], [4, 2], [4, 2]]
+    ref_model = EmbeddingListModel(table_sizes, distribute=False)
+    test_model = EmbeddingListModel(table_sizes,
+                                    distribute=True,
+                                    strategy='memory_balanced',
+                                    dp_input=False,
+                                    column_slice_threshold=10)
+    mp_input_ids = test_model.dist_embeddings.strategy.input_ids_list[self.hvd_rank]
+    dp_inputs, mp_inputs = self.gen_inputs(table_sizes, mp_input_ids=mp_input_ids)
+    self.run_and_test(ref_model, dp_inputs, test_model, mp_inputs)
+
 
 if __name__ == "__main__":
   test.main()
```
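The new `test_column_slice_dup_worker` case covers a worker that owns more than one slice of the same table. A small, self-contained sketch of the `index_offset` bookkeeping that `set_weights` now uses for that case (the rank assignment below is made up for illustration):

```python
# Illustrative only: mirrors the index_offset computation added to set_weights,
# for a hypothetical worker that holds two slices of table 2.
rank_ids = [0, 2, 2]  # table ids assigned to this rank (made-up example)

# For each local slice, count how many earlier local slices came from the same
# table; duplicates get offsets 0, 1, ... so they select different column
# ranges inside _slice_weight_for_rank.
index_offset = [rank_ids[:i].count(rank_id) for i, rank_id in enumerate(rank_ids)]
print(index_offset)  # -> [0, 0, 1]
```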

setup.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -14,16 +14,44 @@
 # limitations under the License.
 """Simple setup script"""
 
+import os
 from setuptools import setup, find_packages
 
+abspath = os.path.dirname(os.path.realpath(__file__))
+
 with open("requirements.txt", encoding='utf-8') as f:
   requirements = f.read().splitlines()  # pylint: disable=invalid-name
 
 print(find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]))
 
+license_header = """#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+
+# Generate version file
+with open(os.path.join(abspath, "version.txt"), encoding="utf-8") as f:
+  version = f.read().strip()
+with open(os.path.join(abspath, "distributed_embeddings/version.py"), "w", encoding="utf-8") as f:
+  f.write(license_header)
+  f.write(F"__version__ = \"{version}\"")
+
 setup(
     name="distributed-embeddings",
-    version="1.0.0",
+    version=version,
     description="Distributed Embedding",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
     install_requires=requirements,
```
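Taken together with the `version.txt` and `__init__.py` changes, the generated `version.py` makes the package version introspectable at runtime. A small sketch of what that enables, assuming a wheel built from this commit is installed:

```python
# Assumes the distributed-embeddings wheel built from this commit is installed.
import distributed_embeddings

# __init__.py now re-exports __version__ from the generated version.py,
# whose content comes from version.txt ("0.1.0" at this commit).
print(distributed_embeddings.__version__)  # -> 0.1.0
```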

version.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+0.1.0
```
