Commit 5d5a4c3

Propagate PyTorch CUDA flags via CMake target
1 parent c3b3825 commit 5d5a4c3

14 files changed: +226 -15 lines changed

CMakeLists.txt

Lines changed: 0 additions & 14 deletions
@@ -11,17 +11,3 @@ endif()
 list(APPEND CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/src/charonload")
 
 find_package(charonload REQUIRED GLOBAL)
-
-
-set(MODIFIED_CMAKE_VARIABLES
-    # Set by Caffe2/public/cuda.cmake as part of TorchConfig.cmake
-    "CMAKE_CUDA_FLAGS"
-    "CMAKE_CUDA_FLAGS_DEBUG"
-    "CMAKE_CUDA_FLAGS_MINSIZEREL"
-    "CMAKE_CUDA_FLAGS_RELEASE"
-    "CMAKE_CUDA_FLAGS_RELWITHDEBINFO"
-)
-
-foreach(var IN LISTS MODIFIED_CMAKE_VARIABLES)
-    set(${var} ${${var}} PARENT_SCOPE)
-endforeach()

docs/src/pytorch_handling.md

Lines changed: 11 additions & 0 deletions
@@ -40,6 +40,16 @@ Set required PIC flag for linking.
 
 ::::
 
+::::{grid-item-card}
+:link: pytorch_handling/cuda_flags
+:link-type: doc
+
+**CUDA Flags**
+^^^
+Set required CUDA flags for parents and siblings.
+
+::::
+
 :::::
 
 
@@ -49,4 +59,5 @@
 pytorch_handling/cpp_standard
 pytorch_handling/cpp11_abi
 pytorch_handling/position_independent_code
+pytorch_handling/cuda_flags
 ```
docs/src/pytorch_handling/cuda_flags.md

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# CUDA Flags
+
+In order to simplify writing CUDA kernels, the PyTorch C++ library enables several compiler flags:
+
+- Using the CUDA architectures of the detected GPUs
+- Enabling `__host__ __device__` lambda functions, e.g., with Thrust/CUB algorithms
+- Enabling relaxed `constexpr` rules to reuse, e.g., `std::clamp` in kernels and `__device__` functions
+- Suppressing some noisy warnings
+
+However, the PyTorch C++ library provides these flags by modifying the (old-school) [``CUDA_NVCC_FLAGS``](<inv:cmake.org#module/FindCUDA>) variable. Although CMake will pick up this variable, the modifications are **only** visible in the directory (and subdirectory) scope(s) where PyTorch has been found by [``find_package``](<inv:cmake.org#command/find_package>). This may lead to compiler errors for dependent targets in parent or sibling directories when finding PyTorch with the ``GLOBAL`` option enabled, as this promotes **only** the respective targets to all scopes but leaves the variable modifications in the calling scope.
+
+
+Charonload automatically detects the modified compile flags and attaches them as an `INTERFACE` property to the CUDA target of the PyTorch C++ library, so that they are correctly propagated to any linking target.
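
Editor's note: as a rough illustration of the flag list in this new doc page, the first three bullets map onto standard `nvcc` options. A minimal sketch in modern per-target CMake, where `my_cuda_target` is a placeholder and the exact flag set PyTorch appends varies between versions:

```cmake
# Sketch only: roughly the nvcc options behind the first three bullets,
# expressed as per-target CMake settings. Warning suppressions are omitted
# since the suppressed diagnostics differ between PyTorch versions.
set_target_properties(my_cuda_target PROPERTIES CUDA_ARCHITECTURES native)  # detected GPUs (CMake >= 3.24)
target_compile_options(my_cuda_target PRIVATE
    $<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda;--expt-relaxed-constexpr>)
```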
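The scoping pitfall described above can be made concrete with a minimal sketch of a top-level CMakeLists.txt, using the lib_dir/binding_dir layout that the test case added in this commit also uses:

```cmake
# Minimal sketch of the problematic layout:
#
#   lib_dir/      finds charonload/PyTorch via find_package(... GLOBAL);
#                 CUDA_NVCC_FLAGS is modified only in this directory scope
#   binding_dir/  sibling directory compiling CUDA sources against torch_cuda
#
# GLOBAL promotes imported targets such as torch_cuda to every scope, but the
# variable modifications stay behind in lib_dir/, so without charonload's
# patching, CUDA sources in binding_dir/ would be compiled without, e.g.,
# --extended-lambda and fail.
add_subdirectory(lib_dir)
add_subdirectory(binding_dir)  # must come AFTER lib_dir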
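From a consumer's perspective, the fix means plain linking suffices. A minimal sketch, where the target name `my_kernels` is hypothetical and a CUDA-enabled libtorch is assumed so that the `torch` target links `torch_cuda`:

```cmake
# Sketch: with the flags attached to torch_cuda as INTERFACE compile options,
# they travel along the link graph as usage requirements, regardless of the
# directory scope this code lives in.
add_library(my_kernels STATIC my_kernels.cu)
target_link_libraries(my_kernels PRIVATE torch)  # pulls in torch_cuda and its flags
```

Attaching the flags as usage requirements of the imported target is what makes the behavior scope-independent: CMake resolves `INTERFACE` compile options at the target level, not the directory level.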

src/charonload/cmake/charonload-config.cmake

Lines changed: 18 additions & 1 deletion
@@ -81,9 +81,25 @@ if(charonload_FIND_QUIETLY)
     set(CUDNN_FIND_QUIETLY 1)
 endif()
 
+# Back up CUDA_NVCC_FLAGS for later restoring
+set(CHARONLOAD_CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS})
+
 find_dependency(Torch)
 
+list(POP_BACK CMAKE_MESSAGE_INDENT)
+
 if(Torch_FOUND)
+    # 1. CUDA flag patching
+    if(NOT CHARONLOAD_CUDA_NVCC_FLAGS STREQUAL CUDA_NVCC_FLAGS AND TARGET torch_cuda)
+        # Use modified CUDA_NVCC_FLAGS
+        target_compile_options(torch_cuda INTERFACE $<$<COMPILE_LANGUAGE:CUDA>:${CUDA_NVCC_FLAGS}>)
+
+        # Restore CUDA_NVCC_FLAGS
+        set(CUDA_NVCC_FLAGS ${CHARONLOAD_CUDA_NVCC_FLAGS})
+        message(STATUS "Patched target \"torch_cuda\" with modified \"CUDA_NVCC_FLAGS\" settings and rolled back the variable modifications.")
+    endif()
+
+    # 2. Python bindings library
     get_target_property(TORCH_LIBRARY_LOCATION torch LOCATION)
     get_filename_component(TORCH_LIB_SEARCH_PATH ${TORCH_LIBRARY_LOCATION} DIRECTORY)
 
@@ -102,7 +118,8 @@
     endif()
 endif()
 
-list(POP_BACK CMAKE_MESSAGE_INDENT)
+# Clean up backup variable
+unset(CHARONLOAD_CUDA_NVCC_FLAGS)
 
 
 include("${CMAKE_CURRENT_LIST_DIR}/torch/cxx_standard.cmake")

tests/data/torch_cuda/two_times_cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,10 @@
 #include <ATen/ops/zeros_like.h>
 #include <c10/cuda/CUDAException.h>
 
+#ifndef __CUDACC_EXTENDED_LAMBDA__
+    #error "Modified CUDA_NVCC_FLAGS (extended lambda) from torch not correctly propagated"
+#endif
+
 template <class T>
 __global__ void
 two_times_kernel(const T* const input, T* const output, const std::size_t N)
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+cmake_minimum_required(VERSION 3.27)
+
+project(torch_cuda_subdirectory LANGUAGES CXX CUDA)
+
+add_subdirectory(lib_dir)
+
+# Must come AFTER lib_dir
+add_subdirectory(binding_dir)
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+# This short-circuits from GLOBAL find_package
+find_package(charonload)
+
+if(charonload_FOUND)
+    charonload_add_torch_library(${TORCH_EXTENSION_NAME} MODULE)
+
+    target_sources(${TORCH_EXTENSION_NAME} PRIVATE bindings.cpp three_times_cuda.cu)
+    target_link_libraries(${TORCH_EXTENSION_NAME} PRIVATE torch_cuda_subdirectory)
+endif()
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+#include <torch/python.h>
+
+#include "three_times_cuda.h"
+#include "two_times_cuda.h"
+
+using namespace pybind11::literals;
+
+#define STRINGIFY_IMPL(x) #x
+#define STRINGIFY(a) STRINGIFY_IMPL(a)
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.doc() = "A C++/CUDA extension module named \"" STRINGIFY(TORCH_EXTENSION_NAME) "\" that is built just-in-time.";
+
+    m.def("two_times", &two_times, "input"_a, R"(
+        Multiply the given input tensor by a factor of 2 on the GPU using CUDA.
+
+        Parameters
+        ----------
+        input
+            A tensor with arbitrary shape and dtype.
+
+        Returns
+        -------
+        A new tensor with the same shape and dtype as ``input`` and where each value is multiplied by 2.
+    )");
+
+    m.def("three_times", &three_times, "input"_a, R"(
+        Multiply the given input tensor by a factor of 3 on the GPU using CUDA.
+
+        Parameters
+        ----------
+        input
+            A tensor with arbitrary shape and dtype.
+
+        Returns
+        -------
+        A new tensor with the same shape and dtype as ``input`` and where each value is multiplied by 3.
+    )");
+}
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#include <ATen/Dispatch.h>
+#include <ATen/ops/zeros_like.h>
+#include <c10/cuda/CUDAException.h>
+
+#ifndef __CUDACC_EXTENDED_LAMBDA__
+    #error "Modified CUDA_NVCC_FLAGS (extended lambda) from torch not correctly propagated"
+#endif
+
+template <class T>
+__global__ void
+three_times_kernel(const T* const input, T* const output, const std::size_t N)
+{
+    for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x)
+    {
+        output[i] = T(3) * input[i];
+    }
+}
+
+at::Tensor
+three_times(const at::Tensor& input)
+{
+    auto output = at::zeros_like(input);
+
+    AT_DISPATCH_ALL_TYPES(input.scalar_type(),
+                          "three_times_kernel",
+                          [&]()
+                          {
+                              const std::uint32_t block_size = 128;
+                              const std::uint32_t num_blocks = (input.numel() + block_size - 1) / block_size;
+                              three_times_kernel<<<num_blocks, block_size>>>(input.data_ptr<scalar_t>(),
+                                                                             output.data_ptr<scalar_t>(),
+                                                                             input.numel());
+                              C10_CUDA_KERNEL_LAUNCH_CHECK();
+                          });
+
+    return output;
+}
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+
+at::Tensor
+three_times(const at::Tensor& input);
