Skip to content

Commit 5879098

Browse files
committed
add 3 more tests, augment API with single getters
1 parent 5433b54 commit 5879098

File tree

4 files changed

+141
-2
lines changed

4 files changed

+141
-2
lines changed

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0):
365365
rtsys.ensure_allocated(stream_handle)
366366
rtsys.set_memsys_to_module(cufunc.module, stream_handle)
367367
rtsys.ensure_initialized(stream_handle)
368-
rtsys.memsys_enable_stats(stream_handle)
368+
if config.CUDA_NRT_STATS:
369+
rtsys.memsys_enable_stats(stream_handle)
369370

370371
# Invoke kernel
371372
driver.launch_kernel(cufunc.handle,

numba_cuda/numba/cuda/runtime/memsys.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,30 @@ extern "C" __global__ void NRT_MemSys_read(uint64_t *managed_memsys)
2525
managed_memsys[3] = TheMSys->stats.mi_free;
2626
}
2727

28+
extern "C" __global__ void NRT_MemSys_read_alloc(uint64_t *managed_memsys)
29+
{
30+
detail::check_memsys();
31+
managed_memsys[0] = TheMSys->stats.alloc;
32+
}
33+
34+
extern "C" __global__ void NRT_MemSys_read_free(uint64_t *managed_memsys)
35+
{
36+
detail::check_memsys();
37+
managed_memsys[0] = TheMSys->stats.free;
38+
}
39+
40+
extern "C" __global__ void NRT_MemSys_read_mi_alloc(uint64_t *managed_memsys)
41+
{
42+
detail::check_memsys();
43+
managed_memsys[0] = TheMSys->stats.mi_alloc;
44+
}
45+
46+
extern "C" __global__ void NRT_MemSys_read_mi_free(uint64_t *managed_memsys)
47+
{
48+
detail::check_memsys();
49+
managed_memsys[0] = TheMSys->stats.mi_free;
50+
}
51+
2852
extern "C" __global__ void NRT_MemSys_init(void)
2953
{
3054
detail::check_memsys();

numba_cuda/numba/cuda/runtime/nrt.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ def _copy_memsys_to_host(self, stream):
163163

164164
@_alloc_init_guard
165165
def get_allocation_stats(self, stream=None):
166+
enabled = self.memsys_stats_enabled(stream)
167+
if not enabled:
168+
raise RuntimeError("NRT stats are disabled.")
166169
memsys = self._copy_memsys_to_host(stream)
167170
return _nrt_mstats(
168171
alloc=memsys["alloc"],
@@ -171,6 +174,51 @@ def get_allocation_stats(self, stream=None):
171174
mi_free=memsys["mi_free"]
172175
)
173176

177+
@_alloc_init_guard
178+
def _get_single_stat(self, stat, stream=None):
179+
got = cuda.managed_array(1, np.uint64)
180+
self._single_thread_launch(
181+
self._memsys_module,
182+
stream,
183+
f"NRT_MemSys_read_{stat}",
184+
[got.device_ctypes_pointer]
185+
)
186+
187+
cuda.synchronize()
188+
return got[0]
189+
190+
@_alloc_init_guard
191+
def memsys_get_stats_alloc(self, stream=None):
192+
enabled = self.memsys_stats_enabled(stream)
193+
if not enabled:
194+
raise RuntimeError("NRT stats are disabled.")
195+
196+
return self._get_single_stat("alloc")
197+
198+
@_alloc_init_guard
199+
def memsys_get_stats_free(self, stream=None):
200+
enabled = self.memsys_stats_enabled(stream)
201+
if not enabled:
202+
raise RuntimeError("NRT stats are disabled.")
203+
204+
return self._get_single_stat("free")
205+
206+
@_alloc_init_guard
207+
def memsys_get_stats_mi_alloc(self, stream=None):
208+
enabled = self.memsys_stats_enabled(stream)
209+
if not enabled:
210+
raise RuntimeError("NRT stats are disabled.")
211+
212+
return self._get_single_stat("mi_alloc")
213+
214+
@_alloc_init_guard
215+
def memsys_get_stats_mi_free(self, stream=None):
216+
enabled = self.memsys_stats_enabled(stream)
217+
if not enabled:
218+
raise RuntimeError("NRT stats are disabled.")
219+
220+
return self._get_single_stat("mi_free")
221+
174222
def set_memsys_to_module(self, module, stream=None):
175223
if self._memsys is None:
176224
raise RuntimeError(
@@ -186,6 +234,7 @@ def set_memsys_to_module(self, module, stream=None):
186234

187235
@_alloc_init_guard
188236
def print_memsys(self, stream=None):
237+
"""Print the current statistics of memsys, for debugging purpose."""
189238
cuda.synchronize()
190239
self._single_thread_launch(
191240
self._memsys_module,

numba_cuda/numba/cuda/tests/nrt/test_nrt.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from unittest.mock import patch
77
from numba.cuda.testing import CUDATestCase
88

9-
from numba.cuda.tests.nrt.mock_numpy import cuda_empty
9+
from numba.cuda.tests.nrt.mock_numpy import cuda_empty, cuda_ones, cuda_arange
1010
from numba.tests.support import run_in_subprocess
1111

1212
from numba import cuda
@@ -160,6 +160,71 @@ def test_stats_env_var_default_off(self):
160160
env.pop('NUMBA_CUDA_NRT_STATS', None)
161161
self.check_env_var_off(env)
162162

163+
def test_stats_status_toggle(self):
164+
165+
@cuda.jit
166+
def foo():
167+
tmp = cuda_ones(3)
168+
arr = cuda_arange(5 * tmp[0]) # noqa: F841
169+
return None
170+
171+
# Switch on stats
172+
rtsys.memsys_enable_stats()
173+
# check the stats are on
174+
self.assertTrue(rtsys.memsys_stats_enabled())
175+
176+
for i in range(2):
177+
# capture the stats state
178+
stats_1 = rtsys.get_allocation_stats()
179+
# Switch off stats
180+
rtsys.memsys_disable_stats()
181+
# check the stats are off
182+
self.assertFalse(rtsys.memsys_stats_enabled())
183+
# run something that would move the counters were they enabled
184+
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
185+
foo[1, 1]()
186+
# Switch on stats
187+
rtsys.memsys_enable_stats()
188+
# check the stats are on
189+
self.assertTrue(rtsys.memsys_stats_enabled())
190+
# capture the stats state (should not have changed)
191+
stats_2 = rtsys.get_allocation_stats()
192+
# run something that will move the counters
193+
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
194+
foo[1, 1]()
195+
# capture the stats state (should have changed)
196+
stats_3 = rtsys.get_allocation_stats()
197+
# check stats_1 == stats_2
198+
self.assertEqual(stats_1, stats_2)
199+
# check stats_2 < stats_3
200+
self.assertLess(stats_2, stats_3)
201+
202+
def test_rtsys_stats_query_raises_exception_when_disabled(self):
203+
# Checks that the standard rtsys.get_allocation_stats() query raises
204+
# when stats counters are turned off.
205+
206+
rtsys.memsys_disable_stats()
207+
self.assertFalse(rtsys.memsys_stats_enabled())
208+
209+
with self.assertRaises(RuntimeError) as raises:
210+
rtsys.get_allocation_stats()
211+
212+
self.assertIn("NRT stats are disabled.", str(raises.exception))
213+
214+
def test_nrt_explicit_stats_query_raises_exception_when_disabled(self):
215+
# Checks the various memsys_get_stats functions raise if queried when
216+
# the stats counters are disabled.
217+
method_variations = ('alloc', 'free', 'mi_alloc', 'mi_free')
218+
for meth in method_variations:
219+
stats_func = getattr(rtsys, f'memsys_get_stats_{meth}')
220+
with self.subTest(stats_func=stats_func):
221+
# Turn stats off
222+
rtsys.memsys_disable_stats()
223+
self.assertFalse(rtsys.memsys_stats_enabled())
224+
with self.assertRaises(RuntimeError) as raises:
225+
stats_func()
226+
self.assertIn("NRT stats are disabled.", str(raises.exception))
227+
163228

164229
if __name__ == '__main__':
165230
unittest.main()

0 commit comments

Comments
 (0)