22# SPDX-License-Identifier: BSD-2-Clause
33
44import numbers
5- from ctypes import byref
65import weakref
76
87from numba import cuda
@@ -31,9 +30,7 @@ def test_gpus_cudevice_indexing(self):
3130 device_ids = [device .id for device in cuda .list_devices ()]
3231 for device_id in device_ids :
3332 with cuda .gpus [device_id ]:
34- # Check that the device is an integer if not using the CUDA
35- # Python bindings, otherwise it's a CUdevice object
36- assert isinstance (device_id , int ) != driver .USE_NV_BINDING
33+ assert not isinstance (device_id , int )
3734 self .assertEqual (cuda .gpus .current .id , device_id )
3835
3936
@@ -91,14 +88,9 @@ def tearDown(self):
9188 def test_attached_primary (self , extra_work = lambda : None ):
9289 # Emulate primary context creation by 3rd party
9390 the_driver = driver .driver
94- if driver .USE_NV_BINDING :
95- dev = driver .binding .CUdevice (0 )
96- binding_hctx = the_driver .cuDevicePrimaryCtxRetain (dev )
97- hctx = driver .drvapi .cu_context (int (binding_hctx ))
98- else :
99- dev = 0
100- hctx = driver .drvapi .cu_context ()
101- the_driver .cuDevicePrimaryCtxRetain (byref (hctx ), dev )
91+ dev = driver .binding .CUdevice (0 )
92+ binding_hctx = the_driver .cuDevicePrimaryCtxRetain (dev )
93+ hctx = driver .drvapi .cu_context (int (binding_hctx ))
10294 try :
10395 ctx = driver .Context (weakref .proxy (self ), hctx )
10496 ctx .push ()
@@ -115,33 +107,29 @@ def test_attached_primary(self, extra_work=lambda: None):
115107 def test_attached_non_primary (self ):
116108 # Emulate non-primary context creation by 3rd party
117109 the_driver = driver .driver
118- if driver .USE_NV_BINDING :
119- flags = 0
120- dev = driver .binding .CUdevice (0 )
121-
122- result , version = driver .binding .cuDriverGetVersion ()
123- self .assertEqual (
124- result ,
125- driver .binding .CUresult .CUDA_SUCCESS ,
126- "Error getting CUDA driver version" ,
127- )
128-
129- # CUDA 13's cuCtxCreate has an optional parameter prepended. The
130- # version of cuCtxCreate in use depends on the cuda.bindings major
131- # version rather than the installed driver version on the machine
132- # we're running on.
133- from cuda import bindings
134-
135- bindings_version = int (bindings .__version__ .split ("." )[0 ])
136- if bindings_version in (11 , 12 ):
137- args = (flags , dev )
138- else :
139- args = (None , flags , dev )
140-
141- hctx = the_driver .cuCtxCreate (* args )
110+ flags = 0
111+ dev = driver .binding .CUdevice (0 )
112+
113+ result , version = driver .binding .cuDriverGetVersion ()
114+ self .assertEqual (
115+ result ,
116+ driver .binding .CUresult .CUDA_SUCCESS ,
117+ "Error getting CUDA driver version" ,
118+ )
119+
120+ # CUDA 13's cuCtxCreate has an optional parameter prepended. The
121+ # version of cuCtxCreate in use depends on the cuda.bindings major
122+ # version rather than the installed driver version on the machine
123+ # we're running on.
124+ from cuda import bindings
125+
126+ bindings_version = int (bindings .__version__ .split ("." )[0 ])
127+ if bindings_version in (11 , 12 ):
128+ args = (flags , dev )
142129 else :
143- hctx = driver .drvapi .cu_context ()
144- the_driver .cuCtxCreate (byref (hctx ), 0 , 0 )
130+ args = (None , flags , dev )
131+
132+ hctx = the_driver .cuCtxCreate (* args )
145133 try :
146134 cuda .current_context ()
147135 except RuntimeError as e :
0 commit comments