Skip to content

Commit f1f8377

Browse files
committed
add another test from CPU target
1 parent a7d2887 commit f1f8377

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

numba_cuda/numba/cuda/tests/nrt/mock_numpy.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12

23
from numba.core import errors, types
34
from numba.core.extending import overload
@@ -20,6 +21,10 @@ def cuda_empty(shape, dtype):
2021
pass
2122

2223

24+
def cuda_empty_like(arr):
25+
pass
26+
27+
2328
@overload(cuda_empty)
2429
def ol_cuda_empty(shape, dtype):
2530
_check_const_str_dtype("empty", dtype)
@@ -40,3 +45,15 @@ def impl(shape, dtype):
4045
else:
4146
msg = f"Cannot parse input types to function np.empty({shape}, {dtype})"
4247
raise errors.TypingError(msg)
48+
49+
50+
@overload(cuda_empty_like)
51+
def ol_cuda_empty_like(a, dtype=None):
52+
_check_const_str_dtype("zeros_like", dtype)
53+
54+
# NumPy uses 'a' as the arg name for the array-like
55+
def impl(a, dtype=None):
56+
arr = np.empty_like(a, dtype=dtype)
57+
arr._zero_fill()
58+
return arr
59+
return impl

numba_cuda/numba/cuda/tests/nrt/test_nrt.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from numba.tests.support import EnableNRTStatsMixin
88
from numba.cuda.testing import CUDATestCase
99

10-
from .mock_numpy import cuda_empty
10+
from numba.cuda.tests.nrt.mock_numpy import cuda_empty, cuda_empty_like
1111

1212
from numba import cuda
1313

@@ -70,6 +70,29 @@ def g(n):
7070
self.assertEqual(cur_stats.alloc - init_stats.alloc, 1)
7171
self.assertEqual(cur_stats.free - init_stats.free, 1)
7272

73+
def test_invalid_computation_of_lifetime(self):
74+
"""
75+
Test issue #1573
76+
"""
77+
@cuda.jit
78+
def if_with_allocation_and_initialization(arr1, test1):
79+
tmp_arr = cuda_empty_like(arr1)
80+
81+
for i in range(tmp_arr.shape[0]):
82+
pass
83+
84+
if test1:
85+
cuda_empty_like(arr1)
86+
87+
arr = np.random.random((5, 5)) # the values are not consumed
88+
89+
init_stats = rtsys.get_allocation_stats()
90+
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
91+
if_with_allocation_and_initialization[1, 1](arr, False)
92+
cur_stats = rtsys.get_allocation_stats()
93+
self.assertEqual(cur_stats.alloc - init_stats.alloc,
94+
cur_stats.free - init_stats.free)
95+
7396

7497
class TestNrtBasic(CUDATestCase):
7598
def test_nrt_launches(self):

0 commit comments

Comments
 (0)