Skip to content

Commit caf1558

Browse files
authored
Avoid Using Device Compute Capability for Linker Class (#429)
When debugging with @gmarkall, we discovered that the linker class uses the current device's compute capability for CUDA source compilation. However, `compile_ptx` can be used to compile code targeting a different architecture from that of the current device. This PR passes the CC used to construct the linker class through to the CUDA source compilation, which avoids CC inconsistencies between compiling and linking. --------- Co-authored-by: Michael Wang <[email protected]>
1 parent 6b49be0 commit caf1558

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2836,10 +2836,7 @@ def add_ptx(self, ptx, name):
28362836
def add_cu(self, cu, name):
28372837
"""Add CUDA source in a string to the link. The name of the source
28382838
file should be specified in `name`."""
2839-
with driver.get_active_context() as ac:
2840-
dev = driver.get_device(ac.devnum)
2841-
cc = dev.compute_capability
2842-
ptx, log = nvrtc.compile(cu, name, cc)
2839+
ptx, log = nvrtc.compile(cu, name, self.cc)
28432840

28442841
if config.DUMP_ASSEMBLY:
28452842
print(("ASSEMBLY %s" % name).center(80, "-"))
@@ -3003,10 +3000,7 @@ def add_ptx(self, ptx, name="<cudapy-ptx>"):
30033000
self._object_codes.append(obj)
30043001

30053002
def add_cu(self, cu, name="<cudapy-cu>"):
3006-
with driver.get_active_context() as ac:
3007-
dev = driver.get_device(ac.devnum)
3008-
cc = dev.compute_capability
3009-
obj, log = nvrtc.compile(cu, name, cc, ltoir=self.lto)
3003+
obj, log = nvrtc.compile(cu, name, self.cc, ltoir=self.lto)
30103004

30113005
if not self.lto and config.DUMP_ASSEMBLY:
30123006
print(("ASSEMBLY %s" % name).center(80, "-"))
@@ -3117,6 +3111,7 @@ def __init__(self, max_registers=0, lineinfo=False, cc=None):
31173111
if lineinfo:
31183112
options[enums.CU_JIT_GENERATE_LINE_INFO] = c_void_p(1)
31193113

3114+
self.cc = cc
31203115
if cc is None:
31213116
# No option value is needed, but we need something as a placeholder
31223117
options[enums.CU_JIT_TARGET_FROM_CUCONTEXT] = 1

numba_cuda/numba/cuda/tests/cudadrv/test_linker.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import warnings
66
from numba import config
77
from numba.cuda.testing import unittest
8-
from numba.cuda.testing import skip_on_cudasim, skip_if_cuda_includes_missing
8+
from numba.cuda.testing import (
9+
skip_on_cudasim,
10+
skip_if_cuda_includes_missing,
11+
skip_if_nvjitlink_missing,
12+
)
913
from numba.cuda.testing import CUDATestCase, test_data_dir
1014
from numba.cuda.cudadrv.driver import CudaAPIError, _Linker, LinkerError
1115
from numba.cuda import require_context
@@ -329,6 +333,18 @@ def test_get_local_mem_per_specialized(self):
329333
calc_size = np.dtype(np.float64).itemsize * LMEM_SIZE
330334
self.assertGreaterEqual(local_mem_size, calc_size)
331335

336+
@skip_if_nvjitlink_missing("nvJitLink not installed or new enough (>12.3)")
337+
def test_link_for_different_cc(self):
338+
linker = _Linker.new(cc=(7, 5), lto=True)
339+
code = """
340+
__device__ int foo(int x) {
341+
return x + 1;
342+
}
343+
"""
344+
linker.add_cu(code, "foo")
345+
ptx = linker.get_linked_ptx().decode()
346+
assert "target sm_75" in ptx
347+
332348

333349
if __name__ == "__main__":
334350
unittest.main()

0 commit comments

Comments
 (0)