Skip to content

Commit 5e531c8

Browse files
committed
test: speed up ipc tests by ~6.5x
1 parent 6c52fc0 commit 5e531c8

File tree

1 file changed

+121
-163
lines changed

1 file changed

+121
-163
lines changed

numba_cuda/numba/cuda/tests/cudapy/test_ipc.py

Lines changed: 121 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import concurrent.futures
45
import multiprocessing as mp
6+
import os
57
import itertools
6-
import traceback
78
import pickle
89

910
import numpy as np
@@ -21,80 +22,78 @@
2122
import unittest
2223

2324

24-
def core_ipc_handle_test(the_work, result_queue):
25-
try:
26-
arr = the_work()
27-
# Catch anything going wrong in the worker function
28-
except: # noqa: E722
29-
# FAILED. propagate the exception as a string
30-
succ = False
31-
out = traceback.format_exc()
32-
else:
33-
# OK. send the ndarray back
34-
succ = True
35-
out = arr
36-
result_queue.put((succ, out))
37-
38-
39-
def base_ipc_handle_test(handle, size, result_queue):
40-
def the_work():
41-
dtype = np.dtype(np.intp)
42-
with cuda.open_ipc_array(
43-
handle, shape=size // dtype.itemsize, dtype=dtype
44-
) as darr:
45-
# copy the data to host
46-
return darr.copy_to_host()
47-
48-
core_ipc_handle_test(the_work, result_queue)
49-
50-
51-
def serialize_ipc_handle_test(handle, result_queue):
52-
def the_work():
53-
dtype = np.dtype(np.intp)
54-
darr = handle.open_array(
55-
cuda.current_context(),
56-
shape=handle.size // dtype.itemsize,
57-
dtype=dtype,
58-
)
25+
def base_ipc_handle_test(handle, size, parent_pid):
    """Worker run in a spawned child process.

    Opens the raw IPC ``handle`` as a device array of ``np.intp`` covering
    ``size`` bytes and returns its contents copied back to the host.
    ``parent_pid`` is the PID of the test process; the assertion guards
    against the work accidentally running in-process.
    """
    assert os.getpid() != parent_pid
    dtype = np.dtype(np.intp)
    nelem = size // dtype.itemsize
    with cuda.open_ipc_array(handle, shape=nelem, dtype=dtype) as darr:
        # copy the data to host
        return darr.copy_to_host()
35+
36+
def serialize_ipc_handle_test(handle, parent_pid):
    """Worker run in a spawned child process.

    Opens a (pickled-and-restored) ``IpcHandle`` as a device array of
    ``np.intp`` in the current context and returns its contents copied to
    the host.  ``parent_pid`` guards against running in the test process.
    """
    pid = os.getpid()
    assert pid != parent_pid

    dtype = np.dtype(np.intp)
    darr = handle.open_array(
        cuda.current_context(),
        shape=handle.size // dtype.itemsize,
        dtype=dtype,
    )
    try:
        # copy the data to host
        arr = darr.copy_to_host()
    finally:
        # Always release the IPC mapping, even if the copy raises —
        # otherwise the handle leaks in the worker process.
        handle.close()
    return arr
50+
51+
52+
def ipc_array_test(ipcarr, parent_pid):
    """Worker run in a spawned child process.

    Opens the IPC array, copies its data to the host, and verifies that a
    second open while the handle is held raises ``ValueError`` with the
    expected message.  Returns the host copy of the data.
    """
    assert os.getpid() != parent_pid
    with ipcarr as darr:
        arr = darr.copy_to_host()
        # Reopening an already-open handle must fail with this exact message.
        try:
            with ipcarr:
                pass
        except ValueError as e:
            if str(e) != "IpcHandle is already opened":
                raise AssertionError("invalid exception message")
        else:
            raise AssertionError("did not raise on reopen")
    return arr
6367

64-
core_ipc_handle_test(the_work, result_queue)
6568

69+
class CUDAIpcTestCase(CUDATestCase):
    """Base class for IPC tests.

    Creates one spawn-context ``ProcessPoolExecutor`` shared by all tests
    in a subclass — reusing worker processes instead of starting a fresh
    ``multiprocessing.Process`` per test is what makes the suite fast.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Chain to the parent fixture first so any CUDA setup it performs
        # is in place before workers are spawned.
        super().setUpClass()
        # CUDA requires the "spawn" start method: forked children inherit a
        # broken driver state.
        cls.exe = concurrent.futures.ProcessPoolExecutor(
            mp_context=mp.get_context("spawn")
        )

    @classmethod
    def tearDownClass(cls) -> None:
        cls.exe.shutdown()
        del cls.exe
        super().tearDownClass()
9089

9190

9291
@linux_only
9392
@skip_under_cuda_memcheck("Hangs cuda-memcheck")
9493
@skip_on_cudasim("Ipc not available in CUDASIM")
9594
@skip_on_arm("CUDA IPC not supported on ARM in Numba")
9695
@skip_on_wsl2("CUDA IPC unreliable on WSL2; skipping IPC tests")
97-
class TestIpcMemory(CUDATestCase):
96+
class TestIpcMemory(CUDAIpcTestCase):
9897
def test_ipc_handle(self):
9998
# prepare data for IPC
10099
arr = np.arange(10, dtype=np.intp)
@@ -109,17 +108,11 @@ def test_ipc_handle(self):
109108
size = ipch.size
110109

111110
# spawn new process for testing
112-
ctx = mp.get_context("spawn")
113-
result_queue = ctx.Queue()
114-
args = (handle_bytes, size, result_queue)
115-
proc = ctx.Process(target=base_ipc_handle_test, args=args)
116-
proc.start()
117-
succ, out = result_queue.get()
118-
if not succ:
119-
self.fail(out)
120-
else:
121-
np.testing.assert_equal(arr, out)
122-
proc.join(3)
111+
fut = self.exe.submit(
112+
base_ipc_handle_test, handle_bytes, size, parent_pid=os.getpid()
113+
)
114+
out = fut.result(timeout=3)
115+
np.testing.assert_equal(arr, out)
123116

124117
def variants(self):
125118
# Test with no slicing and various different slices
@@ -152,17 +145,11 @@ def check_ipc_handle_serialization(self, index_arg=None, foreign=False):
152145
self.assertEqual(ipch_recon.handle.reserved, ipch.handle.reserved)
153146

154147
# spawn new process for testing
155-
ctx = mp.get_context("spawn")
156-
result_queue = ctx.Queue()
157-
args = (ipch, result_queue)
158-
proc = ctx.Process(target=serialize_ipc_handle_test, args=args)
159-
proc.start()
160-
succ, out = result_queue.get()
161-
if not succ:
162-
self.fail(out)
163-
else:
164-
np.testing.assert_equal(expect, out)
165-
proc.join(3)
148+
fut = self.exe.submit(
149+
serialize_ipc_handle_test, ipch, parent_pid=os.getpid()
150+
)
151+
out = fut.result(timeout=3)
152+
np.testing.assert_equal(expect, out)
166153

167154
def test_ipc_handle_serialization(self):
168155
for (
@@ -185,17 +172,9 @@ def check_ipc_array(self, index_arg=None, foreign=False):
185172
ipch = devarr.get_ipc_handle()
186173

187174
# spawn new process for testing
188-
ctx = mp.get_context("spawn")
189-
result_queue = ctx.Queue()
190-
args = (ipch, result_queue)
191-
proc = ctx.Process(target=ipc_array_test, args=args)
192-
proc.start()
193-
succ, out = result_queue.get()
194-
if not succ:
195-
self.fail(out)
196-
else:
197-
np.testing.assert_equal(expect, out)
198-
proc.join(3)
175+
fut = self.exe.submit(ipc_array_test, ipch, parent_pid=os.getpid())
176+
out = fut.result(timeout=3)
177+
np.testing.assert_equal(expect, out)
199178

200179
def test_ipc_array(self):
201180
for (
@@ -206,65 +185,51 @@ def test_ipc_array(self):
206185
self.check_ipc_array(index, foreign)
207186

208187

209-
def staged_ipc_handle_test(handle, device_num, result_queue):
210-
def the_work():
211-
with cuda.gpus[device_num]:
212-
this_ctx = cuda.devices.get_context()
213-
deviceptr = handle.open_staged(this_ctx)
214-
arrsize = handle.size // np.dtype(np.intp).itemsize
215-
hostarray = np.zeros(arrsize, dtype=np.intp)
216-
cuda.driver.device_to_host(
217-
hostarray,
218-
deviceptr,
219-
size=handle.size,
220-
)
221-
handle.close()
188+
def staged_ipc_handle_test(handle, device_num, parent_pid):
    """Worker run in a spawned child process.

    Opens a staged IPC ``handle`` on device ``device_num``, copies the
    device memory back to a host ``np.intp`` array, and returns it.
    ``parent_pid`` guards against running in the test process.
    """
    pid = os.getpid()
    assert pid != parent_pid
    with cuda.gpus[device_num]:
        this_ctx = cuda.devices.get_context()
        deviceptr = handle.open_staged(this_ctx)
        try:
            arrsize = handle.size // np.dtype(np.intp).itemsize
            hostarray = np.zeros(arrsize, dtype=np.intp)
            cuda.driver.device_to_host(
                hostarray,
                deviceptr,
                size=handle.size,
            )
        finally:
            # Always release the IPC mapping, even if the device-to-host
            # copy raises — otherwise the handle leaks in the worker.
            handle.close()
    return hostarray
223203

224-
core_ipc_handle_test(the_work, result_queue)
225-
226-
227-
def staged_ipc_array_test(ipcarr, device_num, result_queue):
228-
try:
229-
with cuda.gpus[device_num]:
230-
with ipcarr as darr:
231-
arr = darr.copy_to_host()
232-
try:
233-
# should fail to reopen
234-
with ipcarr:
235-
pass
236-
except ValueError as e:
237-
if str(e) != "IpcHandle is already opened":
238-
raise AssertionError("invalid exception message")
239-
else:
240-
raise AssertionError("did not raise on reopen")
241-
# Catch any exception so we can propagate it
242-
except: # noqa: E722
243-
# FAILED. propagate the exception as a string
244-
succ = False
245-
out = traceback.format_exc()
246-
else:
247-
# OK. send the ndarray back
248-
succ = True
249-
out = arr
250-
result_queue.put((succ, out))
204+
205+
def staged_ipc_array_test(ipcarr, device_num, parent_pid):
    """Worker run in a spawned child process.

    Opens the IPC array on device ``device_num``, copies its data to the
    host, and verifies that a second open while the handle is held raises
    ``ValueError`` with the expected message.  Returns the host copy.
    """
    assert os.getpid() != parent_pid
    with cuda.gpus[device_num]:
        with ipcarr as darr:
            arr = darr.copy_to_host()
            # Reopening an already-open handle must fail with this
            # exact message.
            try:
                with ipcarr:
                    pass
            except ValueError as e:
                if str(e) != "IpcHandle is already opened":
                    raise AssertionError("invalid exception message")
            else:
                raise AssertionError("did not raise on reopen")
    return arr
251221

252222

253223
@linux_only
254224
@skip_under_cuda_memcheck("Hangs cuda-memcheck")
255225
@skip_on_cudasim("Ipc not available in CUDASIM")
256226
@skip_on_arm("CUDA IPC not supported on ARM in Numba")
257227
@skip_on_wsl2("CUDA IPC unreliable on WSL2; skipping IPC tests")
258-
class TestIpcStaged(CUDATestCase):
228+
class TestIpcStaged(CUDAIpcTestCase):
259229
def test_staged(self):
260230
# prepare data for IPC
261231
arr = np.arange(10, dtype=np.intp)
262232
devarr = cuda.to_device(arr)
263-
264-
# spawn new process for testing
265-
mpctx = mp.get_context("spawn")
266-
result_queue = mpctx.Queue()
267-
268233
# create IPC handle
269234
ctx = cuda.current_context()
270235
ipch = ctx.get_ipc_handle(devarr.gpu_data)
@@ -276,16 +241,15 @@ def test_staged(self):
276241
self.assertEqual(ipch_recon.size, ipch.size)
277242

278243
# Test on every CUDA devices
279-
for device_num in range(len(cuda.gpus)):
280-
args = (ipch, device_num, result_queue)
281-
proc = mpctx.Process(target=staged_ipc_handle_test, args=args)
282-
proc.start()
283-
succ, out = result_queue.get()
284-
proc.join(3)
285-
if not succ:
286-
self.fail(out)
287-
else:
288-
np.testing.assert_equal(arr, out)
244+
futures = [
245+
self.exe.submit(
246+
staged_ipc_handle_test, ipch, device_num, parent_pid=os.getpid()
247+
)
248+
for device_num in range(len(cuda.gpus))
249+
]
250+
251+
for fut in concurrent.futures.as_completed(futures, timeout=3.0):
252+
np.testing.assert_equal(arr, fut.result())
289253

290254
def test_ipc_array(self):
291255
for device_num in range(len(cuda.gpus)):
@@ -295,17 +259,11 @@ def test_ipc_array(self):
295259
ipch = devarr.get_ipc_handle()
296260

297261
# spawn new process for testing
298-
ctx = mp.get_context("spawn")
299-
result_queue = ctx.Queue()
300-
args = (ipch, device_num, result_queue)
301-
proc = ctx.Process(target=staged_ipc_array_test, args=args)
302-
proc.start()
303-
succ, out = result_queue.get()
304-
proc.join(3)
305-
if not succ:
306-
self.fail(out)
307-
else:
308-
np.testing.assert_equal(arr, out)
262+
fut = self.exe.submit(
263+
staged_ipc_array_test, ipch, device_num, parent_pid=os.getpid()
264+
)
265+
out = fut.result(timeout=3)
266+
np.testing.assert_equal(arr, out)
309267

310268

311269
@windows_only

0 commit comments

Comments
 (0)