@@ -241,6 +241,7 @@ def get_cuda_paths():
241241 'libdevice' : _get_libdevice_paths (),
242242 'cudalib_dir' : _get_cudalib_dir (),
243243 'static_cudalib_dir' : _get_static_cudalib_dir (),
244+ 'include_dir' : _get_include_dir (),
244245 }
245246 # Cache result
246247 get_cuda_paths ._cached_result = d
@@ -256,3 +257,35 @@ def get_debian_pkg_libdevice():
256257 if not os .path .exists (pkg_libdevice_location ):
257258 return None
258259 return pkg_libdevice_location
260+
261+
262+ def get_conda_include_dir ():
263+ """
264+ Return the include directory in the current conda environment, if one
265+ is active and it exists.
266+ """
267+ conda_prefix = os .environ .get ('CONDA_PREFIX' )
268+ if conda_prefix :
269+ include_dir = os .path .join (conda_prefix , 'include' )
270+ if os .path .exists (include_dir ):
271+ return include_dir
272+ return None
273+
274+
275+ def get_system_include_dir ():
276+ """Return the system CUDA include directory, if it exists"""
277+ system_cuda_include = '/usr/local/cuda/include'
278+ if os .path .exists (system_cuda_include ):
279+ return system_cuda_include
280+ return None
281+
282+
283+ def _get_include_dir ():
284+ """Find the root include directory."""
285+ options = [
286+ ('Conda environment' , get_conda_include_dir ()),
287+ ('System' , get_system_include_dir ()),
288+ # TODO: add others
289+ ]
290+ by , include_dir = _find_valid_path (options )
291+ return _env_path_tuple (by , include_dir )
0 commit comments