33
44import numpy as np
55import unittest
6- from unittest .mock import patch
76from numba .cuda .testing import CUDATestCase
87
98from numba .cuda .tests .nrt .mock_numpy import cuda_empty , cuda_ones , cuda_arange
10- from numba .tests .support import run_in_subprocess
9+ from numba .tests .support import run_in_subprocess , override_config
1110
1211from numba import cuda
1312from numba .cuda .runtime .nrt import rtsys
1413
1514
1615class TestNrtBasic (CUDATestCase ):
16+ def run (self , result = None ):
17+ with override_config ("CUDA_ENABLE_NRT" , True ):
18+ super (TestNrtBasic , self ).run (result )
19+
1720 def test_nrt_launches (self ):
1821 @cuda .jit
1922 def f (x ):
@@ -24,8 +27,7 @@ def g():
2427 x = cuda_empty (10 , np .int64 )
2528 f (x )
2629
27- with patch ('numba.config.CUDA_ENABLE_NRT' , True , create = True ):
28- g [1 ,1 ]()
30+ g [1 ,1 ]()
2931 cuda .synchronize ()
3032
3133 def test_nrt_ptx_contains_refcount (self ):
@@ -38,8 +40,7 @@ def g():
3840 x = cuda_empty (10 , np .int64 )
3941 f (x )
4042
41- with patch ('numba.config.CUDA_ENABLE_NRT' , True , create = True ):
42- g [1 ,1 ]()
43+ g [1 ,1 ]()
4344
4445 ptx = next (iter (g .inspect_asm ().values ()))
4546
@@ -72,8 +73,7 @@ def g(out_ary):
7273
7374 out_ary = np .zeros (1 , dtype = np .int64 )
7475
75- with patch ('numba.config.CUDA_ENABLE_NRT' , True , create = True ):
76- g [1 ,1 ](out_ary )
76+ g [1 ,1 ](out_ary )
7777
7878 self .assertEqual (out_ary [0 ], 1 )
7979
@@ -168,36 +168,35 @@ def foo():
168168 arr = cuda_arange (5 * tmp [0 ]) # noqa: F841
169169 return None
170170
171- # Switch on stats
172- rtsys .memsys_enable_stats ()
173- # check the stats are on
174- self .assertTrue (rtsys .memsys_stats_enabled ())
175-
176- for i in range (2 ):
177- # capture the stats state
178- stats_1 = rtsys .get_allocation_stats ()
179- # Switch off stats
180- rtsys .memsys_disable_stats ()
181- # check the stats are off
182- self .assertFalse (rtsys .memsys_stats_enabled ())
183- # run something that would move the counters were they enabled
184- with patch ('numba.config.CUDA_ENABLE_NRT' , True , create = True ):
185- foo [1 , 1 ]()
171+ with override_config ('CUDA_ENABLE_NRT' , True ):
186172 # Switch on stats
187173 rtsys .memsys_enable_stats ()
188174 # check the stats are on
189175 self .assertTrue (rtsys .memsys_stats_enabled ())
190- # capture the stats state (should not have changed)
191- stats_2 = rtsys .get_allocation_stats ()
192- # run something that will move the counters
193- with patch ('numba.config.CUDA_ENABLE_NRT' , True , create = True ):
176+
177+ for i in range (2 ):
178+ # capture the stats state
179+ stats_1 = rtsys .get_allocation_stats ()
180+ # Switch off stats
181+ rtsys .memsys_disable_stats ()
182+ # check the stats are off
183+ self .assertFalse (rtsys .memsys_stats_enabled ())
184+ # run something that would move the counters were they enabled
185+ foo [1 , 1 ]()
186+ # Switch on stats
187+ rtsys .memsys_enable_stats ()
188+ # check the stats are on
189+ self .assertTrue (rtsys .memsys_stats_enabled ())
190+ # capture the stats state (should not have changed)
191+ stats_2 = rtsys .get_allocation_stats ()
192+ # run something that will move the counters
194193 foo [1 , 1 ]()
195- # capture the stats state (should have changed)
196- stats_3 = rtsys .get_allocation_stats ()
197- # check stats_1 == stats_2
198- self .assertEqual (stats_1 , stats_2 )
199- # check stats_2 < stats_3
200- self .assertLess (stats_2 , stats_3 )
194+ # capture the stats state (should have changed)
195+ stats_3 = rtsys .get_allocation_stats ()
196+ # check stats_1 == stats_2
197+ self .assertEqual (stats_1 , stats_2 )
198+ # check stats_2 < stats_3
199+ self .assertLess (stats_2 , stats_3 )
201200
202201 def test_rtsys_stats_query_raises_exception_when_disabled (self ):
203202 # Checks that the standard rtsys.get_allocation_stats() query raises
0 commit comments