Skip to content

Commit 8b64eb2

Browse files
committed
Simplify cuda::host_launch API
1 parent b5071bc commit 8b64eb2

File tree

2 files changed

+89
-31
lines changed

2 files changed

+89
-31
lines changed

libcudacxx/include/cuda/__launch/host_launch.h

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,24 @@
2424
#include <cuda/__driver/driver_api.h>
2525
#include <cuda/__stream/stream_ref.h>
2626
#include <cuda/std/__functional/reference_wrapper.h>
27+
#include <cuda/std/__tuple_dir/apply.h>
28+
#include <cuda/std/__tuple_dir/tuple.h>
2729
#include <cuda/std/__type_traits/decay.h>
30+
#include <cuda/std/__type_traits/is_function.h>
2831
#include <cuda/std/__type_traits/is_move_constructible.h>
29-
#include <cuda/std/__utility/forward.h>
30-
#include <cuda/std/tuple>
32+
#include <cuda/std/__type_traits/is_pointer.h>
33+
#include <cuda/std/__utility/move.h>
3134

3235
#include <cuda/std/__cccl/prologue.h>
3336

3437
_CCCL_BEGIN_NAMESPACE_CUDA
3538

39+
template <class _Callable>
40+
_CCCL_HOST_API inline void CUDA_CB __host_func_launcher(void* __callable_ptr)
41+
{
42+
(*static_cast<_Callable*>(__callable_ptr))();
43+
}
44+
3645
template <class _Callable, class... _Args>
3746
struct __stream_callback_data
3847
{
@@ -73,37 +82,29 @@ _CCCL_HOST_API void host_launch(stream_ref __stream, _Callable __callable, _Args
7382
static_assert((::cuda::std::is_move_constructible_v<_Args> && ...),
7483
"All callback arguments must be move constructible");
7584

76-
using _CallbackData = __stream_callback_data<_Callable, _Args...>;
77-
_CallbackData* __callback_data_ptr = new _CallbackData{::cuda::std::move(__callable), {::cuda::std::move(__args)...}};
85+
constexpr auto __has_args = sizeof...(_Args) > 0;
7886

79-
// We use the callback here to have it execute even on stream error, because it needs to free the above allocation
80-
::cuda::__driver::__streamAddCallback(__stream.get(), __stream_callback_launcher<_CallbackData>, __callback_data_ptr);
81-
}
87+
if constexpr (!__has_args && ::cuda::std::is_function_v<_Callable> && ::cuda::std::is_pointer_v<_Callable>)
88+
{
89+
::cuda::__driver::__launchHostFunc(__stream.get(), ::cuda::__host_func_launcher<_Callable>, __callable);
90+
}
91+
else if constexpr (!__has_args && ::cuda::std::__is_cuda_std_reference_wrapper_v<_Callable>)
92+
{
93+
::cuda::__driver::__launchHostFunc(
94+
__stream.get(), ::cuda::__host_func_launcher<typename _Callable::type>, ::cuda::std::addressof(__callable.get()));
95+
}
96+
else
97+
{
98+
using _CallbackData = __stream_callback_data<_Callable, _Args...>;
99+
_CallbackData* __callback_data_ptr =
100+
new _CallbackData{::cuda::std::move(__callable), {::cuda::std::move(__args)...}};
82101

83-
template <class _Callable>
84-
_CCCL_HOST_API inline void CUDA_CB __host_func_launcher(void* __callable_ptr)
85-
{
86-
(*static_cast<_Callable*>(__callable_ptr))();
102+
// We use the callback here to have it execute even on stream error, because it needs to free the above allocation
103+
::cuda::__driver::__streamAddCallback(
104+
__stream.get(), ::cuda::__stream_callback_launcher<_CallbackData>, __callback_data_ptr);
105+
}
87106
}
88107

89-
//! @brief Launches a host callable to be executed in stream order on the provided stream
90-
//!
91-
//! Callable will be called using the supplied reference. If the callable was destroyed
92-
//! or moved by the time it is asynchronously called the behavior is undefined.
93-
//!
94-
//! Callable must not call any APIs from cuda, thrust or cub namespaces.
95-
//! It must not call into CUDA Runtime or Driver APIs. It also can't depend on another
96-
//! thread that might block on any asynchronous CUDA work.
97-
//!
98-
//! @param __stream Stream to launch the host function on
99-
//! @param __callable A reference to a host function or callable object to call in stream order
100-
template <class _Callable>
101-
_CCCL_HOST_API void host_launch(stream_ref __stream, ::cuda::std::reference_wrapper<_Callable> __callable)
102-
{
103-
static_assert(::cuda::std::is_invocable_v<_Callable>, "Callable in reference_wrapper can't take any arguments");
104-
::cuda::__driver::__launchHostFunc(
105-
__stream.get(), __host_func_launcher<_Callable>, ::cuda::std::addressof(__callable.get()));
106-
}
107108
_CCCL_END_NAMESPACE_CUDA
108109

109110
#include <cuda/std/__cccl/epilogue.h>

libcudacxx/test/libcudacxx/cuda/ccclrt/launch/host_launch.cu

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ void unblock_and_wait_stream(cuda::stream_ref stream, cuda::atomic<int>& atomic)
3232
atomic = 0;
3333
}
3434

35+
bool ordinary_function_run_proof = false;
36+
37+
template <class Ret, class... Args>
38+
Ret ordinary_function(Args...)
39+
{
40+
ordinary_function_run_proof = true;
41+
return (Ret) 0;
42+
}
43+
44+
[[nodiscard]] int nodiscard_ordinary_function()
45+
{
46+
ordinary_function_run_proof = true;
47+
return 0;
48+
}
49+
3550
void launch_local_lambda(cuda::stream_ref stream, int& set, int set_to)
3651
{
3752
auto lambda = [&set, set_to]() {
@@ -111,9 +126,51 @@ C2H_CCCLRT_TEST("Host launch", "")
111126
cuda::device_ref device{0};
112127
device.init();
113128

114-
cuda::atomic<int> atomic = 0;
115129
cuda::stream stream{device};
116-
int i = 0;
130+
131+
SECTION("Ordinary function without arguments returning void")
132+
{
133+
CCCLRT_REQUIRE(ordinary_function_run_proof == false);
134+
135+
cuda::host_launch(stream, ordinary_function<void>);
136+
137+
stream.sync();
138+
CCCLRT_REQUIRE(ordinary_function_run_proof == true);
139+
ordinary_function_run_proof = false;
140+
}
141+
SECTION("Ordinary function without arguments returning int")
142+
{
143+
CCCLRT_REQUIRE(ordinary_function_run_proof == false);
144+
145+
cuda::host_launch(stream, ordinary_function<int>);
146+
147+
stream.sync();
148+
CCCLRT_REQUIRE(ordinary_function_run_proof == true);
149+
ordinary_function_run_proof = false;
150+
}
151+
SECTION("Ordinary function with arguments returning void")
152+
{
153+
CCCLRT_REQUIRE(ordinary_function_run_proof == false);
154+
155+
cuda::host_launch(stream, ordinary_function<int, char, double>, 'c', 1.0);
156+
157+
stream.sync();
158+
CCCLRT_REQUIRE(ordinary_function_run_proof == true);
159+
ordinary_function_run_proof = false;
160+
}
161+
SECTION("Nodiscard ordinary function")
162+
{
163+
CCCLRT_REQUIRE(ordinary_function_run_proof == false);
164+
165+
cuda::host_launch(stream, nodiscard_ordinary_function);
166+
167+
stream.sync();
168+
CCCLRT_REQUIRE(ordinary_function_run_proof == true);
169+
ordinary_function_run_proof = false;
170+
}
171+
172+
cuda::atomic<int> atomic = 0;
173+
int i = 0;
117174

118175
auto set_lambda = [&](int set) {
119176
i = set;

0 commit comments

Comments
 (0)