@@ -374,8 +374,11 @@ static ptrdiff_t THTensor_(dataOffset)(THTensor* tensor, ptrdiff_t linearIndex)
   return dataOffset;
 }
 
-static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
+static void THTensor_(checkLinearIndex)(int64_t linearIndex, int64_t numel) {
   THArgCheck(linearIndex < numel && linearIndex >= -numel, 2, "out of range: %d out of %d", (int)linearIndex, (int)numel);
+}
+
+static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
   return linearIndex < 0 ? linearIndex + numel : linearIndex;
 }
 
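Splitting the old wrapLinearIndex into a pure range check (checkLinearIndex) and a pure wrap lets each caller decide when the check runs. A minimal standalone sketch of the same semantics, with hypothetical non-macro names and assert standing in for THArgCheck, behaves like Python-style negative indexing:

#include <assert.h>
#include <stdint.h>

// Hypothetical non-macro versions of the two helpers above (the real ones
// expand through THTensor_(...) and report via THArgCheck). check rejects
// out-of-range values; wrap maps negative indices Python-style.
static void checkLinearIndex(int64_t linearIndex, int64_t numel) {
  assert(linearIndex < numel && linearIndex >= -numel);
}

static int64_t wrapLinearIndex(int64_t linearIndex, int64_t numel) {
  return linearIndex < 0 ? linearIndex + numel : linearIndex;
}

int main(void) {
  checkLinearIndex(-1, 5);              // ok: -1 lies in [-5, 5)
  assert(wrapLinearIndex(-1, 5) == 4);  // -1 means "last element"
  assert(wrapLinearIndex(3, 5) == 3);   // non-negative indices pass through
  return 0;
}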
@@ -389,25 +392,34 @@ void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index)
   ptrdiff_t srcElements = THTensor_(nElement)(src);
   real *src_data = THTensor_(data)(src);
   real *dst_data = THTensor_(data)(dst);
-
   ptrdiff_t nIndices = THLongTensor_nElement(index);
-  if (THTensor_(isContiguous)(src)) {
-    ptrdiff_t i;
-    #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
-    for (i = 0; i < nIndices; i++) {
-      int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
-      dst_data[i] = src_data[linearIndex];
-    }
-  } else {
-    ptrdiff_t i;
-    #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
-    for (i = 0; i < nIndices; i++) {
-      int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
-      int64_t dataOffset = THTensor_(dataOffset)(src, linearIndex);
-      dst_data[i] = src_data[dataOffset];
+  int isContiguous = THTensor_(isContiguous)(src);
+
+  // Exceptions must not be thrown across OpenMP parallel sections, so we
+  // record the value of the invalid index and throw the exception after the
+  // loop.
+  int64_t invalidIdx = -1;
+
+  ptrdiff_t i;
+  #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
+  for (i = 0; i < nIndices; i++) {
+    int64_t idx = index_data[i];
+    if (idx < srcElements && idx >= -srcElements) {
+      idx = THTensor_(wrapLinearIndex)(idx, srcElements);
+      if (isContiguous) {
+        dst_data[i] = src_data[idx];
+      } else {
+        dst_data[i] = src_data[THTensor_(dataOffset)(src, idx)];
+      }
+    } else {
+      THAtomicCompareAndSwapLong(&invalidIdx, -1, idx);
     }
   }
 
+  if (invalidIdx >= 0) {
+    THTensor_(checkLinearIndex)(invalidIdx, srcElements);
+  }
+
   THLongTensor_free(index);
   THTensor_(freeCopyTo)(dst, r_);
 }
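The comment in the hunk above names the key constraint: THArgCheck reports errors via THError, which must not unwind out of an OpenMP parallel region, so each thread only records an offending index and the real check fires once, after the loop. A minimal standalone sketch of this deferral pattern (a hypothetical take_checked; C11 <stdatomic.h> stands in for THAtomicCompareAndSwapLong, and exit() for THError's error reporting):

#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

static void take_checked(const double *src, int64_t numel,
                         const int64_t *index, int64_t n, double *dst) {
  _Atomic int64_t invalidIdx = -1;  // sentinel: no bad index seen yet
  int64_t i;
  #pragma omp parallel for
  for (i = 0; i < n; i++) {
    int64_t idx = index[i];
    if (idx < numel && idx >= -numel) {
      dst[i] = src[idx < 0 ? idx + numel : idx];  // wrap negative indices
    } else {
      // Record only the first offender; losing the race is fine, since
      // some invalid index gets reported either way.
      int64_t expected = -1;
      atomic_compare_exchange_strong(&invalidIdx, &expected, idx);
    }
  }
  // Safe to bail out here, outside the parallel region.
  if (invalidIdx != -1) {
    fprintf(stderr, "out of range: %lld out of %lld\n",
            (long long)invalidIdx, (long long)numel);
    exit(EXIT_FAILURE);
  }
}

int main(void) {
  const double src[5] = {10, 11, 12, 13, 14};
  const int64_t index[3] = {0, -1, 4};
  double dst[3];
  take_checked(src, 5, index, 3, dst);
  printf("%g %g %g\n", dst[0], dst[1], dst[2]);  // 10 14 14
  return 0;
}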
@@ -424,6 +436,7 @@ void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate)
   int is_contiguous = THTensor_(isContiguous)(tensor);
 
   TH_TENSOR_APPLY2(int64_t, index, real, src,
+    THTensor_(checkLinearIndex)(*index_data, numel);
     int64_t linearIndex = THTensor_(wrapLinearIndex)(*index_data, numel);
     int64_t dataOffset = is_contiguous ? linearIndex : THTensor_(dataOffset)(tensor, linearIndex);
     if (accumulate) {
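In put, by contrast, the check can run inline: the TH_TENSOR_APPLY2 loop here is serial (no OpenMP), so THArgCheck may raise at the exact element that is out of range. A rough standalone equivalent of the new loop body, with a hypothetical put_contig restricted to contiguous tensors and assert in place of THArgCheck:

#include <assert.h>
#include <stdint.h>

// Rough standalone equivalent of the put loop body above; put_contig is a
// hypothetical name and only the contiguous case is handled. The check can
// fail fast because nothing here runs under OpenMP.
static void put_contig(double *tensor, int64_t numel,
                       const int64_t *index, const double *src,
                       int64_t n, int accumulate) {
  for (int64_t i = 0; i < n; i++) {
    int64_t idx = index[i];
    assert(idx < numel && idx >= -numel);  // checkLinearIndex
    if (idx < 0) idx += numel;             // wrapLinearIndex
    if (accumulate) {
      tensor[idx] += src[i];  // scatter-add
    } else {
      tensor[idx] = src[i];   // plain scatter
    }
  }
}

int main(void) {
  double t[4] = {0, 0, 0, 0};
  const int64_t idx[2] = {1, -1};
  const double s[2] = {5, 7};
  put_contig(t, 4, idx, s, 2, /*accumulate=*/0);
  assert(t[1] == 5 && t[3] == 7);  // -1 wrapped to the last element
  return 0;
}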