@@ -81,22 +81,32 @@ if(charonload_FIND_QUIETLY)
8181 set (CUDNN_FIND_QUIETLY 1)
8282endif ()
8383
84- # Back up CUDA_NVCC_FLAGS for later restoring
84+ # Back up CUDA_NVCC_FLAGS and CMAKE_CUDA_FLAGS for later restoring
8585set (CHARONLOAD_CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} )
86+ set (CHARONLOAD_CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS} )
8687
8788find_dependency(Torch)
8889
8990list (POP_BACK CMAKE_MESSAGE_INDENT)
9091
9192if (Torch_FOUND)
9293 # 1. CUDA flag patching
93- if (NOT CHARONLOAD_CUDA_NVCC_FLAGS STREQUAL CUDA_NVCC_FLAGS AND TARGET torch_cuda)
94+ message (STATUS "${CUDA_NVCC_FLAGS} " )
95+ message (STATUS "${CHARONLOAD_CUDA_NVCC_FLAGS} " )
96+ if ((NOT CHARONLOAD_CUDA_NVCC_FLAGS STREQUAL CUDA_NVCC_FLAGS OR NOT CHARONLOAD_CUDA_NVCC_FLAGS STREQUAL CUDA_NVCC_FLAGS) AND TARGET torch_cuda)
97+ # Extract modified flags
98+ string (REPLACE "${CHARONLOAD_CUDA_NVCC_FLAGS} " "" CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED "${CUDA_NVCC_FLAGS} " )
99+ string (REPLACE ";" " " CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED "${CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED} " )
100+ string (STRIP "${CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED} " CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED)
101+
94102 # Use modified CUDA_NVCC_FLAGS
95- target_compile_options (torch_cuda INTERFACE $<$<COMPILE_LANGUAGE:CUDA>:${CUDA_NVCC_FLAGS} >)
103+ target_compile_options (torch_cuda INTERFACE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:${CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED} >" )
104+ unset (CHARONLOAD_CUDA_NVCC_FLAGS_MODIFIED)
96105
97- # Restore CUDA_NVCC_FLAGS
106+ # Restore CUDA_NVCC_FLAGS and CMAKE_CUDA_FLAGS
98107 set (CUDA_NVCC_FLAGS ${CHARONLOAD_CUDA_NVCC_FLAGS} )
99- message (STATUS "Patched target \" torch_cuda\" with modified \" CUDA_NVCC_FLAGS\" settings and rolled back the variable modifications." )
108+ set (CMAKE_CUDA_FLAGS ${CHARONLOAD_CMAKE_CUDA_FLAGS} )
109+ message (STATUS "Patched target \" torch_cuda\" with modified \" CUDA_NVCC_FLAGS\" /\" CMAKE_CUDA_FLAGS\" settings and rolled back the variable modifications." )
100110 endif ()
101111
102112 # 2. Python bindings library
@@ -120,6 +130,7 @@ endif()
120130
121131# Clean up backup variable
122132unset (CHARONLOAD_CUDA_NVCC_FLAGS)
133+ unset (CHARONLOAD_CMAKE_CUDA_FLAGS)
123134
124135
125136include ("${CMAKE_CURRENT_LIST_DIR} /torch/cxx_standard.cmake" )
0 commit comments