Skip to content

Commit b5b1286

Browse files
authored
Merge branch 'dev' into fix-docs
2 parents ffe9f98 + 0d3193f commit b5b1286

File tree

11 files changed

+1278
-11
lines changed

11 files changed

+1278
-11
lines changed

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <cuda/std/iterator>
2929
#include <cuda/std/type_traits>
3030
#include <thrust/execution_policy.h>
31+
#include <thrust/iterator/constant_iterator.h>
3132
#include <thrust/logical.h>
3233
#include <thrust/reduce.h>
3334
#if defined(CUCO_HAS_CUDA_BARRIER)
@@ -1085,8 +1086,16 @@ class open_addressing_ref_impl {
10851086
{
10861087
auto constexpr is_outer = false;
10871088
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include
1088-
this->retrieve_impl<is_outer, BlockSize>(
1089-
block, input_probe_begin, n, output_probe, output_match, atomic_counter);
1089+
auto const always_true_stencil = thrust::constant_iterator<bool>(true);
1090+
auto const identity_predicate = cuda::std::identity{};
1091+
this->retrieve_impl<is_outer, BlockSize>(block,
1092+
input_probe_begin,
1093+
n,
1094+
always_true_stencil,
1095+
identity_predicate,
1096+
output_probe,
1097+
output_match,
1098+
atomic_counter);
10901099
}
10911100

10921101
/**
@@ -1134,8 +1143,73 @@ class open_addressing_ref_impl {
11341143
{
11351144
auto constexpr is_outer = true;
11361145
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include
1146+
auto const always_true_stencil = thrust::constant_iterator<bool>(true);
1147+
auto const identity_predicate = cuda::std::identity{};
1148+
this->retrieve_impl<is_outer, BlockSize>(block,
1149+
input_probe_begin,
1150+
n,
1151+
always_true_stencil,
1152+
identity_predicate,
1153+
output_probe,
1154+
output_match,
1155+
atomic_counter);
1156+
}
1157+
1158+
/**
1159+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
1160+
* input_probe_end)` if `pred` of the corresponding stencil returns true.
1161+
*
1162+
* If key `k = *(input_probe_begin + i)` exists in the container and `pred( *(stencil + i) )` returns true,
1163+
* copies `k` to `output_probe` and associated slot contents to `output_match`,
1164+
* respectively. The output order is unspecified.
1165+
*
1166+
* Behavior is undefined if the number of retrieved slots exceeds the size of the output range.
1167+
* Use `count()` to determine the size of the output range.
1168+
*
1169+
* @tparam BlockSize Size of the thread block this operation is executed in
1170+
* @tparam InputProbeIt Device accessible input iterator
1171+
* @tparam StencilIt Device accessible random access iterator whose value_type is
1172+
* convertible to Predicate's argument type
1173+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1174+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
1175+
* @tparam OutputProbeIt Device accessible output iterator whose `value_type` is
1176+
* convertible to the `InputProbeIt`'s `value_type`
1177+
* @tparam OutputMatchIt Device accessible output iterator whose `value_type` is
1178+
* convertible to the container's `value_type`
1179+
* @tparam AtomicCounter Integral atomic counter type that follows the same semantics as
1180+
* `cuda::(std::)atomic(_ref)`
1181+
*
1182+
* @param block Thread block this operation is executed in
1183+
* @param input_probe_begin Beginning of the input sequence of keys
1184+
* @param input_probe_end End of the input sequence of keys
1185+
* @param stencil Beginning of the stencil sequence
1186+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
1187+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
1188+
* `output_match`
1189+
* @param output_match Beginning of the sequence of matching elements
1190+
* @param atomic_counter Atomic object of integral type that is used to count the
1191+
* number of output elements
1192+
*/
1193+
template <int BlockSize,
1194+
class InputProbeIt,
1195+
class StencilIt,
1196+
class Predicate,
1197+
class OutputProbeIt,
1198+
class OutputMatchIt,
1199+
class AtomicCounter>
1200+
__device__ void retrieve_if(cooperative_groups::thread_block const& block,
1201+
InputProbeIt input_probe_begin,
1202+
InputProbeIt input_probe_end,
1203+
StencilIt stencil,
1204+
Predicate pred,
1205+
OutputProbeIt output_probe,
1206+
OutputMatchIt output_match,
1207+
AtomicCounter& atomic_counter) const
1208+
{
1209+
auto constexpr is_outer = false;
1210+
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end);
11371211
this->retrieve_impl<is_outer, BlockSize>(
1138-
block, input_probe_begin, n, output_probe, output_match, atomic_counter);
1212+
block, input_probe_begin, n, stencil, pred, output_probe, output_match, atomic_counter);
11391213
}
11401214

11411215
/**
@@ -1154,6 +1228,10 @@ class open_addressing_ref_impl {
11541228
* @tparam IsOuter Flag indicating if an inner or outer retrieve operation should be performed
11551229
* @tparam BlockSize Size of the thread block this operation is executed in
11561230
* @tparam InputProbeIt Device accessible input iterator
1231+
* @tparam StencilIt Device accessible random access iterator whose value_type is
1232+
* convertible to Predicate's argument type
1233+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1234+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
11571235
* @tparam OutputProbeIt Device accessible output iterator whose `value_type` is
11581236
* convertible to the `InputProbeIt`'s `value_type`
11591237
* @tparam OutputMatchIt Device accessible output iterator whose `value_type` is
@@ -1162,8 +1240,10 @@ class open_addressing_ref_impl {
11621240
* `cuda::(std::)atomic(_ref)`
11631241
*
11641242
* @param block Thread block this operation is executed in
1165-
* @param input_probe_begin Beginning of the input sequence of keys
1166-
* @param input_probe_end End of the input sequence of keys
1243+
* @param input_probe Beginning of the input sequence of keys
1244+
* @param n Number of input keys
1245+
* @param stencil Beginning of the stencil sequence
1246+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
11671247
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
11681248
* `output_match`
11691249
* @param output_match Beginning of the sequence of matching elements
@@ -1173,12 +1253,16 @@ class open_addressing_ref_impl {
11731253
template <bool IsOuter,
11741254
int32_t BlockSize,
11751255
class InputProbeIt,
1256+
class StencilIt,
1257+
class Predicate,
11761258
class OutputProbeIt,
11771259
class OutputMatchIt,
11781260
class AtomicCounter>
11791261
__device__ void retrieve_impl(cooperative_groups::thread_block const& block,
11801262
InputProbeIt input_probe,
11811263
cuco::detail::index_type n,
1264+
StencilIt stencil,
1265+
Predicate pred,
11821266
OutputProbeIt output_probe,
11831267
OutputMatchIt output_match,
11841268
AtomicCounter& atomic_counter) const
@@ -1229,15 +1313,16 @@ class open_addressing_ref_impl {
12291313
};
12301314

12311315
while (flushing_tile.any(idx < n)) {
1232-
bool active_flag = idx < n;
1316+
bool active_flag = idx < n and pred(*(stencil + idx));
12331317
auto const active_flushing_tile =
12341318
cg::binary_partition<flushing_tile_size>(flushing_tile, active_flag);
12351319

12361320
if (active_flag) {
12371321
// perform probing
12381322
// make sure the flushing_tile is converged at this point to get a coalesced load
12391323
auto const probe_key = *(input_probe + idx);
1240-
auto probing_iter = probing_scheme_.template make_iterator<bucket_size>(
1324+
1325+
auto probing_iter = probing_scheme_.template make_iterator<bucket_size>(
12411326
probing_tile, probe_key, storage_ref_.extent());
12421327
auto const init_idx = *probing_iter;
12431328

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,70 @@ class operator_impl<
15341534
ref_.impl_.template retrieve<BlockSize>(
15351535
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
15361536
}
1537+
1538+
/**
1539+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
1540+
* input_probe_end)` if `pred` of the corresponding stencil returns true.
1541+
*
1542+
* If key `k = *(input_probe_begin + i)` exists in the container and `pred( *(stencil + i) )` returns true,
1543+
* copies `k` to `output_probe` and associated slot content to `output_match`, respectively.
1544+
* The output order is unspecified.
1545+
*
1546+
* Behavior is undefined if the number of retrieved slots exceeds the size of the output range.
1547+
* Use `count()` to determine the size of the output range.
1548+
*
1549+
* @tparam BlockSize Size of the thread block this operation is executed in
1550+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
1551+
* convertible to the container's `key_type`
1552+
* @tparam StencilIt Device accessible random access iterator whose value_type is
1553+
* convertible to Predicate's argument type
1554+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1555+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
1556+
* @tparam OutputProbeIt Device accessible output iterator whose `value_type` is
1557+
* convertible to the container's `key_type`
1558+
* @tparam OutputMatchIt Device accessible output iterator whose `value_type` is
1559+
* convertible to the container's `value_type`
1560+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
1561+
* `cuda::atomic(_ref)`
1562+
*
1563+
* @param block Thread block this operation is executed in
1564+
* @param input_probe_begin Beginning of the input sequence of keys
1565+
* @param input_probe_end End of the input sequence of keys
1566+
* @param stencil Beginning of the stencil sequence
1567+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
1568+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
1569+
* `output_match`
1570+
* @param output_match Beginning of the sequence of matching elements
1571+
* @param atomic_counter Counter that is used to determine the next free position in the output
1572+
* sequences
1573+
*/
1574+
template <int BlockSize,
1575+
class InputProbeIt,
1576+
class StencilIt,
1577+
class Predicate,
1578+
class OutputProbeIt,
1579+
class OutputMatchIt,
1580+
class AtomicCounter>
1581+
__device__ void retrieve_if(cooperative_groups::thread_block const& block,
1582+
InputProbeIt input_probe_begin,
1583+
InputProbeIt input_probe_end,
1584+
StencilIt stencil,
1585+
Predicate pred,
1586+
OutputProbeIt output_probe,
1587+
OutputMatchIt output_match,
1588+
AtomicCounter& atomic_counter) const
1589+
{
1590+
auto const& ref_ = static_cast<ref_type const&>(*this);
1591+
ref_.impl_.template retrieve_if<BlockSize>(block,
1592+
input_probe_begin,
1593+
input_probe_end,
1594+
stencil,
1595+
pred,
1596+
output_probe,
1597+
output_match,
1598+
atomic_counter);
1599+
}
15371600
};
1601+
15381602
} // namespace detail
15391603
} // namespace cuco

include/cuco/detail/static_multimap/static_multimap_ref.inl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,69 @@ class operator_impl<
845845
ref_.impl_.template retrieve<BlockSize>(
846846
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
847847
}
848+
849+
/**
850+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
851+
* input_probe_end)` if `pred` of the corresponding stencil returns true.
852+
*
853+
* If key `k = *(input_probe_begin + i)` exists in the container and `pred( *(stencil + i) )` returns true,
854+
* copies `k` to `output_probe` and associated slot content to `output_match`, respectively.
855+
* The output order is unspecified.
856+
*
857+
* Behavior is undefined if the number of retrieved slots exceeds the size of the output range.
858+
* Use `count()` to determine the size of the output range.
859+
*
860+
* @tparam BlockSize Size of the thread block this operation is executed in
861+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
862+
* convertible to the container's `key_type`
863+
* @tparam StencilIt Device accessible random access iterator whose value_type is
864+
* convertible to Predicate's argument type
865+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
866+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
867+
* @tparam OutputProbeIt Device accessible output iterator whose `value_type` is
868+
* convertible to the container's `key_type`
869+
* @tparam OutputMatchIt Device accessible output iterator whose `value_type` is
870+
* convertible to the container's `value_type`
871+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
872+
* `cuda::atomic(_ref)`
873+
*
874+
* @param block Thread block this operation is executed in
875+
* @param input_probe_begin Beginning of the input sequence of keys
876+
* @param input_probe_end End of the input sequence of keys
877+
* @param stencil Beginning of the stencil sequence
878+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
879+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
880+
* `output_match`
881+
* @param output_match Beginning of the sequence of matching elements
882+
* @param atomic_counter Counter that is used to determine the next free position in the output
883+
* sequences
884+
*/
885+
template <int BlockSize,
886+
class InputProbeIt,
887+
class StencilIt,
888+
class Predicate,
889+
class OutputProbeIt,
890+
class OutputMatchIt,
891+
class AtomicCounter>
892+
__device__ void retrieve_if(cooperative_groups::thread_block const& block,
893+
InputProbeIt input_probe_begin,
894+
InputProbeIt input_probe_end,
895+
StencilIt stencil,
896+
Predicate pred,
897+
OutputProbeIt output_probe,
898+
OutputMatchIt output_match,
899+
AtomicCounter& atomic_counter) const
900+
{
901+
auto const& ref_ = static_cast<ref_type const&>(*this);
902+
ref_.impl_.template retrieve_if<BlockSize>(block,
903+
input_probe_begin,
904+
input_probe_end,
905+
stencil,
906+
pred,
907+
output_probe,
908+
output_match,
909+
atomic_counter);
910+
}
848911
};
849912
} // namespace detail
850913
} // namespace cuco

include/cuco/detail/static_multiset/static_multiset_ref.inl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,69 @@ class operator_impl<
655655
ref_.impl_.template retrieve_outer<BlockSize>(
656656
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
657657
}
658+
659+
/**
660+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
661+
* input_probe_end)` if `pred` of the corresponding stencil returns true.
662+
*
663+
* If key `k = *(input_probe_begin + i)` exists in the container and `pred( *(stencil + i) )` returns true,
664+
* copies `k` to `output_probe` and associated slot content to `output_match`, respectively.
665+
* The output order is unspecified.
666+
*
667+
* Behavior is undefined if the number of retrieved slots exceeds the size of the output range.
668+
* Use `count()` to determine the size of the output range.
669+
*
670+
* @tparam BlockSize Size of the thread block this operation is executed in
671+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
672+
* convertible to the container's `key_type`
673+
* @tparam StencilIt Device accessible random access iterator whose value_type is
674+
* convertible to Predicate's argument type
675+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
676+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
677+
* @tparam OutputProbeIt Device accessible output iterator whose `value_type` is
678+
* convertible to the container's `key_type`
679+
* @tparam OutputMatchIt Device accessible output iterator whose `value_type` is
680+
* convertible to the container's `value_type`
681+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
682+
* `cuda::atomic(_ref)`
683+
*
684+
* @param block Thread block this operation is executed in
685+
* @param input_probe_begin Beginning of the input sequence of keys
686+
* @param input_probe_end End of the input sequence of keys
687+
* @param stencil Beginning of the stencil sequence
688+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
689+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
690+
* `output_match`
691+
* @param output_match Beginning of the sequence of matching elements
692+
* @param atomic_counter Counter that is used to determine the next free position in the output
693+
* sequences
694+
*/
695+
template <int BlockSize,
696+
class InputProbeIt,
697+
class StencilIt,
698+
class Predicate,
699+
class OutputProbeIt,
700+
class OutputMatchIt,
701+
class AtomicCounter>
702+
__device__ void retrieve_if(cooperative_groups::thread_block const& block,
703+
InputProbeIt input_probe_begin,
704+
InputProbeIt input_probe_end,
705+
StencilIt stencil,
706+
Predicate pred,
707+
OutputProbeIt output_probe,
708+
OutputMatchIt output_match,
709+
AtomicCounter& atomic_counter) const
710+
{
711+
auto const& ref_ = static_cast<ref_type const&>(*this);
712+
ref_.impl_.template retrieve_if<BlockSize>(block,
713+
input_probe_begin,
714+
input_probe_end,
715+
stencil,
716+
pred,
717+
output_probe,
718+
output_match,
719+
atomic_counter);
720+
}
658721
};
659722

660723
template <typename Key,

0 commit comments

Comments
 (0)