11# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: BSD-2-Clause
33
4+ import pytest
5+ import concurrent .futures
46import multiprocessing as mp
7+ import os
58import itertools
6- import traceback
79import pickle
810
911import numpy as np
2123import unittest
2224
2325
24- def core_ipc_handle_test (the_work , result_queue ):
25- try :
26- arr = the_work ()
27- # Catch anything going wrong in the worker function
28- except : # noqa: E722
29- # FAILED. propagate the exception as a string
30- succ = False
31- out = traceback .format_exc ()
32- else :
33- # OK. send the ndarray back
34- succ = True
35- out = arr
36- result_queue .put ((succ , out ))
37-
38-
39- def base_ipc_handle_test (handle , size , result_queue ):
40- def the_work ():
41- dtype = np .dtype (np .intp )
42- with cuda .open_ipc_array (
43- handle , shape = size // dtype .itemsize , dtype = dtype
44- ) as darr :
45- # copy the data to host
46- return darr .copy_to_host ()
47-
48- core_ipc_handle_test (the_work , result_queue )
49-
50-
51- def serialize_ipc_handle_test (handle , result_queue ):
52- def the_work ():
53- dtype = np .dtype (np .intp )
54- darr = handle .open_array (
55- cuda .current_context (),
56- shape = handle .size // dtype .itemsize ,
57- dtype = dtype ,
58- )
26+ def base_ipc_handle_test (handle , size , parent_pid ):
27+ pid = os .getpid ()
28+ assert pid != parent_pid
29+ dtype = np .dtype (np .intp )
30+ with cuda .open_ipc_array (
31+ handle , shape = size // dtype .itemsize , dtype = dtype
32+ ) as darr :
5933 # copy the data to host
34+ return darr .copy_to_host ()
35+
36+
37+ def serialize_ipc_handle_test (handle , parent_pid ):
38+ pid = os .getpid ()
39+ assert pid != parent_pid
40+
41+ dtype = np .dtype (np .intp )
42+ darr = handle .open_array (
43+ cuda .current_context (),
44+ shape = handle .size // dtype .itemsize ,
45+ dtype = dtype ,
46+ )
47+ # copy the data to host
48+ arr = darr .copy_to_host ()
49+ handle .close ()
50+ return arr
51+
52+
53+ def ipc_array_test (ipcarr , parent_pid ):
54+ pid = os .getpid ()
55+ assert pid != parent_pid
56+ with ipcarr as darr :
6057 arr = darr .copy_to_host ()
61- handle .close ()
62- return arr
58+ with pytest .raises (ValueError , match = "IpcHandle is already opened" ):
59+ with ipcarr :
60+ pass
61+ return arr
6362
64- core_ipc_handle_test (the_work , result_queue )
6563
64+ class CUDAIpcTestCase (CUDATestCase ):
65+ @classmethod
66+ def setUpClass (cls ) -> None :
67+ cls .exe = concurrent .futures .ProcessPoolExecutor (
68+ mp_context = mp .get_context ("spawn" )
69+ )
6670
67- def ipc_array_test (ipcarr , result_queue ):
68- try :
69- with ipcarr as darr :
70- arr = darr .copy_to_host ()
71- try :
72- # should fail to reopen
73- with ipcarr :
74- pass
75- except ValueError as e :
76- if str (e ) != "IpcHandle is already opened" :
77- raise AssertionError ("invalid exception message" )
78- else :
79- raise AssertionError ("did not raise on reopen" )
80- # Catch any exception so we can propagate it
81- except : # noqa: E722
82- # FAILED. propagate the exception as a string
83- succ = False
84- out = traceback .format_exc ()
85- else :
86- # OK. send the ndarray back
87- succ = True
88- out = arr
89- result_queue .put ((succ , out ))
71+ @classmethod
72+ def tearDownClass (cls ) -> None :
73+ cls .exe .shutdown ()
74+ del cls .exe
9075
9176
9277@linux_only
9378@skip_under_cuda_memcheck ("Hangs cuda-memcheck" )
9479@skip_on_cudasim ("Ipc not available in CUDASIM" )
9580@skip_on_arm ("CUDA IPC not supported on ARM in Numba" )
9681@skip_on_wsl2 ("CUDA IPC unreliable on WSL2; skipping IPC tests" )
97- class TestIpcMemory (CUDATestCase ):
82+ class TestIpcMemory (CUDAIpcTestCase ):
9883 def test_ipc_handle (self ):
9984 # prepare data for IPC
10085 arr = np .arange (10 , dtype = np .intp )
@@ -109,17 +94,11 @@ def test_ipc_handle(self):
10994 size = ipch .size
11095
11196 # spawn new process for testing
112- ctx = mp .get_context ("spawn" )
113- result_queue = ctx .Queue ()
114- args = (handle_bytes , size , result_queue )
115- proc = ctx .Process (target = base_ipc_handle_test , args = args )
116- proc .start ()
117- succ , out = result_queue .get ()
118- if not succ :
119- self .fail (out )
120- else :
121- np .testing .assert_equal (arr , out )
122- proc .join (3 )
97+ fut = self .exe .submit (
98+ base_ipc_handle_test , handle_bytes , size , parent_pid = os .getpid ()
99+ )
100+ out = fut .result (timeout = 3 )
101+ np .testing .assert_equal (arr , out )
123102
124103 def variants (self ):
125104 # Test with no slicing and various different slices
@@ -152,17 +131,11 @@ def check_ipc_handle_serialization(self, index_arg=None, foreign=False):
152131 self .assertEqual (ipch_recon .handle .reserved , ipch .handle .reserved )
153132
154133 # spawn new process for testing
155- ctx = mp .get_context ("spawn" )
156- result_queue = ctx .Queue ()
157- args = (ipch , result_queue )
158- proc = ctx .Process (target = serialize_ipc_handle_test , args = args )
159- proc .start ()
160- succ , out = result_queue .get ()
161- if not succ :
162- self .fail (out )
163- else :
164- np .testing .assert_equal (expect , out )
165- proc .join (3 )
134+ fut = self .exe .submit (
135+ serialize_ipc_handle_test , ipch , parent_pid = os .getpid ()
136+ )
137+ out = fut .result (timeout = 3 )
138+ np .testing .assert_equal (expect , out )
166139
167140 def test_ipc_handle_serialization (self ):
168141 for (
@@ -185,17 +158,9 @@ def check_ipc_array(self, index_arg=None, foreign=False):
185158 ipch = devarr .get_ipc_handle ()
186159
187160 # spawn new process for testing
188- ctx = mp .get_context ("spawn" )
189- result_queue = ctx .Queue ()
190- args = (ipch , result_queue )
191- proc = ctx .Process (target = ipc_array_test , args = args )
192- proc .start ()
193- succ , out = result_queue .get ()
194- if not succ :
195- self .fail (out )
196- else :
197- np .testing .assert_equal (expect , out )
198- proc .join (3 )
161+ fut = self .exe .submit (ipc_array_test , ipch , parent_pid = os .getpid ())
162+ out = fut .result (timeout = 3 )
163+ np .testing .assert_equal (expect , out )
199164
200165 def test_ipc_array (self ):
201166 for (
@@ -206,65 +171,45 @@ def test_ipc_array(self):
206171 self .check_ipc_array (index , foreign )
207172
208173
209- def staged_ipc_handle_test (handle , device_num , result_queue ):
210- def the_work ():
211- with cuda .gpus [device_num ]:
212- this_ctx = cuda .devices .get_context ()
213- deviceptr = handle .open_staged (this_ctx )
214- arrsize = handle .size // np .dtype (np .intp ).itemsize
215- hostarray = np .zeros (arrsize , dtype = np .intp )
216- cuda .driver .device_to_host (
217- hostarray ,
218- deviceptr ,
219- size = handle .size ,
220- )
221- handle .close ()
174+ def staged_ipc_handle_test (handle , device_num , parent_pid ):
175+ pid = os .getpid ()
176+ assert pid != parent_pid
177+ with cuda .gpus [device_num ]:
178+ this_ctx = cuda .devices .get_context ()
179+ deviceptr = handle .open_staged (this_ctx )
180+ arrsize = handle .size // np .dtype (np .intp ).itemsize
181+ hostarray = np .zeros (arrsize , dtype = np .intp )
182+ cuda .driver .device_to_host (
183+ hostarray ,
184+ deviceptr ,
185+ size = handle .size ,
186+ )
187+ handle .close ()
222188 return hostarray
223189
224- core_ipc_handle_test (the_work , result_queue )
225-
226-
227- def staged_ipc_array_test (ipcarr , device_num , result_queue ):
228- try :
229- with cuda .gpus [device_num ]:
230- with ipcarr as darr :
231- arr = darr .copy_to_host ()
232- try :
233- # should fail to reopen
234- with ipcarr :
235- pass
236- except ValueError as e :
237- if str (e ) != "IpcHandle is already opened" :
238- raise AssertionError ("invalid exception message" )
239- else :
240- raise AssertionError ("did not raise on reopen" )
241- # Catch any exception so we can propagate it
242- except : # noqa: E722
243- # FAILED. propagate the exception as a string
244- succ = False
245- out = traceback .format_exc ()
246- else :
247- # OK. send the ndarray back
248- succ = True
249- out = arr
250- result_queue .put ((succ , out ))
190+
191+ def staged_ipc_array_test (ipcarr , device_num , parent_pid ):
192+ pid = os .getpid ()
193+ assert pid != parent_pid
194+ with cuda .gpus [device_num ]:
195+ with ipcarr as darr :
196+ arr = darr .copy_to_host ()
197+ with pytest .raises (ValueError , match = "IpcHandle is already opened" ):
198+ with ipcarr :
199+ pass
200+ return arr
251201
252202
253203@linux_only
254204@skip_under_cuda_memcheck ("Hangs cuda-memcheck" )
255205@skip_on_cudasim ("Ipc not available in CUDASIM" )
256206@skip_on_arm ("CUDA IPC not supported on ARM in Numba" )
257207@skip_on_wsl2 ("CUDA IPC unreliable on WSL2; skipping IPC tests" )
258- class TestIpcStaged (CUDATestCase ):
208+ class TestIpcStaged (CUDAIpcTestCase ):
259209 def test_staged (self ):
260210 # prepare data for IPC
261211 arr = np .arange (10 , dtype = np .intp )
262212 devarr = cuda .to_device (arr )
263-
264- # spawn new process for testing
265- mpctx = mp .get_context ("spawn" )
266- result_queue = mpctx .Queue ()
267-
268213 # create IPC handle
269214 ctx = cuda .current_context ()
270215 ipch = ctx .get_ipc_handle (devarr .gpu_data )
@@ -276,16 +221,16 @@ def test_staged(self):
276221 self .assertEqual (ipch_recon .size , ipch .size )
277222
278223 # Test on every CUDA devices
279- for device_num in range ( len (cuda .gpus )):
280- args = ( ipch , device_num , result_queue )
281- proc = mpctx . Process ( target = staged_ipc_handle_test , args = args )
282- proc . start ()
283- succ , out = result_queue . get ( )
284- proc . join ( 3 )
285- if not succ :
286- self . fail ( out )
287- else :
288- np .testing .assert_equal (arr , out )
224+ ngpus = len (cuda .gpus )
225+ futures = [
226+ self . exe . submit (
227+ staged_ipc_handle_test , ipch , device_num , parent_pid = os . getpid ()
228+ )
229+ for device_num in range ( ngpus )
230+ ]
231+
232+ for fut in concurrent . futures . as_completed ( futures , timeout = 3 * ngpus ) :
233+ np .testing .assert_equal (arr , fut . result () )
289234
290235 def test_ipc_array (self ):
291236 for device_num in range (len (cuda .gpus )):
@@ -295,17 +240,11 @@ def test_ipc_array(self):
295240 ipch = devarr .get_ipc_handle ()
296241
297242 # spawn new process for testing
298- ctx = mp .get_context ("spawn" )
299- result_queue = ctx .Queue ()
300- args = (ipch , device_num , result_queue )
301- proc = ctx .Process (target = staged_ipc_array_test , args = args )
302- proc .start ()
303- succ , out = result_queue .get ()
304- proc .join (3 )
305- if not succ :
306- self .fail (out )
307- else :
308- np .testing .assert_equal (arr , out )
243+ fut = self .exe .submit (
244+ staged_ipc_array_test , ipch , device_num , parent_pid = os .getpid ()
245+ )
246+ out = fut .result (timeout = 3 )
247+ np .testing .assert_equal (arr , out )
309248
310249
311250@windows_only
0 commit comments