Skip to content

Commit 483828e

Browse files
authored
Don't throw exceptions inside OpenMP parallel blocks (pytorch#4857)
Fixes undefined behavior: exceptions are not allowed to be thrown across OpenMP constructs.
1 parent 0844b5b commit 483828e

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

aten/src/TH/generic/THTensorMath.c

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,11 @@ static ptrdiff_t THTensor_(dataOffset)(THTensor* tensor, ptrdiff_t linearIndex)
374374
return dataOffset;
375375
}
376376

377-
static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
377+
static void THTensor_(checkLinearIndex)(int64_t linearIndex, int64_t numel) {
378378
THArgCheck(linearIndex < numel && linearIndex >= -numel, 2, "out of range: %d out of %d", (int)linearIndex, (int)numel);
379+
}
380+
381+
static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
379382
return linearIndex < 0 ? linearIndex + numel : linearIndex;
380383
}
381384

@@ -389,25 +392,34 @@ void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index)
389392
ptrdiff_t srcElements = THTensor_(nElement)(src);
390393
real* src_data = THTensor_(data)(src);
391394
real* dst_data = THTensor_(data)(dst);
392-
393395
ptrdiff_t nIndices = THLongTensor_nElement(index);
394-
if (THTensor_(isContiguous)(src)) {
395-
ptrdiff_t i;
396-
#pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
397-
for (i = 0; i < nIndices; i++) {
398-
int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
399-
dst_data[i] = src_data[linearIndex];
400-
}
401-
} else {
402-
ptrdiff_t i;
403-
#pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
404-
for (i = 0; i < nIndices; i++) {
405-
int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
406-
int64_t dataOffset = THTensor_(dataOffset)(src, linearIndex);
407-
dst_data[i] = src_data[dataOffset];
396+
int isContiguous = THTensor_(isContiguous)(src);
397+
398+
// Exceptions must not be thrown across OpenMP parallel sections, so we
399+
// record the value of the invalid index and throw the exception after the
400+
// loop.
401+
int64_t invalidIdx = -1;
402+
403+
ptrdiff_t i;
404+
#pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
405+
for (i = 0; i < nIndices; i++) {
406+
int64_t idx = index_data[i];
407+
if (idx < srcElements && idx >= -srcElements) {
408+
idx = THTensor_(wrapLinearIndex)(idx, srcElements);
409+
if (isContiguous) {
410+
dst_data[i] = src_data[idx];
411+
} else {
412+
dst_data[i] = src_data[THTensor_(dataOffset)(src, idx)];
413+
}
414+
} else {
415+
THAtomicCompareAndSwapLong(&invalidIdx, -1, idx);
408416
}
409417
}
410418

419+
if (invalidIdx >= 0) {
420+
THTensor_(checkLinearIndex)(invalidIdx, srcElements);
421+
}
422+
411423
THLongTensor_free(index);
412424
THTensor_(freeCopyTo)(dst, r_);
413425
}
@@ -424,6 +436,7 @@ void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int ac
424436
int is_contiguous = THTensor_(isContiguous)(tensor);
425437

426438
TH_TENSOR_APPLY2(int64_t, index, real, src,
439+
THTensor_(checkLinearIndex)(*index_data, numel);
427440
int64_t linearIndex = THTensor_(wrapLinearIndex)(*index_data, numel);
428441
int64_t dataOffset = is_contiguous ? linearIndex : THTensor_(dataOffset)(tensor, linearIndex);
429442
if (accumulate) {

0 commit comments

Comments
 (0)