Skip to content

Commit 2b9f630

Browse files
committed
Provide a way to specify a stream
1 parent 96195de commit 2b9f630

File tree

6 files changed

+195
-15
lines changed

6 files changed

+195
-15
lines changed

libcudacxx/include/cuda/std/__execution/policy.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
#include <cuda/__stream/stream_ref.h>
2424
#include <cuda/std/__bit/has_single_bit.h>
25+
#include <cuda/std/__execution/stream_policy.h>
26+
#include <cuda/std/__fwd/policy.h>
2527
#include <cuda/std/cstdint>
2628

2729
#include <cuda/std/__cccl/prologue.h>
@@ -98,13 +100,6 @@ struct __execution_policy_base
98100
}
99101
#endif // _CCCL_STD_VER <= 2017
100102

101-
#if _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
102-
[[nodiscard]] _CCCL_HOST_API static ::cuda::stream_ref get_stream() noexcept
103-
{
104-
return ::cuda::stream_ref{cudaStreamPerThread};
105-
}
106-
#endif // _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
107-
108103
//! @brief Tag that identifies this and all derived classes as a CCCL execution policy
109104
static constexpr uint32_t __cccl_policy_ = _Policy;
110105

@@ -152,6 +147,19 @@ struct __execution_policy_base
152147
constexpr uint32_t __direction_mask{0xFF00FFFF};
153148
return (_Policy & __direction_mask) & (static_cast<uint32_t>(__pol) << 16);
154149
}
150+
151+
#if _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
152+
[[nodiscard]] _CCCL_HOST_API static ::cuda::stream_ref get_stream() noexcept
153+
{
154+
return ::cuda::stream_ref{cudaStreamPerThread};
155+
}
156+
157+
[[nodiscard]] _CCCL_HOST_API static __execution_policy_stream<__execution_policy_base>
158+
on(::cuda::stream_ref __stream) noexcept
159+
{
160+
return __execution_policy_stream<__execution_policy_base>{__stream};
161+
}
162+
#endif // _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
155163
};
156164

157165
using sequenced_policy = __execution_policy_base<static_cast<uint32_t>(__execution_policy::__sequenced)>;
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef _CUDA_STD___EXECUTION_STREAM_POLICY_H
11+
#define _CUDA_STD___EXECUTION_STREAM_POLICY_H
12+
13+
#include <cuda/std/detail/__config>
14+
15+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
16+
# pragma GCC system_header
17+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
18+
# pragma clang system_header
19+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
20+
# pragma system_header
21+
#endif // no system header
22+
23+
#if _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
24+
25+
# include <cuda/__stream/stream_ref.h>
26+
# include <cuda/std/__bit/has_single_bit.h>
27+
# include <cuda/std/__fwd/policy.h>
28+
# include <cuda/std/cstdint>
29+
30+
# include <cuda/std/__cccl/prologue.h>
31+
32+
_CCCL_BEGIN_NAMESPACE_EXECUTION
33+
34+
//! @brief Wrapper around an execution policy to store a stream
35+
template <class _Policy>
36+
class __execution_policy_stream : public _Policy
37+
{
38+
private:
39+
::cuda::stream_ref __stream_;
40+
41+
public:
42+
_CCCL_HOST_API __execution_policy_stream(::cuda::stream_ref __stream) noexcept
43+
: __stream_(__stream)
44+
{}
45+
46+
[[nodiscard]] _CCCL_HOST_API ::cuda::stream_ref get_stream() const noexcept
47+
{
48+
return __stream_;
49+
}
50+
};
51+
52+
_CCCL_END_NAMESPACE_EXECUTION
53+
54+
# include <cuda/std/__cccl/epilogue.h>
55+
56+
#endif // _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
57+
58+
#endif // _CUDA_STD___EXECUTION_STREAM_POLICY_H
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef _CUDA_STD___FWD_POLICY_H
11+
#define _CUDA_STD___FWD_POLICY_H
12+
13+
#include <cuda/std/detail/__config>
14+
15+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
16+
# pragma GCC system_header
17+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
18+
# pragma clang system_header
19+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
20+
# pragma system_header
21+
#endif // no system header
22+
23+
#include <cuda/std/cstdint>
24+
25+
#include <cuda/std/__cccl/prologue.h>
26+
27+
_CCCL_BEGIN_NAMESPACE_EXECUTION
28+
29+
template <uint32_t _Policy>
30+
struct __execution_policy_base;
31+
32+
_CCCL_END_NAMESPACE_EXECUTION
33+
34+
#include <cuda/std/__cccl/epilogue.h>
35+
36+
#endif // _CUDA_STD___FWD_POLICY_H
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
// UNSUPPORTED: nvrtc
12+
13+
#include <cuda/std/execution>
14+
#include <cuda/std/type_traits>
15+
#include <cuda/stream>
16+
17+
template <class Policy>
18+
void test(Policy pol)
19+
{
20+
{ // Ensure that the plain policy returns a well defined stream
21+
cuda::stream_ref expected_stream{cudaStreamPerThread};
22+
assert(cuda::get_stream(pol) == expected_stream);
23+
}
24+
25+
{ // Ensure that we can attach a stream to an execution policy
26+
cuda::stream stream{cuda::device_ref{0}};
27+
auto pol_with_stream = pol.on(stream);
28+
assert(cuda::get_stream(pol_with_stream) == stream);
29+
30+
static_assert(noexcept(pol.on(stream)));
31+
static_assert(cuda::std::is_base_of_v<Policy, decltype(pol_with_stream)>);
32+
static_assert(cuda::std::is_execution_policy_v<decltype(pol_with_stream)>);
33+
}
34+
}
35+
36+
void test()
37+
{
38+
test(cuda::std::execution::seq);
39+
test(cuda::std::execution::par);
40+
test(cuda::std::execution::unseq);
41+
test(cuda::std::execution::par_unseq);
42+
}
43+
44+
int main(int, char**)
45+
{
46+
NV_IF_TARGET(NV_IS_HOST, (test();))
47+
48+
return 0;
49+
}

libcudacxx/test/libcudacxx/std/algorithms/alg.nonmodifying/alg.for_each/pstl_for_each.cu

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <cuda/std/__pstl/for_each.h>
2323
#include <cuda/std/execution>
2424
#include <cuda/std/functional>
25+
#include <cuda/stream>
2526

2627
#include <testing.cuh>
2728
#include <utility.cuh>
@@ -41,9 +42,23 @@ struct mark_present_for_each
4142

4243
C2H_TEST("cuda::std::for_each", "[parallel algorithm]")
4344
{
44-
thrust::device_vector<bool> res(size, false);
45-
mark_present_for_each fn{thrust::raw_pointer_cast(res.data())};
45+
SECTION("with default stream")
46+
{
47+
thrust::device_vector<bool> res(size, false);
48+
mark_present_for_each fn{thrust::raw_pointer_cast(res.data())};
49+
50+
cuda::std::for_each(cuda::std::execution::par_unseq, cuda::counting_iterator{0}, cuda::counting_iterator{size}, fn);
51+
CHECK(thrust::all_of(res.begin(), res.end(), cuda::std::identity{}));
52+
}
4653

47-
cuda::std::for_each(cuda::std::execution::par_unseq, cuda::counting_iterator{0}, cuda::counting_iterator{size}, fn);
48-
CHECK(thrust::all_of(res.begin(), res.end(), cuda::std::identity{}));
54+
SECTION("with unique stream")
55+
{
56+
::cuda::stream stream{::cuda::device_ref{0}};
57+
thrust::device_vector<bool> res(size, false);
58+
mark_present_for_each fn{thrust::raw_pointer_cast(res.data())};
59+
60+
cuda::std::for_each(
61+
cuda::std::execution::par_unseq.on(stream), cuda::counting_iterator{0}, cuda::counting_iterator{size}, fn);
62+
CHECK(thrust::all_of(res.begin(), res.end(), cuda::std::identity{}));
63+
}
4964
}

libcudacxx/test/libcudacxx/std/algorithms/alg.nonmodifying/alg.for_each/pstl_for_each_n.cu

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <cuda/std/__pstl/for_each_n.h>
2525
#include <cuda/std/execution>
2626
#include <cuda/std/functional>
27+
#include <cuda/stream>
2728

2829
#include <testing.cuh>
2930
#include <utility.cuh>
@@ -43,9 +44,22 @@ struct mark_present_for_each
4344

4445
C2H_TEST("cuda::std::for_each_n", "[parallel algorithm]")
4546
{
46-
thrust::device_vector<bool> res(size, false);
47-
mark_present_for_each fn{thrust::raw_pointer_cast(res.data())};
47+
SECTION("with default stream")
48+
{
49+
thrust::device_vector<bool> res(size, false);
50+
mark_present_for_each fn{thrust::raw_pointer_cast(res.data())};
51+
52+
cuda::std::for_each_n(cuda::std::execution::par_unseq, cuda::counting_iterator{0}, size, fn);
53+
CHECK(thrust::all_of(res.begin(), res.end(), cuda::std::identity{}));
54+
}
4855

49-
cuda::std::for_each_n(cuda::std::execution::par_unseq, cuda::counting_iterator{0}, size, fn);
50-
CHECK(thrust::all_of(res.begin(), res.end(), cuda::std::identity{}));
56+
SECTION("with unique stream")
57+
{
58+
::cuda::stream stream{::cuda::device_ref{0}};
59+
thrust::device_vector<bool> res(size, false);
60+
mark_present_for_each fn{thrust::raw_pointer_cast(res.data())};
61+
62+
cuda::std::for_each_n(cuda::std::execution::par_unseq.on(stream), cuda::counting_iterator{0}, size, fn);
63+
CHECK(thrust::all_of(res.begin(), res.end(), cuda::std::identity{}));
64+
}
5165
}

0 commit comments

Comments
 (0)