2828#include < cuda/std/iterator>
2929#include < cuda/std/type_traits>
3030#include < thrust/execution_policy.h>
31+ #include < thrust/iterator/constant_iterator.h>
3132#include < thrust/logical.h>
3233#include < thrust/reduce.h>
3334#if defined(CUCO_HAS_CUDA_BARRIER)
@@ -1085,8 +1086,16 @@ class open_addressing_ref_impl {
10851086 {
10861087 auto constexpr is_outer = false ;
10871088 auto const n = cuco::detail::distance (input_probe_begin, input_probe_end); // TODO include
1088- this ->retrieve_impl <is_outer, BlockSize>(
1089- block, input_probe_begin, n, output_probe, output_match, atomic_counter);
1089+ auto const always_true_stencil = thrust::constant_iterator<bool >(true );
1090+ auto const identity_predicate = cuda::std::identity{};
1091+ this ->retrieve_impl <is_outer, BlockSize>(block,
1092+ input_probe_begin,
1093+ n,
1094+ always_true_stencil,
1095+ identity_predicate,
1096+ output_probe,
1097+ output_match,
1098+ atomic_counter);
10901099 }
10911100
10921101 /* *
@@ -1134,8 +1143,73 @@ class open_addressing_ref_impl {
11341143 {
11351144 auto constexpr is_outer = true ;
11361145 auto const n = cuco::detail::distance (input_probe_begin, input_probe_end); // TODO include
1146+ auto const always_true_stencil = thrust::constant_iterator<bool >(true );
1147+ auto const identity_predicate = cuda::std::identity{};
1148+ this ->retrieve_impl <is_outer, BlockSize>(block,
1149+ input_probe_begin,
1150+ n,
1151+ always_true_stencil,
1152+ identity_predicate,
1153+ output_probe,
1154+ output_match,
1155+ atomic_counter);
1156+ }
1157+
1158+ /* *
1159+ * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
1160+ * input_probe_end)` if `pred` of the corresponding stencil returns true.
1161+ *
1162+ * If key `k = *(first + i)` exists in the container and `pred( *(stencil + i) )` returns true,
1163+ * copies `k` to `output_probe` and associated slot contents to `output_match`,
1164+ * respectively. The output order is unspecified.
1165+ *
1166+ * Behavior is undefined if the size of the output range exceeds the number of retrieved slots.
1167+ * Use `count()` to determine the size of the output range.
1168+ *
1169+ * @tparam BlockSize Size of the thread block this operation is executed in
1170+ * @tparam InputProbeIt Device accessible input iterator
1171+ * @tparam StencilIt Device accessible random access iterator whose value_type is
1172+ * convertible to Predicate's argument type
1173+ * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1174+ * and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
1175+ * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
1176+ * convertible to the `InputProbeIt`'s `value_type`
1177+ * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
1178+ * convertible to the container's `value_type`
1179+ * @tparam AtomicCounter Integral atomic counter type that follows the same semantics as
1180+ * `cuda::(std::)atomic(_ref)`
1181+ *
1182+ * @param block Thread block this operation is executed in
1183+ * @param input_probe_begin Beginning of the input sequence of keys
1184+ * @param input_probe_end End of the input sequence of keys
1185+ * @param stencil Beginning of the stencil sequence
1186+ * @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
1187+ * @param output_probe Beginning of the sequence of keys corresponding to matching elements in
1188+ * `output_match`
1189+ * @param output_match Beginning of the sequence of matching elements
1190+ * @param atomic_counter Atomic object of integral type that is used to count the
1191+ * number of output elements
1192+ */
1193+ template <int BlockSize,
1194+ class InputProbeIt ,
1195+ class StencilIt ,
1196+ class Predicate ,
1197+ class OutputProbeIt ,
1198+ class OutputMatchIt ,
1199+ class AtomicCounter >
1200+ __device__ void retrieve_if (cooperative_groups::thread_block const & block,
1201+ InputProbeIt input_probe_begin,
1202+ InputProbeIt input_probe_end,
1203+ StencilIt stencil,
1204+ Predicate pred,
1205+ OutputProbeIt output_probe,
1206+ OutputMatchIt output_match,
1207+ AtomicCounter& atomic_counter) const
1208+ {
1209+ auto constexpr is_outer = false ;
1210+ auto const n = cuco::detail::distance (input_probe_begin, input_probe_end);
11371211 this ->retrieve_impl <is_outer, BlockSize>(
1138- block, input_probe_begin, n, output_probe, output_match, atomic_counter);
1212+ block, input_probe_begin, n, stencil, pred, output_probe, output_match, atomic_counter);
11391213 }
11401214
11411215 /* *
@@ -1154,6 +1228,10 @@ class open_addressing_ref_impl {
11541228 * @tparam IsOuter Flag indicating if an inner or outer retrieve operation should be performed
11551229 * @tparam BlockSize Size of the thread block this operation is executed in
11561230 * @tparam InputProbeIt Device accessible input iterator
1231+ * @tparam StencilIt Device accessible random access iterator whose value_type is
1232+ * convertible to Predicate's argument type
1233+ * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1234+ * and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
11571235 * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
11581236 * convertible to the `InputProbeIt`'s `value_type`
11591237 * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
@@ -1162,8 +1240,10 @@ class open_addressing_ref_impl {
11621240 * `cuda::(std::)atomic(_ref)`
11631241 *
11641242 * @param block Thread block this operation is executed in
1165- * @param input_probe_begin Beginning of the input sequence of keys
1166- * @param input_probe_end End of the input sequence of keys
1243+ * @param input_probe Beginning of the input sequence of keys
1244+ * @param n Number of input keys
1245+ * @param stencil Beginning of the stencil sequence
1246+ * @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
11671247 * @param output_probe Beginning of the sequence of keys corresponding to matching elements in
11681248 * `output_match`
11691249 * @param output_match Beginning of the sequence of matching elements
@@ -1173,12 +1253,16 @@ class open_addressing_ref_impl {
11731253 template <bool IsOuter,
11741254 int32_t BlockSize,
11751255 class InputProbeIt ,
1256+ class StencilIt ,
1257+ class Predicate ,
11761258 class OutputProbeIt ,
11771259 class OutputMatchIt ,
11781260 class AtomicCounter >
11791261 __device__ void retrieve_impl (cooperative_groups::thread_block const & block,
11801262 InputProbeIt input_probe,
11811263 cuco::detail::index_type n,
1264+ StencilIt stencil,
1265+ Predicate pred,
11821266 OutputProbeIt output_probe,
11831267 OutputMatchIt output_match,
11841268 AtomicCounter& atomic_counter) const
@@ -1229,15 +1313,16 @@ class open_addressing_ref_impl {
12291313 };
12301314
12311315 while (flushing_tile.any (idx < n)) {
1232- bool active_flag = idx < n;
1316+ bool active_flag = idx < n and pred (*(stencil + idx)) ;
12331317 auto const active_flushing_tile =
12341318 cg::binary_partition<flushing_tile_size>(flushing_tile, active_flag);
12351319
12361320 if (active_flag) {
12371321 // perform probing
12381322 // make sure the flushing_tile is converged at this point to get a coalesced load
12391323 auto const probe_key = *(input_probe + idx);
1240- auto probing_iter = probing_scheme_.template make_iterator <bucket_size>(
1324+
1325+ auto probing_iter = probing_scheme_.template make_iterator <bucket_size>(
12411326 probing_tile, probe_key, storage_ref_.extent ());
12421327 auto const init_idx = *probing_iter;
12431328
0 commit comments