Skip to content

Commit fc0080c

Browse files
committed
Implement execution policies
This implements the standard execution policies, mostly copying the design from `cuda::experimental::execution` I slightly adopted it to be extensible for the case that we want to extend the policies with policies
1 parent cf69aec commit fc0080c

File tree

11 files changed

+360
-134
lines changed

11 files changed

+360
-134
lines changed

cudax/include/cuda/experimental/__execution/bulk.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT __bulk_t
270270
[[nodiscard]] _CCCL_TRIVIAL_API auto operator()(_Policy __policy, _Shape __shape, _Fn __fn) const
271271
{
272272
static_assert(::cuda::std::integral<_Shape>);
273-
static_assert(is_execution_policy_v<_Policy>);
273+
static_assert(::cuda::std::is_execution_policy_v<_Policy>);
274274
using __closure_t = typename _BulkTag::template __closure_t<_Policy, _Shape, _Fn>;
275275
return __closure_t{{__policy, __shape, static_cast<_Fn&&>(__fn)}};
276276
}

cudax/include/cuda/experimental/__execution/policy.cuh

Lines changed: 46 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -23,57 +23,20 @@
2323

2424
#include <cuda/std/__concepts/concept_macros.h>
2525
#include <cuda/std/__execution/env.h>
26+
#include <cuda/std/__execution/policy.h>
2627
#include <cuda/std/__type_traits/is_convertible.h>
28+
#include <cuda/std/__type_traits/is_execution_policy.h>
2729

2830
#include <cuda/experimental/__execution/prologue.cuh>
2931

30-
namespace cuda::experimental
32+
namespace cuda::experimental::execution
3133
{
32-
namespace execution
33-
{
34-
struct sequenced_policy;
35-
struct parallel_policy;
36-
struct parallel_unsequenced_policy;
37-
struct unsequenced_policy;
38-
struct any_execution_policy;
39-
} // namespace execution
40-
41-
// execution policy type trait
42-
template <class _Ty>
43-
inline constexpr bool is_execution_policy_v = false;
44-
45-
template <>
46-
inline constexpr bool is_execution_policy_v<execution::sequenced_policy> = true;
47-
48-
template <>
49-
inline constexpr bool is_execution_policy_v<execution::parallel_policy> = true;
50-
51-
template <>
52-
inline constexpr bool is_execution_policy_v<execution::parallel_unsequenced_policy> = true;
53-
54-
template <>
55-
inline constexpr bool is_execution_policy_v<execution::unsequenced_policy> = true;
5634

57-
template <>
58-
inline constexpr bool is_execution_policy_v<execution::any_execution_policy> = true;
59-
60-
template <class _Ty>
61-
struct is_execution_policy : ::cuda::std::bool_constant<is_execution_policy_v<_Ty>>
62-
{};
63-
64-
namespace execution
65-
{
66-
enum class __execution_policy
67-
{
68-
invalid_execution_policy,
69-
sequenced,
70-
parallel,
71-
parallel_unsequenced,
72-
unsequenced,
73-
};
74-
75-
template <__execution_policy _Policy>
76-
struct __policy;
35+
using ::cuda::std::execution::__execution_policy;
36+
using ::cuda::std::execution::par;
37+
using ::cuda::std::execution::par_unseq;
38+
using ::cuda::std::execution::seq;
39+
using ::cuda::std::execution::unseq;
7740

7841
struct any_execution_policy
7942
{
@@ -83,8 +46,8 @@ struct any_execution_policy
8346
_CCCL_HIDE_FROM_ABI any_execution_policy() = default;
8447

8548
template <__execution_policy _Policy>
86-
_CCCL_HOST_API constexpr any_execution_policy(__policy<_Policy> __pol) noexcept
87-
: value(__pol)
49+
_CCCL_HOST_API constexpr any_execution_policy(::cuda::std::execution::__policy<_Policy>) noexcept
50+
: value(_Policy)
8851
{}
8952

9053
_CCCL_HOST_API constexpr operator __execution_policy() const noexcept
@@ -97,37 +60,38 @@ struct any_execution_policy
9760
return value;
9861
}
9962

100-
__execution_policy value = __execution_policy::invalid_execution_policy;
101-
};
102-
103-
template <__execution_policy _Policy>
104-
struct _CCCL_DECLSPEC_EMPTY_BASES __policy : ::cuda::std::integral_constant<__execution_policy, _Policy>
105-
{};
106-
107-
struct sequenced_policy : __policy<__execution_policy::sequenced>
108-
{};
109-
110-
struct parallel_policy : __policy<__execution_policy::parallel>
111-
{};
112-
113-
struct parallel_unsequenced_policy : __policy<__execution_policy::parallel_unsequenced>
114-
{};
63+
template <__execution_policy _Policy>
64+
[[nodiscard]] _CCCL_HOST_API friend constexpr bool
65+
operator==(const any_execution_policy& pol, const ::cuda::std::execution::__policy<_Policy>&) noexcept
66+
{
67+
return pol.value == _Policy;
68+
}
11569

116-
struct unsequenced_policy : __policy<__execution_policy::unsequenced>
117-
{};
70+
#if _CCCL_STD_VER <= 2017
71+
template <__execution_policy _Policy>
72+
[[nodiscard]] _CCCL_HOST_API friend constexpr bool
73+
operator==(const ::cuda::std::execution::__policy<_Policy>&, const any_execution_policy& pol) noexcept
74+
{
75+
return pol.value == _Policy;
76+
}
11877

119-
_CCCL_GLOBAL_CONSTANT sequenced_policy seq{};
120-
_CCCL_GLOBAL_CONSTANT parallel_policy par{};
121-
_CCCL_GLOBAL_CONSTANT parallel_unsequenced_policy par_unseq{};
122-
_CCCL_GLOBAL_CONSTANT unsequenced_policy unseq{};
78+
template <__execution_policy _Policy>
79+
[[nodiscard]] _CCCL_HOST_API friend constexpr bool
80+
operator!=(const any_execution_policy& pol, const ::cuda::std::execution::__policy<_Policy>&) noexcept
81+
{
82+
return pol.value != _Policy;
83+
}
12384

124-
template <__execution_policy _Policy>
125-
inline constexpr bool __is_parallel_execution_policy =
126-
_Policy == __execution_policy::parallel || _Policy == __execution_policy::parallel_unsequenced;
85+
template <__execution_policy _Policy>
86+
[[nodiscard]] _CCCL_HOST_API friend constexpr bool
87+
operator!=(const ::cuda::std::execution::__policy<_Policy>&, const any_execution_policy& pol)
88+
{
89+
return pol.value != _Policy;
90+
}
91+
#endif // _CCCL_STD_VER <= 2017
12792

128-
template <__execution_policy _Policy>
129-
inline constexpr bool __is_unsequenced_execution_policy =
130-
_Policy == __execution_policy::unsequenced || _Policy == __execution_policy::parallel_unsequenced;
93+
__execution_policy value = __execution_policy::__invalid_execution_policy;
94+
};
13195

13296
struct get_execution_policy_t;
13397

@@ -161,8 +125,14 @@ struct get_execution_policy_t
161125

162126
_CCCL_GLOBAL_CONSTANT get_execution_policy_t get_execution_policy{};
163127

164-
} // namespace execution
165-
} // namespace cuda::experimental
128+
} // namespace cuda::experimental::execution
129+
130+
_CCCL_BEGIN_NAMESPACE_CUDA_STD
131+
132+
template <>
133+
inline constexpr bool is_execution_policy_v<::cuda::experimental::execution::any_execution_policy> = true;
134+
135+
_CCCL_END_NAMESPACE_CUDA_STD
166136

167137
#include <cuda/experimental/__execution/epilogue.cuh>
168138

cudax/test/execution/env.cu

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include <cuda/std/execution>
1112
#include <cuda/std/type_traits>
1213

1314
#include <cuda/experimental/container.cuh>
@@ -84,9 +85,9 @@ C2H_TEST("env_t is constructible from an any_resource", "[execution][env]")
8485
SECTION("Passing an any_resource, a stream and a policy")
8586
{
8687
cudax::stream stream{cuda::device_ref{0}};
87-
env_t env{mr, stream, cudax::execution::par_unseq};
88+
env_t env{mr, stream, cuda::std::execution::par_unseq};
8889
CHECK(env.query(cuda::get_stream) == stream);
89-
CHECK((env.query(cudax::execution::get_execution_policy) == cudax::execution::par_unseq));
90+
CHECK((env.query(cudax::execution::get_execution_policy) == cuda::std::execution::par_unseq));
9091
CHECK(env.query(cuda::mr::get_memory_resource) == mr);
9192
}
9293
}
@@ -115,9 +116,10 @@ C2H_TEST("env_t is constructible from an any_resource passed as an rvalue", "[ex
115116
SECTION("Passing an any_resource, a stream and a policy")
116117
{
117118
cudax::stream stream{cuda::device_ref{0}};
118-
env_t env{cudax::any_resource<cuda::mr::device_accessible>{test_resource{}}, stream, cudax::execution::par_unseq};
119+
env_t env{
120+
cudax::any_resource<cuda::mr::device_accessible>{test_resource{}}, stream, cuda::std::execution::par_unseq};
119121
CHECK(env.query(cuda::get_stream) == stream);
120-
CHECK(env.query(cudax::execution::get_execution_policy) == cudax::execution::par_unseq);
122+
CHECK(env.query(cudax::execution::get_execution_policy) == cuda::std::execution::par_unseq);
121123
CHECK(env.query(cuda::mr::get_memory_resource)
122124
== cudax::any_resource<cuda::mr::device_accessible>{test_resource{}});
123125
}
@@ -147,9 +149,9 @@ C2H_TEST("env_t is constructible from a resource", "[execution][env]")
147149
SECTION("Passing an any_resource, a stream and a policy")
148150
{
149151
cudax::stream stream{cuda::device_ref{0}};
150-
env_t env{mr, stream, cudax::execution::par_unseq};
152+
env_t env{mr, stream, cuda::std::execution::par_unseq};
151153
CHECK(env.query(cuda::get_stream) == stream);
152-
CHECK(env.query(cudax::execution::get_execution_policy) == cudax::execution::par_unseq);
154+
CHECK(env.query(cudax::execution::get_execution_policy) == cuda::std::execution::par_unseq);
153155
CHECK(env.query(cuda::mr::get_memory_resource) == mr);
154156
}
155157
}
@@ -176,9 +178,9 @@ C2H_TEST("env_t is constructible from a resource passed as an rvalue", "[executi
176178
SECTION("Passing an any_resource, a stream and a policy")
177179
{
178180
cudax::stream stream{cuda::device_ref{0}};
179-
env_t env{test_resource{}, stream, cudax::execution::par_unseq};
181+
env_t env{test_resource{}, stream, cuda::std::execution::par_unseq};
180182
CHECK(env.query(cuda::get_stream) == stream);
181-
CHECK(env.query(cudax::execution::get_execution_policy) == cudax::execution::par_unseq);
183+
CHECK(env.query(cudax::execution::get_execution_policy) == cuda::std::execution::par_unseq);
182184
CHECK(env.query(cuda::mr::get_memory_resource) == test_resource{});
183185
}
184186
}
@@ -187,7 +189,7 @@ struct some_env_t
187189
{
188190
test_resource res_{};
189191
cudax::stream stream_{cuda::device_ref{0}};
190-
cudax::execution::any_execution_policy policy_ = cudax::execution::par_unseq;
192+
cudax::execution::any_execution_policy policy_ = cuda::std::execution::par_unseq;
191193

192194
const test_resource& query(cuda::mr::get_memory_resource_t) const noexcept
193195
{
@@ -218,7 +220,7 @@ struct bad_env_t
218220
{
219221
test_resource res_{};
220222
cudax::stream stream_{cuda::device_ref{0}};
221-
cudax::execution::any_execution_policy policy_ = cudax::execution::par_unseq;
223+
cudax::execution::any_execution_policy policy_ = cuda::std::execution::par_unseq;
222224

223225
template <bool Enable = WithResource, cuda::std::enable_if_t<Enable, int> = 0>
224226
const test_resource& query(cuda::mr::get_memory_resource_t) const noexcept

cudax/test/execution/policies/policies.cu

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include <cuda/std/execution>
1112
#include <cuda/std/type_traits>
1213

1314
#include <cuda/experimental/execution.cuh>
@@ -21,54 +22,13 @@ using is_same = cuda::std::is_same<cuda::std::remove_cvref_t<T>, U>;
2122

2223
C2H_TEST("Execution policies", "[execution][policies]")
2324
{
24-
namespace execution = cuda::experimental::execution;
25+
namespace execution = cuda::std::execution;
2526
SECTION("Individual options")
2627
{
27-
execution::any_execution_policy pol = execution::seq;
28-
pol = execution::par;
29-
pol = execution::par_unseq;
30-
pol = execution::unseq;
28+
cudax::execution::any_execution_policy pol = execution::seq;
29+
pol = execution::par;
30+
pol = execution::par_unseq;
31+
pol = execution::unseq;
3132
CHECK(pol == execution::unseq);
3233
}
33-
34-
SECTION("Global instances")
35-
{
36-
STATIC_CHECK(execution::seq == execution::seq);
37-
STATIC_CHECK(execution::par == execution::par);
38-
STATIC_CHECK(execution::par_unseq == execution::par_unseq);
39-
STATIC_CHECK(execution::unseq == execution::unseq);
40-
41-
STATIC_CHECK_FALSE(execution::seq != execution::seq);
42-
STATIC_CHECK_FALSE(execution::par != execution::par);
43-
STATIC_CHECK_FALSE(execution::par_unseq != execution::par_unseq);
44-
STATIC_CHECK_FALSE(execution::unseq != execution::unseq);
45-
46-
STATIC_CHECK_FALSE(execution::seq == execution::unseq);
47-
STATIC_CHECK_FALSE(execution::par == execution::seq);
48-
STATIC_CHECK_FALSE(execution::par_unseq == execution::par);
49-
STATIC_CHECK_FALSE(execution::unseq == execution::par_unseq);
50-
51-
STATIC_CHECK(execution::seq != execution::unseq);
52-
STATIC_CHECK(execution::par != execution::seq);
53-
STATIC_CHECK(execution::par_unseq != execution::par);
54-
STATIC_CHECK(execution::unseq != execution::par_unseq);
55-
}
56-
57-
SECTION("is_parallel_execution_policy")
58-
{
59-
using execution::__is_parallel_execution_policy;
60-
STATIC_CHECK(!__is_parallel_execution_policy<execution::seq>);
61-
STATIC_CHECK(__is_parallel_execution_policy<execution::par>);
62-
STATIC_CHECK(__is_parallel_execution_policy<execution::par_unseq>);
63-
STATIC_CHECK(!__is_parallel_execution_policy<execution::unseq>);
64-
}
65-
66-
SECTION("is_unsequenced_execution_policy")
67-
{
68-
using execution::__is_unsequenced_execution_policy;
69-
STATIC_CHECK(!__is_unsequenced_execution_policy<execution::seq>);
70-
STATIC_CHECK(!__is_unsequenced_execution_policy<execution::par>);
71-
STATIC_CHECK(__is_unsequenced_execution_policy<execution::par_unseq>);
72-
STATIC_CHECK(__is_unsequenced_execution_policy<execution::unseq>);
73-
}
7434
}

0 commit comments

Comments
 (0)