Skip to content

Commit 010e85f

Browse files
authored
Fix NRT stats for cuda python (#168)
* Fix NRT stats when cuda-python bindings in use The NRT stats implementation did not take the difference in ctypes pointers between the ctypes and cuda-python bindings into account. * NRT tests: Add a test for single stats This functionality was previously not exercised by the test suite.
1 parent 91a9c66 commit 010e85f

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

numba_cuda/numba/cuda/runtime/nrt.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from numba import cuda, config
77
from numba.core.runtime.nrt import _nrt_mstats
8-
from numba.cuda.cudadrv.driver import Linker, driver, launch_kernel
8+
from numba.cuda.cudadrv.driver import (Linker, driver, launch_kernel,
9+
USE_NV_BINDING)
910
from numba.cuda.cudadrv import devices
1011
from numba.cuda.api import get_current_device
1112
from numba.cuda.utils import _readenv
@@ -128,6 +129,18 @@ def _single_thread_launch(self, module, stream, name, params=()):
128129
cooperative=False
129130
)
130131

132+
def _ctypes_pointer(self, array):
133+
"""
134+
Given an array, return a ctypes pointer to the data suitable for
135+
passing to ``launch_kernel``.
136+
"""
137+
ptr = array.device_ctypes_pointer
138+
139+
if USE_NV_BINDING:
140+
ptr = ctypes.c_void_p(int(ptr))
141+
142+
return ptr
143+
131144
def ensure_initialized(self, stream=None):
132145
"""
133146
If memsys is not initialized, initialize memsys
@@ -174,12 +187,13 @@ def memsys_stats_enabled(self, stream=None):
174187
context
175188
"""
176189
enabled_ar = cuda.managed_array(1, np.uint8)
190+
enabled_ptr = self._ctypes_pointer(enabled_ar)
177191

178192
self._single_thread_launch(
179193
self._memsys_module,
180194
stream,
181195
"NRT_MemSys_stats_enabled",
182-
(enabled_ar.device_ctypes_pointer,)
196+
(enabled_ptr,)
183197
)
184198

185199
cuda.synchronize()
@@ -198,12 +212,13 @@ def _copy_memsys_to_host(self, stream):
198212
])
199213

200214
stats_for_read = cuda.managed_array(1, dt)
215+
stats_ptr = self._ctypes_pointer(stats_for_read)
201216

202217
self._single_thread_launch(
203218
self._memsys_module,
204219
stream,
205220
"NRT_MemSys_read",
206-
[stats_for_read.device_ctypes_pointer]
221+
[stats_ptr]
207222
)
208223
cuda.synchronize()
209224

@@ -231,11 +246,13 @@ def _get_single_stat(self, stat, stream=None):
231246
Get a single stat from the memsys
232247
"""
233248
got = cuda.managed_array(1, np.uint64)
249+
got_ptr = self._ctypes_pointer(got)
250+
234251
self._single_thread_launch(
235252
self._memsys_module,
236253
stream,
237254
f"NRT_MemSys_read_{stat}",
238-
[got.device_ctypes_pointer]
255+
[got_ptr]
239256
)
240257

241258
cuda.synchronize()
@@ -294,11 +311,13 @@ def set_memsys_to_module(self, module, stream=None):
294311
raise RuntimeError(
295312
"Please allocate NRT Memsys first before setting to module.")
296313

314+
memsys_ptr = self._ctypes_pointer(self._memsys)
315+
297316
self._single_thread_launch(
298317
module,
299318
stream,
300319
"NRT_MemSys_set",
301-
[self._memsys.device_ctypes_pointer,]
320+
[memsys_ptr]
302321
)
303322

304323
@_alloc_init_guard

numba_cuda/numba/cuda/tests/nrt/test_nrt.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,38 @@ def test_nrt_explicit_stats_query_raises_exception_when_disabled(self):
229229
stats_func()
230230
self.assertIn("NRT stats are disabled.", str(raises.exception))
231231

232+
def test_read_one_stat(self):
233+
@cuda.jit
234+
def foo():
235+
tmp = np.ones(3)
236+
arr = np.arange(5 * tmp[0]) # noqa: F841
237+
return None
238+
239+
with (
240+
override_config('CUDA_ENABLE_NRT', True),
241+
override_config('CUDA_NRT_STATS', True)
242+
):
243+
244+
# Switch on stats
245+
rtsys.memsys_enable_stats()
246+
247+
# Launch the kernel a couple of times to increase stats
248+
foo[1, 1]()
249+
foo[1, 1]()
250+
251+
# Get stats struct and individual stats
252+
stats = rtsys.get_allocation_stats()
253+
stats_alloc = rtsys.memsys_get_stats_alloc()
254+
stats_mi_alloc = rtsys.memsys_get_stats_mi_alloc()
255+
stats_free = rtsys.memsys_get_stats_free()
256+
stats_mi_free = rtsys.memsys_get_stats_mi_free()
257+
258+
# Check individual stats match stats struct
259+
self.assertEqual(stats.alloc, stats_alloc)
260+
self.assertEqual(stats.mi_alloc, stats_mi_alloc)
261+
self.assertEqual(stats.free, stats_free)
262+
self.assertEqual(stats.mi_free, stats_mi_free)
263+
232264

233265
if __name__ == '__main__':
234266
unittest.main()

0 commit comments

Comments
 (0)