2121import traceback
2222import asyncio
2323import pathlib
24+ import subprocess
25+ import tempfile
26+ import re
2427from itertools import product
2528from abc import ABCMeta , abstractmethod
2629from ctypes import (c_int , byref , c_size_t , c_char , c_char_p , addressof ,
3639from .drvapi import API_PROTOTYPES
3740from .drvapi import cu_occupancy_b2d_size , cu_stream_callback_pyobj , cu_uuid
3841from .mappings import FILE_EXTENSION_MAP
39- from .linkable_code import LinkableCode , LTOIR
42+ from .linkable_code import LinkableCode , LTOIR , Fatbin , Object
4043from numba .cuda .cudadrv import enums , drvapi , nvrtc
4144
4245USE_NV_BINDING = config .CUDA_USE_NVIDIA_BINDING
@@ -2710,12 +2713,25 @@ def add_file_guess_ext(self, path_or_code, ignore_nonlto=False):
27102713 "Don't know how to link file with extension "
27112714 f"{ ext } "
27122715 )
2713- if ignore_nonlto and kind != FILE_EXTENSION_MAP ["ltoir" ]:
2714- warnings .warn (
2715- f"Not adding { path_or_code } as it is not optimizable "
2716- "at link time, and `ignore_nonlto == True`."
2717- )
2718- return
2716+
2717+ if ignore_nonlto :
2718+ warn_and_return = False
2719+ if kind in (
2720+ FILE_EXTENSION_MAP ["fatbin" ], FILE_EXTENSION_MAP ["o" ]
2721+ ):
2722+ entry_types = inspect_obj_content (path_or_code )
2723+ if "nvvm" not in entry_types :
2724+ warn_and_return = True
2725+ elif kind != FILE_EXTENSION_MAP ["ltoir" ]:
2726+ warn_and_return = True
2727+
2728+ if warn_and_return :
2729+ warnings .warn (
2730+ f"Not adding { path_or_code } as it is not "
2731+ "optimizable at link time, and `ignore_nonlto == "
2732+ "True`."
2733+ )
2734+ return
27192735
27202736 self .add_file (path_or_code , kind )
27212737 return
@@ -2729,12 +2745,24 @@ def add_file_guess_ext(self, path_or_code, ignore_nonlto=False):
27292745 if path_or_code .kind == "cu" :
27302746 self .add_cu (path_or_code .data , path_or_code .name )
27312747 else :
2732- if ignore_nonlto and not isinstance (path_or_code .kind , LTOIR ):
2733- warnings .warn (
2734- f"Not adding { path_or_code .name } as it is not "
2735- "optimizable at link time, and `ignore_nonlto == True`."
2736- )
2737- return
2748+ if ignore_nonlto :
2749+ warn_and_return = False
2750+ if isinstance (path_or_code , (Fatbin , Object )):
2751+ with tempfile .NamedTemporaryFile ("w" ) as fp :
2752+ fp .write (path_or_code .data )
2753+ entry_types = inspect_obj_content (fp .name )
2754+ if "nvvm" not in entry_types :
2755+ warn_and_return = True
2756+ elif not isinstance (path_or_code , LTOIR ):
2757+ warn_and_return = True
2758+
2759+ if warn_and_return :
2760+ warnings .warn (
2761+ f"Not adding { path_or_code .name } as it is not "
2762+ "optimizable at link time, and `ignore_nonlto == "
2763+ "True`."
2764+ )
2765+ return
27382766
27392767 self .add_data (
27402768 path_or_code .data , path_or_code .kind , path_or_code .name
@@ -3411,3 +3439,16 @@ def get_version():
34113439 Return the driver version as a tuple of (major, minor)
34123440 """
34133441 return driver .get_version ()
3442+
3443+
3444+ def inspect_obj_content (objpath : str ):
3445+ code_types :set [str ] = set ()
3446+
3447+ out = subprocess .run (["cuobjdump" , objpath ], capture_output = True )
3448+ objtable = out .stdout .decode ()
3449+ entry_pattern = r"Fatbin (.*) code"
3450+ for line in objtable .split ("\n " ):
3451+ if match := re .match (entry_pattern , line ):
3452+ code_types .add (match .group (1 ))
3453+
3454+ return code_types
0 commit comments