|
47 | 47 | except ImportError: |
48 | 48 | NvJitLinker, NvJitLinkError = None, None |
49 | 49 |
|
| 50 | +from cuda.bindings.nvjitlink import nvJitLinkError |
| 51 | +from cuda.bindings import nvjitlink as _nvjitlink |
| 52 | + |
50 | 53 | USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING |
51 | 54 |
|
52 | 55 | if USE_NV_BINDING: |
@@ -2594,7 +2597,7 @@ def new(cls, |
2594 | 2597 | "Enabling pynvjitlink requires CUDA 12." |
2595 | 2598 | ) |
2596 | 2599 | if config.CUDA_ENABLE_PYNVJITLINK: |
2597 | | - linker = PyNvJitLinker |
| 2600 | + linker = PyNvJitLinkerNew |
2598 | 2601 |
|
2599 | 2602 | elif config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY: |
2600 | 2603 | linker = MVCLinker |
@@ -3024,6 +3027,143 @@ def complete(self): |
3024 | 3027 | return bytes(np.ctypeslib.as_array(cubin_ptr, shape=(size,))) |
3025 | 3028 |
|
3026 | 3029 |
|
class PyNvJitLinkerNew(Linker):
    """Linker implemented on top of the ``cuda.bindings.nvjitlink`` module.

    Inputs are added with the ``add_*`` methods and the link is finalised by
    ``get_linked_cubin`` / ``get_linked_ptx`` / ``complete``. The most recent
    nvJitLink logs are exposed through ``info_log`` and ``error_log``.
    """

    # NOTE(review): the handle returned by ``_nvjitlink.create`` is never
    # released by this class; confirm whether an explicit destroy call is
    # needed for cleanup.

    def __init__(
        self,
        max_registers=None,
        lineinfo=False,
        cc=None,
        lto=False,
        additional_flags=None,
    ):
        """Create an nvJitLink handle configured for compute capability *cc*.

        ``cc`` is a ``(major, minor)`` list or tuple. ``max_registers``,
        ``lineinfo`` and ``lto`` map to ``-maxrregcount=``, ``-lineinfo`` and
        ``-lto``; ``additional_flags`` are appended to the options verbatim.

        Raises ``RuntimeError`` if ``cc`` is missing and ``TypeError`` if it
        is not a two-element list or tuple.
        """
        if cc is None:
            raise RuntimeError("NvJitLink requires CC to be specified")
        # Enforce the length as well as the type, as the message promises.
        if not isinstance(cc, (list, tuple)) or len(cc) != 2:
            raise TypeError("`cc` must be a list or tuple of length 2")

        sm_ver = f"{cc[0] * 10 + cc[1]}"
        options = [f"-arch=sm_{sm_ver}"]
        if max_registers:
            options.append(f"-maxrregcount={max_registers}")
        if lineinfo:
            options.append("-lineinfo")
        if lto:
            options.append("-lto")
        if additional_flags is not None:
            options.extend(additional_flags)

        self.handle = _nvjitlink.create(len(options), options)
        self.lto = lto
        self.options = options

        self._info_log = None
        self._error_log = None
        self._complete = False

    @property
    def info_log(self):
        """Info log captured by the most recent operation, or ``None``."""
        return self._info_log

    @property
    def error_log(self):
        """Error log captured by the most recent failure, or ``None``."""
        return self._error_log

    def _read_log(self, get_size, get_log):
        """Fetch a log through a size/getter pair and return it as ``str``.

        The getter is given a pre-sized buffer to fill, the same calling
        pattern this class uses for ``get_linked_cubin``; the text is taken
        from that buffer rather than from the getter's return value.
        """
        buf = bytearray(get_size(self.handle))
        get_log(self.handle, buf)
        return buf.decode("utf-8", errors="replace").rstrip("\x00")

    def _read_error_log(self):
        """Refresh ``_error_log`` from the handle."""
        self._error_log = self._read_log(
            _nvjitlink.get_error_log_size, _nvjitlink.get_error_log
        )

    def _read_info_log(self):
        """Refresh ``_info_log`` from the handle."""
        self._info_log = self._read_log(
            _nvjitlink.get_info_log_size, _nvjitlink.get_info_log
        )

    def _link_failure(self, exc):
        """Build the exception raised for a failed nvJitLink call."""
        # NOTE(review): confirm that nvJitLinkError accepts a message string;
        # if its constructor expects a status code, raise LinkerError here
        # instead.
        return nvJitLinkError(f"{exc}\n{self.error_log}")

    def add_data(self, input_type, data, name):
        """Add *data* of kind *input_type* to the link under *name*.

        ``input_type`` must expose a ``.value`` attribute (the
        ``_nvjitlink.InputType`` members used by the ``add_*`` helpers do).
        On failure both logs are captured and ``nvJitLinkError`` is raised.
        """
        if self._complete:
            raise nvJitLinkError("Cannot add data to already-completed link")

        try:
            _nvjitlink.add_data(
                self.handle,
                input_type.value,
                data,
                len(data),
                name,
            )
        except (nvJitLinkError, RuntimeError) as e:
            # Each log goes to its matching attribute.
            self._read_error_log()
            self._read_info_log()
            raise self._link_failure(e) from e

    def add_cubin(self, cubin, name=None):
        """Add a cubin image to the link."""
        name = name or "unnamed-cubin"
        self.add_data(_nvjitlink.InputType.CUBIN, cubin, name)

    def add_ptx(self, ptx, name=None):
        """Add PTX source to the link."""
        name = name or "unnamed-ptx"
        self.add_data(_nvjitlink.InputType.PTX, ptx, name)

    def add_ltoir(self, ltoir, name=None):
        """Add LTO-IR to the link."""
        name = name or "unnamed-ltoir"
        self.add_data(_nvjitlink.InputType.LTOIR, ltoir, name)

    def add_object(self, object_, name=None):
        """Add a host object file to the link."""
        name = name or "unnamed-object"
        self.add_data(_nvjitlink.InputType.OBJECT, object_, name)

    def add_fatbin(self, fatbin, name=None):
        """Add a fatbin to the link."""
        name = name or "unnamed-fatbin"
        self.add_data(_nvjitlink.InputType.FATBIN, fatbin, name)

    def add_library(self, library, name=None):
        """Add a library to the link."""
        # Same fallback-name convention as the other add_* helpers.
        name = name or "unnamed-library"
        self.add_data(_nvjitlink.InputType.LIBRARY, library, name)

    def add_file(self, path, kind):
        """Read *path* and add its contents to the link as *kind*.

        Raises ``LinkerError`` if the file does not exist.
        """
        try:
            with open(path, "rb") as f:
                data = f.read()
        except FileNotFoundError:
            raise LinkerError(f"{path} not found")

        name = pathlib.Path(path).name
        # add_data takes (input_type, data, name).
        # NOTE(review): add_data reads ``kind.value`` - confirm callers pass
        # an ``_nvjitlink.InputType`` member here.
        self.add_data(kind, data, name)

    def _finish(self, get_size, get_output):
        """Complete the link and return the output as a ``bytearray``.

        The info log is refreshed whether or not the link succeeds; on
        failure the error log is captured and ``nvJitLinkError`` is raised.
        """
        try:
            _nvjitlink.complete(self.handle)
            self._complete = True
            out = bytearray(get_size(self.handle))
            get_output(self.handle, out)
            return out
        except (nvJitLinkError, RuntimeError) as e:
            self._read_error_log()
            raise self._link_failure(e) from e
        finally:
            self._read_info_log()

    def get_linked_cubin(self):
        """Complete the link and return the linked cubin."""
        return self._finish(
            _nvjitlink.get_linked_cubin_size, _nvjitlink.get_linked_cubin
        )

    def get_linked_ptx(self):
        """Complete the link and return the linked PTX."""
        return self._finish(
            _nvjitlink.get_linked_ptx_size, _nvjitlink.get_linked_ptx
        )

    def complete(self):
        """Return the linked cubin, translating failures to ``LinkerError``."""
        try:
            return self.get_linked_cubin()
        except nvJitLinkError as e:
            raise LinkerError(str(e)) from e
| 3165 | + |
| 3166 | + |
3027 | 3167 | class PyNvJitLinker(Linker): |
3028 | 3168 | def __init__( |
3029 | 3169 | self, |
|
0 commit comments