Skip to content

Commit 5433b54

Browse files
committed
add 3 tests in TestNRTStatistics
1 parent 8494547 commit 5433b54

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

numba_cuda/numba/cuda/tests/nrt/test_nrt.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import re
2+
import os
3+
24
import numpy as np
35
import unittest
46
from unittest.mock import patch
57
from numba.cuda.testing import CUDATestCase
68

79
from numba.cuda.tests.nrt.mock_numpy import cuda_empty
10+
from numba.tests.support import run_in_subprocess
811

912
from numba import cuda
13+
from numba.cuda.runtime.nrt import rtsys
1014

1115

1216
class TestNrtBasic(CUDATestCase):
@@ -74,5 +78,88 @@ def g(out_ary):
7478
self.assertEqual(out_ary[0], 1)
7579

7680

81+
class TestNrtStatistics(CUDATestCase):
82+
83+
def setUp(self):
84+
self._stream = cuda.default_stream()
85+
# Store the current stats state
86+
self.__stats_state = rtsys.memsys_stats_enabled(self._stream)
87+
88+
def tearDown(self):
89+
# Set stats state back to whatever it was before the test ran
90+
if self.__stats_state:
91+
rtsys.memsys_enable_stats(self._stream)
92+
else:
93+
rtsys.memsys_disable_stats(self._stream)
94+
95+
def test_stats_env_var_explicit_on(self):
96+
# Checks that explicitly turning the stats on via the env var works.
97+
src = """if 1:
98+
from numba import cuda
99+
from unittest.mock import patch
100+
from numba.cuda.runtime import rtsys
101+
from numba.cuda.tests.nrt.mock_numpy import cuda_arange
102+
103+
@cuda.jit
104+
def foo():
105+
x = cuda_arange(10)[0]
106+
107+
# initialize the NRT before use
108+
rtsys.initialize()
109+
assert rtsys.memsys_stats_enabled()
110+
orig_stats = rtsys.get_allocation_stats()
111+
foo[1, 1]()
112+
new_stats = rtsys.get_allocation_stats()
113+
total_alloc = new_stats.alloc - orig_stats.alloc
114+
total_free = new_stats.free - orig_stats.free
115+
total_mi_alloc = new_stats.mi_alloc - orig_stats.mi_alloc
116+
total_mi_free = new_stats.mi_free - orig_stats.mi_free
117+
118+
expected = 1
119+
assert total_alloc == expected
120+
assert total_free == expected
121+
assert total_mi_alloc == expected
122+
assert total_mi_free == expected
123+
"""
124+
125+
# Check env var explicitly being set works
126+
env = os.environ.copy()
127+
env['NUMBA_CUDA_NRT_STATS'] = "1"
128+
env['NUMBA_CUDA_ENABLE_NRT'] = "1"
129+
run_in_subprocess(src, env=env)
130+
131+
def check_env_var_off(self, env):
132+
133+
src = """if 1:
134+
from numba import cuda
135+
import numpy as np
136+
from numba.cuda.runtime import rtsys
137+
138+
@cuda.jit
139+
def foo():
140+
arr = np.arange(10)[0]
141+
142+
assert rtsys.memsys_stats_enabled() == False
143+
try:
144+
rtsys.get_allocation_stats()
145+
except RuntimeError as e:
146+
assert "NRT stats are disabled." in str(e)
147+
"""
148+
run_in_subprocess(src, env=env)
149+
150+
def test_stats_env_var_explicit_off(self):
151+
# Checks that explicitly turning the stats off via the env var works.
152+
env = os.environ.copy()
153+
env['NUMBA_CUDA_NRT_STATS'] = "0"
154+
self.check_env_var_off(env)
155+
156+
def test_stats_env_var_default_off(self):
157+
# Checks that the env var not being set is the same as "off", i.e.
158+
# default for Numba is off.
159+
env = os.environ.copy()
160+
env.pop('NUMBA_CUDA_NRT_STATS', None)
161+
self.check_env_var_off(env)
162+
163+
77164
if __name__ == '__main__':
78165
unittest.main()

0 commit comments

Comments
 (0)