Skip to content

Commit aa80f05

Browse files
zasdfgbnm authored and facebook-github-bot committed
Remove sync in Embedding caused by unique (pytorch#66091)
Summary: Pull Request resolved: pytorch#66091

Reviewed By: albanD

Differential Revision: D31385576

Pulled By: ngimel

fbshipit-source-id: e656d4d9c38b705c71853ca295f977d1cddc61a1
1 parent 1932bc6 commit aa80f05

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

aten/src/ATen/native/cuda/Embedding.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,11 @@ template <typename scalar_t, typename accscalar_t, typename index_t>
179179
__global__ void renorm_kernel(
180180
scalar_t* weights, index_t* indices, accscalar_t max_norm,
181181
accscalar_t norm_type, int64_t dim,
182-
int64_t weights_stride0, int64_t weights_stride1) {
182+
int64_t weights_stride0, int64_t weights_stride1,
183+
int64_t *num_unique_indices) {
184+
if (blockIdx.x >= *num_unique_indices) {
185+
return;
186+
}
183187

184188
// Some casting hacks since dynamic shared memory and templates don't work together:
185189
extern __shared__ unsigned char smem[];
@@ -315,7 +319,8 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
315319
static_assert(num_threads % C10_WARP_SIZE == 0 &&
316320
num_threads <= cuda_utils::kCUDABlockReduceMaxThreads,
317321
"BlockReduceSum requires all warps be active");
318-
dim3 grid = num_unique_indices.item<int64_t>();
322+
int64_t *num_unique_indices_ptr = num_unique_indices.data_ptr<int64_t>();
323+
dim3 grid = unique_indices.numel();
319324
dim3 block = num_threads;
320325
int dim = self.stride(0);
321326

@@ -326,7 +331,8 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
326331
unique_indices.data_ptr<index_t>(),
327332
static_cast<accscalar_t>(max_norm),
328333
static_cast<accscalar_t>(norm_type),
329-
dim, self.stride(0), self.stride(1));
334+
dim, self.stride(0), self.stride(1),
335+
num_unique_indices_ptr);
330336
C10_CUDA_KERNEL_LAUNCH_CHECK();
331337
});
332338
});

0 commit comments

Comments (0)