2929#include < cuda/std/iterator>
3030#include < cuda/std/type_traits>
3131#include < thrust/execution_policy.h>
32+ #include < thrust/iterator/constant_iterator.h>
3233#include < thrust/logical.h>
3334#include < thrust/reduce.h>
3435#if defined(CUCO_HAS_CUDA_BARRIER)
@@ -1084,8 +1085,16 @@ class open_addressing_ref_impl {
10841085 {
10851086 auto constexpr is_outer = false ;
10861087 auto const n = cuco::detail::distance (input_probe_begin, input_probe_end); // TODO include
1087- this ->retrieve_impl <is_outer, BlockSize>(
1088- block, input_probe_begin, n, output_probe, output_match, atomic_counter);
1088+ auto const always_true_stencil = thrust::constant_iterator<bool >(true );
1089+ auto const identity_predicate = cuda::std::identity{};
1090+ this ->retrieve_impl <is_outer, BlockSize>(block,
1091+ input_probe_begin,
1092+ n,
1093+ always_true_stencil,
1094+ identity_predicate,
1095+ output_probe,
1096+ output_match,
1097+ atomic_counter);
10891098 }
10901099
10911100 /* *
@@ -1133,8 +1142,73 @@ class open_addressing_ref_impl {
11331142 {
11341143 auto constexpr is_outer = true ;
11351144 auto const n = cuco::detail::distance (input_probe_begin, input_probe_end); // TODO include
1145+ auto const always_true_stencil = thrust::constant_iterator<bool >(true );
1146+ auto const identity_predicate = cuda::std::identity{};
1147+ this ->retrieve_impl <is_outer, BlockSize>(block,
1148+ input_probe_begin,
1149+ n,
1150+ always_true_stencil,
1151+ identity_predicate,
1152+ output_probe,
1153+ output_match,
1154+ atomic_counter);
1155+ }
1156+
1157+ /* *
1158+ * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
1159+ * input_probe_end)` if `pred` of the corresponding stencil returns true.
1160+ *
1161+ * If key `k = *(input_probe_begin + i)` exists in the container and `pred( *(stencil + i) )` returns true,
1162+ * copies `k` to `output_probe` and associated slot contents to `output_match`,
1163+ * respectively. The output order is unspecified.
1164+ *
1165+ * Behavior is undefined if the number of retrieved slots exceeds the size of the output range.
1166+ * Use `count()` to determine the size of the output range.
1167+ *
1168+ * @tparam BlockSize Size of the thread block this operation is executed in
1169+ * @tparam InputProbeIt Device accessible input iterator
1170+ * @tparam StencilIt Device accessible random access iterator whose value_type is
1171+ * convertible to Predicate's argument type
1172+ * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1173+ * and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
1174+ * @tparam OutputProbeIt Device accessible output iterator whose `value_type` is
1175+ * convertible to the `InputProbeIt`'s `value_type`
1176+ * @tparam OutputMatchIt Device accessible output iterator whose `value_type` is
1177+ * convertible to the container's `value_type`
1178+ * @tparam AtomicCounter Integral atomic counter type that follows the same semantics as
1179+ * `cuda::(std::)atomic(_ref)`
1180+ *
1181+ * @param block Thread block this operation is executed in
1182+ * @param input_probe_begin Beginning of the input sequence of keys
1183+ * @param input_probe_end End of the input sequence of keys
1184+ * @param stencil Beginning of the stencil sequence
1185+ * @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
1186+ * @param output_probe Beginning of the sequence of keys corresponding to matching elements in
1187+ * `output_match`
1188+ * @param output_match Beginning of the sequence of matching elements
1189+ * @param atomic_counter Atomic object of integral type that is used to count the
1190+ * number of output elements
1191+ */
// Conditional (stencil + predicate) variant of the block-wide retrieve: computes the number of
// probe keys and forwards everything to `retrieve_impl` with `IsOuter = false`.
1192+ template <int BlockSize,
1193+ class InputProbeIt ,
1194+ class StencilIt ,
1195+ class Predicate ,
1196+ class OutputProbeIt ,
1197+ class OutputMatchIt ,
1198+ class AtomicCounter >
1199+ __device__ void retrieve_if (cooperative_groups::thread_block const & block,
1200+ InputProbeIt input_probe_begin,
1201+ InputProbeIt input_probe_end,
1202+ StencilIt stencil,
1203+ Predicate pred,
1204+ OutputProbeIt output_probe,
1205+ OutputMatchIt output_match,
1206+ AtomicCounter& atomic_counter) const
1207+ {
// Compile-time flag passed as the `IsOuter` template argument of `retrieve_impl`; `false`
// selects the inner (non-outer) retrieve.
1208+ auto constexpr is_outer = false ;
// Number of probe keys in `[input_probe_begin, input_probe_end)`; `pred` is tested on the
// matching stencil range `[stencil, stencil + n)`.
1209+ auto const n = cuco::detail::distance (input_probe_begin, input_probe_end);
// Delegate to the shared implementation, handing the caller's stencil and predicate through
// unchanged.
11361210 this ->retrieve_impl <is_outer, BlockSize>(
1137- block, input_probe_begin, n, output_probe, output_match, atomic_counter);
1211+ block, input_probe_begin, n, stencil, pred, output_probe, output_match, atomic_counter);
11381212 }
11391213
11401214 /* *
@@ -1153,6 +1227,10 @@ class open_addressing_ref_impl {
11531227 * @tparam IsOuter Flag indicating if an inner or outer retrieve operation should be performed
11541228 * @tparam BlockSize Size of the thread block this operation is executed in
11551229 * @tparam InputProbeIt Device accessible input iterator
1230+ * @tparam StencilIt Device accessible random access iterator whose value_type is
1231+ * convertible to Predicate's argument type
1232+ * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1233+ * and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
11561234 * @tparam OutputProbeIt Device accessible output iterator whose `value_type` is
11571235 * convertible to the `InputProbeIt`'s `value_type`
11581236 * @tparam OutputMatchIt Device accessible output iterator whose `value_type` is
@@ -1161,8 +1239,10 @@ class open_addressing_ref_impl {
11611239 * `cuda::(std::)atomic(_ref)`
11621240 *
11631241 * @param block Thread block this operation is executed in
1164- * @param input_probe_begin Beginning of the input sequence of keys
1165- * @param input_probe_end End of the input sequence of keys
1242+ * @param input_probe Beginning of the input sequence of keys
1243+ * @param n Number of input keys
1244+ * @param stencil Beginning of the stencil sequence
1245+ * @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
11661246 * @param output_probe Beginning of the sequence of keys corresponding to matching elements in
11671247 * `output_match`
11681248 * @param output_match Beginning of the sequence of matching elements
@@ -1172,12 +1252,16 @@ class open_addressing_ref_impl {
11721252 template <bool IsOuter,
11731253 int BlockSize,
11741254 class InputProbeIt ,
1255+ class StencilIt ,
1256+ class Predicate ,
11751257 class OutputProbeIt ,
11761258 class OutputMatchIt ,
11771259 class AtomicCounter >
11781260 __device__ void retrieve_impl (cooperative_groups::thread_block const & block,
11791261 InputProbeIt input_probe,
11801262 cuco::detail::index_type n,
1263+ StencilIt stencil,
1264+ Predicate pred,
11811265 OutputProbeIt output_probe,
11821266 OutputMatchIt output_match,
11831267 AtomicCounter& atomic_counter) const
@@ -1228,15 +1312,16 @@ class open_addressing_ref_impl {
12281312 };
12291313
12301314 while (flushing_tile.any (idx < n)) {
1231- bool active_flag = idx < n;
1315+ bool active_flag = idx < n and pred (*(stencil + idx)) ;
12321316 auto const active_flushing_tile =
12331317 cg::binary_partition<flushing_tile_size>(flushing_tile, active_flag);
12341318
12351319 if (active_flag) {
12361320 // perform probing
12371321 // make sure the flushing_tile is converged at this point to get a coalesced load
12381322 auto const probe_key = *(input_probe + idx);
1239- auto probing_iter = probing_scheme_.template make_iterator <bucket_size>(
1323+
1324+ auto probing_iter = probing_scheme_.template make_iterator <bucket_size>(
12401325 probing_tile, probe_key, storage_ref_.extent ());
12411326 auto const init_idx = *probing_iter;
12421327
0 commit comments