Skip to content

Commit cd09270

Browse files
Add a type and operation enum to CUB (#6780)
And use it in the scan tunings to test it.
1 parent bcbe451 commit cd09270

File tree

2 files changed

+131
-57
lines changed

2 files changed

+131
-57
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#pragma once
5+
6+
#include <cub/config.cuh>
7+
8+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
9+
# pragma GCC system_header
10+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
11+
# pragma clang system_header
12+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
13+
# pragma system_header
14+
#endif // no system header
15+
16+
#include <cuda/std/__functional/operations.h>
17+
#include <cuda/std/__type_traits/is_signed.h>
18+
19+
CUB_NAMESPACE_BEGIN
20+
21+
namespace detail
22+
{
23+
// Enumerates the arithmetic types that classify_type can tell apart, so that a type can be passed around as a plain
// enum value. This is a copy of cccl_type_enum from cccl/c/types.h, which we cannot share, since CCCL.C's public
// interface does not depend on libcu++.
enum class type_t
{
  // signed integers, named by their width in bits
  int8, int16, int32, int64, int128,
  // unsigned integers, named by their width in bits
  uint8, uint16, uint32, uint64, uint128,
  // floating-point types, named by their width in bits
  float32, float64,
  // any type without a dedicated enumerator
  other
};
41+
42+
template <typename T>
43+
inline constexpr auto classify_type = type_t::other;
44+
45+
template <>
46+
inline constexpr auto classify_type<char> = ::cuda::std::is_signed_v<char> ? type_t::int8 : type_t::uint8;
47+
template <>
48+
inline constexpr auto classify_type<signed char> = type_t::int8;
49+
template <>
50+
inline constexpr auto classify_type<unsigned char> = type_t::uint8;
51+
52+
template <>
53+
inline constexpr auto classify_type<signed short> = type_t::int16;
54+
template <>
55+
inline constexpr auto classify_type<unsigned short> = type_t::uint16;
56+
57+
template <>
58+
inline constexpr auto classify_type<signed int> = type_t::int32;
59+
template <>
60+
inline constexpr auto classify_type<unsigned int> = type_t::uint32;
61+
62+
template <>
63+
inline constexpr auto classify_type<signed long> = sizeof(signed long) == 4 ? type_t::int32 : type_t::int64;
64+
template <>
65+
inline constexpr auto classify_type<unsigned long> = sizeof(unsigned long) == 4 ? type_t::uint32 : type_t::uint64;
66+
67+
template <>
68+
inline constexpr auto classify_type<signed long long> = type_t::int64;
69+
template <>
70+
inline constexpr auto classify_type<unsigned long long> = type_t::int64;
71+
72+
#if _CCCL_HAS_INT128()
// The 128-bit integer types are only classified when the toolchain provides them.
template <>
inline constexpr auto classify_type<__int128_t> = type_t::int128;
// Fix: the unsigned variant must report uint128; it previously reported int128, so anything specialized on
// type_t::uint128 could never be selected for __uint128_t.
template <>
inline constexpr auto classify_type<__uint128_t> = type_t::uint128;
#endif // _CCCL_HAS_INT128()
78+
79+
template <>
80+
inline constexpr auto classify_type<float> = type_t::float32;
81+
template <>
82+
inline constexpr auto classify_type<double> = type_t::float64;
83+
84+
// Enumerates the binary operators that classify_op can recognize. Similar to cccl_op_kind_t from cccl/c/types.h.
enum class op_kind_t
{
  plus, // reported for ::cuda::std::plus
  other // any operator without a dedicated enumerator
};
90+
91+
template <typename T>
92+
inline constexpr auto classify_op = op_kind_t::other;
93+
94+
template <typename T>
95+
inline constexpr auto classify_op<::cuda::std::plus<T>> = op_kind_t::plus;
96+
} // namespace detail
97+
CUB_NAMESPACE_END

cub/cub/device/dispatch/tuning/tuning_scan.cuh

Lines changed: 34 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <cub/block/block_load.cuh>
1919
#include <cub/block/block_scan.cuh>
2020
#include <cub/block/block_store.cuh>
21+
#include <cub/device/dispatch/tuning/common.cuh>
2122
#include <cub/thread/thread_load.cuh>
2223
#include <cub/util_device.cuh>
2324
#include <cub/util_type.cuh>
@@ -46,11 +47,6 @@ enum class primitive_op
4647
no,
4748
yes
4849
};
49-
enum class op_type
50-
{
51-
plus,
52-
unknown
53-
};
5450
enum class offset_size
5551
{
5652
_4,
@@ -88,24 +84,6 @@ constexpr _CCCL_HOST_DEVICE primitive_op is_primitive_op()
8884
return basic_binary_op_t<ScanOpT>::value ? primitive_op::yes : primitive_op::no;
8985
}
9086

91-
template <typename Op>
92-
struct is_plus
93-
{
94-
static constexpr bool value = false;
95-
};
96-
97-
template <typename T>
98-
struct is_plus<::cuda::std::plus<T>>
99-
{
100-
static constexpr bool value = true;
101-
};
102-
103-
template <class ScanOpT>
104-
constexpr _CCCL_HOST_DEVICE op_type classify_op()
105-
{
106-
return is_plus<ScanOpT>::value ? op_type::plus : op_type::unknown;
107-
}
108-
10987
template <class ValueT>
11088
constexpr _CCCL_HOST_DEVICE value_size classify_value_size()
11189
{
@@ -139,14 +117,14 @@ constexpr _CCCL_HOST_DEVICE offset_size classify_offset_size()
139117
template <class ValueT,
140118
class AccumT,
141119
class OffsetT,
142-
op_type OpTypeT,
120+
op_kind_t OpTypeT,
143121
primitive_accum PrimitiveAccumulator = is_primitive_accum<AccumT>(),
144122
offset_size OffsetSize = classify_offset_size<OffsetT>(),
145123
value_size ValueSize = classify_value_size<ValueT>()>
146124
struct sm75_tuning;
147125

148126
template <class ValueT, class AccumT, class OffsetT>
149-
struct sm75_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_8, value_size::_4>
127+
struct sm75_tuning<ValueT, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_8, value_size::_4>
150128
{
151129
// ipt_7.tpb_128.ns_628.dcid_1.l2w_520.trp_1.ld_0
152130
static constexpr int threads = 128;
@@ -159,13 +137,10 @@ struct sm75_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes,
159137

160138
// Add sm89 tuning and verify it
161139

162-
template <class AccumT,
163-
primitive_op PrimitiveOp,
164-
primitive_accum PrimitiveAccumulator = is_primitive_accum<AccumT>(),
165-
accum_size AccumSize = classify_accum_size<AccumT>()>
140+
template <type_t AccumT, primitive_op PrimitiveOp, primitive_accum PrimitiveAccumulator, accum_size AccumSize>
166141
struct sm80_tuning;
167142

168-
template <class T>
143+
template <type_t T>
169144
struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_1>
170145
{
171146
static constexpr int threads = 320;
@@ -175,7 +150,7 @@ struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_1>
175150
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
176151
};
177152

178-
template <class T>
153+
template <type_t T>
179154
struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_2>
180155
{
181156
static constexpr int threads = 352;
@@ -185,7 +160,7 @@ struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_2>
185160
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
186161
};
187162

188-
template <class T>
163+
template <type_t T>
189164
struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_4>
190165
{
191166
static constexpr int threads = 320;
@@ -195,7 +170,7 @@ struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_4>
195170
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
196171
};
197172

198-
template <class T>
173+
template <type_t T>
199174
struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_8>
200175
{
201176
static constexpr int threads = 288;
@@ -206,7 +181,7 @@ struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_8>
206181
};
207182

208183
template <>
209-
struct sm80_tuning<float, primitive_op::yes, primitive_accum::yes, accum_size::_4>
184+
struct sm80_tuning<type_t::float32, primitive_op::yes, primitive_accum::yes, accum_size::_4>
210185
{
211186
static constexpr int threads = 288;
212187
static constexpr int items = 8;
@@ -216,7 +191,7 @@ struct sm80_tuning<float, primitive_op::yes, primitive_accum::yes, accum_size::_
216191
};
217192

218193
template <>
219-
struct sm80_tuning<double, primitive_op::yes, primitive_accum::yes, accum_size::_8>
194+
struct sm80_tuning<type_t::float64, primitive_op::yes, primitive_accum::yes, accum_size::_8>
220195
{
221196
static constexpr int threads = 384;
222197
static constexpr int items = 12;
@@ -227,7 +202,7 @@ struct sm80_tuning<double, primitive_op::yes, primitive_accum::yes, accum_size::
227202

228203
#if _CCCL_HAS_INT128()
229204
template <>
230-
struct sm80_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16>
205+
struct sm80_tuning<type_t::int128, primitive_op::yes, primitive_accum::no, accum_size::_16>
231206
{
232207
static constexpr int threads = 640;
233208
static constexpr int items = 24;
@@ -237,8 +212,8 @@ struct sm80_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_siz
237212
};
238213

239214
template <>
240-
struct sm80_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_size::_16>
241-
: sm80_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16>
215+
struct sm80_tuning<type_t::uint128, primitive_op::yes, primitive_accum::no, accum_size::_16>
216+
: sm80_tuning<type_t::int128, primitive_op::yes, primitive_accum::no, accum_size::_16>
242217
{};
243218
#endif
244219

@@ -283,15 +258,15 @@ struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_si
283258
template <class ValueT,
284259
class AccumT,
285260
class OffsetT,
286-
op_type OpTypeT,
261+
op_kind_t OpTypeT,
287262
primitive_accum PrimitiveAccumulator = is_primitive_accum<AccumT>(),
288263
offset_size OffsetSize = classify_offset_size<OffsetT>(),
289264
value_size ValueSize = classify_value_size<ValueT>()>
290265
struct sm100_tuning;
291266

292267
// sum
293268
template <class ValueT, class AccumT, class OffsetT>
294-
struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_4, value_size::_1>
269+
struct sm100_tuning<ValueT, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_4, value_size::_1>
295270
{
296271
// ipt_18.tpb_512.ns_768.dcid_7.l2w_820.trp_1.ld_0 1.188818 1.005682 1.173041 1.305288
297272
static constexpr int items = 18;
@@ -303,7 +278,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
303278
};
304279

305280
template <class ValueT, class AccumT, class OffsetT>
306-
struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_8, value_size::_1>
281+
struct sm100_tuning<ValueT, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_8, value_size::_1>
307282
{
308283
// ipt_14.tpb_384.ns_228.dcid_7.l2w_775.trp_1.ld_1 1.107210 1.000000 1.100637 1.307692
309284
static constexpr int items = 14;
@@ -315,7 +290,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
315290
};
316291

317292
template <class ValueT, class AccumT, class OffsetT>
318-
struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_4, value_size::_2>
293+
struct sm100_tuning<ValueT, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_4, value_size::_2>
319294
{
320295
// ipt_13.tpb_512.ns_1384.dcid_7.l2w_720.trp_1.ld_0 1.128443 1.002841 1.119688 1.307692
321296
static constexpr int items = 13;
@@ -331,7 +306,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
331306
// struct sm100_tuning<ValueT,
332307
// AccumT,
333308
// OffsetT,
334-
// op_type::plus,
309+
// op_kind_t::plus,
335310
// primitive_value::yes,
336311
// primitive_accum::yes,
337312
// offset_size::_8,
@@ -347,7 +322,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
347322
// };
348323

349324
template <class ValueT, class AccumT, class OffsetT>
350-
struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_4, value_size::_4>
325+
struct sm100_tuning<ValueT, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_4, value_size::_4>
351326
{
352327
// ipt_22.tpb_384.ns_1904.dcid_6.l2w_830.trp_1.ld_0 1.148442 0.997167 1.139902 1.462651
353328
static constexpr int items = 22;
@@ -359,7 +334,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
359334
};
360335

361336
template <class ValueT, class AccumT, class OffsetT>
362-
struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_8, value_size::_4>
337+
struct sm100_tuning<ValueT, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_8, value_size::_4>
363338
{
364339
// ipt_19.tpb_416.ns_956.dcid_7.l2w_550.trp_1.ld_1 1.146142 0.994350 1.137459 1.455636
365340
static constexpr int items = 19;
@@ -371,7 +346,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
371346
};
372347

373348
template <class ValueT, class AccumT, class OffsetT>
374-
struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_4, value_size::_8>
349+
struct sm100_tuning<ValueT, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_4, value_size::_8>
375350
{
376351
// ipt_23.tpb_416.ns_772.dcid_5.l2w_710.trp_1.ld_0 1.089468 1.015581 1.085630 1.264583
377352
static constexpr int items = 23;
@@ -383,7 +358,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
383358
};
384359

385360
template <class ValueT, class AccumT, class OffsetT>
386-
struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_8, value_size::_8>
361+
struct sm100_tuning<ValueT, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_8, value_size::_8>
387362
{
388363
// ipt_22.tpb_320.ns_328.dcid_2.l2w_965.trp_1.ld_0 1.080133 1.000000 1.075577 1.248963
389364
static constexpr int items = 22;
@@ -395,19 +370,19 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
395370
};
396371

397372
// todo(gonidelis): Add tunings for i128, float and double.
398-
// template <class OffsetT> struct sm100_tuning<float, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_8,
373+
// template <class OffsetT> struct sm100_tuning<float, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_8,
399374
// accum_size::_4>;
400375
// Default explicitly so it doesn't pick up the sm100<I64, I64> tuning.
401376
template <class AccumT, class OffsetT>
402-
struct sm100_tuning<double, AccumT, OffsetT, op_type::plus, primitive_accum::yes, offset_size::_8, value_size::_8>
377+
struct sm100_tuning<double, AccumT, OffsetT, op_kind_t::plus, primitive_accum::yes, offset_size::_8, value_size::_8>
403378
: sm90_tuning<double, primitive_op::yes, primitive_accum::yes, accum_size::_8>
404379
{};
405380

406381
#if _CCCL_HAS_INT128()
407-
// template <class OffsetT> struct sm100_tuning<__int128_t, OffsetT, op_type::plus, primitive_accum::no,
382+
// template <class OffsetT> struct sm100_tuning<__int128_t, OffsetT, op_kind_t::plus, primitive_accum::no,
408383
// offset_size::_8, accum_size::_16> : tuning<576, 21, 860, 630> {}; template <class OffsetT> struct
409-
// sm100_tuning<__uint128_t, OffsetT, op_type::plus, primitive_accum::no, offset_size::_8, accum_size::_16>
410-
// : sm100_tuning<__int128_t, OffsetT, op_type::plus, primitive_accum::no, offset_size::_8, accum_size::_16>
384+
// sm100_tuning<__uint128_t, OffsetT, op_kind_t::plus, primitive_accum::no, offset_size::_8, accum_size::_16>
385+
// : sm100_tuning<__int128_t, OffsetT, op_kind_t::plus, primitive_accum::no, offset_size::_8, accum_size::_16>
411386
// {};
412387
#endif
413388

@@ -537,13 +512,16 @@ struct policy_hub
537512
_CCCL_HOST_DEVICE static auto select_agent_policy750(long) -> typename Policy600::ScanPolicyT;
538513

539514
using ScanPolicyT =
540-
decltype(select_agent_policy750<sm75_tuning<InputValueT, AccumT, OffsetT, classify_op<ScanOpT>()>, InputValueT>(
541-
0));
515+
decltype(select_agent_policy750<sm75_tuning<InputValueT, AccumT, OffsetT, classify_op<ScanOpT>>, InputValueT>(0));
542516
};
543517

544518
struct Policy800 : ChainedPolicy<800, Policy800, Policy750>
545519
{
546-
using ScanPolicyT = decltype(select_agent_policy<sm80_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0));
520+
using ScanPolicyT =
521+
decltype(select_agent_policy<sm80_tuning<classify_type<AccumT>,
522+
is_primitive_op<ScanOpT>(),
523+
is_primitive_accum<AccumT>(),
524+
classify_accum_size<AccumT>()>>(0));
547525
};
548526

549527
struct Policy860
@@ -583,8 +561,7 @@ struct policy_hub
583561
_CCCL_HOST_DEVICE static auto select_agent_policy100(long) -> typename Policy900::ScanPolicyT;
584562

585563
using ScanPolicyT =
586-
decltype(select_agent_policy100<sm100_tuning<InputValueT, AccumT, OffsetT, classify_op<ScanOpT>()>, InputValueT>(
587-
0));
564+
decltype(select_agent_policy100<sm100_tuning<InputValueT, AccumT, OffsetT, classify_op<ScanOpT>>, InputValueT>(0));
588565
};
589566

590567
using MaxPolicy = Policy1000;

0 commit comments

Comments
 (0)