Skip to content

Commit f4d1a80

Browse files
committed
Explicitly control the use of stream in tests with NRT libraries
1 parent 06f5e53 commit f4d1a80

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

numba_cuda/numba/cuda/runtime/nrt.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,13 +58,16 @@ def allocate(self, stream):
5858
self.set_memsys_to_module(self._memsys_module, stream=stream)
5959

6060
def _single_thread_launch(self, module, stream, name, params=()):
61+
if stream is None:
62+
stream = cuda.default_stream()
63+
6164
func = module.get_function(name)
6265
launch_kernel(
6366
func.handle,
6467
1, 1, 1,
6568
1, 1, 1,
6669
0,
67-
stream,
70+
stream.handle,
6871
params,
6972
cooperative=False
7073
)
@@ -92,7 +95,7 @@ def memsys_stats_disabled(self, stream):
9295
self._single_thread_launch(
9396
self._memsys_module, stream, "NRT_MemSys_disable")
9497

95-
def _copy_memsys_to_host(self, stream=0):
98+
def _copy_memsys_to_host(self, stream):
9699
self.ensure_allocate(stream)
97100
self.ensure_initialize(stream)
98101

@@ -116,7 +119,7 @@ def _copy_memsys_to_host(self, stream=0):
116119

117120
return stats_for_read[0]
118121

119-
def get_allocation_stats(self, stream=0):
122+
def get_allocation_stats(self, stream):
120123
memsys = self._copy_memsys_to_host(stream)
121124
return _nrt_mstats(
122125
alloc=memsys["alloc"],

numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
from numba.cuda.runtime import rtsys
77
from numba.tests.support import EnableNRTStatsMixin
88
from numba.cuda.testing import CUDATestCase
9-
109
from numba.cuda.tests.nrt.mock_numpy import cuda_empty, cuda_empty_like
1110

1211
from numba import cuda
@@ -25,17 +24,17 @@ def test_no_return(self):
2524
"""
2625
n = 10
2726

28-
@cuda.jit(debug=True)
27+
@cuda.jit
2928
def kernel():
3029
for i in range(n):
3130
temp = cuda_empty(2, np.float64) # noqa: F841
3231
return None
3332

34-
init_stats = rtsys.get_allocation_stats()
35-
33+
stream = cuda.default_stream()
34+
init_stats = rtsys.get_allocation_stats(stream)
3635
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
37-
kernel[1,1]()
38-
cur_stats = rtsys.get_allocation_stats()
36+
kernel[1, 1, stream]()
37+
cur_stats = rtsys.get_allocation_stats(stream)
3938
self.assertEqual(cur_stats.alloc - init_stats.alloc, n)
4039
self.assertEqual(cur_stats.free - init_stats.free, n)
4140

@@ -57,10 +56,11 @@ def g(n):
5756

5857
return None
5958

60-
init_stats = rtsys.get_allocation_stats()
59+
stream = cuda.default_stream()
60+
init_stats = rtsys.get_allocation_stats(stream)
6161
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
62-
g[1, 1](10)
63-
cur_stats = rtsys.get_allocation_stats()
62+
g[1, 1, stream](10)
63+
cur_stats = rtsys.get_allocation_stats(stream)
6464
self.assertEqual(cur_stats.alloc - init_stats.alloc, 1)
6565
self.assertEqual(cur_stats.free - init_stats.free, 1)
6666

@@ -80,10 +80,11 @@ def if_with_allocation_and_initialization(arr1, test1):
8080

8181
arr = np.random.random((5, 5)) # the values are not consumed
8282

83-
init_stats = rtsys.get_allocation_stats()
83+
stream = cuda.default_stream()
84+
init_stats = rtsys.get_allocation_stats(stream)
8485
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
85-
if_with_allocation_and_initialization[1, 1](arr, False)
86-
cur_stats = rtsys.get_allocation_stats()
86+
if_with_allocation_and_initialization[1, 1, stream](arr, False)
87+
cur_stats = rtsys.get_allocation_stats(stream)
8788
self.assertEqual(cur_stats.alloc - init_stats.alloc,
8889
cur_stats.free - init_stats.free)
8990

@@ -103,10 +104,12 @@ def f(arr):
103104
res += t[i]
104105

105106
arr = np.ones((2, 2))
106-
init_stats = rtsys.get_allocation_stats()
107+
108+
stream = cuda.default_stream()
109+
init_stats = rtsys.get_allocation_stats(stream)
107110
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
108-
f[1, 1](arr)
109-
cur_stats = rtsys.get_allocation_stats()
111+
f[1, 1, stream](arr)
112+
cur_stats = rtsys.get_allocation_stats(stream)
110113
self.assertEqual(cur_stats.alloc - init_stats.alloc,
111114
cur_stats.free - init_stats.free)
112115

0 commit comments

Comments
 (0)