diff --git a/mmdetection3d/mmdet3d/ops/bev_pool/src/bev_pool_cuda.cu b/mmdetection3d/mmdet3d/ops/bev_pool/src/bev_pool_cuda.cu index 9ae3b28..9ae23e4 100644 --- a/mmdetection3d/mmdet3d/ops/bev_pool/src/bev_pool_cuda.cu +++ b/mmdetection3d/mmdet3d/ops/bev_pool/src/bev_pool_cuda.cu @@ -29,7 +29,7 @@ __global__ void bev_pool_kernel(int b, int d, int h, int w, int n, int c, int n_ if (index >= n_intervals) return; int interval_start = interval_starts[index]; int interval_length = interval_lengths[index]; - const int* cur_geom_feats = geom_feats + interval_start * 4; + const int* cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4; const float* cur_x = x + interval_start * c + cur_c; float* cur_out = out + cur_geom_feats[3] * d * h * w * c + cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c + @@ -71,7 +71,7 @@ __global__ void bev_pool_grad_kernel(int b, int d, int h, int w, int n, int c, i int interval_start = interval_starts[index]; int interval_length = interval_lengths[index]; - const int* cur_geom_feats = geom_feats + interval_start * 4; + const int* cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4; float* cur_x_grad = x_grad + interval_start * c + cur_c; const float* cur_out_grad = out_grad + cur_geom_feats[3] * d * h * w * c +