|
2 | 2 | from numba.cuda.testing import skip_on_cudasim |
3 | 3 | from numba.cuda.testing import CUDATestCase |
4 | 4 | from numba.cuda.cudadrv.driver import PyNvJitLinker |
| 5 | +from numba.cuda import get_current_device |
| 6 | + |
| 7 | +from numba import cuda |
| 8 | +from numba import config |
| 9 | +from numba.tests.support import run_in_subprocess, override_config |
5 | 10 |
|
6 | 11 | import itertools |
7 | 12 | import os |
8 | 13 | import io |
9 | 14 | import contextlib |
10 | 15 | import warnings |
11 | 16 |
|
12 | | -from numba.cuda import get_current_device |
13 | | -from numba import cuda |
14 | | -from numba import config |
15 | 17 |
|
16 | 18 | TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR") |
17 | 19 | if TEST_BIN_DIR: |
@@ -251,5 +253,61 @@ def kernel(): |
251 | 253 | pass |
252 | 254 |
|
253 | 255 |
|
| 256 | +class TestLinkerUsage(CUDATestCase): |
| 257 | + """Test that whether pynvjitlink can be enabled by both environment variable |
| 258 | + and modification of config at runtime. |
| 259 | + """ |
| 260 | + def test_linker_enabled_envvar(self): |
| 261 | + # Linkable code is only supported via pynvjitlink |
| 262 | + src = """if 1: |
| 263 | + import os |
| 264 | + from numba import cuda |
| 265 | +
|
| 266 | + TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR") |
| 267 | + if TEST_BIN_DIR: |
| 268 | + test_device_functions_cubin = os.path.join( |
| 269 | + TEST_BIN_DIR, "test_device_functions.cubin" |
| 270 | + ) |
| 271 | + print(TEST_BIN_DIR) |
| 272 | + files = ( |
| 273 | + test_device_functions_cubin, |
| 274 | + ) |
| 275 | + for lto in [True, False]: |
| 276 | + for file in files: |
| 277 | + sig = "uint32(uint32, uint32)" |
| 278 | + add_from_numba = cuda.declare_device("add_from_numba", sig) |
| 279 | +
|
| 280 | + @cuda.jit(link=[file], lto=lto) |
| 281 | + def kernel(result): |
| 282 | + result[0] = add_from_numba(1, 2) |
| 283 | +
|
| 284 | + result = cuda.device_array(1) |
| 285 | + kernel[1, 1](result) |
| 286 | + assert result[0] == 3 |
| 287 | + """ |
| 288 | + env = os.environ.copy() |
| 289 | + env['NUMBA_CUDA_ENABLE_PYNVJITLINK'] = "1" |
| 290 | + print(env['NUMBA_CUDA_TEST_BIN_DIR']) |
| 291 | + run_in_subprocess(src, env=env) |
| 292 | + |
| 293 | + def test_linker_enabled_config(self): |
| 294 | + with override_config("CUDA_ENABLE_PYNVJITLINK", True): |
| 295 | + files = ( |
| 296 | + test_device_functions_cubin, |
| 297 | + ) |
| 298 | + for lto in [True, False]: |
| 299 | + for file in files: |
| 300 | + sig = "uint32(uint32, uint32)" |
| 301 | + add_from_numba = cuda.declare_device("add_from_numba", sig) |
| 302 | + |
| 303 | + @cuda.jit(link=[file], lto=lto) |
| 304 | + def kernel(result): |
| 305 | + result[0] = add_from_numba(1, 2) |
| 306 | + |
| 307 | + result = cuda.device_array(1) |
| 308 | + kernel[1, 1](result) |
| 309 | + assert result[0] == 3 |
| 310 | + |
| 311 | + |
254 | 312 | if __name__ == "__main__": |
255 | 313 | unittest.main() |
0 commit comments