Skip to content

Commit f54ed97

Browse files
committed
Refactor our execution policies
We currently tag our execution policies with just an enumeration for that represents the standard execution policies. However, this is not sufficient for our use cases, because we also want to pass along the execution backend and the memory_direction of the algorithm. This changes our policies so that they take an unsigned integer instead of an enumeration and then adds facilities to set and get the respective properties
1 parent 5b84c17 commit f54ed97

File tree

5 files changed

+110
-71
lines changed

5 files changed

+110
-71
lines changed

cudax/include/cuda/experimental/__execution/policy.cuh

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,48 +44,48 @@ struct any_execution_policy
4444

4545
_CCCL_HIDE_FROM_ABI any_execution_policy() = default;
4646

47-
template <__execution_policy _Policy>
48-
_CCCL_HOST_API constexpr any_execution_policy(::cuda::std::execution::__policy<_Policy>) noexcept
49-
: value(_Policy)
47+
template <uint32_t _Policy>
48+
_CCCL_HOST_API constexpr any_execution_policy(::cuda::std::execution::__execution_policy_base<_Policy>) noexcept
49+
: value(value_type{_Policy})
5050
{}
5151

5252
_CCCL_HOST_API constexpr operator __execution_policy() const noexcept
5353
{
5454
return value;
5555
}
5656

57-
_CCCL_HOST_API constexpr auto operator()() const noexcept -> __execution_policy
57+
_CCCL_HOST_API constexpr auto operator()() const noexcept -> value_type
5858
{
5959
return value;
6060
}
6161

62-
template <__execution_policy _Policy>
62+
template <uint32_t _Policy>
6363
[[nodiscard]] _CCCL_HOST_API friend constexpr bool
64-
operator==(const any_execution_policy& pol, const ::cuda::std::execution::__policy<_Policy>&) noexcept
64+
operator==(const any_execution_policy& pol, const ::cuda::std::execution::__execution_policy_base<_Policy>&) noexcept
6565
{
66-
return pol.value == _Policy;
66+
return pol.value == value_type{_Policy};
6767
}
6868

6969
#if _CCCL_STD_VER <= 2017
70-
template <__execution_policy _Policy>
70+
template <uint32_t _Policy>
7171
[[nodiscard]] _CCCL_HOST_API friend constexpr bool
72-
operator==(const ::cuda::std::execution::__policy<_Policy>&, const any_execution_policy& pol) noexcept
72+
operator==(const ::cuda::std::execution::__execution_policy_base<_Policy>&, const any_execution_policy& pol) noexcept
7373
{
74-
return pol.value == _Policy;
74+
return pol.value == value_type{_Policy};
7575
}
7676

77-
template <__execution_policy _Policy>
77+
template <uint32_t _Policy>
7878
[[nodiscard]] _CCCL_HOST_API friend constexpr bool
79-
operator!=(const any_execution_policy& pol, const ::cuda::std::execution::__policy<_Policy>&) noexcept
79+
operator!=(const any_execution_policy& pol, const ::cuda::std::execution::__execution_policy_base<_Policy>&) noexcept
8080
{
81-
return pol.value != _Policy;
81+
return pol.value != value_type{_Policy};
8282
}
8383

84-
template <__execution_policy _Policy>
84+
template <uint32_t _Policy>
8585
[[nodiscard]] _CCCL_HOST_API friend constexpr bool
86-
operator!=(const ::cuda::std::execution::__policy<_Policy>&, const any_execution_policy& pol)
86+
operator!=(const ::cuda::std::execution::__execution_policy_base<_Policy>&, const any_execution_policy& pol)
8787
{
88-
return pol.value != _Policy;
88+
return pol.value != value_type{_Policy};
8989
}
9090
#endif // _CCCL_STD_VER <= 2017
9191

libcudacxx/include/cuda/std/__execution/policy.h

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
# pragma system_header
2121
#endif // no system header
2222

23-
#include <cuda/std/__type_traits/underlying_type.h>
2423
#include <cuda/std/cstdint>
2524

2625
#include <cuda/std/__cccl/prologue.h>
2726

2827
_CCCL_BEGIN_NAMESPACE_CUDA_STD_EXECUTION
2928

30-
enum class __execution_policy : uint32_t
29+
//! @brief Enumerates the standard execution policies
30+
enum class __execution_policy : uint8_t
3131
{
3232
__invalid_execution_policy = 0,
3333
__sequenced = 1 << 0,
@@ -36,49 +36,73 @@ enum class __execution_policy : uint32_t
3636
__parallel_unsequenced = __execution_policy::__parallel | __execution_policy::__unsequenced,
3737
};
3838

39-
[[nodiscard]] _CCCL_API constexpr bool
40-
__satisfies_execution_policy(__execution_policy __lhs, __execution_policy __rhs) noexcept
39+
//! @brief Enumerates the different backends we support
40+
//! @note Not an enum class because a user might specify multiple backends
41+
enum __execution_backend : uint8_t
4142
{
42-
return (static_cast<uint32_t>(__lhs) & static_cast<uint32_t>(__rhs)) != 0;
43-
}
43+
// The backends we provide
44+
__none = 0,
45+
#if _CCCL_HAS_BACKEND_CUDA()
46+
__cuda = 1 << 1,
47+
#endif // _CCCL_HAS_BACKEND_CUDA()
48+
#if _CCCL_HAS_BACKEND_OMP()
49+
__omp = 1 << 2,
50+
#endif // _CCCL_HAS_BACKEND_OMP()
51+
#if _CCCL_HAS_BACKEND_TBB()
52+
__tbb = 1 << 3,
53+
#endif // _CCCL_HAS_BACKEND_TBB()
54+
};
4455

45-
template <__execution_policy _Policy>
46-
struct __policy
56+
//! @brief Base class for our execution policies.
57+
//! It takes an untagged uint32_t because we want to be able to store 3 different enumerations in it.
58+
template <uint32_t _Policy>
59+
struct __execution_policy_base
4760
{
48-
template <__execution_policy _OtherPolicy>
49-
[[nodiscard]] _CCCL_API friend constexpr bool operator==(const __policy&, const __policy<_OtherPolicy>&) noexcept
61+
template <uint32_t _OtherPolicy>
62+
[[nodiscard]] _CCCL_API friend constexpr bool
63+
operator==(const __execution_policy_base&, const __execution_policy_base<_OtherPolicy>&) noexcept
5064
{
51-
using __underlying_t = underlying_type_t<__execution_policy>;
52-
return (static_cast<__underlying_t>(_Policy) == static_cast<__underlying_t>(_OtherPolicy));
65+
return _Policy == _OtherPolicy;
5366
}
5467

5568
#if _CCCL_STD_VER <= 2017
56-
template <__execution_policy _OtherPolicy>
57-
[[nodiscard]] _CCCL_API friend constexpr bool operator!=(const __policy&, const __policy<_OtherPolicy>&) noexcept
69+
template <uint32_t _OtherPolicy>
70+
[[nodiscard]] _CCCL_API friend constexpr bool
71+
operator!=(const __execution_policy_base&, const __execution_policy_base<_OtherPolicy>&) noexcept
5872
{
59-
using __underlying_t = underlying_type_t<__execution_policy>;
60-
return (static_cast<__underlying_t>(_Policy) != static_cast<__underlying_t>(_OtherPolicy));
73+
return _Policy != _OtherPolicy;
6174
}
6275
#endif // _CCCL_STD_VER <= 2017
6376

64-
static constexpr __execution_policy __policy_ = _Policy;
65-
};
77+
//! @brief Tag that identifies this and all derived classes as a CCCL execution policy
78+
static constexpr uint32_t __cccl_policy_ = _Policy;
79+
80+
//! @brief Extracts the execution policy from the stored _Policy
81+
[[nodiscard]] _CCCL_API static constexpr __execution_policy __get_policy() noexcept
82+
{
83+
constexpr uint32_t __policy_mask{0x000000FF};
84+
return __execution_policy{_Policy & __policy_mask};
85+
}
6686

67-
struct sequenced_policy : public __policy<__execution_policy::__sequenced>
68-
{};
87+
//! @brief Extracts the execution backend from the stored _Policy
88+
[[nodiscard]] _CCCL_API static constexpr __execution_backend __get_backend() noexcept
89+
{
90+
constexpr uint32_t __backend_mask{0x0000FF00};
91+
return __execution_backend{(_Policy & __backend_mask) >> 8};
92+
}
93+
};
6994

95+
using sequenced_policy = __execution_policy_base<static_cast<uint32_t>(__execution_policy::__sequenced)>;
7096
_CCCL_GLOBAL_CONSTANT sequenced_policy seq{};
7197

72-
struct parallel_policy : public __policy<__execution_policy::__parallel>
73-
{};
98+
using parallel_policy = __execution_policy_base<static_cast<uint32_t>(__execution_policy::__parallel)>;
7499
_CCCL_GLOBAL_CONSTANT parallel_policy par{};
75100

76-
struct parallel_unsequenced_policy : public __policy<__execution_policy::__parallel_unsequenced>
77-
{};
101+
using parallel_unsequenced_policy =
102+
__execution_policy_base<static_cast<uint32_t>(__execution_policy::__parallel_unsequenced)>;
78103
_CCCL_GLOBAL_CONSTANT parallel_unsequenced_policy par_unseq{};
79104

80-
struct unsequenced_policy : public __policy<__execution_policy::__unsequenced>
81-
{};
105+
using unsequenced_policy = __execution_policy_base<static_cast<uint32_t>(__execution_policy::__unsequenced)>;
82106
_CCCL_GLOBAL_CONSTANT unsequenced_policy unseq{};
83107

84108
_CCCL_END_NAMESPACE_CUDA_STD_EXECUTION
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#ifndef _CUDA_STD___INTERNAL_PSTL_CONFIG_H
12+
#define _CUDA_STD___INTERNAL_PSTL_CONFIG_H
13+
14+
#include <cuda/std/detail/__config>
15+
16+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
17+
# pragma GCC system_header
18+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
19+
# pragma clang system_header
20+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
21+
# pragma system_header
22+
#endif // no system header
23+
24+
#include <cuda/std/__cccl/prologue.h>
25+
26+
#define _CCCL_HAS_BACKEND_CUDA() 0
27+
#define _CCCL_HAS_BACKEND_OMP() 0
28+
#define _CCCL_HAS_BACKEND_TBB() 0
29+
30+
#include <cuda/std/__cccl/epilogue.h>
31+
32+
#endif // _CUDA_STD___INTERNAL_PSTL_CONFIG_H

libcudacxx/include/cuda/std/__type_traits/is_execution_policy.h

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,51 +28,33 @@
2828

2929
_CCCL_BEGIN_NAMESPACE_CUDA_STD
3030

31-
template <class>
31+
template <class, class = void>
3232
inline constexpr bool is_execution_policy_v = false;
3333

34-
// Ensure we ignore cv qualifiers
35-
template <class _Tp>
36-
inline constexpr bool is_execution_policy_v<const _Tp> = is_execution_policy_v<_Tp>;
37-
38-
template <class _Tp>
39-
inline constexpr bool is_execution_policy_v<volatile _Tp> = is_execution_policy_v<_Tp>;
40-
41-
template <class _Tp>
42-
inline constexpr bool is_execution_policy_v<const volatile _Tp> = is_execution_policy_v<_Tp>;
43-
44-
// Explicitly mark our execution policies as such
45-
template <>
46-
inline constexpr bool is_execution_policy_v<::cuda::std::execution::sequenced_policy> = true;
47-
48-
template <>
49-
inline constexpr bool is_execution_policy_v<::cuda::std::execution::parallel_policy> = true;
50-
51-
template <>
52-
inline constexpr bool is_execution_policy_v<::cuda::std::execution::parallel_unsequenced_policy> = true;
53-
54-
template <>
55-
inline constexpr bool is_execution_policy_v<::cuda::std::execution::unsequenced_policy> = true;
34+
template <class _Policy>
35+
inline constexpr bool is_execution_policy_v<_Policy, void_t<decltype(_Policy::__cccl_policy_)>> = true;
5636

5737
template <class _Tp>
5838
struct _CCCL_NO_SPECIALIZATIONS is_execution_policy : bool_constant<is_execution_policy_v<_Tp>>
5939
{};
6040

6141
// Detect parallel policies
62-
template <class, class = void>
42+
template <class _Policy, bool = is_execution_policy_v<_Policy>>
6343
inline constexpr bool __is_parallel_execution_policy_v = false;
6444

6545
template <class _Policy>
66-
inline constexpr bool __is_parallel_execution_policy_v<_Policy, void_t<decltype(_Policy::__policy_)>> =
67-
__satisfies_execution_policy(_Policy::__policy_, ::cuda::std::execution::__execution_policy::__parallel);
46+
inline constexpr bool __is_parallel_execution_policy_v<_Policy, true> =
47+
_Policy::__get_policy() == ::cuda::std::execution::__execution_policy::__parallel
48+
|| _Policy::__get_policy() == ::cuda::std::execution::__execution_policy::__parallel_unsequenced;
6849

6950
// Detect unsequenced policies
70-
template <class, class = void>
51+
template <class _Policy, bool = is_execution_policy_v<_Policy>>
7152
inline constexpr bool __is_unsequenced_execution_policy_v = false;
7253

7354
template <class _Policy>
74-
inline constexpr bool __is_unsequenced_execution_policy_v<_Policy, void_t<decltype(_Policy::__policy_)>> =
75-
__satisfies_execution_policy(_Policy::__policy_, ::cuda::std::execution::__execution_policy::__unsequenced);
55+
inline constexpr bool __is_unsequenced_execution_policy_v<_Policy, true> =
56+
_Policy::__get_policy() == ::cuda::std::execution::__execution_policy::__unsequenced
57+
|| _Policy::__get_policy() == ::cuda::std::execution::__execution_policy::__parallel_unsequenced;
7658

7759
_CCCL_END_NAMESPACE_CUDA_STD
7860

libcudacxx/include/cuda/std/detail/__config

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <cuda/std/__internal/cpp_dialect.h>
1616
#include <cuda/std/__internal/features.h>
1717
#include <cuda/std/__internal/namespaces.h>
18+
#include <cuda/std/__internal/pstl_config.h>
1819
#include <cuda/std/__internal/thread_api.h>
1920
#include <cuda/std/__internal/version.h>
2021

0 commit comments

Comments
 (0)