Skip to content

Commit 4f2bc2b

Browse files
reset files
1 parent b4ededf commit 4f2bc2b

File tree

3 files changed

+7
-27
lines changed

3 files changed

+7
-27
lines changed

numba_cuda/numba/cuda/cudadrv/libs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import sys
1515
import ctypes
16+
1617
from numba.misc.findlib import find_lib
1718
from numba.cuda.cuda_paths import get_cuda_paths
1819
from numba.cuda.cudadrv.driver import locate_driver_and_loader, load_driver
@@ -50,7 +51,7 @@ def get_cudalib(lib, static=False):
5051
loader's search mechanism.
5152
"""
5253
if lib == 'nvvm':
53-
return get_cuda_paths()['nvvm'].info
54+
return get_cuda_paths()['nvvm'].info or _dllnamepattern % 'nvvm'
5455
else:
5556
dir_type = 'static_cudalib_dir' if static else 'cudalib_dir'
5657
libdir = get_cuda_paths()[dir_type].info

numba_cuda/numba/cuda/cudadrv/nvrtc.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,10 @@
1-
from ctypes import (
2-
byref,
3-
c_char,
4-
c_char_p,
5-
c_int,
6-
c_size_t,
7-
c_void_p,
8-
POINTER,
9-
)
1+
from ctypes import byref, c_char, c_char_p, c_int, c_size_t, c_void_p, POINTER
102
from enum import IntEnum
113
from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
124
NvrtcSupportError)
135
from numba.cuda.cuda_paths import get_cuda_paths
146
import functools
157
import os
16-
import sys
178
import threading
189
import warnings
1910

@@ -23,9 +14,6 @@
2314
# Result code
2415
nvrtc_result = c_int
2516

26-
PLATFORM_LINUX = sys.platform.startswith("linux")
27-
PLATFORM_WIN = sys.platform.startswith("win32")
28-
2917

3018
class NvrtcResult(IntEnum):
3119
NVRTC_SUCCESS = 0
@@ -43,7 +31,6 @@ class NvrtcResult(IntEnum):
4331

4432

4533
_nvrtc_lock = threading.Lock()
46-
_nvrtc_obj = []
4734

4835

4936
class NvrtcProgram:
@@ -123,9 +110,10 @@ class NVRTC:
123110
def __new__(cls):
124111
with _nvrtc_lock:
125112
if cls.__INSTANCE is None:
113+
from numba.cuda.cudadrv.libs import open_cudalib
126114
cls.__INSTANCE = inst = object.__new__(cls)
127115
try:
128-
lib = _nvrtc_obj[0]
116+
lib = open_cudalib('nvrtc')
129117
except OSError as e:
130118
cls.__INSTANCE = None
131119
raise NvrtcSupportError("NVRTC cannot be loaded") from e

numba_cuda/numba/cuda/cudadrv/nvvm.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,8 @@
55
import re
66
import sys
77
import warnings
8-
from ctypes import (
9-
c_void_p,
10-
c_int,
11-
POINTER,
12-
c_char_p,
13-
c_size_t,
14-
byref,
15-
c_char
16-
)
8+
from ctypes import (c_void_p, c_int, POINTER, c_char_p, c_size_t, byref,
9+
c_char)
1710

1811
import threading
1912

@@ -23,8 +16,6 @@
2316
from .libs import get_libdevice, open_libdevice, open_cudalib
2417
from numba.core import cgutils, config
2518

26-
PLATFORM_LINUX = sys.platform.startswith("linux")
27-
PLATFORM_WIN = sys.platform.startswith("win32")
2819

2920
logger = logging.getLogger(__name__)
3021

0 commit comments

Comments
 (0)