Skip to content

Commit 9c76abf

Browse files
committed
test: refactor process-based tests to use concurrent futures in order to simplify tests
1 parent 3ebbe29 commit 9c76abf

File tree

5 files changed

+93
-163
lines changed

5 files changed

+93
-163
lines changed

numba_cuda/numba/cuda/tests/cudadrv/test_init.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import concurrent.futures
45
import multiprocessing as mp
56
import os
67

@@ -20,88 +21,83 @@ def cuInit_raising(arg):
2021
# not assigned until we attempt to initialize - mock.patch.object cannot locate
2122
# the non-existent original method, and so fails. Instead we patch
2223
# driver.cuInit with our raising version prior to any attempt to initialize.
23-
def cuInit_raising_test(result_queue):
24+
def cuInit_raising_test():
2425
driver.cuInit = cuInit_raising
2526

26-
success = False
27-
msg = None
28-
2927
try:
3028
# A CUDA operation that forces initialization of the device
3129
cuda.device_array(1)
3230
except CudaSupportError as e:
3331
success = True
3432
msg = e.msg
33+
else:
34+
success = False
35+
msg = None
3536

36-
result_queue.put((success, msg))
37+
return success, msg
3738

3839

3940
# Similar to cuInit_raising_test above, but for testing that the string
4041
# returned by cuda_error() is as expected.
41-
def initialization_error_test(result_queue):
42+
def initialization_error_test():
4243
driver.cuInit = cuInit_raising
4344

44-
success = False
45-
msg = None
46-
4745
try:
4846
# A CUDA operation that forces initialization of the device
4947
cuda.device_array(1)
5048
except CudaSupportError:
5149
success = True
50+
else:
51+
success = False
5252

53-
msg = cuda.cuda_error()
54-
result_queue.put((success, msg))
53+
return success, cuda.cuda_error()
5554

5655

5756
# For testing the path where Driver.__init__() catches a CudaSupportError
58-
def cuda_disabled_test(result_queue):
59-
success = False
60-
msg = None
61-
57+
def cuda_disabled_test():
6258
try:
6359
# A CUDA operation that forces initialization of the device
6460
cuda.device_array(1)
6561
except CudaSupportError as e:
6662
success = True
6763
msg = e.msg
64+
else:
65+
success = False
66+
msg = None
6867

69-
result_queue.put((success, msg))
68+
return success, msg
7069

7170

7271
# Similar to cuda_disabled_test, but checks cuda.cuda_error() instead of the
7372
# exception raised on initialization
74-
def cuda_disabled_error_test(result_queue):
75-
success = False
76-
msg = None
77-
73+
def cuda_disabled_error_test():
7874
try:
7975
# A CUDA operation that forces initialization of the device
8076
cuda.device_array(1)
8177
except CudaSupportError:
8278
success = True
79+
else:
80+
success = False
8381

84-
msg = cuda.cuda_error()
85-
result_queue.put((success, msg))
82+
return success, cuda.cuda_error()
8683

8784

8885
@skip_on_cudasim("CUDA Simulator does not initialize driver")
8986
class TestInit(CUDATestCase):
9087
def _test_init_failure(self, target, expected):
9188
# Run the initialization failure test in a separate subprocess
92-
ctx = mp.get_context("spawn")
93-
result_queue = ctx.Queue()
94-
proc = ctx.Process(target=target, args=(result_queue,))
95-
proc.start()
96-
proc.join(30) # should complete within 30s
97-
success, msg = result_queue.get()
89+
with concurrent.futures.ProcessPoolExecutor(
90+
mp_context=mp.get_context("spawn")
91+
) as exe:
92+
# should complete within 30s
93+
success, msg = exe.submit(target).result(timeout=30)
9894

9995
# Ensure the child process raised an exception during initialization
10096
# before checking the message
10197
if not success:
102-
self.fail("CudaSupportError not raised")
98+
assert "CudaSupportError not raised" in msg
10399

104-
self.assertIn(expected, msg)
100+
assert expected in msg
105101

106102
def test_init_failure_raising(self):
107103
expected = "Error at driver init: CUDA_ERROR_UNKNOWN (999)"

numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import concurrent.futures
45
import multiprocessing
56
import os
67
from numba.cuda.testing import unittest
78

89

9-
def set_visible_devices_and_check(q):
10-
try:
11-
from numba import cuda
12-
import os
10+
def set_visible_devices_and_check():
11+
from numba import cuda
12+
import os
1313

14-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
15-
q.put(len(cuda.gpus.lst))
16-
except: # noqa: E722
17-
# Sentinel value for error executing test code
18-
q.put(-1)
14+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
15+
return len(cuda.gpus.lst)
1916

2017

2118
class TestVisibleDevices(unittest.TestCase):
@@ -38,22 +35,13 @@ def test_visible_devices_set_after_import(self):
3835
msg = "Cannot test when CUDA_VISIBLE_DEVICES already set"
3936
self.skipTest(msg)
4037

41-
ctx = multiprocessing.get_context("spawn")
42-
q = ctx.Queue()
43-
p = ctx.Process(target=set_visible_devices_and_check, args=(q,))
44-
p.start()
45-
try:
46-
visible_gpu_count = q.get()
47-
finally:
48-
p.join()
49-
50-
# Make an obvious distinction between an error running the test code
51-
# and an incorrect number of GPUs in the list
52-
msg = "Error running set_visible_devices_and_check"
53-
self.assertNotEqual(visible_gpu_count, -1, msg=msg)
54-
55-
# The actual check that we see only one GPU
56-
self.assertEqual(visible_gpu_count, 1)
38+
with concurrent.futures.ProcessPoolExecutor(
39+
mp_context=multiprocessing.get_context("spawn")
40+
) as exe:
41+
future = exe.submit(set_visible_devices_and_check)
42+
43+
visible_gpu_count = future.result()
44+
assert visible_gpu_count == 1
5745

5846

5947
if __name__ == "__main__":

numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,35 @@
33

44
import os
55
import multiprocessing as mp
6+
import pytest
7+
import concurrent.futures
68

79
import numpy as np
810

911
from numba import cuda
1012
from numba.cuda.testing import skip_on_cudasim, CUDATestCase
13+
from numba.cuda.cudadrv.error import CudaDriverError
1114
import unittest
1215

13-
has_mp_get_context = hasattr(mp, "get_context")
14-
is_unix = os.name == "posix"
15-
16-
17-
def fork_test(q):
18-
from numba.cuda.cudadrv.error import CudaDriverError
19-
20-
try:
21-
cuda.to_device(np.arange(1))
22-
except CudaDriverError as e:
23-
q.put(e)
24-
else:
25-
q.put(None)
26-
2716

2817
@skip_on_cudasim("disabled for cudasim")
2918
class TestMultiprocessing(CUDATestCase):
30-
@unittest.skipUnless(has_mp_get_context, "requires mp.get_context")
31-
@unittest.skipUnless(is_unix, "requires Unix")
19+
@unittest.skipUnless(hasattr(mp, "get_context"), "requires mp.get_context")
20+
@unittest.skipUnless(os.name == "posix", "requires Unix")
3221
def test_fork(self):
3322
"""
3423
Test fork detection.
3524
"""
3625
cuda.current_context() # force cuda initialize
37-
# fork in process that also uses CUDA
38-
ctx = mp.get_context("fork")
39-
q = ctx.Queue()
40-
proc = ctx.Process(target=fork_test, args=[q])
41-
proc.start()
42-
exc = q.get()
43-
proc.join()
44-
# there should be an exception raised in the child process
45-
self.assertIsNotNone(exc)
46-
self.assertIn("CUDA initialized before forking", str(exc))
26+
with concurrent.futures.ProcessPoolExecutor(
27+
mp_context=mp.get_context("fork")
28+
) as exe:
29+
future = exe.submit(cuda.to_device, np.arange(1))
30+
31+
with pytest.raises(
32+
CudaDriverError, match="CUDA initialized before forking"
33+
):
34+
future.result()
4735

4836

4937
if __name__ == "__main__":

numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4-
import traceback
54
import threading
65
import multiprocessing
76
import numpy as np
@@ -13,12 +12,7 @@
1312
)
1413
import unittest
1514

16-
try:
17-
from concurrent.futures import ThreadPoolExecutor
18-
except ImportError:
19-
has_concurrent_futures = False
20-
else:
21-
has_concurrent_futures = True
15+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
2216

2317

2418
has_mp_get_context = hasattr(multiprocessing, "get_context")
@@ -41,52 +35,34 @@ def use_foo(x):
4135
np.testing.assert_equal(ary, expected)
4236

4337

44-
def spawn_process_entry(q):
45-
try:
46-
check_concurrent_compiling()
47-
# Catch anything that goes wrong in the threads
48-
except: # noqa: E722
49-
msg = traceback.format_exc()
50-
q.put("\n".join(["", "=" * 80, msg]))
51-
else:
52-
q.put(None)
53-
54-
5538
@skip_under_cuda_memcheck("Hangs cuda-memcheck")
5639
@skip_on_cudasim("disabled for cudasim")
5740
class TestMultiThreadCompiling(CUDATestCase):
58-
@unittest.skipIf(not has_concurrent_futures, "no concurrent.futures")
5941
def test_concurrent_compiling(self):
6042
check_concurrent_compiling()
6143

6244
@unittest.skipIf(not has_mp_get_context, "no multiprocessing.get_context")
6345
def test_spawn_concurrent_compilation(self):
6446
# force CUDA context init
6547
cuda.get_current_device()
66-
# use "spawn" to avoid inheriting the CUDA context
67-
ctx = multiprocessing.get_context("spawn")
68-
69-
q = ctx.Queue()
70-
p = ctx.Process(target=spawn_process_entry, args=(q,))
71-
p.start()
72-
try:
73-
err = q.get()
74-
finally:
75-
p.join()
76-
if err is not None:
77-
raise AssertionError(err)
78-
self.assertEqual(p.exitcode, 0, "test failed in child process")
48+
49+
with ProcessPoolExecutor(
50+
# use "spawn" to avoid inheriting the CUDA context
51+
mp_context=multiprocessing.get_context("spawn")
52+
) as exe:
53+
future = exe.submit(check_concurrent_compiling)
54+
future.result()
7955

8056
def test_invalid_context_error_with_d2h(self):
8157
def d2h(arr, out):
8258
out[:] = arr.copy_to_host()
8359

8460
arr = np.arange(1, 4)
8561
out = np.zeros_like(arr)
86-
darr = cuda.to_device(arr)
87-
th = threading.Thread(target=d2h, args=[darr, out])
88-
th.start()
89-
th.join()
62+
63+
with ThreadPoolExecutor() as exe:
64+
exe.submit(d2h, cuda.to_device(arr), out)
65+
9066
np.testing.assert_equal(arr, out)
9167

9268
def test_invalid_context_error_with_d2d(self):

0 commit comments

Comments
 (0)