Skip to content

Commit 4ba8d90

Browse files
find extra includes dynamically
1 parent 327f908 commit 4ba8d90

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

numba_cuda/numba/cuda/cuda_paths.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def get_cuda_paths():
241241
'libdevice': _get_libdevice_paths(),
242242
'cudalib_dir': _get_cudalib_dir(),
243243
'static_cudalib_dir': _get_static_cudalib_dir(),
244+
'include_dir': _get_include_dir(),
244245
}
245246
# Cache result
246247
get_cuda_paths._cached_result = d
@@ -256,3 +257,35 @@ def get_debian_pkg_libdevice():
256257
if not os.path.exists(pkg_libdevice_location):
257258
return None
258259
return pkg_libdevice_location
260+
261+
262+
def get_conda_include_dir():
263+
"""
264+
Return the include directory in the current conda environment, if one
265+
is active and it exists.
266+
"""
267+
conda_prefix = os.environ.get('CONDA_PREFIX')
268+
if conda_prefix:
269+
include_dir = os.path.join(conda_prefix, 'include')
270+
if os.path.exists(include_dir):
271+
return include_dir
272+
return None
273+
274+
275+
def get_system_include_dir():
276+
"""Return the system CUDA include directory, if it exists"""
277+
system_cuda_include = '/usr/local/cuda/include'
278+
if os.path.exists(system_cuda_include):
279+
return system_cuda_include
280+
return None
281+
282+
283+
def _get_include_dir():
284+
"""Find the root include directory."""
285+
options = [
286+
('Conda environment', get_conda_include_dir()),
287+
('System', get_system_include_dir()),
288+
# TODO: add others
289+
]
290+
by, include_dir = _find_valid_path(options)
291+
return _env_path_tuple(by, include_dir)

numba_cuda/numba/cuda/cudadrv/nvrtc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numba.core import config
44
from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
55
NvrtcSupportError)
6-
6+
from numba.cuda.cuda_paths import _get_include_dir
77
import functools
88
import os
99
import threading
@@ -237,7 +237,7 @@ def compile(src, name, cc):
237237

238238
cudadrv_path = os.path.dirname(os.path.abspath(__file__))
239239
numba_cuda_path = os.path.dirname(cudadrv_path)
240-
numba_include = f'-I{numba_cuda_path}'
240+
numba_include = f'-I{numba_cuda_path} -I{_get_include_dir()}'
241241
options = [arch, include, numba_include, '-rdc', 'true']
242242

243243
# Compile the program

0 commit comments

Comments
 (0)