|
24 | 24 | #include <cuda/__driver/driver_api.h> |
25 | 25 | #include <cuda/__stream/stream_ref.h> |
26 | 26 | #include <cuda/std/__functional/reference_wrapper.h> |
| 27 | +#include <cuda/std/__tuple_dir/apply.h> |
| 28 | +#include <cuda/std/__tuple_dir/tuple.h> |
27 | 29 | #include <cuda/std/__type_traits/decay.h> |
| 30 | +#include <cuda/std/__type_traits/is_function.h> |
28 | 31 | #include <cuda/std/__type_traits/is_move_constructible.h> |
29 | | -#include <cuda/std/__utility/forward.h> |
30 | | -#include <cuda/std/tuple> |
| 32 | +#include <cuda/std/__type_traits/is_pointer.h> |
| 33 | +#include <cuda/std/__utility/move.h> |
31 | 34 |
|
32 | 35 | #include <cuda/std/__cccl/prologue.h> |
33 | 36 |
|
34 | 37 | _CCCL_BEGIN_NAMESPACE_CUDA |
35 | 38 |
|
| 39 | +template <class _Callable> |
| 40 | +_CCCL_HOST_API inline void CUDA_CB __host_func_launcher(void* __callable_ptr) |
| 41 | +{ |
| 42 | + (*static_cast<_Callable*>(__callable_ptr))(); |
| 43 | +} |
| 44 | + |
36 | 45 | template <class _Callable, class... _Args> |
37 | 46 | struct __stream_callback_data |
38 | 47 | { |
@@ -73,37 +82,29 @@ _CCCL_HOST_API void host_launch(stream_ref __stream, _Callable __callable, _Args |
73 | 82 | static_assert((::cuda::std::is_move_constructible_v<_Args> && ...), |
74 | 83 | "All callback arguments must be move constructible"); |
75 | 84 |
|
76 | | - using _CallbackData = __stream_callback_data<_Callable, _Args...>; |
77 | | - _CallbackData* __callback_data_ptr = new _CallbackData{::cuda::std::move(__callable), {::cuda::std::move(__args)...}}; |
| 85 | + constexpr auto __has_args = sizeof...(_Args) > 0; |
78 | 86 |
|
79 | | - // We use the callback here to have it execute even on stream error, because it needs to free the above allocation |
80 | | - ::cuda::__driver::__streamAddCallback(__stream.get(), __stream_callback_launcher<_CallbackData>, __callback_data_ptr); |
81 | | -} |
| 87 | + if constexpr (!__has_args && ::cuda::std::is_function_v<_Callable> && ::cuda::std::is_pointer_v<_Callable>) |
| 88 | + { |
| 89 | + ::cuda::__driver::__launchHostFunc(__stream.get(), ::cuda::__host_func_launcher<_Callable>, __callable); |
| 90 | + } |
| 91 | + else if constexpr (!__has_args && ::cuda::std::__is_cuda_std_reference_wrapper_v<_Callable>) |
| 92 | + { |
| 93 | + ::cuda::__driver::__launchHostFunc( |
| 94 | + __stream.get(), ::cuda::__host_func_launcher<typename _Callable::type>, ::cuda::std::addressof(__callable.get())); |
| 95 | + } |
| 96 | + else |
| 97 | + { |
| 98 | + using _CallbackData = __stream_callback_data<_Callable, _Args...>; |
| 99 | + _CallbackData* __callback_data_ptr = |
| 100 | + new _CallbackData{::cuda::std::move(__callable), {::cuda::std::move(__args)...}}; |
82 | 101 |
|
83 | | -template <class _Callable> |
84 | | -_CCCL_HOST_API inline void CUDA_CB __host_func_launcher(void* __callable_ptr) |
85 | | -{ |
86 | | - (*static_cast<_Callable*>(__callable_ptr))(); |
| 102 | + // We use the callback here to have it execute even on stream error, because it needs to free the above allocation |
| 103 | + ::cuda::__driver::__streamAddCallback( |
| 104 | + __stream.get(), ::cuda::__stream_callback_launcher<_CallbackData>, __callback_data_ptr); |
| 105 | + } |
87 | 106 | } |
88 | 107 |
|
89 | | -//! @brief Launches a host callable to be executed in stream order on the provided stream |
90 | | -//! |
91 | | -//! Callable will be called using the supplied reference. If the callable was destroyed |
92 | | -//! or moved by the time it is asynchronously called the behavior is undefined. |
93 | | -//! |
94 | | -//! Callable must not call any APIs from cuda, thrust or cub namespaces. |
95 | | -//! It must not call into CUDA Runtime or Driver APIs. It also can't depend on another |
96 | | -//! thread that might block on any asynchronous CUDA work. |
97 | | -//! |
98 | | -//! @param __stream Stream to launch the host function on |
99 | | -//! @param __callable A reference to a host function or callable object to call in stream order |
100 | | -template <class _Callable> |
101 | | -_CCCL_HOST_API void host_launch(stream_ref __stream, ::cuda::std::reference_wrapper<_Callable> __callable) |
102 | | -{ |
103 | | - static_assert(::cuda::std::is_invocable_v<_Callable>, "Callable in reference_wrapper can't take any arguments"); |
104 | | - ::cuda::__driver::__launchHostFunc( |
105 | | - __stream.get(), __host_func_launcher<_Callable>, ::cuda::std::addressof(__callable.get())); |
106 | | -} |
107 | 108 | _CCCL_END_NAMESPACE_CUDA |
108 | 109 |
|
109 | 110 | #include <cuda/std/__cccl/epilogue.h> |
|
0 commit comments