Skip to content

Commit 111bbaf

Browse files
committed
Use make factory to create probing iterators
1 parent edeafaa commit 111bbaf

File tree

5 files changed

+55
-47
lines changed

5 files changed

+55
-47
lines changed

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ class open_addressing_ref_impl {
379379
auto const val = this->heterogeneous_value(value);
380380
auto const key = this->extract_key(val);
381381

382-
auto probing_iter = probing_scheme_.operator()<bucket_size>(key, storage_ref_.extent());
382+
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
383383
auto const init_idx = *probing_iter;
384384

385385
while (true) {
@@ -428,9 +428,10 @@ class open_addressing_ref_impl {
428428
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
429429
Value const& value) noexcept
430430
{
431-
auto const val = this->heterogeneous_value(value);
432-
auto const key = this->extract_key(val);
433-
auto probing_iter = probing_scheme_.operator()<bucket_size>(group, key, storage_ref_.extent());
431+
auto const val = this->heterogeneous_value(value);
432+
auto const key = this->extract_key(val);
433+
auto probing_iter =
434+
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
434435
auto const init_idx = *probing_iter;
435436

436437
while (true) {
@@ -523,7 +524,7 @@ class open_addressing_ref_impl {
523524

524525
auto const val = this->heterogeneous_value(value);
525526
auto const key = this->extract_key(val);
526-
auto probing_iter = probing_scheme_.operator()<bucket_size>(key, storage_ref_.extent());
527+
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
527528
auto const init_idx = *probing_iter;
528529

529530
while (true) {
@@ -594,9 +595,10 @@ class open_addressing_ref_impl {
594595
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
595596
#endif
596597

597-
auto const val = this->heterogeneous_value(value);
598-
auto const key = this->extract_key(val);
599-
auto probing_iter = probing_scheme_.operator()<bucket_size>(group, key, storage_ref_.extent());
598+
auto const val = this->heterogeneous_value(value);
599+
auto const key = this->extract_key(val);
600+
auto probing_iter =
601+
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
600602
auto const init_idx = *probing_iter;
601603

602604
while (true) {
@@ -683,7 +685,7 @@ class open_addressing_ref_impl {
683685
{
684686
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
685687

686-
auto probing_iter = probing_scheme_.operator()<bucket_size>(key, storage_ref_.extent());
688+
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
687689
auto const init_idx = *probing_iter;
688690

689691
while (true) {
@@ -726,7 +728,8 @@ class open_addressing_ref_impl {
726728
__device__ bool erase(cooperative_groups::thread_block_tile<cg_size> const& group,
727729
ProbeKey const& key) noexcept
728730
{
729-
auto probing_iter = probing_scheme_.operator()<bucket_size>(group, key, storage_ref_.extent());
731+
auto probing_iter =
732+
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
730733
auto const init_idx = *probing_iter;
731734

732735
while (true) {
@@ -783,7 +786,7 @@ class open_addressing_ref_impl {
783786
[[nodiscard]] __device__ bool contains(ProbeKey const& key) const noexcept
784787
{
785788
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
786-
auto probing_iter = probing_scheme_.operator()<bucket_size>(key, storage_ref_.extent());
789+
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
787790
auto const init_idx = *probing_iter;
788791

789792
while (true) {
@@ -820,7 +823,8 @@ class open_addressing_ref_impl {
820823
[[nodiscard]] __device__ bool contains(
821824
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
822825
{
823-
auto probing_iter = probing_scheme_.operator()<bucket_size>(group, key, storage_ref_.extent());
826+
auto probing_iter =
827+
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
824828
auto const init_idx = *probing_iter;
825829

826830
while (true) {
@@ -859,7 +863,7 @@ class open_addressing_ref_impl {
859863
[[nodiscard]] __device__ iterator find(ProbeKey const& key) const noexcept
860864
{
861865
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
862-
auto probing_iter = probing_scheme_.operator()<bucket_size>(key, storage_ref_.extent());
866+
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
863867
auto const init_idx = *probing_iter;
864868

865869
while (true) {
@@ -900,7 +904,8 @@ class open_addressing_ref_impl {
900904
[[nodiscard]] __device__ iterator find(
901905
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
902906
{
903-
auto probing_iter = probing_scheme_.operator()<bucket_size>(group, key, storage_ref_.extent());
907+
auto probing_iter =
908+
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
904909
auto const init_idx = *probing_iter;
905910

906911
while (true) {
@@ -949,7 +954,7 @@ class open_addressing_ref_impl {
949954
if constexpr (not allows_duplicates) {
950955
return static_cast<size_type>(this->contains(key));
951956
} else {
952-
auto probing_iter = probing_scheme_.operator()<bucket_size>(key, storage_ref_.extent());
957+
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
953958
auto const init_idx = *probing_iter;
954959
size_type count = 0;
955960

@@ -993,7 +998,8 @@ class open_addressing_ref_impl {
993998
[[nodiscard]] __device__ size_type count(
994999
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
9951000
{
996-
auto probing_iter = probing_scheme_.operator()<bucket_size>(group, key, storage_ref_.extent());
1001+
auto probing_iter =
1002+
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
9971003
auto const init_idx = *probing_iter;
9981004
size_type count = 0;
9991005

@@ -1216,8 +1222,8 @@ class open_addressing_ref_impl {
12161222
// perform probing
12171223
// make sure the flushing_tile is converged at this point to get a coalesced load
12181224
auto const probe_key = *(input_probe + idx);
1219-
auto probing_iter =
1220-
probing_scheme_.operator()<bucket_size>(probing_tile, probe_key, storage_ref_.extent());
1225+
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(
1226+
probing_tile, probe_key, storage_ref_.extent());
12211227
auto const init_idx = *probing_iter;
12221228

12231229
bool running = true;
@@ -1348,7 +1354,7 @@ class open_addressing_ref_impl {
13481354
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
13491355
{
13501356
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
1351-
auto probing_iter = probing_scheme_.operator()<bucket_size>(key, storage_ref_.extent());
1357+
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
13521358
auto const init_idx = *probing_iter;
13531359

13541360
while (true) {
@@ -1397,7 +1403,8 @@ class open_addressing_ref_impl {
13971403
ProbeKey const& key,
13981404
CallbackOp&& callback_op) const noexcept
13991405
{
1400-
auto probing_iter = probing_scheme_.operator()<bucket_size>(group, key, storage_ref_.extent());
1406+
auto probing_iter =
1407+
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
14011408
auto const init_idx = *probing_iter;
14021409
bool empty = false;
14031410

@@ -1461,7 +1468,8 @@ class open_addressing_ref_impl {
14611468
CallbackOp&& callback_op,
14621469
SyncOp&& sync_op) const noexcept
14631470
{
1464-
auto probing_iter = probing_scheme_.operator()<bucket_size>(group, key, storage_ref_.extent());
1471+
auto probing_iter =
1472+
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
14651473
auto const init_idx = *probing_iter;
14661474
bool empty = false;
14671475

include/cuco/detail/probing_scheme/probing_scheme_impl.inl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ __host__ __device__ constexpr auto linear_probing<CGSize, Hash>::rebind_hash_fun
106106

107107
template <int32_t CGSize, typename Hash>
108108
template <int32_t BucketSize, typename ProbeKey, typename Extent>
109-
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::operator()(
109+
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::make_iterator(
110110
ProbeKey const& probe_key, Extent upper_bound) const noexcept
111111
{
112112
using size_type = typename Extent::value_type;
@@ -117,7 +117,7 @@ __host__ __device__ constexpr auto linear_probing<CGSize, Hash>::operator()(
117117

118118
template <int32_t CGSize, typename Hash>
119119
template <int32_t BucketSize, typename ProbeKey, typename Extent>
120-
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::operator()(
120+
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::make_iterator(
121121
cooperative_groups::thread_block_tile<cg_size> const& g,
122122
ProbeKey const& probe_key,
123123
Extent upper_bound) const noexcept
@@ -167,7 +167,7 @@ __host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::rebind_
167167

168168
template <int32_t CGSize, typename Hash1, typename Hash2>
169169
template <int32_t BucketSize, typename ProbeKey, typename Extent>
170-
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operator()(
170+
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::make_iterator(
171171
ProbeKey const& probe_key, Extent upper_bound) const noexcept
172172
{
173173
using size_type = typename Extent::value_type;
@@ -183,7 +183,7 @@ __host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operato
183183

184184
template <int32_t CGSize, typename Hash1, typename Hash2>
185185
template <int32_t BucketSize, typename ProbeKey, typename Extent>
186-
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operator()(
186+
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::make_iterator(
187187
cooperative_groups::thread_block_tile<cg_size> const& g,
188188
ProbeKey const& probe_key,
189189
Extent upper_bound) const noexcept

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,8 @@ class operator_impl<
517517
auto const key = ref_.impl_.extract_key(val);
518518
auto const probing_scheme = ref_.impl_.probing_scheme();
519519
auto storage_ref = ref_.impl_.storage_ref();
520-
auto probing_iter = probing_scheme.operator()<bucket_size>(key, storage_ref.extent());
521-
auto const init_idx = *probing_iter;
520+
auto probing_iter = probing_scheme.make_iterator<bucket_size>(key, storage_ref.extent());
521+
auto const init_idx = *probing_iter;
522522

523523
while (true) {
524524
auto const bucket_slots = storage_ref[*probing_iter];
@@ -565,7 +565,7 @@ class operator_impl<
565565
auto const key = ref_.impl_.extract_key(val);
566566
auto const probing_scheme = ref_.impl_.probing_scheme();
567567
auto storage_ref = ref_.impl_.storage_ref();
568-
auto probing_iter = probing_scheme.operator()<bucket_size>(group, key, storage_ref.extent());
568+
auto probing_iter = probing_scheme.make_iterator<bucket_size>(group, key, storage_ref.extent());
569569
auto const init_idx = *probing_iter;
570570

571571
while (true) {
@@ -883,9 +883,9 @@ class operator_impl<
883883
auto const key = ref_.impl_.extract_key(val);
884884
auto const probing_scheme = ref_.impl_.probing_scheme();
885885
auto storage_ref = ref_.impl_.storage_ref();
886-
auto probing_iter = probing_scheme.operator()<bucket_size>(key, storage_ref.extent());
887-
auto const init_idx = *probing_iter;
888-
auto const empty_value = ref_.empty_value_sentinel();
886+
auto probing_iter = probing_scheme.make_iterator<bucket_size>(key, storage_ref.extent());
887+
auto const init_idx = *probing_iter;
888+
auto const empty_value = ref_.empty_value_sentinel();
889889

890890
// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
891891
auto constexpr wait_for_payload = (not UseDirectApply) and (sizeof(value_type) > 8);
@@ -959,8 +959,8 @@ class operator_impl<
959959
auto const key = ref_.impl_.extract_key(val);
960960
auto const probing_scheme = ref_.impl_.probing_scheme();
961961
auto storage_ref = ref_.impl_.storage_ref();
962-
auto probing_iter = probing_scheme.operator()<bucket_size>(group, key, storage_ref.extent());
963-
auto const init_idx = *probing_iter;
962+
auto probing_iter = probing_scheme.make_iterator<bucket_size>(group, key, storage_ref.extent());
963+
auto const init_idx = *probing_iter;
964964
auto const empty_value = ref_.empty_value_sentinel();
965965

966966
// wait for payload only when init != sentinel and insert strategy is not `packed_cas`

include/cuco/probing_scheme.cuh

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class linear_probing : private detail::probing_scheme_base<CGSize> {
6666
NewHash const& hash) const noexcept;
6767

6868
/**
69-
* @brief Operator to return a probing iterator
69+
* @brief Returns a probing iterator
7070
*
7171
* @tparam BucketSize Size of the bucket
7272
* @tparam ProbeKey Type of probing key
@@ -77,11 +77,11 @@ class linear_probing : private detail::probing_scheme_base<CGSize> {
7777
* @return An iterator whose value_type is convertible to slot index type
7878
*/
7979
template <int32_t BucketSize, typename ProbeKey, typename Extent>
80-
__host__ __device__ constexpr auto operator()(ProbeKey const& probe_key,
81-
Extent upper_bound) const noexcept;
80+
__host__ __device__ constexpr auto make_iterator(ProbeKey const& probe_key,
81+
Extent upper_bound) const noexcept;
8282

8383
/**
84-
* @brief Operator to return a CG-based probing iterator
84+
* @brief Returns a CG-based probing iterator
8585
*
8686
* @tparam BucketSize Size of the bucket
8787
* @tparam ProbeKey Type of probing key
@@ -93,7 +93,7 @@ class linear_probing : private detail::probing_scheme_base<CGSize> {
9393
* @return An iterator whose value_type is convertible to slot index type
9494
*/
9595
template <int32_t BucketSize, typename ProbeKey, typename Extent>
96-
__host__ __device__ constexpr auto operator()(
96+
__host__ __device__ constexpr auto make_iterator(
9797
cooperative_groups::thread_block_tile<cg_size> const& g,
9898
ProbeKey const& probe_key,
9999
Extent upper_bound) const noexcept;
@@ -163,7 +163,7 @@ class double_hashing : private detail::probing_scheme_base<CGSize> {
163163
[[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const;
164164

165165
/**
166-
* @brief Operator to return a probing iterator
166+
* @brief Returns a probing iterator
167167
*
168168
* @tparam BucketSize Size of the bucket
169169
* @tparam ProbeKey Type of probing key
@@ -174,11 +174,11 @@ class double_hashing : private detail::probing_scheme_base<CGSize> {
174174
* @return An iterator whose value_type is convertible to slot index type
175175
*/
176176
template <int32_t BucketSize, typename ProbeKey, typename Extent>
177-
__host__ __device__ constexpr auto operator()(ProbeKey const& probe_key,
178-
Extent upper_bound) const noexcept;
177+
__host__ __device__ constexpr auto make_iterator(ProbeKey const& probe_key,
178+
Extent upper_bound) const noexcept;
179179

180180
/**
181-
* @brief Operator to return a CG-based probing iterator
181+
* @brief Returns a CG-based probing iterator
182182
*
183183
* @tparam BucketSize Size of the bucket
184184
* @tparam ProbeKey Type of probing key
@@ -190,7 +190,7 @@ class double_hashing : private detail::probing_scheme_base<CGSize> {
190190
* @return An iterator whose value_type is convertible to slot index type
191191
*/
192192
template <int32_t BucketSize, typename ProbeKey, typename Extent>
193-
__host__ __device__ constexpr auto operator()(
193+
__host__ __device__ constexpr auto make_iterator(
194194
cooperative_groups::thread_block_tile<cg_size> const& g,
195195
ProbeKey const& probe_key,
196196
Extent upper_bound) const noexcept;

tests/utility/probing_scheme_test.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ __global__ void generate_scalar_probing_sequence(Key key,
4444
auto probing_scheme = ProbingScheme{};
4545

4646
if (tid == 0) {
47-
auto iter = probing_scheme.operator()<BucketSize>(key, upper_bound);
47+
auto iter = probing_scheme.make_iterator<BucketSize>(key, upper_bound);
4848

4949
for (size_t i = 0; i < seq_length; ++i) {
5050
out_seq[i] = *iter;
51-
iter++;
51+
++iter;
5252
}
5353
}
5454
}
@@ -68,11 +68,11 @@ __global__ void generate_cg_probing_sequence(Key key,
6868
auto const tile =
6969
cooperative_groups::tiled_partition<cg_size>(cooperative_groups::this_thread_block());
7070

71-
auto iter = probing_scheme.operator()<BucketSize>(tile, key, upper_bound);
71+
auto iter = probing_scheme.make_iterator<BucketSize>(tile, key, upper_bound);
7272

7373
for (size_t i = tile.thread_rank(); i < seq_length; ++i) {
7474
out_seq[i] = *iter;
75-
iter++;
75+
++iter;
7676
}
7777
}
7878
}

0 commit comments

Comments
 (0)