1818#include < cub/block/block_load.cuh>
1919#include < cub/block/block_scan.cuh>
2020#include < cub/block/block_store.cuh>
21+ #include < cub/device/dispatch/tuning/common.cuh>
2122#include < cub/thread/thread_load.cuh>
2223#include < cub/util_device.cuh>
2324#include < cub/util_type.cuh>
@@ -46,11 +47,6 @@ enum class primitive_op
4647 no,
4748 yes
4849};
49- enum class op_type
50- {
51- plus,
52- unknown
53- };
5450enum class offset_size
5551{
5652 _4,
@@ -88,24 +84,6 @@ constexpr _CCCL_HOST_DEVICE primitive_op is_primitive_op()
8884 return basic_binary_op_t <ScanOpT>::value ? primitive_op::yes : primitive_op::no;
8985}
9086
91- template <typename Op>
92- struct is_plus
93- {
94- static constexpr bool value = false ;
95- };
96-
97- template <typename T>
98- struct is_plus <::cuda::std::plus<T>>
99- {
100- static constexpr bool value = true ;
101- };
102-
103- template <class ScanOpT >
104- constexpr _CCCL_HOST_DEVICE op_type classify_op ()
105- {
106- return is_plus<ScanOpT>::value ? op_type::plus : op_type::unknown;
107- }
108-
10987template <class ValueT >
11088constexpr _CCCL_HOST_DEVICE value_size classify_value_size ()
11189{
@@ -139,14 +117,14 @@ constexpr _CCCL_HOST_DEVICE offset_size classify_offset_size()
139117template <class ValueT ,
140118 class AccumT ,
141119 class OffsetT ,
142- op_type OpTypeT,
120+ op_kind_t OpTypeT,
143121 primitive_accum PrimitiveAccumulator = is_primitive_accum<AccumT>(),
144122 offset_size OffsetSize = classify_offset_size<OffsetT>(),
145123 value_size ValueSize = classify_value_size<ValueT>()>
146124struct sm75_tuning ;
147125
148126template <class ValueT , class AccumT , class OffsetT >
149- struct sm75_tuning <ValueT, AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_8, value_size::_4>
127+ struct sm75_tuning <ValueT, AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_8, value_size::_4>
150128{
151129 // ipt_7.tpb_128.ns_628.dcid_1.l2w_520.trp_1.ld_0
152130 static constexpr int threads = 128 ;
@@ -159,13 +137,10 @@ struct sm75_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes,
159137
160138// Add sm89 tuning and verify it
161139
162- template <class AccumT ,
163- primitive_op PrimitiveOp,
164- primitive_accum PrimitiveAccumulator = is_primitive_accum<AccumT>(),
165- accum_size AccumSize = classify_accum_size<AccumT>()>
140+ template <type_t AccumT, primitive_op PrimitiveOp, primitive_accum PrimitiveAccumulator, accum_size AccumSize>
166141struct sm80_tuning ;
167142
168- template <class T >
143+ template <type_t T>
169144struct sm80_tuning <T, primitive_op::yes, primitive_accum::yes, accum_size::_1>
170145{
171146 static constexpr int threads = 320 ;
@@ -175,7 +150,7 @@ struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_1>
175150 static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
176151};
177152
178- template <class T >
153+ template <type_t T>
179154struct sm80_tuning <T, primitive_op::yes, primitive_accum::yes, accum_size::_2>
180155{
181156 static constexpr int threads = 352 ;
@@ -185,7 +160,7 @@ struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_2>
185160 static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
186161};
187162
188- template <class T >
163+ template <type_t T>
189164struct sm80_tuning <T, primitive_op::yes, primitive_accum::yes, accum_size::_4>
190165{
191166 static constexpr int threads = 320 ;
@@ -195,7 +170,7 @@ struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_4>
195170 static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
196171};
197172
198- template <class T >
173+ template <type_t T>
199174struct sm80_tuning <T, primitive_op::yes, primitive_accum::yes, accum_size::_8>
200175{
201176 static constexpr int threads = 288 ;
@@ -206,7 +181,7 @@ struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_8>
206181};
207182
208183template <>
209- struct sm80_tuning <float , primitive_op::yes, primitive_accum::yes, accum_size::_4>
184+ struct sm80_tuning <type_t ::float32 , primitive_op::yes, primitive_accum::yes, accum_size::_4>
210185{
211186 static constexpr int threads = 288 ;
212187 static constexpr int items = 8 ;
@@ -216,7 +191,7 @@ struct sm80_tuning<float, primitive_op::yes, primitive_accum::yes, accum_size::_
216191};
217192
218193template <>
219- struct sm80_tuning <double , primitive_op::yes, primitive_accum::yes, accum_size::_8>
194+ struct sm80_tuning <type_t ::float64 , primitive_op::yes, primitive_accum::yes, accum_size::_8>
220195{
221196 static constexpr int threads = 384 ;
222197 static constexpr int items = 12 ;
@@ -227,7 +202,7 @@ struct sm80_tuning<double, primitive_op::yes, primitive_accum::yes, accum_size::
227202
228203#if _CCCL_HAS_INT128()
229204template <>
230- struct sm80_tuning <__int128_t , primitive_op::yes, primitive_accum::no, accum_size::_16>
205+ struct sm80_tuning <type_t ::int128 , primitive_op::yes, primitive_accum::no, accum_size::_16>
231206{
232207 static constexpr int threads = 640 ;
233208 static constexpr int items = 24 ;
@@ -237,8 +212,8 @@ struct sm80_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_siz
237212};
238213
239214template <>
240- struct sm80_tuning <__uint128_t , primitive_op::yes, primitive_accum::no, accum_size::_16>
241- : sm80_tuning<__int128_t , primitive_op::yes, primitive_accum::no, accum_size::_16>
215+ struct sm80_tuning <type_t ::uint128 , primitive_op::yes, primitive_accum::no, accum_size::_16>
216+ : sm80_tuning<type_t ::int128 , primitive_op::yes, primitive_accum::no, accum_size::_16>
242217{};
243218#endif
244219
@@ -283,15 +258,15 @@ struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_si
283258template <class ValueT ,
284259 class AccumT ,
285260 class OffsetT ,
286- op_type OpTypeT,
261+ op_kind_t OpTypeT,
287262 primitive_accum PrimitiveAccumulator = is_primitive_accum<AccumT>(),
288263 offset_size OffsetSize = classify_offset_size<OffsetT>(),
289264 value_size ValueSize = classify_value_size<ValueT>()>
290265struct sm100_tuning ;
291266
292267// sum
293268template <class ValueT , class AccumT , class OffsetT >
294- struct sm100_tuning <ValueT, AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_4, value_size::_1>
269+ struct sm100_tuning <ValueT, AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_4, value_size::_1>
295270{
296271 // ipt_18.tpb_512.ns_768.dcid_7.l2w_820.trp_1.ld_0 1.188818 1.005682 1.173041 1.305288
297272 static constexpr int items = 18 ;
@@ -303,7 +278,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
303278};
304279
305280template <class ValueT , class AccumT , class OffsetT >
306- struct sm100_tuning <ValueT, AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_8, value_size::_1>
281+ struct sm100_tuning <ValueT, AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_8, value_size::_1>
307282{
308283 // ipt_14.tpb_384.ns_228.dcid_7.l2w_775.trp_1.ld_1 1.107210 1.000000 1.100637 1.307692
309284 static constexpr int items = 14 ;
@@ -315,7 +290,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
315290};
316291
317292template <class ValueT , class AccumT , class OffsetT >
318- struct sm100_tuning <ValueT, AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_4, value_size::_2>
293+ struct sm100_tuning <ValueT, AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_4, value_size::_2>
319294{
320295 // ipt_13.tpb_512.ns_1384.dcid_7.l2w_720.trp_1.ld_0 1.128443 1.002841 1.119688 1.307692
321296 static constexpr int items = 13 ;
@@ -331,7 +306,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
331306// struct sm100_tuning<ValueT,
332307// AccumT,
333308// OffsetT,
334- // op_type ::plus,
309+ // op_kind_t ::plus,
335310// primitive_value::yes,
336311// primitive_accum::yes,
337312// offset_size::_8,
@@ -347,7 +322,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
347322// };
348323
349324template <class ValueT , class AccumT , class OffsetT >
350- struct sm100_tuning <ValueT, AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_4, value_size::_4>
325+ struct sm100_tuning <ValueT, AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_4, value_size::_4>
351326{
352327 // ipt_22.tpb_384.ns_1904.dcid_6.l2w_830.trp_1.ld_0 1.148442 0.997167 1.139902 1.462651
353328 static constexpr int items = 22 ;
@@ -359,7 +334,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
359334};
360335
361336template <class ValueT , class AccumT , class OffsetT >
362- struct sm100_tuning <ValueT, AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_8, value_size::_4>
337+ struct sm100_tuning <ValueT, AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_8, value_size::_4>
363338{
364339 // ipt_19.tpb_416.ns_956.dcid_7.l2w_550.trp_1.ld_1 1.146142 0.994350 1.137459 1.455636
365340 static constexpr int items = 19 ;
@@ -371,7 +346,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
371346};
372347
373348template <class ValueT , class AccumT , class OffsetT >
374- struct sm100_tuning <ValueT, AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_4, value_size::_8>
349+ struct sm100_tuning <ValueT, AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_4, value_size::_8>
375350{
376351 // ipt_23.tpb_416.ns_772.dcid_5.l2w_710.trp_1.ld_0 1.089468 1.015581 1.085630 1.264583
377352 static constexpr int items = 23 ;
@@ -383,7 +358,7 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
383358};
384359
385360template <class ValueT , class AccumT , class OffsetT >
386- struct sm100_tuning <ValueT, AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_8, value_size::_8>
361+ struct sm100_tuning <ValueT, AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_8, value_size::_8>
387362{
388363 // ipt_22.tpb_320.ns_328.dcid_2.l2w_965.trp_1.ld_0 1.080133 1.000000 1.075577 1.248963
389364 static constexpr int items = 22 ;
@@ -395,19 +370,19 @@ struct sm100_tuning<ValueT, AccumT, OffsetT, op_type::plus, primitive_accum::yes
395370};
396371
397372// todo(gonidelis): Add tunings for i128, float and double.
398- // template <class OffsetT> struct sm100_tuning<float, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_8,
373+ // template <class OffsetT> struct sm100_tuning<float, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_8,
399374// accum_size::_4>;
400375// Default explicitly so it doesn't pick up the sm100<I64, I64> tuning.
401376template <class AccumT , class OffsetT >
402- struct sm100_tuning <double , AccumT, OffsetT, op_type ::plus, primitive_accum::yes, offset_size::_8, value_size::_8>
377+ struct sm100_tuning <double , AccumT, OffsetT, op_kind_t ::plus, primitive_accum::yes, offset_size::_8, value_size::_8>
403378 : sm90_tuning<double , primitive_op::yes, primitive_accum::yes, accum_size::_8>
404379{};
405380
406381#if _CCCL_HAS_INT128()
407- // template <class OffsetT> struct sm100_tuning<__int128_t, OffsetT, op_type ::plus, primitive_accum::no,
382+ // template <class OffsetT> struct sm100_tuning<__int128_t, OffsetT, op_kind_t ::plus, primitive_accum::no,
408383// offset_size::_8, accum_size::_16> : tuning<576, 21, 860, 630> {}; template <class OffsetT> struct
409- // sm100_tuning<__uint128_t, OffsetT, op_type ::plus, primitive_accum::no, offset_size::_8, accum_size::_16>
410- // : sm100_tuning<__int128_t, OffsetT, op_type ::plus, primitive_accum::no, offset_size::_8, accum_size::_16>
384+ // sm100_tuning<__uint128_t, OffsetT, op_kind_t ::plus, primitive_accum::no, offset_size::_8, accum_size::_16>
385+ // : sm100_tuning<__int128_t, OffsetT, op_kind_t ::plus, primitive_accum::no, offset_size::_8, accum_size::_16>
411386// {};
412387#endif
413388
@@ -537,13 +512,16 @@ struct policy_hub
537512 _CCCL_HOST_DEVICE static auto select_agent_policy750 (long ) -> typename Policy600::ScanPolicyT;
538513
539514 using ScanPolicyT =
540- decltype (select_agent_policy750<sm75_tuning<InputValueT, AccumT, OffsetT, classify_op<ScanOpT>()>, InputValueT>(
541- 0 ));
515+ decltype (select_agent_policy750<sm75_tuning<InputValueT, AccumT, OffsetT, classify_op<ScanOpT>>, InputValueT>(0 ));
542516 };
543517
544518 struct Policy800 : ChainedPolicy<800 , Policy800, Policy750>
545519 {
546- using ScanPolicyT = decltype (select_agent_policy<sm80_tuning<AccumT, is_primitive_op<ScanOpT>()>>(0 ));
520+ using ScanPolicyT =
521+ decltype (select_agent_policy<sm80_tuning<classify_type<AccumT>,
522+ is_primitive_op<ScanOpT>(),
523+ is_primitive_accum<AccumT>(),
524+ classify_accum_size<AccumT>()>>(0 ));
547525 };
548526
549527 struct Policy860
@@ -583,8 +561,7 @@ struct policy_hub
583561 _CCCL_HOST_DEVICE static auto select_agent_policy100 (long ) -> typename Policy900::ScanPolicyT;
584562
585563 using ScanPolicyT =
586- decltype (select_agent_policy100<sm100_tuning<InputValueT, AccumT, OffsetT, classify_op<ScanOpT>()>, InputValueT>(
587- 0 ));
564+ decltype (select_agent_policy100<sm100_tuning<InputValueT, AccumT, OffsetT, classify_op<ScanOpT>>, InputValueT>(0 ));
588565 };
589566
590567 using MaxPolicy = Policy1000;
0 commit comments