|
2 | 2 | import re |
3 | 3 | import os |
4 | 4 | from collections import namedtuple |
| 5 | +import platform |
5 | 6 |
|
6 | 7 | from numba.core.config import IS_WIN32 |
7 | 8 | from numba.misc.findlib import find_lib, find_file |
@@ -259,14 +260,41 @@ def get_debian_pkg_libdevice(): |
259 | 260 | return pkg_libdevice_location |
260 | 261 |
|
261 | 262 |
|
| 263 | +def get_current_cuda_target_name(): |
| 264 | + """Determine conda's CTK target folder based on system and machine arch. |
| 265 | + |
| 266 | + CTK's conda package delivers headers based on its architecture type. For example, |
| 267 | + `x86_64` machine places header under `$CONDA_PREFIX/targets/x86_64-linux`, and |
| 268 | + `aarch64` places under `$CONDA_PREFIX/targets/sbsa-linux`. Read more about the |
| 269 | + nuances at cudart's conda feedstock: |
| 270 | + https://github.com/conda-forge/cuda-cudart-feedstock/blob/main/recipe/meta.yaml#L8-L11 # noqa: E501 |
| 271 | + """ |
| 272 | + system = platform.system() |
| 273 | + machine = platform.machine() |
| 274 | + |
| 275 | + if system == "Linux": |
| 276 | + arch_to_targets = { |
| 277 | + 'x86_64': 'x86_64-linux', |
| 278 | + 'aarch64': 'sbsa-linux' |
| 279 | + } |
| 280 | + return arch_to_targets.get(machine) |
| 281 | + |
| 282 | + return None |
| 283 | + |
262 | 284 | def get_conda_include_dir(): |
263 | 285 | """ |
264 | 286 | Return the include directory in the current conda environment, if one |
265 | 287 | is active and it exists. |
266 | 288 | """ |
267 | 289 | conda_prefix = os.environ.get('CONDA_PREFIX') |
| 290 | + target_name = get_current_cuda_target_name() |
| 291 | + |
268 | 292 | if conda_prefix: |
269 | | - include_dir = os.path.join(conda_prefix, 'include') |
| 293 | + if target_name: |
| 294 | + include_dir = os.path.join(conda_prefix, f'targets/{target_name}/include') |
| 295 | + else: |
| 296 | + # A fallback when target cannot determined, though usually it shouldn't. |
| 297 | + include_dir = os.path.join(conda_prefix, f'include') |
270 | 298 | if os.path.exists(include_dir): |
271 | 299 | return include_dir |
272 | 300 | return None |
|
0 commit comments