@@ -179,7 +179,11 @@ template <typename scalar_t, typename accscalar_t, typename index_t>
 __global__ void renorm_kernel(
     scalar_t* weights, index_t* indices, accscalar_t max_norm,
     accscalar_t norm_type, int64_t dim,
-    int64_t weights_stride0, int64_t weights_stride1) {
+    int64_t weights_stride0, int64_t weights_stride1,
+    int64_t* num_unique_indices) {
+  if (blockIdx.x >= *num_unique_indices) {
+    return;
+  }

   // Some casting hacks since dynamic shared memory and templates don't work together:
   extern __shared__ unsigned char smem[];
@@ -315,7 +319,8 @@ Tensor& embedding_renorm_cuda_(Tensor& self, const Tensor& indices,
   static_assert(num_threads % C10_WARP_SIZE == 0 &&
                 num_threads <= cuda_utils::kCUDABlockReduceMaxThreads,
                 "BlockReduceSum requires all warps be active");
-  dim3 grid = num_unique_indices.item<int64_t>();
+  int64_t* num_unique_indices_ptr = num_unique_indices.data_ptr<int64_t>();
+  dim3 grid = unique_indices.numel();
   dim3 block = num_threads;
   int dim = self.stride(0);

@@ -326,7 +331,8 @@ Tensor& embedding_renorm_cuda_(Tensor& self, const Tensor& indices,
           unique_indices.data_ptr<index_t>(),
           static_cast<accscalar_t>(max_norm),
           static_cast<accscalar_t>(norm_type),
-          dim, self.stride(0), self.stride(1));
+          dim, self.stride(0), self.stride(1),
+          num_unique_indices_ptr);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
   });
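Taken together, the hunks drop the host-side num_unique_indices.item<int64_t>() call, which forces a device-to-host copy and a synchronization just to size the grid. Instead, the grid is sized to unique_indices.numel() (an upper bound known on the host), the unique count is passed to the kernel as a device pointer, and any block whose index is at or beyond that count returns immediately. Below is a minimal standalone sketch of that launch pattern; the names (process_first_n, d_count) are hypothetical and not taken from the PyTorch sources.

// Sketch: launch with an upper-bound grid and let each block consult a
// device-side count, bailing out if it has no work. This mirrors the
// pattern above without the host-side .item<int64_t>() synchronization.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void process_first_n(const float* in, float* out,
                                const int64_t* d_count) {
  // Blocks past the device-side count exit immediately.
  if (blockIdx.x >= *d_count) {
    return;
  }
  out[blockIdx.x] = 2.0f * in[blockIdx.x];
}

int main() {
  const int64_t capacity = 8;  // upper bound known on the host
  const int64_t count = 5;     // in the real code this value lives only on the GPU

  float h_in[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float *d_in, *d_out;
  int64_t* d_count;
  cudaMalloc(&d_in, capacity * sizeof(float));
  cudaMalloc(&d_out, capacity * sizeof(float));
  cudaMalloc(&d_count, sizeof(int64_t));
  cudaMemcpy(d_in, h_in, capacity * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_count, &count, sizeof(int64_t), cudaMemcpyHostToDevice);

  // Grid is sized to the upper bound; only the first *d_count blocks do work.
  process_first_n<<<(unsigned int)capacity, 1>>>(d_in, d_out, d_count);
  cudaDeviceSynchronize();

  float h_out[8] = {};
  cudaMemcpy(h_out, d_out, capacity * sizeof(float), cudaMemcpyDeviceToHost);
  for (int64_t i = 0; i < count; ++i) {
    printf("%f\n", h_out[i]);
  }

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_count);
  return 0;
}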