Skip to content

Commit d2f71cb

Browse files
authored
make CuDNN finders respect library major version (pytorch#5399)
1 parent 40d79e4 commit d2f71cb

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

aten/cmake/FindCuDNN.cmake

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,21 @@ include(FindPackageHandleStandardArgs)
1515

1616
set(CUDNN_ROOT_DIR "" CACHE PATH "Folder contains NVIDIA cuDNN")
1717

18-
find_path(CUDNN_INCLUDE_DIR cudnn.h
18+
if($ENV{CUDNN_INCLUDE_DIR})
19+
SET(CUDNN_INCLUDE_DIR $ENV{CUDNN_INCLUDE_DIR})
20+
else($ENV{CUDNN_INCLUDE_DIR})
21+
find_path(CUDNN_INCLUDE_DIR cudnn.h
1922
HINTS ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
2023
PATH_SUFFIXES cuda/include include)
24+
endif($ENV{CUDNN_INCLUDE_DIR})
2125

22-
find_library(CUDNN_LIBRARY cudnn
26+
if($ENV{CUDNN_LIBRARY})
27+
SET(CUDNN_LIBRARY $ENV{CUDNN_LIBRARY})
28+
else($ENV{CUDNN_LIBRARY})
29+
find_library(CUDNN_LIBRARY cudnn
2330
HINTS ${CUDNN_LIB_DIR} ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
2431
PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
32+
endif($ENV{CUDNN_LIBRARY})
2533

2634
find_package_handle_standard_args(
2735
CUDNN DEFAULT_MSG CUDNN_INCLUDE_DIR CUDNN_LIBRARY)

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -709,8 +709,7 @@ def run(self):
709709
"torch/csrc/cuda/python_nccl.cpp",
710710
]
711711
if WITH_CUDNN:
712-
main_libraries += ['cudnn']
713-
library_dirs.insert(0, CUDNN_LIB_DIR)
712+
main_libraries += [CUDNN_LIBRARY]
714713
# NOTE: these are at the front, in case there's another cuDNN in CUDA path
715714
include_dirs.insert(0, CUDNN_INCLUDE_DIR)
716715
if not IS_WINDOWS:

tools/setup_helpers/cudnn.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@
3535
if IS_CONDA:
3636
lib_paths.append(os.path.join(CONDA_DIR, 'lib'))
3737
include_paths.append(os.path.join(CONDA_DIR, 'include'))
38+
for path in include_paths:
39+
if path is None or not os.path.exists(path):
40+
continue
41+
include_file_path = os.path.join(path, 'cudnn.h')
42+
if os.path.exists(include_file_path):
43+
CUDNN_INCLUDE_DIR = path
44+
CUDNN_INCLUDE_VERSION = -1
45+
with open(include_file_path) as f:
46+
for line in f:
47+
if "#define CUDNN_MAJOR" in line:
48+
CUDNN_INCLUDE_VERSION = int(line.split()[-1])
49+
break
50+
if CUDNN_INCLUDE_VERSION == -1:
51+
raise AssertionError("Could not find #define CUDNN_MAJOR in " + include_file_path)
52+
break
53+
3854
for path in lib_paths:
3955
if path is None or not os.path.exists(path):
4056
continue
@@ -45,18 +61,11 @@
4561
CUDNN_LIB_DIR = path
4662
break
4763
else:
48-
libraries = sorted(glob.glob(os.path.join(path, 'libcudnn*')))
64+
libraries = sorted(glob.glob(os.path.join(path, 'libcudnn*' + str(CUDNN_INCLUDE_VERSION) + "*")))
4965
if libraries:
5066
CUDNN_LIBRARY = libraries[0]
5167
CUDNN_LIB_DIR = path
5268
break
53-
for path in include_paths:
54-
if path is None or not os.path.exists(path):
55-
continue
56-
if os.path.exists((os.path.join(path, 'cudnn.h'))):
57-
CUDNN_INCLUDE_DIR = path
58-
break
59-
6069
# Specifying the library directly will overwrite the lib directory
6170
library = os.getenv('CUDNN_LIBRARY')
6271
if library is not None and os.path.exists(library):

0 commit comments

Comments
 (0)