Skip to content

Commit 6db4762

Browse files
vitor1001, pre-commit-ci[bot], and PointKernel
authored
Several fixes to make cuCo compile with LLVM (#733)
- Add the `template` keyword before dependent member-template names to help parsing.
- Do not mark functions as `constexpr` when they cannot be evaluated at compile time.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yunsong Wang <[email protected]>
1 parent 4b19502 commit 6db4762

File tree

20 files changed

+157
-139
lines changed

20 files changed

+157
-139
lines changed

include/cuco/detail/bloom_filter/bloom_filter_impl.cuh

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -363,44 +363,44 @@ class bloom_filter_impl {
363363
// const;
364364

365365
template <class InputIt, class OutputIt>
366-
__host__ constexpr void contains(InputIt first,
367-
InputIt last,
368-
OutputIt output_begin,
369-
cuda::stream_ref stream) const
366+
__host__ void contains(InputIt first,
367+
InputIt last,
368+
OutputIt output_begin,
369+
cuda::stream_ref stream) const
370370
{
371371
this->contains_async(first, last, output_begin, stream);
372372
stream.wait();
373373
}
374374

375375
template <class InputIt, class OutputIt>
376-
__host__ constexpr void contains_async(InputIt first,
377-
InputIt last,
378-
OutputIt output_begin,
379-
cuda::stream_ref stream) const noexcept
376+
__host__ void contains_async(InputIt first,
377+
InputIt last,
378+
OutputIt output_begin,
379+
cuda::stream_ref stream) const noexcept
380380
{
381381
auto const always_true = thrust::constant_iterator<bool>{true};
382382
this->contains_if_async(first, last, always_true, cuda::std::identity{}, output_begin, stream);
383383
}
384384

385385
template <class InputIt, class StencilIt, class Predicate, class OutputIt>
386-
__host__ constexpr void contains_if(InputIt first,
387-
InputIt last,
388-
StencilIt stencil,
389-
Predicate pred,
390-
OutputIt output_begin,
391-
cuda::stream_ref stream) const
386+
__host__ void contains_if(InputIt first,
387+
InputIt last,
388+
StencilIt stencil,
389+
Predicate pred,
390+
OutputIt output_begin,
391+
cuda::stream_ref stream) const
392392
{
393393
this->contains_if_async(first, last, stencil, pred, output_begin, stream);
394394
stream.wait();
395395
}
396396

397397
template <class InputIt, class StencilIt, class Predicate, class OutputIt>
398-
__host__ constexpr void contains_if_async(InputIt first,
399-
InputIt last,
400-
StencilIt stencil,
401-
Predicate pred,
402-
OutputIt output_begin,
403-
cuda::stream_ref stream) const noexcept
398+
__host__ void contains_if_async(InputIt first,
399+
InputIt last,
400+
StencilIt stencil,
401+
Predicate pred,
402+
OutputIt output_begin,
403+
cuda::stream_ref stream) const noexcept
404404
{
405405
auto const num_keys = cuco::detail::distance(first, last);
406406
if (num_keys == 0) { return; }

include/cuco/detail/hyperloglog/kernels.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ CUCO_KERNEL void add_shmem_vectorized(typename RefType::value_type const* first,
4242
{
4343
using value_type = typename RefType::value_type;
4444
using vector_type = cuda::std::array<value_type, VectorSize>;
45-
using local_ref_type = typename RefType::with_scope<cuda::thread_scope_block>;
45+
using local_ref_type = typename RefType::template with_scope<cuda::thread_scope_block>;
4646

4747
// Base address of dynamic shared memory is guaranteed to be aligned to at least 16 bytes which is
4848
// sufficient for this purpose
@@ -91,7 +91,7 @@ CUCO_KERNEL void add_shmem_vectorized(typename RefType::value_type const* first,
9191
template <class InputIt, class RefType>
9292
CUCO_KERNEL void add_shmem(InputIt first, cuco::detail::index_type n, RefType ref)
9393
{
94-
using local_ref_type = typename RefType::with_scope<cuda::thread_scope_block>;
94+
using local_ref_type = typename RefType::template with_scope<cuda::thread_scope_block>;
9595

9696
// TODO assert alignment
9797
extern __shared__ cuda::std::byte local_sketch[];

include/cuco/detail/open_addressing/kernels.cuh

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -691,19 +691,19 @@ CUCO_KERNEL void retrieve(InputProbeIt input_probe,
691691

692692
if (block_begin_offset < block_end_offset) {
693693
if constexpr (IsOuter) {
694-
ref.retrieve_outer<BlockSize>(block,
695-
input_probe + block_begin_offset,
696-
input_probe + block_end_offset,
697-
output_probe,
698-
output_match,
699-
atomic_counter);
694+
ref.template retrieve_outer<BlockSize>(block,
695+
input_probe + block_begin_offset,
696+
input_probe + block_end_offset,
697+
output_probe,
698+
output_match,
699+
atomic_counter);
700700
} else {
701-
ref.retrieve<BlockSize>(block,
702-
input_probe + block_begin_offset,
703-
input_probe + block_end_offset,
704-
output_probe,
705-
output_match,
706-
atomic_counter);
701+
ref.template retrieve<BlockSize>(block,
702+
input_probe + block_begin_offset,
703+
input_probe + block_end_offset,
704+
output_probe,
705+
output_match,
706+
atomic_counter);
707707
}
708708
}
709709
}

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -522,17 +522,18 @@ class open_addressing_ref_impl {
522522
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
523523
#endif
524524

525-
auto const val = this->heterogeneous_value(value);
526-
auto const key = this->extract_key(val);
527-
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
525+
auto const val = this->heterogeneous_value(value);
526+
auto const key = this->extract_key(val);
527+
auto probing_iter =
528+
probing_scheme_.template make_iterator<bucket_size>(key, storage_ref_.extent());
528529
auto const init_idx = *probing_iter;
529530

530531
while (true) {
531532
auto const bucket_slots = storage_ref_[*probing_iter];
532533

533534
for (auto i = 0; i < bucket_size; ++i) {
534-
auto const eq_res =
535-
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(bucket_slots[i]));
535+
auto const eq_res = this->predicate_.template operator()<is_insert::YES>(
536+
key, this->extract_key(bucket_slots[i]));
536537
auto* slot_ptr = this->get_slot_ptr(*probing_iter, i);
537538

538539
// If the key is already in the container, return false
@@ -598,7 +599,7 @@ class open_addressing_ref_impl {
598599
auto const val = this->heterogeneous_value(value);
599600
auto const key = this->extract_key(val);
600601
auto probing_iter =
601-
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
602+
probing_scheme_.template make_iterator<bucket_size>(group, key, storage_ref_.extent());
602603
auto const init_idx = *probing_iter;
603604

604605
while (true) {
@@ -607,8 +608,8 @@ class open_addressing_ref_impl {
607608
auto const [state, intra_bucket_index] = [&]() {
608609
auto res = detail::equal_result::UNEQUAL;
609610
for (auto i = 0; i < bucket_size; ++i) {
610-
res =
611-
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(bucket_slots[i]));
611+
res = this->predicate_.template operator()<is_insert::YES>(
612+
key, this->extract_key(bucket_slots[i]));
612613
if (res != detail::equal_result::UNEQUAL) { return bucket_probing_results{res, i}; }
613614
}
614615
// returns dummy index `-1` for UNEQUAL
@@ -685,15 +686,16 @@ class open_addressing_ref_impl {
685686
{
686687
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
687688

688-
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
689+
auto probing_iter =
690+
probing_scheme_.template make_iterator<bucket_size>(key, storage_ref_.extent());
689691
auto const init_idx = *probing_iter;
690692

691693
while (true) {
692694
auto const bucket_slots = storage_ref_[*probing_iter];
693695

694696
for (auto& slot_content : bucket_slots) {
695697
auto const eq_res =
696-
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot_content));
698+
this->predicate_.template operator()<is_insert::NO>(key, this->extract_key(slot_content));
697699

698700
// Key doesn't exist, return false
699701
if (eq_res == detail::equal_result::EMPTY) { return false; }
@@ -729,7 +731,7 @@ class open_addressing_ref_impl {
729731
ProbeKey const& key) noexcept
730732
{
731733
auto probing_iter =
732-
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
734+
probing_scheme_.template make_iterator<bucket_size>(group, key, storage_ref_.extent());
733735
auto const init_idx = *probing_iter;
734736

735737
while (true) {
@@ -738,7 +740,8 @@ class open_addressing_ref_impl {
738740
auto const [state, intra_bucket_index] = [&]() {
739741
auto res = detail::equal_result::UNEQUAL;
740742
for (auto i = 0; i < bucket_size; ++i) {
741-
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(bucket_slots[i]));
743+
res = this->predicate_.template operator()<is_insert::NO>(
744+
key, this->extract_key(bucket_slots[i]));
742745
if (res != detail::equal_result::UNEQUAL) { return bucket_probing_results{res, i}; }
743746
}
744747
// returns dummy index `-1` for UNEQUAL
@@ -824,7 +827,7 @@ class open_addressing_ref_impl {
824827
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
825828
{
826829
auto probing_iter =
827-
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
830+
probing_scheme_.template make_iterator<bucket_size>(group, key, storage_ref_.extent());
828831
auto const init_idx = *probing_iter;
829832

830833
while (true) {
@@ -833,7 +836,8 @@ class open_addressing_ref_impl {
833836
auto const state = [&]() {
834837
auto res = detail::equal_result::UNEQUAL;
835838
for (auto i = 0; i < bucket_size; ++i) {
836-
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(bucket_slots[i]));
839+
res = this->predicate_.template operator()<is_insert::NO>(
840+
key, this->extract_key(bucket_slots[i]));
837841
if (res != detail::equal_result::UNEQUAL) { return res; }
838842
}
839843
return res;
@@ -863,16 +867,17 @@ class open_addressing_ref_impl {
863867
[[nodiscard]] __device__ iterator find(ProbeKey const& key) const noexcept
864868
{
865869
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
866-
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
870+
auto probing_iter =
871+
probing_scheme_.template make_iterator<bucket_size>(key, storage_ref_.extent());
867872
auto const init_idx = *probing_iter;
868873

869874
while (true) {
870875
// TODO atomic_ref::load if insert operator is present
871876
auto const bucket_slots = storage_ref_[*probing_iter];
872877

873878
for (auto i = 0; i < bucket_size; ++i) {
874-
switch (
875-
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(bucket_slots[i]))) {
879+
switch (this->predicate_.template operator()<is_insert::NO>(
880+
key, this->extract_key(bucket_slots[i]))) {
876881
case detail::equal_result::EMPTY: {
877882
return this->end();
878883
}
@@ -905,7 +910,7 @@ class open_addressing_ref_impl {
905910
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
906911
{
907912
auto probing_iter =
908-
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
913+
probing_scheme_.template make_iterator<bucket_size>(group, key, storage_ref_.extent());
909914
auto const init_idx = *probing_iter;
910915

911916
while (true) {
@@ -914,7 +919,8 @@ class open_addressing_ref_impl {
914919
auto const [state, intra_bucket_index] = [&]() {
915920
auto res = detail::equal_result::UNEQUAL;
916921
for (auto i = 0; i < bucket_size; ++i) {
917-
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(bucket_slots[i]));
922+
res = this->predicate_.template operator()<is_insert::NO>(
923+
key, this->extract_key(bucket_slots[i]));
918924
if (res != detail::equal_result::UNEQUAL) { return bucket_probing_results{res, i}; }
919925
}
920926
// returns dummy index `-1` for UNEQUAL
@@ -954,7 +960,8 @@ class open_addressing_ref_impl {
954960
if constexpr (not allows_duplicates) {
955961
return static_cast<size_type>(this->contains(key));
956962
} else {
957-
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
963+
auto probing_iter =
964+
probing_scheme_.template make_iterator<bucket_size>(key, storage_ref_.extent());
958965
auto const init_idx = *probing_iter;
959966
size_type count = 0;
960967

@@ -999,7 +1006,7 @@ class open_addressing_ref_impl {
9991006
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
10001007
{
10011008
auto probing_iter =
1002-
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
1009+
probing_scheme_.template make_iterator<bucket_size>(group, key, storage_ref_.extent());
10031010
auto const init_idx = *probing_iter;
10041011
size_type count = 0;
10051012

@@ -1222,7 +1229,7 @@ class open_addressing_ref_impl {
12221229
// perform probing
12231230
// make sure the flushing_tile is converged at this point to get a coalesced load
12241231
auto const probe_key = *(input_probe + idx);
1225-
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(
1232+
auto probing_iter = probing_scheme_.template make_iterator<bucket_size>(
12261233
probing_tile, probe_key, storage_ref_.extent());
12271234
auto const init_idx = *probing_iter;
12281235

@@ -1242,7 +1249,7 @@ class open_addressing_ref_impl {
12421249
equals[i] = false;
12431250
if (running) {
12441251
// inspect slot content
1245-
switch (this->predicate_.operator()<is_insert::NO>(
1252+
switch (this->predicate_.template operator()<is_insert::NO>(
12461253
probe_key, this->extract_key(bucket_slots[i]))) {
12471254
case detail::equal_result::EMPTY: {
12481255
running = false;
@@ -1354,16 +1361,17 @@ class open_addressing_ref_impl {
13541361
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
13551362
{
13561363
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
1357-
auto probing_iter = probing_scheme_.make_iterator<bucket_size>(key, storage_ref_.extent());
1364+
auto probing_iter =
1365+
probing_scheme_.template make_iterator<bucket_size>(key, storage_ref_.extent());
13581366
auto const init_idx = *probing_iter;
13591367

13601368
while (true) {
13611369
// TODO atomic_ref::load if insert operator is present
13621370
auto const bucket_slots = this->storage_ref_[*probing_iter];
13631371

13641372
for (int32_t i = 0; i < bucket_size; ++i) {
1365-
switch (
1366-
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(bucket_slots[i]))) {
1373+
switch (this->predicate_.template operator()<is_insert::NO>(
1374+
key, this->extract_key(bucket_slots[i]))) {
13671375
case detail::equal_result::EMPTY: {
13681376
return;
13691377
}
@@ -1404,7 +1412,7 @@ class open_addressing_ref_impl {
14041412
CallbackOp&& callback_op) const noexcept
14051413
{
14061414
auto probing_iter =
1407-
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
1415+
probing_scheme_.template make_iterator<bucket_size>(group, key, storage_ref_.extent());
14081416
auto const init_idx = *probing_iter;
14091417
bool empty = false;
14101418

@@ -1413,8 +1421,8 @@ class open_addressing_ref_impl {
14131421
auto const bucket_slots = this->storage_ref_[*probing_iter];
14141422

14151423
for (int32_t i = 0; i < bucket_size and !empty; ++i) {
1416-
switch (
1417-
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(bucket_slots[i]))) {
1424+
switch (this->predicate_.template operator()<is_insert::NO>(
1425+
key, this->extract_key(bucket_slots[i]))) {
14181426
case detail::equal_result::EMPTY: {
14191427
empty = true;
14201428
continue;
@@ -1469,7 +1477,7 @@ class open_addressing_ref_impl {
14691477
SyncOp&& sync_op) const noexcept
14701478
{
14711479
auto probing_iter =
1472-
probing_scheme_.make_iterator<bucket_size>(group, key, storage_ref_.extent());
1480+
probing_scheme_.template make_iterator<bucket_size>(group, key, storage_ref_.extent());
14731481
auto const init_idx = *probing_iter;
14741482
bool empty = false;
14751483

@@ -1478,8 +1486,8 @@ class open_addressing_ref_impl {
14781486
auto const bucket_slots = this->storage_ref_[*probing_iter];
14791487

14801488
for (int32_t i = 0; i < bucket_size and !empty; ++i) {
1481-
switch (
1482-
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(bucket_slots[i]))) {
1489+
switch (this->predicate_.template operator()<is_insert::NO>(
1490+
key, this->extract_key(bucket_slots[i]))) {
14831491
case detail::equal_result::EMPTY: {
14841492
empty = true;
14851493
continue;

include/cuco/detail/static_map/kernels.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem(
191191

192192
// Shared map initialization
193193
__shared__ typename SharedMapRefType::value_type slots[bucket_extent.value()];
194-
auto storage = SharedMapRefType::storage_ref_type(bucket_extent, slots);
194+
using storage_ref_type = typename SharedMapRefType::storage_ref_type;
195+
auto storage = storage_ref_type(bucket_extent, slots);
195196
auto const num_buckets = storage.num_buckets();
196197

197198
using atomic_type = cuda::atomic<int32_t, cuda::thread_scope_block>;

include/cuco/detail/static_map/static_map.inl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ std::pair<KeyOut, ValueOut>
703703
static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve_all(
704704
KeyOut keys_out, ValueOut values_out, cuda::stream_ref stream) const
705705
{
706-
auto const zipped_out_begin = thrust::make_zip_iterator(cuda::std::tuple{keys_out, values_out});
706+
auto const zipped_out_begin = thrust::make_zip_iterator(keys_out, values_out);
707707
auto const zipped_out_end = impl_->retrieve_all(zipped_out_begin, stream);
708708
auto const num = std::distance(zipped_out_begin, zipped_out_end);
709709

0 commit comments

Comments
 (0)