
Commit f04a474

Fix propagation of PyTorch CUDA flags for multiple archs
1 parent 07363bd

2 files changed (+54, -5 lines)

src/charonload/cmake/charonload-config.cmake

Lines changed: 16 additions & 5 deletions
@@ -81,22 +81,32 @@ if(charonload_FIND_QUIETLY)
     set(CUDNN_FIND_QUIETLY 1)
 endif()
 
-# Back up CUDA_NVCC_FLAGS for later restoring
+# Back up CUDA_NVCC_FLAGS and CMAKE_CUDA_FLAGS for later restoring
 set(CHARONLOAD_CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS})
+set(CHARONLOAD_CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS})
 
 find_dependency(Torch)
 
 list(POP_BACK CMAKE_MESSAGE_INDENT)
 
 if(Torch_FOUND)
     # 1. CUDA flag patching
-    if(NOT CHARONLOAD_CUDA_NVCC_FLAGS STREQUAL CUDA_NVCC_FLAGS AND TARGET torch_cuda)
+    message(STATUS "${CUDA_NVCC_FLAGS}")
+    message(STATUS "${CHARONLOAD_CUDA_NVCC_FLAGS}")
+    if((NOT CHARONLOAD_CUDA_NVCC_FLAGS STREQUAL CUDA_NVCC_FLAGS OR NOT CHARONLOAD_CMAKE_CUDA_FLAGS STREQUAL CMAKE_CUDA_FLAGS) AND TARGET torch_cuda)
+        # Extract modified flags
+        string(REPLACE "${CHARONLOAD_CUDA_NVCC_FLAGS}" "" CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED "${CUDA_NVCC_FLAGS}")
+        string(REPLACE ";" " " CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED "${CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED}")
+        string(STRIP "${CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED}" CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED)
+
         # Use modified CUDA_NVCC_FLAGS
-        target_compile_options(torch_cuda INTERFACE $<$<COMPILE_LANGUAGE:CUDA>:${CUDA_NVCC_FLAGS}>)
+        target_compile_options(torch_cuda INTERFACE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:${CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED}>")
+        unset(CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED)
 
-        # Restore CUDA_NVCC_FLAGS
+        # Restore CUDA_NVCC_FLAGS and CMAKE_CUDA_FLAGS
         set(CUDA_NVCC_FLAGS ${CHARONLOAD_CUDA_NVCC_FLAGS})
-        message(STATUS "Patched target \"torch_cuda\" with modified \"CUDA_NVCC_FLAGS\" settings and rolled back the variable modifications.")
+        set(CMAKE_CUDA_FLAGS ${CHARONLOAD_CMAKE_CUDA_FLAGS})
+        message(STATUS "Patched target \"torch_cuda\" with modified \"CUDA_NVCC_FLAGS\"/\"CMAKE_CUDA_FLAGS\" settings and rolled back the variable modifications.")
     endif()
 
     # 2. Python bindings library
@@ -120,6 +130,7 @@ endif()
 
 # Clean up backup variable
 unset(CHARONLOAD_CUDA_NVCC_FLAGS)
+unset(CHARONLOAD_CMAKE_CUDA_FLAGS)
 
 
 include("${CMAKE_CURRENT_LIST_DIR}/torch/cxx_standard.cmake")
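
Why the SHELL: prefix matters here: with several architectures in TORCH_CUDA_ARCH_LIST, Torch's find module appends repeated "-gencode arch=...,code=..." pairs to CUDA_NVCC_FLAGS. If the extracted flags were passed to target_compile_options as a plain list, CMake's option de-duplication would strip the repeated "-gencode" tokens and break the pairs apart; joining the extracted flags into one space-separated string and prefixing it with SHELL: keeps the group intact and lets CMake re-split it correctly on the nvcc command line. A minimal sketch of the pattern, with a hypothetical interface target and hard-coded flag values standing in for torch_cuda and the Torch-provided flags:

cmake_minimum_required(VERSION 3.18)
project(shell_flag_demo LANGUAGES NONE)  # a real project would enable CUDA

# Hypothetical stand-in for torch_cuda.
add_library(my_cuda_target INTERFACE)

# Roughly what TORCH_CUDA_ARCH_LIST="8.0 8.6" produces: the "-gencode"
# token repeats, so naive list propagation would let CMake's option
# de-duplication drop the repeats and break the arch pairs.
set(arch_flags
    -gencode arch=compute_80,code=sm_80
    -gencode arch=compute_86,code=sm_86)

# Join into one string and wrap it in SHELL: so the whole group survives
# de-duplication and is re-split on the compiler command line.
string(REPLACE ";" " " arch_flags_str "${arch_flags}")
target_compile_options(my_cuda_target INTERFACE
    "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:${arch_flags_str}>")

The committed change applies the same idea, except that the string is derived from the difference between the backed-up flags and the values left behind by find_dependency(Torch) rather than being hard-coded.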

tests/test_finder.py

Lines changed: 38 additions & 0 deletions
@@ -5,6 +5,7 @@
 import importlib.util
 import io
 import multiprocessing
+import os
 import pathlib
 import platform
 import shutil
@@ -87,6 +88,43 @@ def test_torch_cuda(shared_datadir: pathlib.Path, tmp_path: pathlib.Path) -> None:
     assert torch.equal(t_output, 2 * t_input)
 
 
+def _torch_cuda_custom_archs(shared_datadir: pathlib.Path, tmp_path: pathlib.Path) -> None:
+    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0 8.6"  # Use 2 different archs
+
+    project_directory = shared_datadir / "torch_cuda"
+    build_directory = tmp_path / "build"
+
+    charonload.module_config["test_torch_cuda_custom_archs"] = charonload.Config(
+        project_directory,
+        build_directory,
+        stubs_directory=VSCODE_STUBS_DIRECTORY,
+    )
+
+    import test_torch_cuda_custom_archs as test_torch
+
+    t_input = torch.randint(0, 10, size=(3, 3, 3), dtype=torch.float, device="cuda")
+    t_output = test_torch.two_times(t_input)
+
+    assert t_output.device == t_input.device
+    assert t_output.shape == t_input.shape
+    assert torch.equal(t_output, 2 * t_input)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+def test_torch_cuda_custom_archs(shared_datadir: pathlib.Path, tmp_path: pathlib.Path) -> None:
+    p = multiprocessing.get_context("spawn").Process(
+        target=_torch_cuda_custom_archs,
+        args=(
+            shared_datadir,
+            tmp_path,
+        ),
+    )
+
+    p.start()
+    p.join()
+    assert p.exitcode == 0
+
+
 def test_torch_common_static(shared_datadir: pathlib.Path, tmp_path: pathlib.Path) -> None:
     project_directory = shared_datadir / "torch_common_static"
     build_directory = tmp_path / "build"
