Commit 7f63eab
test: speed up ipc tests by ~6.5x (#527)
This PR speeds up the multiprocessing-based IPC tests by using the concurrent futures interface. This is faster because processes are reused from a pool, as opposed to creating all the machinery to run code in a new process for every single test or `subTest` invocation. While this is less isolated, I would argue that using pools and/or the concurrent futures interface is a much more common use case (it's much less noisy, for example).
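A minimal sketch of the pattern (illustrative only; `work` and its return value are placeholders, not names from the test suite): one spawn-based pool is created up front and reused, instead of paying process start-up cost in every test.

import concurrent.futures
import multiprocessing as mp
import os


def work(parent_pid):
    # Runs in a pooled worker process; sanity-check that we are
    # actually in a child, not the parent.
    assert os.getpid() != parent_pid
    return 42


if __name__ == "__main__":
    # One spawn-based pool, reused for every submission; the old code
    # built a fresh ctx.Process plus a Queue for each test/subTest.
    with concurrent.futures.ProcessPoolExecutor(
        mp_context=mp.get_context("spawn")
    ) as exe:
        fut = exe.submit(work, parent_pid=os.getpid())
        # Future.result() re-raises any exception from the child, which
        # is what lets the tests drop the manual (success, traceback)
        # plumbing through a result queue.
        assert fut.result(timeout=3) == 42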
1 parent: 46cda77 · commit: 7f63eab

File tree

1 file changed (+102 -163 lines)

numba_cuda/numba/cuda/tests/cudapy/test_ipc.py

Lines changed: 102 additions & 163 deletions
@@ -1,9 +1,11 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: BSD-2-Clause
 
+import pytest
+import concurrent.futures
 import multiprocessing as mp
+import os
 import itertools
-import traceback
 import pickle
 
 import numpy as np
@@ -21,80 +23,63 @@
 import unittest
 
 
-def core_ipc_handle_test(the_work, result_queue):
-    try:
-        arr = the_work()
-    # Catch anything going wrong in the worker function
-    except:  # noqa: E722
-        # FAILED. propagate the exception as a string
-        succ = False
-        out = traceback.format_exc()
-    else:
-        # OK. send the ndarray back
-        succ = True
-        out = arr
-    result_queue.put((succ, out))
-
-
-def base_ipc_handle_test(handle, size, result_queue):
-    def the_work():
-        dtype = np.dtype(np.intp)
-        with cuda.open_ipc_array(
-            handle, shape=size // dtype.itemsize, dtype=dtype
-        ) as darr:
-            # copy the data to host
-            return darr.copy_to_host()
-
-    core_ipc_handle_test(the_work, result_queue)
-
-
-def serialize_ipc_handle_test(handle, result_queue):
-    def the_work():
-        dtype = np.dtype(np.intp)
-        darr = handle.open_array(
-            cuda.current_context(),
-            shape=handle.size // dtype.itemsize,
-            dtype=dtype,
-        )
+def base_ipc_handle_test(handle, size, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+    dtype = np.dtype(np.intp)
+    with cuda.open_ipc_array(
+        handle, shape=size // dtype.itemsize, dtype=dtype
+    ) as darr:
         # copy the data to host
+        return darr.copy_to_host()
+
+
+def serialize_ipc_handle_test(handle, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+
+    dtype = np.dtype(np.intp)
+    darr = handle.open_array(
+        cuda.current_context(),
+        shape=handle.size // dtype.itemsize,
+        dtype=dtype,
+    )
+    # copy the data to host
+    arr = darr.copy_to_host()
+    handle.close()
+    return arr
+
+
+def ipc_array_test(ipcarr, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+    with ipcarr as darr:
         arr = darr.copy_to_host()
-        handle.close()
-        return arr
+    with pytest.raises(ValueError, match="IpcHandle is already opened"):
+        with ipcarr:
+            pass
+    return arr
 
-    core_ipc_handle_test(the_work, result_queue)
 
+class CUDAIpcTestCase(CUDATestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.exe = concurrent.futures.ProcessPoolExecutor(
+            mp_context=mp.get_context("spawn")
+        )
 
-def ipc_array_test(ipcarr, result_queue):
-    try:
-        with ipcarr as darr:
-            arr = darr.copy_to_host()
-        try:
-            # should fail to reopen
-            with ipcarr:
-                pass
-        except ValueError as e:
-            if str(e) != "IpcHandle is already opened":
-                raise AssertionError("invalid exception message")
-        else:
-            raise AssertionError("did not raise on reopen")
-    # Catch any exception so we can propagate it
-    except:  # noqa: E722
-        # FAILED. propagate the exception as a string
-        succ = False
-        out = traceback.format_exc()
-    else:
-        # OK. send the ndarray back
-        succ = True
-        out = arr
-    result_queue.put((succ, out))
+    @classmethod
+    def tearDownClass(cls) -> None:
+        cls.exe.shutdown()
+        del cls.exe
 
 
 @linux_only
 @skip_under_cuda_memcheck("Hangs cuda-memcheck")
 @skip_on_cudasim("Ipc not available in CUDASIM")
 @skip_on_arm("CUDA IPC not supported on ARM in Numba")
 @skip_on_wsl2("CUDA IPC unreliable on WSL2; skipping IPC tests")
-class TestIpcMemory(CUDATestCase):
+class TestIpcMemory(CUDAIpcTestCase):
     def test_ipc_handle(self):
         # prepare data for IPC
         arr = np.arange(10, dtype=np.intp)
@@ -109,17 +94,11 @@ def test_ipc_handle(self):
         size = ipch.size
 
         # spawn new process for testing
-        ctx = mp.get_context("spawn")
-        result_queue = ctx.Queue()
-        args = (handle_bytes, size, result_queue)
-        proc = ctx.Process(target=base_ipc_handle_test, args=args)
-        proc.start()
-        succ, out = result_queue.get()
-        if not succ:
-            self.fail(out)
-        else:
-            np.testing.assert_equal(arr, out)
-        proc.join(3)
+        fut = self.exe.submit(
+            base_ipc_handle_test, handle_bytes, size, parent_pid=os.getpid()
+        )
+        out = fut.result(timeout=3)
+        np.testing.assert_equal(arr, out)
 
     def variants(self):
         # Test with no slicing and various different slices
@@ -152,17 +131,11 @@ def check_ipc_handle_serialization(self, index_arg=None, foreign=False):
         self.assertEqual(ipch_recon.handle.reserved, ipch.handle.reserved)
 
         # spawn new process for testing
-        ctx = mp.get_context("spawn")
-        result_queue = ctx.Queue()
-        args = (ipch, result_queue)
-        proc = ctx.Process(target=serialize_ipc_handle_test, args=args)
-        proc.start()
-        succ, out = result_queue.get()
-        if not succ:
-            self.fail(out)
-        else:
-            np.testing.assert_equal(expect, out)
-        proc.join(3)
+        fut = self.exe.submit(
+            serialize_ipc_handle_test, ipch, parent_pid=os.getpid()
+        )
+        out = fut.result(timeout=3)
+        np.testing.assert_equal(expect, out)
 
     def test_ipc_handle_serialization(self):
         for (
@@ -185,17 +158,9 @@ def check_ipc_array(self, index_arg=None, foreign=False):
         ipch = devarr.get_ipc_handle()
 
         # spawn new process for testing
-        ctx = mp.get_context("spawn")
-        result_queue = ctx.Queue()
-        args = (ipch, result_queue)
-        proc = ctx.Process(target=ipc_array_test, args=args)
-        proc.start()
-        succ, out = result_queue.get()
-        if not succ:
-            self.fail(out)
-        else:
-            np.testing.assert_equal(expect, out)
-        proc.join(3)
+        fut = self.exe.submit(ipc_array_test, ipch, parent_pid=os.getpid())
+        out = fut.result(timeout=3)
+        np.testing.assert_equal(expect, out)
 
     def test_ipc_array(self):
         for (
@@ -206,65 +171,45 @@ def test_ipc_array(self):
             self.check_ipc_array(index, foreign)
 
 
-def staged_ipc_handle_test(handle, device_num, result_queue):
-    def the_work():
-        with cuda.gpus[device_num]:
-            this_ctx = cuda.devices.get_context()
-            deviceptr = handle.open_staged(this_ctx)
-            arrsize = handle.size // np.dtype(np.intp).itemsize
-            hostarray = np.zeros(arrsize, dtype=np.intp)
-            cuda.driver.device_to_host(
-                hostarray,
-                deviceptr,
-                size=handle.size,
-            )
-            handle.close()
+def staged_ipc_handle_test(handle, device_num, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+    with cuda.gpus[device_num]:
+        this_ctx = cuda.devices.get_context()
+        deviceptr = handle.open_staged(this_ctx)
+        arrsize = handle.size // np.dtype(np.intp).itemsize
+        hostarray = np.zeros(arrsize, dtype=np.intp)
+        cuda.driver.device_to_host(
+            hostarray,
+            deviceptr,
+            size=handle.size,
+        )
+        handle.close()
         return hostarray
 
-    core_ipc_handle_test(the_work, result_queue)
-
-
-def staged_ipc_array_test(ipcarr, device_num, result_queue):
-    try:
-        with cuda.gpus[device_num]:
-            with ipcarr as darr:
-                arr = darr.copy_to_host()
-            try:
-                # should fail to reopen
-                with ipcarr:
-                    pass
-            except ValueError as e:
-                if str(e) != "IpcHandle is already opened":
-                    raise AssertionError("invalid exception message")
-            else:
-                raise AssertionError("did not raise on reopen")
-    # Catch any exception so we can propagate it
-    except:  # noqa: E722
-        # FAILED. propagate the exception as a string
-        succ = False
-        out = traceback.format_exc()
-    else:
-        # OK. send the ndarray back
-        succ = True
-        out = arr
-    result_queue.put((succ, out))
+
+def staged_ipc_array_test(ipcarr, device_num, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+    with cuda.gpus[device_num]:
+        with ipcarr as darr:
+            arr = darr.copy_to_host()
+        with pytest.raises(ValueError, match="IpcHandle is already opened"):
+            with ipcarr:
+                pass
+    return arr
 
 
 @linux_only
 @skip_under_cuda_memcheck("Hangs cuda-memcheck")
 @skip_on_cudasim("Ipc not available in CUDASIM")
 @skip_on_arm("CUDA IPC not supported on ARM in Numba")
 @skip_on_wsl2("CUDA IPC unreliable on WSL2; skipping IPC tests")
-class TestIpcStaged(CUDATestCase):
+class TestIpcStaged(CUDAIpcTestCase):
     def test_staged(self):
         # prepare data for IPC
         arr = np.arange(10, dtype=np.intp)
         devarr = cuda.to_device(arr)
-
-        # spawn new process for testing
-        mpctx = mp.get_context("spawn")
-        result_queue = mpctx.Queue()
-
         # create IPC handle
         ctx = cuda.current_context()
         ipch = ctx.get_ipc_handle(devarr.gpu_data)
@@ -276,16 +221,16 @@ def test_staged(self):
         self.assertEqual(ipch_recon.size, ipch.size)
 
         # Test on every CUDA devices
-        for device_num in range(len(cuda.gpus)):
-            args = (ipch, device_num, result_queue)
-            proc = mpctx.Process(target=staged_ipc_handle_test, args=args)
-            proc.start()
-            succ, out = result_queue.get()
-            proc.join(3)
-            if not succ:
-                self.fail(out)
-            else:
-                np.testing.assert_equal(arr, out)
+        ngpus = len(cuda.gpus)
+        futures = [
+            self.exe.submit(
+                staged_ipc_handle_test, ipch, device_num, parent_pid=os.getpid()
+            )
+            for device_num in range(ngpus)
+        ]
+
+        for fut in concurrent.futures.as_completed(futures, timeout=3 * ngpus):
+            np.testing.assert_equal(arr, fut.result())
 
     def test_ipc_array(self):
         for device_num in range(len(cuda.gpus)):
@@ -295,17 +240,11 @@ def test_ipc_array(self):
             ipch = devarr.get_ipc_handle()
 
             # spawn new process for testing
-            ctx = mp.get_context("spawn")
-            result_queue = ctx.Queue()
-            args = (ipch, device_num, result_queue)
-            proc = ctx.Process(target=staged_ipc_array_test, args=args)
-            proc.start()
-            succ, out = result_queue.get()
-            proc.join(3)
-            if not succ:
-                self.fail(out)
-            else:
-                np.testing.assert_equal(arr, out)
+            fut = self.exe.submit(
+                staged_ipc_array_test, ipch, device_num, parent_pid=os.getpid()
+            )
+            out = fut.result(timeout=3)
+            np.testing.assert_equal(arr, out)
 
 
 @windows_only
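The diff also replaces the hand-rolled reopen checks (nested try/except plus AssertionError bookkeeping) with pytest.raises. A minimal sketch of that idiom, using a toy Handle class as a stand-in for the real IpcHandle (the class and test name here are illustrative, not from the repository):

import pytest


class Handle:
    # Toy stand-in for IpcHandle; only models "cannot be reopened".
    def __init__(self):
        self._opened = False

    def __enter__(self):
        if self._opened:
            raise ValueError("IpcHandle is already opened")
        self._opened = True
        return self

    def __exit__(self, *exc):
        pass


def test_reopen_rejected():
    handle = Handle()
    with handle:
        # While the handle is open, a second open must fail. match= is a
        # regex searched against str(exc), which replaces the old manual
        # comparison of the exception message.
        with pytest.raises(ValueError, match="IpcHandle is already opened"):
            with handle:
                pass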
