Skip to content

Commit 99cab49

Browse files
Remove some unnecessary uses of ContextResettingTestCase (#507)
Some cases that I think look safe to remove; a few more complicated ones still remain. Testing on CI for now - seems to pass locally for me. --------- Co-authored-by: brandon-b-miller <[email protected]>
1 parent c83f379 commit 99cab49

14 files changed

+46
-59
lines changed

numba_cuda/numba/cuda/testing.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@
3232
@pytest.mark.usefixtures("initialize_from_pytest_config")
3333
class CUDATestCase(TestCase):
3434
"""
35-
For tests that use a CUDA device. Test methods in a CUDATestCase must not
36-
be run out of module order, because the ContextResettingTestCase may reset
37-
the context and destroy resources used by a normal CUDATestCase if any of
38-
its tests are run between tests from a CUDATestCase.
35+
For tests that use a CUDA device.
3936
4037
Methods assertFileCheckAsm and assertFileCheckLLVM will inspect a
4138
CUDADispatcher and assert that the compilation artifacts match the
@@ -187,21 +184,6 @@ def assertFileCheckMatches(
187184
)
188185

189186

190-
class ContextResettingTestCase(CUDATestCase):
191-
"""
192-
For tests where the context needs to be reset after each test. Typically
193-
these inspect or modify parts of the context that would usually be expected
194-
to be internal implementation details (such as the state of allocations and
195-
deallocations, etc.).
196-
"""
197-
198-
def tearDown(self):
199-
super().tearDown()
200-
from numba.cuda.cudadrv.devices import reset
201-
202-
reset()
203-
204-
205187
def skip_on_cudasim(reason):
206188
"""Skip this test if running on the CUDA simulator"""
207189
return unittest.skipIf(config.ENABLE_CUDASIM, reason)

numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@ class TestContextStack(CUDATestCase):
1414
def setUp(self):
1515
super().setUp()
1616
# Reset before testing
17-
cuda.close()
18-
19-
def test_gpus_current(self):
20-
self.assertIs(cuda.gpus.current, None)
21-
with cuda.gpus[0]:
22-
self.assertEqual(int(cuda.gpus.current.id), 0)
17+
cuda.current_context().reset()
2318

2419
def test_gpus_len(self):
2520
self.assertGreater(len(cuda.gpus), 0)
@@ -45,7 +40,7 @@ def test_gpus_cudevice_indexing(self):
4540
class TestContextAPI(CUDATestCase):
4641
def tearDown(self):
4742
super().tearDown()
48-
cuda.close()
43+
cuda.current_context().reset()
4944

5045
def test_context_memory(self):
5146
try:
@@ -91,7 +86,7 @@ def switch_gpu():
9186
class Test3rdPartyContext(CUDATestCase):
9287
def tearDown(self):
9388
super().tearDown()
94-
cuda.close()
89+
cuda.current_context().reset()
9590

9691
def test_attached_primary(self, extra_work=lambda: None):
9792
# Emulate primary context creation by 3rd party

numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66
import numpy as np
77

88
from numba.cuda.cudadrv import driver, drvapi, devices
9-
from numba.cuda.testing import unittest, ContextResettingTestCase
9+
from numba.cuda.testing import unittest, CUDATestCase
1010
from numba.cuda.testing import skip_on_cudasim
1111

1212

1313
@skip_on_cudasim("CUDA Memory API unsupported in the simulator")
14-
class TestCudaMemory(ContextResettingTestCase):
14+
class TestCudaMemory(CUDATestCase):
1515
def setUp(self):
1616
super().setUp()
1717
self.context = devices.get_context()
1818

1919
def tearDown(self):
20+
self.context.reset()
2021
del self.context
2122
super(TestCudaMemory, self).tearDown()
2223

@@ -107,7 +108,7 @@ def dtor():
107108
self.assertEqual(dtor_invoked[0], 2)
108109

109110

110-
class TestCudaMemoryFunctions(ContextResettingTestCase):
111+
class TestCudaMemoryFunctions(CUDATestCase):
111112
def setUp(self):
112113
super().setUp()
113114
self.context = devices.get_context()
@@ -153,7 +154,7 @@ def test_d2d(self):
153154

154155

155156
@skip_on_cudasim("CUDA Memory API unsupported in the simulator")
156-
class TestMVExtent(ContextResettingTestCase):
157+
class TestMVExtent(CUDATestCase):
157158
def test_c_contiguous_array(self):
158159
ary = np.arange(100)
159160
arysz = ary.dtype.itemsize * ary.size

numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,16 @@ class TestDeviceOnlyEMMPlugin(CUDATestCase):
112112
def setUp(self):
113113
super().setUp()
114114
# Always start afresh with a new context and memory manager
115-
cuda.close()
116-
cuda.set_memory_manager(DeviceOnlyEMMPlugin)
115+
ctx = cuda.current_context()
116+
ctx.reset()
117+
self._initial_memory_manager = ctx.memory_manager
118+
ctx.memory_manager = DeviceOnlyEMMPlugin(context=ctx)
117119

118120
def tearDown(self):
119121
super().tearDown()
120-
# Unset the memory manager for subsequent tests
121-
cuda.close()
122-
cuda.cudadrv.driver._memory_manager = None
122+
ctx = cuda.current_context()
123+
ctx.reset()
124+
ctx.memory_manager = self._initial_memory_manager
123125

124126
def test_memalloc(self):
125127
mgr = cuda.current_context().memory_manager
@@ -129,6 +131,7 @@ def test_memalloc(self):
129131
arr_1 = np.arange(10)
130132
d_arr_1 = cuda.device_array_like(arr_1)
131133
self.assertTrue(mgr.memalloc_called)
134+
132135
self.assertEqual(mgr.count, 1)
133136
self.assertEqual(mgr.allocations[1], arr_1.nbytes)
134137

numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import numpy as np
55
from numba.cuda.cudadrv import driver
66
from numba import cuda
7-
from numba.cuda.testing import unittest, ContextResettingTestCase
7+
from numba.cuda.testing import unittest, CUDATestCase
88

99

10-
class TestHostAlloc(ContextResettingTestCase):
10+
class TestHostAlloc(CUDATestCase):
11+
def tearDown(self):
12+
cuda.current_context().reset()
13+
1114
def test_host_alloc_driver(self):
1215
n = 32
1316
mem = cuda.current_context().memhostalloc(n, mapped=True)

numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from llvmlite import ir
55

66
from numba.cuda.cudadrv import nvvm
7-
from numba.cuda.testing import unittest, ContextResettingTestCase
7+
from numba.cuda.testing import unittest, CUDATestCase
88
from numba.cuda.testing import skip_on_cudasim
99

1010

1111
@skip_on_cudasim("Inline PTX cannot be used in the simulator")
12-
class TestCudaInlineAsm(ContextResettingTestCase):
12+
class TestCudaInlineAsm(CUDATestCase):
1313
def test_inline_rsqrt(self):
1414
mod = ir.Module(__name__)
1515
mod.triple = "nvptx64-nvidia-cuda"

numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,19 @@
55
from ctypes import byref, c_size_t
66
from numba.cuda.cudadrv.driver import device_memset, driver, USE_NV_BINDING
77
from numba import cuda
8-
from numba.cuda.testing import unittest, ContextResettingTestCase
8+
from numba.cuda.testing import unittest, CUDATestCase
99
from numba.cuda.testing import skip_on_cudasim, skip_on_arm
1010
from numba.cuda.tests.support import linux_only
1111

1212

1313
@skip_on_cudasim("CUDA Driver API unsupported in the simulator")
1414
@linux_only
1515
@skip_on_arm("Managed Alloc support is experimental/untested on ARM")
16-
class TestManagedAlloc(ContextResettingTestCase):
16+
class TestManagedAlloc(CUDATestCase):
17+
def tearDown(self):
18+
super().tearDown()
19+
cuda.current_context().reset()
20+
1721
def get_total_gpu_memory(self):
1822
# We use a driver function to directly get the total GPU memory because
1923
# an EMM plugin may report something different (or not implement

numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from numba.cuda.cudadrv.linkable_code import CUSource
1212
from numba.cuda.testing import (
1313
CUDATestCase,
14-
ContextResettingTestCase,
1514
skip_on_cudasim,
1615
)
1716

@@ -42,7 +41,7 @@ def get_hashable_handle_value(handle):
4241

4342

4443
@skip_on_cudasim("Module loading not implemented in the simulator")
45-
class TestModuleCallbacksBasic(ContextResettingTestCase):
44+
class TestModuleCallbacksBasic(CUDATestCase):
4645
def test_basic(self):
4746
counter = 0
4847

numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import platform
66

77
from numba import cuda
8-
from numba.cuda.testing import unittest, ContextResettingTestCase
8+
from numba.cuda.testing import unittest, CUDATestCase
99

1010

11-
class TestPinned(ContextResettingTestCase):
11+
class TestPinned(CUDATestCase):
1212
def _run_copies(self, A):
1313
A0 = np.copy(A)
1414

numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
# SPDX-License-Identifier: BSD-2-Clause
33

44
import unittest
5-
from numba.cuda.testing import ContextResettingTestCase
5+
from numba.cuda.testing import CUDATestCase
66
from numba import cuda
77
from numba.cuda.testing import skip_on_cudasim
88

99

1010
@skip_on_cudasim("CUDA Profiler unsupported in the simulator")
11-
class TestProfiler(ContextResettingTestCase):
11+
class TestProfiler(CUDATestCase):
1212
def test_profiling(self):
1313
with cuda.profiling():
1414
a = cuda.device_array(10)

0 commit comments

Comments
 (0)