Skip to content

Commit 94ca655

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into fix-int-namespace
2 parents 5a551f8 + 0d3193f commit 94ca655

File tree

11 files changed

+1277
-10
lines changed

11 files changed

+1277
-10
lines changed

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cuda/std/iterator>
3030
#include <cuda/std/type_traits>
3131
#include <thrust/execution_policy.h>
32+
#include <thrust/iterator/constant_iterator.h>
3233
#include <thrust/logical.h>
3334
#include <thrust/reduce.h>
3435
#if defined(CUCO_HAS_CUDA_BARRIER)
@@ -1084,8 +1085,16 @@ class open_addressing_ref_impl {
10841085
{
10851086
auto constexpr is_outer = false;
10861087
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include
1087-
this->retrieve_impl<is_outer, BlockSize>(
1088-
block, input_probe_begin, n, output_probe, output_match, atomic_counter);
1088+
auto const always_true_stencil = thrust::constant_iterator<bool>(true);
1089+
auto const identity_predicate = cuda::std::identity{};
1090+
this->retrieve_impl<is_outer, BlockSize>(block,
1091+
input_probe_begin,
1092+
n,
1093+
always_true_stencil,
1094+
identity_predicate,
1095+
output_probe,
1096+
output_match,
1097+
atomic_counter);
10891098
}
10901099

10911100
/**
@@ -1133,8 +1142,73 @@ class open_addressing_ref_impl {
11331142
{
11341143
auto constexpr is_outer = true;
11351144
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include
1145+
auto const always_true_stencil = thrust::constant_iterator<bool>(true);
1146+
auto const identity_predicate = cuda::std::identity{};
1147+
this->retrieve_impl<is_outer, BlockSize>(block,
1148+
input_probe_begin,
1149+
n,
1150+
always_true_stencil,
1151+
identity_predicate,
1152+
output_probe,
1153+
output_match,
1154+
atomic_counter);
1155+
}
1156+
1157+
/**
1158+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
1159+
* input_probe_end)` if `pred` of the corresponding stencil returns true.
1160+
*
1161+
* If key `k = *(first + i)` exists in the container and `pred( *(stencil + i) )` returns true,
1162+
* copies `k` to `output_probe` and associated slot contents to `output_match`,
1163+
* respectively. The output order is unspecified.
1164+
*
1165+
* Behavior is undefined if the size of the output range exceeds the number of retrieved slots.
1166+
* Use `count()` to determine the size of the output range.
1167+
*
1168+
* @tparam BlockSize Size of the thread block this operation is executed in
1169+
* @tparam InputProbeIt Device accessible input iterator
1170+
* @tparam StencilIt Device accessible random access iterator whose value_type is
1171+
* convertible to Predicate's argument type
1172+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1173+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
1174+
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
1175+
* convertible to the `InputProbeIt`'s `value_type`
1176+
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
1177+
* convertible to the container's `value_type`
1178+
* @tparam AtomicCounter Integral atomic counter type that follows the same semantics as
1179+
* `cuda::(std::)atomic(_ref)`
1180+
*
1181+
* @param block Thread block this operation is executed in
1182+
* @param input_probe_begin Beginning of the input sequence of keys
1183+
* @param input_probe_end End of the input sequence of keys
1184+
* @param stencil Beginning of the stencil sequence
1185+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
1186+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
1187+
* `output_match`
1188+
* @param output_match Beginning of the sequence of matching elements
1189+
* @param atomic_counter Atomic object of integral type that is used to count the
1190+
* number of output elements
1191+
*/
1192+
template <int BlockSize,
1193+
class InputProbeIt,
1194+
class StencilIt,
1195+
class Predicate,
1196+
class OutputProbeIt,
1197+
class OutputMatchIt,
1198+
class AtomicCounter>
1199+
__device__ void retrieve_if(cooperative_groups::thread_block const& block,
1200+
InputProbeIt input_probe_begin,
1201+
InputProbeIt input_probe_end,
1202+
StencilIt stencil,
1203+
Predicate pred,
1204+
OutputProbeIt output_probe,
1205+
OutputMatchIt output_match,
1206+
AtomicCounter& atomic_counter) const
1207+
{
1208+
auto constexpr is_outer = false;
1209+
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end);
11361210
this->retrieve_impl<is_outer, BlockSize>(
1137-
block, input_probe_begin, n, output_probe, output_match, atomic_counter);
1211+
block, input_probe_begin, n, stencil, pred, output_probe, output_match, atomic_counter);
11381212
}
11391213

11401214
/**
@@ -1153,6 +1227,10 @@ class open_addressing_ref_impl {
11531227
* @tparam IsOuter Flag indicating if an inner or outer retrieve operation should be performed
11541228
* @tparam BlockSize Size of the thread block this operation is executed in
11551229
* @tparam InputProbeIt Device accessible input iterator
1230+
* @tparam StencilIt Device accessible random access iterator whose value_type is
1231+
* convertible to Predicate's argument type
1232+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1233+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
11561234
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
11571235
* convertible to the `InputProbeIt`'s `value_type`
11581236
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
@@ -1161,8 +1239,10 @@ class open_addressing_ref_impl {
11611239
* `cuda::(std::)atomic(_ref)`
11621240
*
11631241
* @param block Thread block this operation is executed in
1164-
* @param input_probe_begin Beginning of the input sequence of keys
1165-
* @param input_probe_end End of the input sequence of keys
1242+
* @param input_probe Beginning of the input sequence of keys
1243+
* @param n Number of input keys
1244+
* @param stencil Beginning of the stencil sequence
1245+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
11661246
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
11671247
* `output_match`
11681248
* @param output_match Beginning of the sequence of matching elements
@@ -1172,12 +1252,16 @@ class open_addressing_ref_impl {
11721252
template <bool IsOuter,
11731253
int BlockSize,
11741254
class InputProbeIt,
1255+
class StencilIt,
1256+
class Predicate,
11751257
class OutputProbeIt,
11761258
class OutputMatchIt,
11771259
class AtomicCounter>
11781260
__device__ void retrieve_impl(cooperative_groups::thread_block const& block,
11791261
InputProbeIt input_probe,
11801262
cuco::detail::index_type n,
1263+
StencilIt stencil,
1264+
Predicate pred,
11811265
OutputProbeIt output_probe,
11821266
OutputMatchIt output_match,
11831267
AtomicCounter& atomic_counter) const
@@ -1228,15 +1312,16 @@ class open_addressing_ref_impl {
12281312
};
12291313

12301314
while (flushing_tile.any(idx < n)) {
1231-
bool active_flag = idx < n;
1315+
bool active_flag = idx < n and pred(*(stencil + idx));
12321316
auto const active_flushing_tile =
12331317
cg::binary_partition<flushing_tile_size>(flushing_tile, active_flag);
12341318

12351319
if (active_flag) {
12361320
// perform probing
12371321
// make sure the flushing_tile is converged at this point to get a coalesced load
12381322
auto const probe_key = *(input_probe + idx);
1239-
auto probing_iter = probing_scheme_.template make_iterator<bucket_size>(
1323+
1324+
auto probing_iter = probing_scheme_.template make_iterator<bucket_size>(
12401325
probing_tile, probe_key, storage_ref_.extent());
12411326
auto const init_idx = *probing_iter;
12421327

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,70 @@ class operator_impl<
15341534
ref_.impl_.template retrieve<BlockSize>(
15351535
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
15361536
}
1537+
1538+
/**
1539+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
1540+
* input_probe_end)` if `pred` of the corresponding stencil returns true.
1541+
*
1542+
* If key `k = *(first + i)` exists in the container and `pred( *(stencil + i) )` returns true,
1543+
* copies `k` to `output_probe` and associated slot content to `output_match`, respectively.
1544+
* The output order is unspecified.
1545+
*
1546+
* Behavior is undefined if the size of the output range exceeds the number of retrieved slots.
1547+
* Use `count()` to determine the size of the output range.
1548+
*
1549+
* @tparam BlockSize Size of the thread block this operation is executed in
1550+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
1551+
* convertible to the container's `key_type`
1552+
* @tparam StencilIt Device accessible random access iterator whose value_type is
1553+
* convertible to Predicate's argument type
1554+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
1555+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
1556+
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
1557+
* convertible to the container's `key_type`
1558+
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
1559+
* convertible to the container's `value_type`
1560+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
1561+
* `cuda::atomic(_ref)`
1562+
*
1563+
* @param block Thread block this operation is executed in
1564+
* @param input_probe_begin Beginning of the input sequence of keys
1565+
* @param input_probe_end End of the input sequence of keys
1566+
* @param stencil Beginning of the stencil sequence
1567+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
1568+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
1569+
* `output_match`
1570+
* @param output_match Beginning of the sequence of matching elements
1571+
* @param atomic_counter Counter that is used to determine the next free position in the output
1572+
* sequences
1573+
*/
1574+
template <int BlockSize,
1575+
class InputProbeIt,
1576+
class StencilIt,
1577+
class Predicate,
1578+
class OutputProbeIt,
1579+
class OutputMatchIt,
1580+
class AtomicCounter>
1581+
__device__ void retrieve_if(cooperative_groups::thread_block const& block,
1582+
InputProbeIt input_probe_begin,
1583+
InputProbeIt input_probe_end,
1584+
StencilIt stencil,
1585+
Predicate pred,
1586+
OutputProbeIt output_probe,
1587+
OutputMatchIt output_match,
1588+
AtomicCounter& atomic_counter) const
1589+
{
1590+
auto const& ref_ = static_cast<ref_type const&>(*this);
1591+
ref_.impl_.template retrieve_if<BlockSize>(block,
1592+
input_probe_begin,
1593+
input_probe_end,
1594+
stencil,
1595+
pred,
1596+
output_probe,
1597+
output_match,
1598+
atomic_counter);
1599+
}
15371600
};
1601+
15381602
} // namespace detail
15391603
} // namespace cuco

include/cuco/detail/static_multimap/static_multimap_ref.inl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,69 @@ class operator_impl<
845845
ref_.impl_.template retrieve<BlockSize>(
846846
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
847847
}
848+
849+
/**
850+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
851+
* input_probe_end)` if `pred` of the corresponding stencil returns true.
852+
*
853+
* If key `k = *(first + i)` exists in the container and `pred( *(stencil + i) )` returns true,
854+
* copies `k` to `output_probe` and associated slot content to `output_match`, respectively.
855+
* The output order is unspecified.
856+
*
857+
* Behavior is undefined if the size of the output range exceeds the number of retrieved slots.
858+
* Use `count()` to determine the size of the output range.
859+
*
860+
* @tparam BlockSize Size of the thread block this operation is executed in
861+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
862+
* convertible to the container's `key_type`
863+
* @tparam StencilIt Device accessible random access iterator whose value_type is
864+
* convertible to Predicate's argument type
865+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
866+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
867+
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
868+
* convertible to the container's `key_type`
869+
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
870+
* convertible to the container's `value_type`
871+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
872+
* `cuda::atomic(_ref)`
873+
*
874+
* @param block Thread block this operation is executed in
875+
* @param input_probe_begin Beginning of the input sequence of keys
876+
* @param input_probe_end End of the input sequence of keys
877+
* @param stencil Beginning of the stencil sequence
878+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
879+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
880+
* `output_match`
881+
* @param output_match Beginning of the sequence of matching elements
882+
* @param atomic_counter Counter that is used to determine the next free position in the output
883+
* sequences
884+
*/
885+
template <int BlockSize,
886+
class InputProbeIt,
887+
class StencilIt,
888+
class Predicate,
889+
class OutputProbeIt,
890+
class OutputMatchIt,
891+
class AtomicCounter>
892+
__device__ void retrieve_if(cooperative_groups::thread_block const& block,
893+
InputProbeIt input_probe_begin,
894+
InputProbeIt input_probe_end,
895+
StencilIt stencil,
896+
Predicate pred,
897+
OutputProbeIt output_probe,
898+
OutputMatchIt output_match,
899+
AtomicCounter& atomic_counter) const
900+
{
901+
auto const& ref_ = static_cast<ref_type const&>(*this);
902+
ref_.impl_.template retrieve_if<BlockSize>(block,
903+
input_probe_begin,
904+
input_probe_end,
905+
stencil,
906+
pred,
907+
output_probe,
908+
output_match,
909+
atomic_counter);
910+
}
848911
};
849912
} // namespace detail
850913
} // namespace cuco

include/cuco/detail/static_multiset/static_multiset_ref.inl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,69 @@ class operator_impl<
655655
ref_.impl_.template retrieve_outer<BlockSize>(
656656
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
657657
}
658+
659+
/**
660+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
661+
* input_probe_end)` if `pred` of the corresponding stencil returns true.
662+
*
663+
* If key `k = *(first + i)` exists in the container and `pred( *(stencil + i) )` returns true,
664+
* copies `k` to `output_probe` and associated slot content to `output_match`, respectively.
665+
* The output order is unspecified.
666+
*
667+
* Behavior is undefined if the size of the output range exceeds the number of retrieved slots.
668+
* Use `count()` to determine the size of the output range.
669+
*
670+
* @tparam BlockSize Size of the thread block this operation is executed in
671+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
672+
* convertible to the container's `key_type`
673+
* @tparam StencilIt Device accessible random access iterator whose value_type is
674+
* convertible to Predicate's argument type
675+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
676+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
677+
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
678+
* convertible to the container's `key_type`
679+
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
680+
* convertible to the container's `value_type`
681+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
682+
* `cuda::atomic(_ref)`
683+
*
684+
* @param block Thread block this operation is executed in
685+
* @param input_probe_begin Beginning of the input sequence of keys
686+
* @param input_probe_end End of the input sequence of keys
687+
* @param stencil Beginning of the stencil sequence
688+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
689+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
690+
* `output_match`
691+
* @param output_match Beginning of the sequence of matching elements
692+
* @param atomic_counter Counter that is used to determine the next free position in the output
693+
* sequences
694+
*/
695+
template <int BlockSize,
696+
class InputProbeIt,
697+
class StencilIt,
698+
class Predicate,
699+
class OutputProbeIt,
700+
class OutputMatchIt,
701+
class AtomicCounter>
702+
__device__ void retrieve_if(cooperative_groups::thread_block const& block,
703+
InputProbeIt input_probe_begin,
704+
InputProbeIt input_probe_end,
705+
StencilIt stencil,
706+
Predicate pred,
707+
OutputProbeIt output_probe,
708+
OutputMatchIt output_match,
709+
AtomicCounter& atomic_counter) const
710+
{
711+
auto const& ref_ = static_cast<ref_type const&>(*this);
712+
ref_.impl_.template retrieve_if<BlockSize>(block,
713+
input_probe_begin,
714+
input_probe_end,
715+
stencil,
716+
pred,
717+
output_probe,
718+
output_match,
719+
atomic_counter);
720+
}
658721
};
659722

660723
template <typename Key,

0 commit comments

Comments
 (0)