Skip to content

Commit f882ea1

Browse files
Merge pull request #1 from isVoid/numba-cuda-runtime
Determine conda include path based on machine kind
2 parents 5833c87 + c5b7df4 commit f882ea1

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

numba_cuda/numba/cuda/cuda_paths.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import re
33
import os
44
from collections import namedtuple
5+
import platform
56

67
from numba.core.config import IS_WIN32
78
from numba.misc.findlib import find_lib, find_file
@@ -259,14 +260,41 @@ def get_debian_pkg_libdevice():
259260
return pkg_libdevice_location
260261

261262

263+
def get_current_cuda_target_name():
264+
"""Determine conda's CTK target folder based on system and machine arch.
265+
266+
CTK's conda package delivers headers based on its architecture type. For example,
267+
`x86_64` machine places header under `$CONDA_PREFIX/targets/x86_64-linux`, and
268+
`aarch64` places under `$CONDA_PREFIX/targets/sbsa-linux`. Read more about the
269+
nuances at cudart's conda feedstock:
270+
https://github.com/conda-forge/cuda-cudart-feedstock/blob/main/recipe/meta.yaml#L8-L11 # noqa: E501
271+
"""
272+
system = platform.system()
273+
machine = platform.machine()
274+
275+
if system == "Linux":
276+
arch_to_targets = {
277+
'x86_64': 'x86_64-linux',
278+
'aarch64': 'sbsa-linux'
279+
}
280+
return arch_to_targets.get(machine)
281+
282+
return None
283+
262284
def get_conda_include_dir():
263285
"""
264286
Return the include directory in the current conda environment, if one
265287
is active and it exists.
266288
"""
267289
conda_prefix = os.environ.get('CONDA_PREFIX')
290+
target_name = get_current_cuda_target_name()
291+
268292
if conda_prefix:
269-
include_dir = os.path.join(conda_prefix, 'include')
293+
if target_name:
294+
include_dir = os.path.join(conda_prefix, f'targets/{target_name}/include')
295+
else:
296+
# A fallback when target cannot determined, though usually it shouldn't.
297+
include_dir = os.path.join(conda_prefix, f'include')
270298
if os.path.exists(include_dir):
271299
return include_dir
272300
return None

numba_cuda/numba/cuda/cudadrv/nvrtc.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numba.core import config
44
from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
55
NvrtcSupportError)
6-
from numba.cuda.cuda_paths import _get_include_dir
6+
from numba.cuda.cuda_paths import get_cuda_paths
77
import functools
88
import os
99
import threading
@@ -233,12 +233,13 @@ def compile(src, name, cc):
233233
# being optimized away.
234234
major, minor = cc
235235
arch = f'--gpu-architecture=compute_{major}{minor}'
236-
include = f'-I{config.CUDA_INCLUDE_PATH}'
236+
237+
cuda_include = f"-I{get_cuda_paths()['include_dir'].info}"
237238

238239
cudadrv_path = os.path.dirname(os.path.abspath(__file__))
239240
numba_cuda_path = os.path.dirname(cudadrv_path)
240-
numba_include = f'-I{numba_cuda_path} -I{_get_include_dir()}'
241-
options = [arch, include, numba_include, '-rdc', 'true']
241+
numba_include = f'-I{numba_cuda_path}'
242+
options = [arch, cuda_include, numba_include, '-rdc', 'true']
242243

243244
# Compile the program
244245
compile_error = nvrtc.compile_program(program, options)

0 commit comments

Comments
 (0)