22# SPDX-License-Identifier: BSD-2-Clause
33
44import numbers
5- from ctypes import byref
65import weakref
76
87from numba import cuda
@@ -31,9 +30,6 @@ def test_gpus_cudevice_indexing(self):
3130 device_ids = [device .id for device in cuda .list_devices ()]
3231 for device_id in device_ids :
3332 with cuda .gpus [device_id ]:
34- # Check that the device is an integer if not using the CUDA
35- # Python bindings, otherwise it's a CUdevice object
36- assert isinstance (device_id , int ) != driver .USE_NV_BINDING
3733 self .assertEqual (cuda .gpus .current .id , device_id )
3834
3935
@@ -91,14 +87,9 @@ def tearDown(self):
9187 def test_attached_primary (self , extra_work = lambda : None ):
9288 # Emulate primary context creation by 3rd party
9389 the_driver = driver .driver
94- if driver .USE_NV_BINDING :
95- dev = driver .binding .CUdevice (0 )
96- binding_hctx = the_driver .cuDevicePrimaryCtxRetain (dev )
97- hctx = driver .drvapi .cu_context (int (binding_hctx ))
98- else :
99- dev = 0
100- hctx = driver .drvapi .cu_context ()
101- the_driver .cuDevicePrimaryCtxRetain (byref (hctx ), dev )
90+ dev = driver .binding .CUdevice (0 )
91+ binding_hctx = the_driver .cuDevicePrimaryCtxRetain (dev )
92+ hctx = driver .drvapi .cu_context (int (binding_hctx ))
10293 try :
10394 ctx = driver .Context (weakref .proxy (self ), hctx )
10495 ctx .push ()
@@ -115,33 +106,29 @@ def test_attached_primary(self, extra_work=lambda: None):
115106 def test_attached_non_primary (self ):
116107 # Emulate non-primary context creation by 3rd party
117108 the_driver = driver .driver
118- if driver .USE_NV_BINDING :
119- flags = 0
120- dev = driver .binding .CUdevice (0 )
121-
122- result , version = driver .binding .cuDriverGetVersion ()
123- self .assertEqual (
124- result ,
125- driver .binding .CUresult .CUDA_SUCCESS ,
126- "Error getting CUDA driver version" ,
127- )
128-
129- # CUDA 13's cuCtxCreate has an optional parameter prepended. The
130- # version of cuCtxCreate in use depends on the cuda.bindings major
131- # version rather than the installed driver version on the machine
132- # we're running on.
133- from cuda import bindings
134-
135- bindings_version = int (bindings .__version__ .split ("." )[0 ])
136- if bindings_version in (11 , 12 ):
137- args = (flags , dev )
138- else :
139- args = (None , flags , dev )
140-
141- hctx = the_driver .cuCtxCreate (* args )
109+ flags = 0
110+ dev = driver .binding .CUdevice (0 )
111+
112+ result , version = driver .binding .cuDriverGetVersion ()
113+ self .assertEqual (
114+ result ,
115+ driver .binding .CUresult .CUDA_SUCCESS ,
116+ "Error getting CUDA driver version" ,
117+ )
118+
119+ # CUDA 13's cuCtxCreate has an optional parameter prepended. The
120+ # version of cuCtxCreate in use depends on the cuda.bindings major
121+ # version rather than the installed driver version on the machine
122+ # we're running on.
123+ from cuda import bindings
124+
125+ bindings_version = int (bindings .__version__ .split ("." )[0 ])
126+ if bindings_version in (11 , 12 ):
127+ args = (flags , dev )
142128 else :
143- hctx = driver .drvapi .cu_context ()
144- the_driver .cuCtxCreate (byref (hctx ), 0 , 0 )
129+ args = (None , flags , dev )
130+
131+ hctx = the_driver .cuCtxCreate (* args )
145132 try :
146133 cuda .current_context ()
147134 except RuntimeError as e :
0 commit comments