44from numba .cuda .cudadrv .driver import PyNvJitLinker
55
66import itertools
7+ import os
78from numba .cuda import get_current_device
89from numba import cuda
910from numba import config
1011
11-
12- @unittest .skipIf (not config .ENABLE_PYNVJITLINK , "pynvjitlink not enabled" )
12+ TEST_BIN_DIR = os .getenv ("NUMBA_CUDA_TEST_BIN_DIR" )
13+ test_device_functions_a = os .path .join (
14+ TEST_BIN_DIR , "test_device_functions.a"
15+ )
16+ test_device_functions_cubin = os .path .join (
17+ TEST_BIN_DIR , "test_device_functions.cubin"
18+ )
19+ test_device_functions_cu = os .path .join (
20+ TEST_BIN_DIR , "test_device_functions.cu"
21+ )
22+ test_device_functions_fatbin = os .path .join (
23+ TEST_BIN_DIR , "test_device_functions.fatbin"
24+ )
25+ test_device_functions_o = os .path .join (
26+ TEST_BIN_DIR , "test_device_functions.o"
27+ )
28+ test_device_functions_ptx = os .path .join (
29+ TEST_BIN_DIR , "test_device_functions.ptx"
30+ )
31+ test_device_functions_ltoir = os .path .join (
32+ TEST_BIN_DIR , "test_device_functions.ltoir"
33+ )
34+
35+
36+ @unittest .skipIf (
37+ not config .ENABLE_PYNVJITLINK or not TEST_BIN_DIR ,
38+ "pynvjitlink not enabled"
39+ )
1340@skip_on_cudasim ("Linking unsupported in the simulator" )
1441class TestLinker (CUDATestCase ):
1542 _NUMBA_NVIDIA_BINDING_0_ENV = {"NUMBA_CUDA_USE_NVIDIA_BINDING" : "0" }
@@ -91,12 +118,12 @@ def test_nvjitlink_ptx_compile_options(self):
91118
92119 def test_nvjitlink_add_file_guess_ext_linkable_code (self ):
93120 files = (
94- "test_device_functions.a" ,
95- "test_device_functions.cubin" ,
96- "test_device_functions.cu" ,
97- "test_device_functions.fatbin" ,
98- "test_device_functions.o" ,
99- "test_device_functions.ptx" ,
121+ test_device_functions_a ,
122+ test_device_functions_cubin ,
123+ test_device_functions_cu ,
124+ test_device_functions_fatbin ,
125+ test_device_functions_o ,
126+ test_device_functions_ptx ,
100127 )
101128 for file in files :
102129 with self .subTest (file = file ):
@@ -106,7 +133,7 @@ def test_nvjitlink_add_file_guess_ext_linkable_code(self):
106133 patched_linker .add_file_guess_ext (file )
107134
108135 def test_nvjitlink_test_add_file_guess_ext_invalid_input (self ):
109- with open ("test_device_functions.cubin" , "rb" ) as f :
136+ with open (test_device_functions_cubin , "rb" ) as f :
110137 content = f .read ()
111138
112139 patched_linker = PyNvJitLinker (
@@ -121,12 +148,12 @@ def test_nvjitlink_test_add_file_guess_ext_invalid_input(self):
121148
122149 def test_nvjitlink_jit_with_linkable_code (self ):
123150 files = (
124- "test_device_functions.a" ,
125- "test_device_functions.cubin" ,
126- "test_device_functions.cu" ,
127- "test_device_functions.fatbin" ,
128- "test_device_functions.o" ,
129- "test_device_functions.ptx" ,
151+ test_device_functions_a ,
152+ test_device_functions_cubin ,
153+ test_device_functions_cu ,
154+ test_device_functions_fatbin ,
155+ test_device_functions_o ,
156+ test_device_functions_ptx ,
130157 )
131158 for file in files :
132159 with self .subTest (file = file ):
@@ -142,7 +169,7 @@ def kernel(result):
142169 assert result [0 ] == 3
143170
144171 def test_nvjitlink_jit_with_linkable_code_lto (self ):
145- file = "test_device_functions.ltoir"
172+ file = test_device_functions_ltoir
146173
147174 sig = "uint32(uint32, uint32)"
148175 add_from_numba = cuda .declare_device ("add_from_numba" , sig )
@@ -156,7 +183,7 @@ def kernel(result):
156183 assert result [0 ] == 3
157184
158185 def test_nvjitlink_jit_with_invalid_linkable_code (self ):
159- with open ("test_device_functions.cubin" , "rb" ) as f :
186+ with open (test_device_functions_cubin , "rb" ) as f :
160187 content = f .read ()
161188 with self .assertRaisesRegex (
162189 TypeError , "Expected path to file or a LinkableCode"
0 commit comments