Skip to content

Commit 5dafd70

Browse files
skywFDecaYed
authored and committed
Support custom keras layer in DistributedEmbedding wrapper.
Switched submodule to thrust to avoid build issue.
1 parent a581613 commit 5dafd70

File tree

7 files changed

+62
-16
lines changed

7 files changed

+62
-16
lines changed

.gitmodules

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
1-
[submodule "third_party/cub"]
2-
path = third_party/cub
3-
url = https://github.com/NVIDIA/cub.git
1+
[submodule "third_party/thrust"]
2+
path = third_party/thrust
3+
url = https://github.com/NVIDIA/thrust.git

Makefile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ TARGET_LIB = distributed_embeddings/python/ops/_embedding_lookup_ops.so
4040
all: $(TARGET_LIB)
4141

4242
%_kernels.cu.o: distributed_embeddings/cc/kernels/%_kernels.cu distributed_embeddings/cc/kernels/%.h
43-
$(NVCC) -c -o $@ $< -Ithird_party/cub $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
43+
$(NVCC) -c -o $@ $< -Ithird_party/thrust/dependencies/cub $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
4444

4545
%_kernels.cc.o: distributed_embeddings/cc/kernels/%_kernels.cc distributed_embeddings/cc/kernels/%.h
4646
$(CXX) -c -o $@ $< $(CFLAGS) -Wall -fPIC -I/usr/local/cuda/include

distributed_embeddings/python/layers/dist_model_parallel.py

Lines changed: 20 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,9 @@ def __init__(self,
7474
# column_slice can be used to enable more table concat, so keep it in single process
7575
self.column_slice_threshold = column_slice_threshold
7676
self.global_configs = [e.get_config() for e in embeddings]
77+
# Insert layer type information to config dicts
78+
for config, embedding in zip(self.global_configs, embeddings):
79+
config['layer_type'] = type(embedding)
7780
if input_table_map is None:
7881
input_table_map = list(range(len(embeddings)))
7982

@@ -274,8 +277,10 @@ def _create_concat(self, table_configs, input_maps):
274277
for concat_config in concat_configs:
275278
input_dims = concat_config.pop('input_dims')
276279
if len(input_dims) > 1:
277-
orig_initializer = initializers.deserialize(concat_config['embeddings_initializer'])
278-
concat_config['embeddings_initializer'] = ConcatInitializer(orig_initializer, input_dims)
280+
# TODO(deyuf): custom layer without initializer will be concat but init is not wrapped
281+
if 'embeddings_initializer' in concat_config:
282+
orig_initializer = initializers.deserialize(concat_config['embeddings_initializer'])
283+
concat_config['embeddings_initializer'] = ConcatInitializer(orig_initializer, input_dims)
279284

280285
# record weight offsets for get/set.
281286
weight_offsets = [concat_config.pop('offsets', None) for concat_config in concat_configs]
@@ -363,8 +368,12 @@ def __init__(self,
363368
# create local embeddings
364369
self.local_embedding_layers = []
365370
for config in self.strategy.local_configs[self.rank]:
366-
config['synchronization'] = tf.VariableSynchronization.NONE
367-
self.local_embedding_layers.append(Embedding.from_config(config))
371+
layer_type = config.pop('layer_type')
372+
# For stock keras Embedding, we switch underlying layer for better performance
373+
# If inputs are custom layers, original layer will be used
374+
# TODO(deyuf): Check functionality coverage, add fallback or type picking api
375+
layer_type = Embedding if layer_type == tf.keras.layers.Embedding else layer_type
376+
self.local_embedding_layers.append(layer_type.from_config(config))
368377
self.offsets = [
369378
None if offset == 0 else tf.constant([offset], dtype=tf.int64)
370379
for offset in self.strategy.local_input_offsets[self.rank]
@@ -651,6 +660,11 @@ def build(self, input_shape):
651660
F"Global batchsize {batch_sizes[0]} not divisible workers count {self.world_size}.")
652661
for layer in self.local_embedding_layers:
653662
layer.build(input_shape[0] if input_shape else None)
663+
for var in layer.trainable_weights:
664+
# Mark local(model parallel) variable. use prefix de(distributed embeddings) to avoid conflicts.
665+
var.de_local = True
666+
# set built flag to prevent above build trigger again and above flag fall off
667+
layer.built = True
654668
self.built = True
655669

656670
def call(self, inputs): # pylint: disable=missing-function-docstring
@@ -671,7 +685,7 @@ def broadcast_variables(model_vars, root_rank=0): # pylint: disable=missing-any
671685
dp_vars = []
672686
mp_vars = []
673687
for var in model_vars:
674-
if var.synchronization == tf.VariableSynchronization.NONE:
688+
if hasattr(var, 'de_local'):
675689
mp_vars.append(var)
676690
else:
677691
dp_vars.append(var)
@@ -693,7 +707,7 @@ def gradient(self, target, sources, output_gradients=None):
693707
mp_grads = []
694708
split_infos = []
695709
for grad, var in zip(gradients, sources):
696-
if var.synchronization == tf.VariableSynchronization.NONE:
710+
if hasattr(var, 'de_local'):
697711
if isinstance(grad, tf.IndexedSlices):
698712
mp_grads.append(tf.IndexedSlices(grad.values / hvd.size(), grad.indices,
699713
grad.dense_shape))

distributed_embeddings/python/layers/dist_model_parallel_test.py

Lines changed: 37 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,28 @@
2121
import horovod.tensorflow as hvd
2222
from distributed_embeddings.python.layers import dist_model_parallel as dmp
2323

24+
2425
# There are some functions in TF that pylint can't inspect correctly which leads to incorrect
2526
# report of unexpected-keyword-arg, no-value-for-parameter. Disable them globally here
2627
# pylint: disable=no-self-use,unexpected-keyword-arg,no-value-for-parameter,missing-docstring
28+
class CustomEmbedding(tf.keras.layers.Layer):
29+
30+
def __init__(self, input_dim, output_dim, **kwargs):
31+
super().__init__(**kwargs)
32+
self.input_dim = input_dim
33+
self.output_dim = output_dim
34+
35+
def build(self, _):
36+
self.params = self.add_weight("params",
37+
shape=[self.input_dim, self.output_dim],
38+
dtype=tf.float32)
39+
40+
def call(self, inputs):
41+
return tf.gather(params=self.params, indices=inputs, axis=None)
42+
43+
def get_config(self):
44+
config = {'input_dim': self.input_dim, 'output_dim': self.output_dim}
45+
return config
2746

2847

2948
class EmbeddingListModel(tf.keras.Model):
@@ -35,11 +54,15 @@ def __init__(self,
3554
strategy='basic',
3655
dp_input=True,
3756
input_table_map=None,
38-
column_slice_threshold=None):
57+
column_slice_threshold=None,
58+
test_custom_layer=False):
3959
super().__init__()
4060
self.embeddings = []
4161
for size in table_sizes:
42-
self.embeddings.append(tf.keras.layers.Embedding(*size))
62+
if test_custom_layer:
63+
self.embeddings.append(CustomEmbedding(*size))
64+
else:
65+
self.embeddings.append(tf.keras.layers.Embedding(*size))
4366
if distribute:
4467
self.dist_embeddings = dmp.DistributedEmbedding(self.embeddings,
4568
strategy=strategy,
@@ -386,6 +409,18 @@ def test_fewer_tables_than_workers(self):
386409
dp_inputs, _ = self.gen_inputs(table_sizes)
387410
self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
388411

412+
def test_custom_embedding_layer(self):
413+
table_sizes = self.gen_table_sizes()
414+
415+
ref_model = EmbeddingListModel(table_sizes, distribute=False, test_custom_layer=True)
416+
test_model = EmbeddingListModel(table_sizes,
417+
distribute=True,
418+
strategy='basic',
419+
test_custom_layer=True)
420+
421+
dp_inputs, _ = self.gen_inputs(table_sizes)
422+
self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
423+
389424

390425
if __name__ == "__main__":
391426
test.main()

distributed_embeddings/python/layers/embedding.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,6 @@ def __init__(self,
6767
activity_regularizer=None,
6868
embeddings_constraint=None,
6969
combiner=None,
70-
synchronization=tf.VariableSynchronization.AUTO,
7170
**kwargs):
7271
if 'input_shape' not in kwargs:
7372
kwargs['input_shape'] = (None,)
@@ -89,7 +88,6 @@ def __init__(self,
8988
self.activity_regularizer = regularizers.get(activity_regularizer)
9089
self.embeddings_constraint = constraints.get(embeddings_constraint)
9190
self.combiner = combiner
92-
self.synchronization = synchronization
9391

9492
@tf_utils.shape_type_conversion
9593
def build(self, input_shape): # pylint: disable=unused-argument
@@ -98,7 +96,6 @@ def build(self, input_shape): # pylint: disable=unused-argument
9896
name='embeddings',
9997
regularizer=self.embeddings_regularizer,
10098
constraint=self.embeddings_constraint,
101-
synchronization=self.synchronization,
10299
experimental_autocast=False)
103100
self.built = True
104101

third_party/cub

Lines changed: 0 additions & 1 deletion
This file was deleted.

third_party/thrust

Submodule thrust added at 65fbe23

0 commit comments

Comments
 (0)