Skip to content

Commit 7710372

Browse files
committed
remove all patches, use override_config; Only allocate memsys when NRT is enabled
1 parent 237dae4 commit 7710372

File tree

4 files changed

+58
-52
lines changed

4 files changed

+58
-52
lines changed

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,14 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0):
362362

363363
stream_handle = stream and stream.handle or zero_stream
364364

365-
rtsys.ensure_allocated(stream_handle)
366-
rtsys.set_memsys_to_module(cufunc.module, stream_handle)
367-
rtsys.ensure_initialized(stream_handle)
368-
if config.CUDA_NRT_STATS:
369-
rtsys.memsys_enable_stats(stream_handle)
365+
if hasattr(self, "target_context") and self.target_context.enable_nrt:
366+
# If NRT is enabled, we also initialize the memsys. The statistics
367+
# are controlled by a different config setting `NRT_STATS`.
368+
rtsys.ensure_allocated(stream_handle)
369+
rtsys.set_memsys_to_module(cufunc.module, stream_handle)
370+
rtsys.ensure_initialized(stream_handle)
371+
if config.CUDA_NRT_STATS:
372+
rtsys.memsys_enable_stats(stream_handle)
370373

371374
# Invoke kernel
372375
driver.launch_kernel(cufunc.handle,

numba_cuda/numba/cuda/runtime/nrt.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ extern "C" __device__ void* NRT_Allocate(size_t size)
3333
{
3434
void* ptr = NULL;
3535
ptr = malloc(size);
36-
if (TheMSys->stats.enabled) { TheMSys->stats.alloc++; }
36+
if (TheMSys && TheMSys->stats.enabled) { TheMSys->stats.alloc++; }
3737
return ptr;
3838
}
3939

@@ -48,7 +48,7 @@ extern "C" __device__ void NRT_MemInfo_init(NRT_MemInfo* mi,
4848
mi->dtor_info = dtor_info;
4949
mi->data = data;
5050
mi->size = size;
51-
if (TheMSys->stats.enabled) { TheMSys->stats.mi_alloc++; }
51+
if (TheMSys && TheMSys->stats.enabled) { TheMSys->stats.mi_alloc++; }
5252
}
5353

5454
extern "C"
@@ -63,7 +63,7 @@ __device__ NRT_MemInfo* NRT_MemInfo_new(
6363
extern "C" __device__ void NRT_Free(void* ptr)
6464
{
6565
free(ptr);
66-
if (TheMSys->stats.enabled) { TheMSys->stats.free++; }
66+
if (TheMSys && TheMSys->stats.enabled) { TheMSys->stats.free++; }
6767
}
6868

6969
extern "C" __device__ void NRT_dealloc(NRT_MemInfo* mi)
@@ -74,7 +74,7 @@ extern "C" __device__ void NRT_dealloc(NRT_MemInfo* mi)
7474
extern "C" __device__ void NRT_MemInfo_destroy(NRT_MemInfo* mi)
7575
{
7676
NRT_dealloc(mi);
77-
if (TheMSys->stats.enabled) { TheMSys->stats.mi_free++; }
77+
if (TheMSys && TheMSys->stats.enabled) { TheMSys->stats.mi_free++; }
7878
}
7979

8080
extern "C" __device__ void NRT_MemInfo_call_dtor(NRT_MemInfo* mi)
@@ -151,7 +151,7 @@ extern "C" __device__ void* NRT_Allocate_External(size_t size) {
151151
ptr = malloc(size);
152152
//NRT_Debug(nrt_debug_print("NRT_Allocate_External bytes=%zu ptr=%p\n", size, ptr));
153153

154-
if (TheMSys->stats.enabled)
154+
if (TheMSys && TheMSys->stats.enabled)
155155
{
156156
TheMSys->stats.alloc++;
157157
}

numba_cuda/numba/cuda/tests/nrt/test_nrt.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33

44
import numpy as np
55
import unittest
6-
from unittest.mock import patch
76
from numba.cuda.testing import CUDATestCase
87

98
from numba.cuda.tests.nrt.mock_numpy import cuda_empty, cuda_ones, cuda_arange
10-
from numba.tests.support import run_in_subprocess
9+
from numba.tests.support import run_in_subprocess, override_config
1110

1211
from numba import cuda
1312
from numba.cuda.runtime.nrt import rtsys
1413

1514

1615
class TestNrtBasic(CUDATestCase):
16+
def run(self, result=None):
17+
with override_config("CUDA_ENABLE_NRT", True):
18+
super(TestNrtBasic, self).run(result)
19+
1720
def test_nrt_launches(self):
1821
@cuda.jit
1922
def f(x):
@@ -24,8 +27,7 @@ def g():
2427
x = cuda_empty(10, np.int64)
2528
f(x)
2629

27-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
28-
g[1,1]()
30+
g[1,1]()
2931
cuda.synchronize()
3032

3133
def test_nrt_ptx_contains_refcount(self):
@@ -38,8 +40,7 @@ def g():
3840
x = cuda_empty(10, np.int64)
3941
f(x)
4042

41-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
42-
g[1,1]()
43+
g[1,1]()
4344

4445
ptx = next(iter(g.inspect_asm().values()))
4546

@@ -72,8 +73,7 @@ def g(out_ary):
7273

7374
out_ary = np.zeros(1, dtype=np.int64)
7475

75-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
76-
g[1,1](out_ary)
76+
g[1,1](out_ary)
7777

7878
self.assertEqual(out_ary[0], 1)
7979

@@ -168,36 +168,35 @@ def foo():
168168
arr = cuda_arange(5 * tmp[0]) # noqa: F841
169169
return None
170170

171-
# Switch on stats
172-
rtsys.memsys_enable_stats()
173-
# check the stats are on
174-
self.assertTrue(rtsys.memsys_stats_enabled())
175-
176-
for i in range(2):
177-
# capture the stats state
178-
stats_1 = rtsys.get_allocation_stats()
179-
# Switch off stats
180-
rtsys.memsys_disable_stats()
181-
# check the stats are off
182-
self.assertFalse(rtsys.memsys_stats_enabled())
183-
# run something that would move the counters were they enabled
184-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
185-
foo[1, 1]()
171+
with override_config('CUDA_ENABLE_NRT', True):
186172
# Switch on stats
187173
rtsys.memsys_enable_stats()
188174
# check the stats are on
189175
self.assertTrue(rtsys.memsys_stats_enabled())
190-
# capture the stats state (should not have changed)
191-
stats_2 = rtsys.get_allocation_stats()
192-
# run something that will move the counters
193-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
176+
177+
for i in range(2):
178+
# capture the stats state
179+
stats_1 = rtsys.get_allocation_stats()
180+
# Switch off stats
181+
rtsys.memsys_disable_stats()
182+
# check the stats are off
183+
self.assertFalse(rtsys.memsys_stats_enabled())
184+
# run something that would move the counters were they enabled
185+
foo[1, 1]()
186+
# Switch on stats
187+
rtsys.memsys_enable_stats()
188+
# check the stats are on
189+
self.assertTrue(rtsys.memsys_stats_enabled())
190+
# capture the stats state (should not have changed)
191+
stats_2 = rtsys.get_allocation_stats()
192+
# run something that will move the counters
194193
foo[1, 1]()
195-
# capture the stats state (should have changed)
196-
stats_3 = rtsys.get_allocation_stats()
197-
# check stats_1 == stats_2
198-
self.assertEqual(stats_1, stats_2)
199-
# check stats_2 < stats_3
200-
self.assertLess(stats_2, stats_3)
194+
# capture the stats state (should have changed)
195+
stats_3 = rtsys.get_allocation_stats()
196+
# check stats_1 == stats_2
197+
self.assertEqual(stats_1, stats_2)
198+
# check stats_2 < stats_3
199+
self.assertLess(stats_2, stats_3)
201200

202201
def test_rtsys_stats_query_raises_exception_when_disabled(self):
203202
# Checks that the standard rtsys.get_allocation_stats() query raises

numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import gc
33
import numpy as np
44
import unittest
5-
from unittest.mock import patch
5+
from numba.tests.support import override_config
66
from numba.cuda.runtime import rtsys
77
from numba.cuda.tests.support import EnableNRTStatsMixin
88
from numba.cuda.testing import CUDATestCase
@@ -18,10 +18,18 @@ def setUp(self):
1818
gc.collect()
1919
super(TestNrtRefCt, self).setUp()
2020

21+
def tearDown(self):
22+
super(TestNrtRefCt, self).tearDown()
23+
24+
def run(self, result=None):
25+
with override_config("CUDA_ENABLE_NRT", True):
26+
super(TestNrtRefCt, self).run(result)
27+
2128
def test_no_return(self):
2229
"""
2330
Test issue #1291
2431
"""
32+
2533
n = 10
2634

2735
@cuda.jit
@@ -31,8 +39,7 @@ def kernel():
3139
return None
3240

3341
init_stats = rtsys.get_allocation_stats()
34-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
35-
kernel[1, 1]()
42+
kernel[1, 1]()
3643
cur_stats = rtsys.get_allocation_stats()
3744
self.assertEqual(cur_stats.alloc - init_stats.alloc, n)
3845
self.assertEqual(cur_stats.free - init_stats.free, n)
@@ -56,8 +63,7 @@ def g(n):
5663
return None
5764

5865
init_stats = rtsys.get_allocation_stats()
59-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
60-
g[1, 1](10)
66+
g[1, 1](10)
6167
cur_stats = rtsys.get_allocation_stats()
6268
self.assertEqual(cur_stats.alloc - init_stats.alloc, 1)
6369
self.assertEqual(cur_stats.free - init_stats.free, 1)
@@ -79,8 +85,7 @@ def if_with_allocation_and_initialization(arr1, test1):
7985
arr = np.random.random((5, 5)) # the values are not consumed
8086

8187
init_stats = rtsys.get_allocation_stats()
82-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
83-
if_with_allocation_and_initialization[1, 1](arr, False)
88+
if_with_allocation_and_initialization[1, 1](arr, False)
8489
cur_stats = rtsys.get_allocation_stats()
8590
self.assertEqual(cur_stats.alloc - init_stats.alloc,
8691
cur_stats.free - init_stats.free)
@@ -103,8 +108,7 @@ def f(arr):
103108
arr = np.ones((2, 2))
104109

105110
init_stats = rtsys.get_allocation_stats()
106-
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
107-
f[1, 1](arr)
111+
f[1, 1](arr)
108112
cur_stats = rtsys.get_allocation_stats()
109113
self.assertEqual(cur_stats.alloc - init_stats.alloc,
110114
cur_stats.free - init_stats.free)

0 commit comments

Comments
 (0)