 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: BSD-2-Clause
 
+import concurrent.futures
 import multiprocessing as mp
+import os
 import itertools
-import traceback
 import pickle
 
 import numpy as np
 import unittest
 
 
-def core_ipc_handle_test(the_work, result_queue):
-    try:
-        arr = the_work()
-    # Catch anything going wrong in the worker function
-    except:  # noqa: E722
-        # FAILED. propagate the exception as a string
-        succ = False
-        out = traceback.format_exc()
-    else:
-        # OK. send the ndarray back
-        succ = True
-        out = arr
-    result_queue.put((succ, out))
-
-
-def base_ipc_handle_test(handle, size, result_queue):
-    def the_work():
-        dtype = np.dtype(np.intp)
-        with cuda.open_ipc_array(
-            handle, shape=size // dtype.itemsize, dtype=dtype
-        ) as darr:
-            # copy the data to host
-            return darr.copy_to_host()
-
-    core_ipc_handle_test(the_work, result_queue)
-
-
-def serialize_ipc_handle_test(handle, result_queue):
-    def the_work():
-        dtype = np.dtype(np.intp)
-        darr = handle.open_array(
-            cuda.current_context(),
-            shape=handle.size // dtype.itemsize,
-            dtype=dtype,
-        )
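+# The worker functions below are submitted to a spawn-based process pool; each
+# receives the parent's PID and asserts it is running in a different process
+# before touching the IPC handle.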
+def base_ipc_handle_test(handle, size, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+    dtype = np.dtype(np.intp)
+    with cuda.open_ipc_array(
+        handle, shape=size // dtype.itemsize, dtype=dtype
+    ) as darr:
         # copy the data to host
+        return darr.copy_to_host()
+
+
+def serialize_ipc_handle_test(handle, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+
+    dtype = np.dtype(np.intp)
+    darr = handle.open_array(
+        cuda.current_context(),
+        shape=handle.size // dtype.itemsize,
+        dtype=dtype,
+    )
+    # copy the data to host
+    arr = darr.copy_to_host()
+    handle.close()
+    return arr
+
+
+def ipc_array_test(ipcarr, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+    with ipcarr as darr:
         arr = darr.copy_to_host()
-        handle.close()
-        return arr
+        try:
+            # should fail to reopen
+            with ipcarr:
+                pass
+        except ValueError as e:
+            if str(e) != "IpcHandle is already opened":
+                raise AssertionError("invalid exception message")
+        else:
+            raise AssertionError("did not raise on reopen")
+    return arr
 
-    core_ipc_handle_test(the_work, result_queue)
 
+class CUDAIpcTestCase(CUDATestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
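+        # A single spawn-based process pool shared by all tests in the class;
+        # tests submit the worker functions above and read back host arrays.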
+        cls.exe = concurrent.futures.ProcessPoolExecutor(
+            mp_context=mp.get_context("spawn")
+        )
 
-def ipc_array_test(ipcarr, result_queue):
-    try:
-        with ipcarr as darr:
-            arr = darr.copy_to_host()
-            try:
-                # should fail to reopen
-                with ipcarr:
-                    pass
-            except ValueError as e:
-                if str(e) != "IpcHandle is already opened":
-                    raise AssertionError("invalid exception message")
-            else:
-                raise AssertionError("did not raise on reopen")
-    # Catch any exception so we can propagate it
-    except:  # noqa: E722
-        # FAILED. propagate the exception as a string
-        succ = False
-        out = traceback.format_exc()
-    else:
-        # OK. send the ndarray back
-        succ = True
-        out = arr
-    result_queue.put((succ, out))
+    @classmethod
+    def tearDownClass(cls) -> None:
+        cls.exe.shutdown()
+        del cls.exe
+
+    # def setUp(self) -> None:
+    #     self.exe = concurrent.futures.ProcessPoolExecutor(
+    #         mp_context=mp.get_context("spawn")
+    #     )
+    #
+    # def tearDown(self) -> None:
+    #     self.exe.shutdown(wait=True)
+    #     del self.exe
 
 
 @linux_only
 @skip_under_cuda_memcheck("Hangs cuda-memcheck")
 @skip_on_cudasim("Ipc not available in CUDASIM")
 @skip_on_arm("CUDA IPC not supported on ARM in Numba")
 @skip_on_wsl2("CUDA IPC unreliable on WSL2; skipping IPC tests")
-class TestIpcMemory(CUDATestCase):
+class TestIpcMemory(CUDAIpcTestCase):
     def test_ipc_handle(self):
         # prepare data for IPC
         arr = np.arange(10, dtype=np.intp)
@@ -109,17 +108,11 @@ def test_ipc_handle(self):
         size = ipch.size
 
         # spawn new process for testing
-        ctx = mp.get_context("spawn")
-        result_queue = ctx.Queue()
-        args = (handle_bytes, size, result_queue)
-        proc = ctx.Process(target=base_ipc_handle_test, args=args)
-        proc.start()
-        succ, out = result_queue.get()
-        if not succ:
-            self.fail(out)
-        else:
-            np.testing.assert_equal(arr, out)
-        proc.join(3)
+        fut = self.exe.submit(
+            base_ipc_handle_test, handle_bytes, size, parent_pid=os.getpid()
+        )
+        out = fut.result(timeout=3)
+        np.testing.assert_equal(arr, out)
 
     def variants(self):
         # Test with no slicing and various different slices
@@ -152,17 +145,11 @@ def check_ipc_handle_serialization(self, index_arg=None, foreign=False):
         self.assertEqual(ipch_recon.handle.reserved, ipch.handle.reserved)
 
         # spawn new process for testing
-        ctx = mp.get_context("spawn")
-        result_queue = ctx.Queue()
-        args = (ipch, result_queue)
-        proc = ctx.Process(target=serialize_ipc_handle_test, args=args)
-        proc.start()
-        succ, out = result_queue.get()
-        if not succ:
-            self.fail(out)
-        else:
-            np.testing.assert_equal(expect, out)
-        proc.join(3)
+        fut = self.exe.submit(
+            serialize_ipc_handle_test, ipch, parent_pid=os.getpid()
+        )
+        out = fut.result(timeout=3)
+        np.testing.assert_equal(expect, out)
 
     def test_ipc_handle_serialization(self):
         for (
@@ -185,17 +172,9 @@ def check_ipc_array(self, index_arg=None, foreign=False):
         ipch = devarr.get_ipc_handle()
 
         # spawn new process for testing
-        ctx = mp.get_context("spawn")
-        result_queue = ctx.Queue()
-        args = (ipch, result_queue)
-        proc = ctx.Process(target=ipc_array_test, args=args)
-        proc.start()
-        succ, out = result_queue.get()
-        if not succ:
-            self.fail(out)
-        else:
-            np.testing.assert_equal(expect, out)
-        proc.join(3)
+        fut = self.exe.submit(ipc_array_test, ipch, parent_pid=os.getpid())
+        out = fut.result(timeout=3)
+        np.testing.assert_equal(expect, out)
 
     def test_ipc_array(self):
         for (
@@ -206,65 +185,51 @@ def test_ipc_array(self):
             self.check_ipc_array(index, foreign)
 
 
-def staged_ipc_handle_test(handle, device_num, result_queue):
-    def the_work():
-        with cuda.gpus[device_num]:
-            this_ctx = cuda.devices.get_context()
-            deviceptr = handle.open_staged(this_ctx)
-            arrsize = handle.size // np.dtype(np.intp).itemsize
-            hostarray = np.zeros(arrsize, dtype=np.intp)
-            cuda.driver.device_to_host(
-                hostarray,
-                deviceptr,
-                size=handle.size,
-            )
-            handle.close()
+def staged_ipc_handle_test(handle, device_num, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+    with cuda.gpus[device_num]:
+        this_ctx = cuda.devices.get_context()
+        deviceptr = handle.open_staged(this_ctx)
+        arrsize = handle.size // np.dtype(np.intp).itemsize
+        hostarray = np.zeros(arrsize, dtype=np.intp)
+        cuda.driver.device_to_host(
+            hostarray,
+            deviceptr,
+            size=handle.size,
+        )
+        handle.close()
         return hostarray
 
-    core_ipc_handle_test(the_work, result_queue)
-
-
-def staged_ipc_array_test(ipcarr, device_num, result_queue):
-    try:
-        with cuda.gpus[device_num]:
-            with ipcarr as darr:
-                arr = darr.copy_to_host()
-            try:
-                # should fail to reopen
-                with ipcarr:
-                    pass
-            except ValueError as e:
-                if str(e) != "IpcHandle is already opened":
-                    raise AssertionError("invalid exception message")
-            else:
-                raise AssertionError("did not raise on reopen")
-    # Catch any exception so we can propagate it
-    except:  # noqa: E722
-        # FAILED. propagate the exception as a string
-        succ = False
-        out = traceback.format_exc()
-    else:
-        # OK. send the ndarray back
-        succ = True
-        out = arr
-    result_queue.put((succ, out))
+
+def staged_ipc_array_test(ipcarr, device_num, parent_pid):
+    pid = os.getpid()
+    assert pid != parent_pid
+    with cuda.gpus[device_num]:
+        with ipcarr as darr:
+            arr = darr.copy_to_host()
+        try:
+            # should fail to reopen
+            with ipcarr:
+                pass
+        except ValueError as e:
+            if str(e) != "IpcHandle is already opened":
+                raise AssertionError("invalid exception message")
+        else:
+            raise AssertionError("did not raise on reopen")
+    return arr
 
 
 @linux_only
 @skip_under_cuda_memcheck("Hangs cuda-memcheck")
 @skip_on_cudasim("Ipc not available in CUDASIM")
 @skip_on_arm("CUDA IPC not supported on ARM in Numba")
 @skip_on_wsl2("CUDA IPC unreliable on WSL2; skipping IPC tests")
-class TestIpcStaged(CUDATestCase):
+class TestIpcStaged(CUDAIpcTestCase):
     def test_staged(self):
         # prepare data for IPC
         arr = np.arange(10, dtype=np.intp)
         devarr = cuda.to_device(arr)
-
-        # spawn new process for testing
-        mpctx = mp.get_context("spawn")
-        result_queue = mpctx.Queue()
-
         # create IPC handle
         ctx = cuda.current_context()
         ipch = ctx.get_ipc_handle(devarr.gpu_data)
@@ -276,16 +241,15 @@ def test_staged(self):
         self.assertEqual(ipch_recon.size, ipch.size)
 
         # Test on every CUDA device
-        for device_num in range(len(cuda.gpus)):
-            args = (ipch, device_num, result_queue)
-            proc = mpctx.Process(target=staged_ipc_handle_test, args=args)
-            proc.start()
-            succ, out = result_queue.get()
-            proc.join(3)
-            if not succ:
-                self.fail(out)
-            else:
-                np.testing.assert_equal(arr, out)
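+        # Submit one worker per device and check each result as it completes.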
+        futures = [
+            self.exe.submit(
+                staged_ipc_handle_test, ipch, device_num, parent_pid=os.getpid()
+            )
+            for device_num in range(len(cuda.gpus))
+        ]
+
+        for fut in concurrent.futures.as_completed(futures, timeout=3.0):
+            np.testing.assert_equal(arr, fut.result())
 
     def test_ipc_array(self):
         for device_num in range(len(cuda.gpus)):
@@ -295,17 +259,11 @@ def test_ipc_array(self):
             ipch = devarr.get_ipc_handle()
 
             # spawn new process for testing
-            ctx = mp.get_context("spawn")
-            result_queue = ctx.Queue()
-            args = (ipch, device_num, result_queue)
-            proc = ctx.Process(target=staged_ipc_array_test, args=args)
-            proc.start()
-            succ, out = result_queue.get()
-            proc.join(3)
-            if not succ:
-                self.fail(out)
-            else:
-                np.testing.assert_equal(arr, out)
+            fut = self.exe.submit(
+                staged_ipc_array_test, ipch, device_num, parent_pid=os.getpid()
+            )
+            out = fut.result(timeout=3)
+            np.testing.assert_equal(arr, out)
 
 
 @windows_only