Skip to content

Commit bac11f6

Browse files
committed
Examine fatbin input that contains LTOIR
1 parent afcce87 commit bac11f6

File tree

3 files changed

+65
-13
lines changed

3 files changed

+65
-13
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import traceback
2222
import asyncio
2323
import pathlib
24+
import subprocess
25+
import tempfile
26+
import re
2427
from itertools import product
2528
from abc import ABCMeta, abstractmethod
2629
from ctypes import (c_int, byref, c_size_t, c_char, c_char_p, addressof,
@@ -36,7 +39,7 @@
3639
from .drvapi import API_PROTOTYPES
3740
from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
3841
from .mappings import FILE_EXTENSION_MAP
39-
from .linkable_code import LinkableCode, LTOIR
42+
from .linkable_code import LinkableCode, LTOIR, Fatbin, Object
4043
from numba.cuda.cudadrv import enums, drvapi, nvrtc
4144

4245
USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
@@ -2710,12 +2713,25 @@ def add_file_guess_ext(self, path_or_code, ignore_nonlto=False):
27102713
"Don't know how to link file with extension "
27112714
f"{ext}"
27122715
)
2713-
if ignore_nonlto and kind != FILE_EXTENSION_MAP["ltoir"]:
2714-
warnings.warn(
2715-
f"Not adding {path_or_code} as it is not optimizable "
2716-
"at link time, and `ignore_nonlto == True`."
2717-
)
2718-
return
2716+
2717+
if ignore_nonlto:
2718+
warn_and_return = False
2719+
if kind in (
2720+
FILE_EXTENSION_MAP["fatbin"], FILE_EXTENSION_MAP["o"]
2721+
):
2722+
entry_types = inspect_obj_content(path_or_code)
2723+
if "nvvm" not in entry_types:
2724+
warn_and_return = True
2725+
elif kind != FILE_EXTENSION_MAP["ltoir"]:
2726+
warn_and_return = True
2727+
2728+
if warn_and_return:
2729+
warnings.warn(
2730+
f"Not adding {path_or_code} as it is not "
2731+
"optimizable at link time, and `ignore_nonlto == "
2732+
"True`."
2733+
)
2734+
return
27192735

27202736
self.add_file(path_or_code, kind)
27212737
return
@@ -2729,12 +2745,24 @@ def add_file_guess_ext(self, path_or_code, ignore_nonlto=False):
27292745
if path_or_code.kind == "cu":
27302746
self.add_cu(path_or_code.data, path_or_code.name)
27312747
else:
2732-
if ignore_nonlto and not isinstance(path_or_code.kind, LTOIR):
2733-
warnings.warn(
2734-
f"Not adding {path_or_code.name} as it is not "
2735-
"optimizable at link time, and `ignore_nonlto == True`."
2736-
)
2737-
return
2748+
if ignore_nonlto:
2749+
warn_and_return = False
2750+
if isinstance(path_or_code, (Fatbin, Object)):
2751+
with tempfile.NamedTemporaryFile("w") as fp:
2752+
fp.write(path_or_code.data)
2753+
entry_types = inspect_obj_content(fp.name)
2754+
if "nvvm" not in entry_types:
2755+
warn_and_return = True
2756+
elif not isinstance(path_or_code, LTOIR):
2757+
warn_and_return = True
2758+
2759+
if warn_and_return:
2760+
warnings.warn(
2761+
f"Not adding {path_or_code.name} as it is not "
2762+
"optimizable at link time, and `ignore_nonlto == "
2763+
"True`."
2764+
)
2765+
return
27382766

27392767
self.add_data(
27402768
path_or_code.data, path_or_code.kind, path_or_code.name
@@ -3411,3 +3439,16 @@ def get_version():
34113439
Return the driver version as a tuple of (major, minor)
34123440
"""
34133441
return driver.get_version()
3442+
3443+
3444+
def inspect_obj_content(objpath: str):
    """
    Return the kinds of code entries that ``cuobjdump`` reports for a file.

    ``cuobjdump`` is run on *objpath* and its standard output is scanned
    line by line. Every line that starts with ``Fatbin <kind> code``
    contributes ``<kind>`` to the result.

    Parameters
    ----------
    objpath : str
        Path passed verbatim to ``cuobjdump``.

    Returns
    -------
    set[str]
        The distinct ``<kind>`` strings found. The set is empty when the
        output contains no matching line (including when ``cuobjdump``
        failed and printed nothing useful).

    Raises
    ------
    FileNotFoundError
        If the ``cuobjdump`` executable cannot be launched because it is
        not found. The exception type is the same one ``subprocess.run``
        raises; only the message is made actionable.
    """
    try:
        out = subprocess.run(["cuobjdump", objpath], capture_output=True)
    except FileNotFoundError as exc:
        raise FileNotFoundError(
            "cuobjdump has not been found. You may need to install the "
            "CUDA toolkit and ensure that it is available on your PATH."
        ) from exc

    if out.returncode != 0:
        # Keep the best-effort behaviour (still return whatever could be
        # parsed) but do not hide the failure: otherwise the caller cannot
        # tell "no LTO-IR inside" apart from "inspection did not work".
        details = out.stderr.decode(errors="replace").strip()
        warnings.warn(
            f"cuobjdump exited with status {out.returncode} while "
            f"inspecting {objpath}: {details}"
        )

    # Undecodable bytes are replaced rather than raising, since only the
    # plain-ASCII "Fatbin ... code" header lines matter here.
    objtable = out.stdout.decode(errors="replace")
    entry_pattern = re.compile(r"Fatbin (.*) code")

    code_types: set[str] = set()
    for line in objtable.split("\n"):
        if match := entry_pattern.match(line):
            code_types.add(match.group(1))

    return code_types

numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
test_device_functions_fatbin = os.path.join(
2828
TEST_BIN_DIR, "test_device_functions.fatbin"
2929
)
30+
test_device_functions_fatbin_multi = os.path.join(
31+
TEST_BIN_DIR, "test_device_functions_multi.fatbin"
32+
)
3033
test_device_functions_o = os.path.join(
3134
TEST_BIN_DIR, "test_device_functions.o"
3235
)
@@ -178,6 +181,7 @@ def test_nvjitlink_jit_with_linkable_code_lto_dump_assembly(self):
178181
files = [
179182
test_device_functions_cu,
180183
test_device_functions_ltoir,
184+
test_device_functions_fatbin_multi
181185
]
182186

183187
config.DUMP_ASSEMBLY = True

numba_cuda/numba/cuda/tests/test_binary_generation/Makefile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@ endif
1414
# Gencode flags suitable for most tests
1515
GENCODE := -gencode arch=compute_$(GPU_CC),code=sm_$(GPU_CC)
1616

17+
MULTI_GENCODE := -gencode arch=compute_$(GPU_CC),code=[sm_$(GPU_CC),lto_$(GPU_CC)]
18+
1719
# Fatbin tests need to generate code for an additional compute capability
1820
FATBIN_GENCODE := $(GENCODE) -gencode arch=compute_$(ALT_CC),code=sm_$(ALT_CC)
1921

22+
# Fatbin that contains both LTO-IR and SASS for multiple architectures
23+
MULTI_FATBIN_GENCODE := $(MULTI_GENCODE) -gencode arch=compute_$(ALT_CC),code=[sm_$(ALT_CC),lto_$(ALT_CC)]
24+
2025
# LTO-IR tests need to generate for the LTO "architecture" instead
2126
LTOIR_GENCODE := -gencode arch=lto_$(GPU_CC),code=lto_$(GPU_CC)
2227

@@ -30,6 +35,7 @@ PTX_FLAGS := $(GENCODE) -ptx
3035
OBJECT_FLAGS := $(GENCODE) -dc
3136
LIBRARY_FLAGS := $(GENCODE) -lib
3237
FATBIN_FLAGS := $(FATBIN_GENCODE) --fatbin
38+
MULTI_FATBIN_FLAGS := $(MULTI_FATBIN_GENCODE) --fatbin
3339
LTOIR_FLAGS := $(LTOIR_GENCODE) -dc
3440

3541
OUTPUT_DIR := ./
@@ -41,6 +47,7 @@ all:
4147
nvcc $(NVCC_FLAGS) $(CUBIN_FLAGS) -o $(OUTPUT_DIR)/undefined_extern.cubin undefined_extern.cu
4248
nvcc $(NVCC_FLAGS) $(CUBIN_FLAGS) -o $(OUTPUT_DIR)/test_device_functions.cubin test_device_functions.cu
4349
nvcc $(NVCC_FLAGS) $(FATBIN_FLAGS) -o $(OUTPUT_DIR)/test_device_functions.fatbin test_device_functions.cu
50+
nvcc $(NVCC_FLAGS) $(MULTI_FATBIN_FLAGS) -o $(OUTPUT_DIR)/test_device_functions_multi.fatbin test_device_functions.cu
4451
nvcc $(NVCC_FLAGS) $(PTX_FLAGS) -o $(OUTPUT_DIR)/test_device_functions.ptx test_device_functions.cu
4552
nvcc $(NVCC_FLAGS) $(OBJECT_FLAGS) -o $(OUTPUT_DIR)/test_device_functions.o test_device_functions.cu
4653
nvcc $(NVCC_FLAGS) $(LIBRARY_FLAGS) -o $(OUTPUT_DIR)/test_device_functions.a test_device_functions.cu

0 commit comments

Comments
 (0)