Skip to content

Commit 5273e4a

Browse files
committed
make cuda nrt test mixin
1 parent 5879098 commit 5273e4a

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import unittest
55
from unittest.mock import patch
66
from numba.cuda.runtime import rtsys
7-
from numba.tests.support import EnableNRTStatsMixin
7+
from numba.cuda.tests.support import EnableNRTStatsMixin
88
from numba.cuda.testing import CUDATestCase
99
from numba.cuda.tests.nrt.mock_numpy import cuda_empty, cuda_empty_like
1010

@@ -30,11 +30,10 @@ def kernel():
3030
temp = cuda_empty(2, np.float64) # noqa: F841
3131
return None
3232

33-
stream = cuda.default_stream()
34-
init_stats = rtsys.get_allocation_stats(stream)
33+
init_stats = rtsys.get_allocation_stats()
3534
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
36-
kernel[1, 1, stream]()
37-
cur_stats = rtsys.get_allocation_stats(stream)
35+
kernel[1, 1]()
36+
cur_stats = rtsys.get_allocation_stats()
3837
self.assertEqual(cur_stats.alloc - init_stats.alloc, n)
3938
self.assertEqual(cur_stats.free - init_stats.free, n)
4039

@@ -56,11 +55,10 @@ def g(n):
5655

5756
return None
5857

59-
stream = cuda.default_stream()
60-
init_stats = rtsys.get_allocation_stats(stream)
58+
init_stats = rtsys.get_allocation_stats()
6159
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
62-
g[1, 1, stream](10)
63-
cur_stats = rtsys.get_allocation_stats(stream)
60+
g[1, 1](10)
61+
cur_stats = rtsys.get_allocation_stats()
6462
self.assertEqual(cur_stats.alloc - init_stats.alloc, 1)
6563
self.assertEqual(cur_stats.free - init_stats.free, 1)
6664

@@ -80,11 +78,10 @@ def if_with_allocation_and_initialization(arr1, test1):
8078

8179
arr = np.random.random((5, 5)) # the values are not consumed
8280

83-
stream = cuda.default_stream()
84-
init_stats = rtsys.get_allocation_stats(stream)
81+
init_stats = rtsys.get_allocation_stats()
8582
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
86-
if_with_allocation_and_initialization[1, 1, stream](arr, False)
87-
cur_stats = rtsys.get_allocation_stats(stream)
83+
if_with_allocation_and_initialization[1, 1](arr, False)
84+
cur_stats = rtsys.get_allocation_stats()
8885
self.assertEqual(cur_stats.alloc - init_stats.alloc,
8986
cur_stats.free - init_stats.free)
9087

@@ -105,11 +102,10 @@ def f(arr):
105102

106103
arr = np.ones((2, 2))
107104

108-
stream = cuda.default_stream()
109-
init_stats = rtsys.get_allocation_stats(stream)
105+
init_stats = rtsys.get_allocation_stats()
110106
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
111-
f[1, 1, stream](arr)
112-
cur_stats = rtsys.get_allocation_stats(stream)
107+
f[1, 1](arr)
108+
cur_stats = rtsys.get_allocation_stats()
113109
self.assertEqual(cur_stats.alloc - init_stats.alloc,
114110
cur_stats.free - init_stats.free)
115111

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from numba.cuda.runtime.nrt import rtsys
2+
3+
4+
class EnableNRTStatsMixin(object):
5+
"""Mixin to enable the NRT statistics counters."""
6+
7+
def setUp(self):
8+
rtsys.memsys_enable_stats()
9+
10+
def tearDown(self):
11+
rtsys.memsys_disable_stats()

0 commit comments

Comments
 (0)