1616 */
1717
1818#include " embedding_lookup.h"
19- #include " tensorflow/core/framework/op_kernel.h"
2019#include " tensorflow/core/framework/resource_mgr.h"
2120#include " tensorflow/core/framework/resource_var.h"
2221
@@ -45,62 +44,25 @@ class ReadVariableNoCopyOp : public OpKernel {
4544 DataType dtype_;
4645};
4746
48- template <typename Device, typename T, typename Tindices>
49- class EmbeddingLookupConstantHotnessOp : public OpKernel {
50- public:
51- explicit EmbeddingLookupConstantHotnessOp (OpKernelConstruction* context) : OpKernel(context) {
52- OP_REQUIRES_OK (context, context->GetAttr (" combiner" , &_combiner));
53- }
54-
55- void Compute (OpKernelContext* context) override {
56- const Tensor& params = context->input (0 );
57- const Tensor& ids = context->input (1 );
58-
59- auto num_rows = ids.dim_size (0 );
60- auto nnz_per_row = ids.dim_size (1 );
61- auto embedding_width = params.dim_size (1 );
62-
63- TensorShape output_shape ({num_rows, embedding_width});
64- Tensor* output = nullptr ;
65- OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
66-
67- EmbeddingLookupConstantHotnessFunctor<Device, T, Tindices>()(
68- context->eigen_device <Device>(), output->flat <T>().data (), params.flat <T>().data (),
69- ids.flat <Tindices>().data (), nnz_per_row, num_rows, embedding_width,
70- StringToEnum (_combiner));
71- }
72-
73- private:
74- string _combiner;
75- };
76-
77- template <typename Device, typename T, typename Tindices>
78- class EmbeddingLookupConstantHotnessGradOp : public OpKernel {
47+ template <typename Device, typename Tindices>
48+ class RowToSplitOp : public OpKernel {
7949 public:
80- explicit EmbeddingLookupConstantHotnessGradOp (OpKernelConstruction* context) : OpKernel(context) {
81- OP_REQUIRES_OK (context, context->GetAttr (" combiner" , &_combiner));
82- }
50+ explicit RowToSplitOp (OpKernelConstruction* context) : OpKernel(context) {}
8351
8452 void Compute (OpKernelContext* context) override {
85- const Tensor& grad = context->input (0 );
86- const Tensor& ids = context->input (1 );
53+ // [n, 2]
54+ const Tensor& row = context->input (0 );
55+ auto num_ids = row.dim_size (0 );
56+ auto num_rows = context->input (1 ).scalar <int32>()();
8757
88- auto num_rows = ids.dim_size (0 );
89- auto nnz_per_row = ids.dim_size (1 );
90- auto nnz = num_rows * nnz_per_row;
91- auto embedding_width = grad.dim_size (1 );
92-
93- TensorShape output_shape ({nnz, embedding_width});
58+ TensorShape output_shape ({num_rows + 1 });
9459 Tensor* output = nullptr ;
9560 OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
9661
97- EmbeddingLookupConstantHotnessGradFunctor <Device, T, Tindices>()(
98- context-> eigen_device <Device>(), output->flat <T>(). data (), grad. flat <T >().data (),
99- nnz_per_row, num_rows, embedding_width, StringToEnum (_combiner) );
62+ RowToSplitFunctor <Device, Tindices>()(context-> eigen_device <Device>(),
63+ output->flat <Tindices >().data (),
64+ row. flat <Tindices>(). data (), num_ids, num_rows );
10065 }
101-
102- private:
103- string _combiner;
10466};
10567
10668template <typename Device, typename T, typename Tindices>
@@ -118,14 +80,17 @@ class EmbeddingLookupVariableHotnessOp : public OpKernel {
11880 auto num_rows = offsets.dim_size (0 ) - 1 ;
11981 auto embedding_width = params.dim_size (1 );
12082
83+ auto num_ids = ids.dim_size (0 );
84+ auto ave_red_len = num_ids / num_rows;
85+
12186 TensorShape output_shape ({num_rows, embedding_width});
12287 Tensor* output = nullptr ;
12388 OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
12489
12590 EmbeddingLookupVariableHotnessFunctor<Device, T, Tindices>()(
12691 context->eigen_device <Device>(), output->flat <T>().data (), params.flat <T>().data (),
12792 ids.flat <Tindices>().data (), offsets.flat <Tindices>().data (), num_rows, embedding_width,
128- StringToEnum (_combiner));
93+ StringToEnum (_combiner), ave_red_len );
12994 }
13095
13196 private:
@@ -140,21 +105,20 @@ class EmbeddingLookupVariableHotnessGradOp : public OpKernel {
140105 }
141106
142107 void Compute (OpKernelContext* context) override {
143- const Tensor& grad = context->input (0 );
144- const Tensor& ids = context->input (1 );
145- const Tensor& offsets = context->input (2 );
146-
147- auto num_rows = offsets.dim_size (0 ) - 1 ;
108+ const Tensor& ids = context->input (0 );
109+ const Tensor& offset_in = context->input (1 );
110+ const Tensor& grad = context->input (2 );
111+ const Tensor& param = context->input (3 );
112+ auto num_ids = ids.dim_size (0 );
113+ auto num_rows = offset_in.dim_size (0 ) - 1 ;
148114 auto embedding_width = grad.dim_size (1 );
149- auto nnz = ids.dim_size (0 );
150-
151- TensorShape output_shape ({nnz, embedding_width});
152- Tensor* output = nullptr ;
153- OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
115+ auto max_red_len = grad.dim_size (0 );
116+ auto dense_shape_dim0 = param.dim_size (0 );
154117
155118 EmbeddingLookupVariableHotnessGradFunctor<Device, T, Tindices>()(
156- context->eigen_device <Device>(), output->flat <T>().data (), grad.flat <T>().data (),
157- offsets.flat <Tindices>().data (), num_rows, embedding_width, StringToEnum (_combiner));
119+ context, ids.flat <Tindices>().data (), offset_in.flat <Tindices>().data (),
120+ grad.flat <T>().data (), num_ids, embedding_width, num_rows, dense_shape_dim0, max_red_len,
121+ StringToEnum (_combiner));
158122 }
159123
160124 private:
@@ -167,26 +131,21 @@ REGISTER_KERNEL_BUILDER(Name("ReadVariableNoCopy").Device(DEVICE_DEFAULT).HostMe
167131REGISTER_KERNEL_BUILDER (Name(" ReadVariableNoCopy" ).Device(DEVICE_GPU).HostMemory(" resource" ),
168132 ReadVariableNoCopyOp);
169133
170- #define REGISTER_GPU (T, Tindices ) \
171- REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupConstantHotness" ) \
172- .Device(DEVICE_GPU) \
173- .TypeConstraint<T>(" T" ) \
174- .TypeConstraint<Tindices>(" Tindices" ), \
175- EmbeddingLookupConstantHotnessOp<Eigen::GpuDevice, T, Tindices>); \
176- REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupConstantHotnessGrad" ) \
177- .Device(DEVICE_GPU) \
178- .TypeConstraint<T>(" T" ) \
179- .TypeConstraint<Tindices>(" Tindices" ), \
180- EmbeddingLookupConstantHotnessGradOp<Eigen::GpuDevice, T, Tindices>); \
181- REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupVariableHotness" ) \
182- .Device(DEVICE_GPU) \
183- .TypeConstraint<T>(" T" ) \
184- .TypeConstraint<Tindices>(" Tindices" ), \
185- EmbeddingLookupVariableHotnessOp<Eigen::GpuDevice, T, Tindices>); \
186- REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupVariableHotnessGrad" ) \
187- .Device(DEVICE_GPU) \
188- .TypeConstraint<T>(" T" ) \
189- .TypeConstraint<Tindices>(" Tindices" ), \
134+ #define REGISTER_GPU (T, Tindices ) \
135+ REGISTER_KERNEL_BUILDER (Name(" RowToSplit" ) \
136+ .Device(DEVICE_GPU) \
137+ .TypeConstraint<Tindices>(" Tindices" ) \
138+ .HostMemory(" shape" ), \
139+ RowToSplitOp<Eigen::GpuDevice, Tindices>); \
140+ REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupVariableHotness" ) \
141+ .Device(DEVICE_GPU) \
142+ .TypeConstraint<T>(" T" ) \
143+ .TypeConstraint<Tindices>(" Tindices" ), \
144+ EmbeddingLookupVariableHotnessOp<Eigen::GpuDevice, T, Tindices>); \
145+ REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupVariableHotnessGrad" ) \
146+ .Device(DEVICE_GPU) \
147+ .TypeConstraint<T>(" T" ) \
148+ .TypeConstraint<Tindices>(" Tindices" ), \
190149 EmbeddingLookupVariableHotnessGradOp<Eigen::GpuDevice, T, Tindices>);
191150
192151REGISTER_GPU (float , int64_t )
0 commit comments