3434#include < cuda/std/tuple>
3535#include < cuda/std/type_traits>
3636#include < cuda/stream_ref>
37+ #include < cuda/utility>
3738#include < thrust/iterator/constant_iterator.h>
3839
3940#include < cooperative_groups.h>
@@ -142,15 +143,14 @@ class bloom_filter_impl {
/**
 * @brief Sets the bits for one hashed key inside a single filter block.
 *
 * For every word index `i` in `[0, words_per_block)` this asks the policy for the
 * bit pattern via `policy_.word_pattern(hash_value, i)` and ORs each non-zero
 * pattern into word `block_index * words_per_block + i` of `words_` using a
 * relaxed atomic `fetch_or`.
 *
 * @tparam HashValue Type of the hash value handed to the policy
 * @tparam BlockIndex Type of the block index
 *
 * @param hash_value Hash value forwarded unchanged to `policy_.word_pattern`
 * @param block_index Index of the target block; its first word is located at
 * `words_ + block_index * words_per_block`
 */
template <class HashValue, class BlockIndex>
__device__ void add_impl(HashValue const& hash_value, BlockIndex block_index)
{
  // `cuda::static_for` replaces the former `#pragma unroll` loop over the word
  // indices 0 .. words_per_block - 1. The index is passed to the callable as an
  // object whose call operator (`i()`) yields the index value.
  cuda::static_for<words_per_block>([&](auto i) {
    auto const word = policy_.word_pattern(hash_value, i());
    // OR-ing an all-zero pattern would leave the stored word unchanged, so the
    // atomic is skipped in that case.
    if (word != 0) {
      auto atom_word = cuda::atomic_ref<word_type, thread_scope>{
        *(words_ + (block_index * words_per_block + i()))};
      atom_word.fetch_or(word, cuda::memory_order_relaxed);
    }
  });
}
155155
156156 template <class CG , class ProbeKey >
@@ -205,9 +205,11 @@ class bloom_filter_impl {
205205 block_index = policy_.block_index (hash_value, num_blocks_);
206206 }
207207
208- for (uint32_t j = 0 ; (j < num_threads) and (i + j < num_keys); ++j) {
209- this ->add_impl (group, group.shfl (hash_value, j), group.shfl (block_index, j));
210- }
208+ cuda::static_for<num_threads>([&](auto j) {
209+ if ((j () < num_threads) and (i + j () < num_keys)) {
210+ this ->add_impl (group, group.shfl (hash_value, j ()), group.shfl (block_index, j ()));
211+ }
212+ });
211213 }
212214 } else { // subdivide given CG into multiple optimal CGs
213215 typename policy_type::hash_result_type hash_value;
@@ -225,10 +227,13 @@ class bloom_filter_impl {
225227 block_index = policy_.block_index (hash_value, num_blocks_);
226228 }
227229
228- for (uint32_t j = 0 ; (j < worker_num_threads) and (i + worker_offset + j < num_keys); ++j) {
229- this ->add_impl (
230- worker_group, worker_group.shfl (hash_value, j), worker_group.shfl (block_index, j));
231- }
230+ cuda::static_for<worker_num_threads>([&](auto j) {
231+ if ((j () < worker_num_threads) and (i + worker_offset + j () < num_keys)) {
232+ this ->add_impl (worker_group,
233+ worker_group.shfl (hash_value, j ()),
234+ worker_group.shfl (block_index, j ()));
235+ }
236+ });
232237 }
233238 }
234239 }
@@ -245,12 +250,13 @@ class bloom_filter_impl {
245250 *(words_ + (block_index * words_per_block + rank))};
246251 atom_word.fetch_or (policy_.word_pattern (hash_value, rank), cuda::memory_order_relaxed);
247252 } else {
248- #pragma unroll
249- for (auto i = rank; i < words_per_block; i += num_threads) {
250- auto atom_word = cuda::atomic_ref<word_type, thread_scope>{
251- *(words_ + (block_index * words_per_block + i))};
252- atom_word.fetch_or (policy_.word_pattern (hash_value, i), cuda::memory_order_relaxed);
253- }
253+ cuda::static_for<words_per_block>([&](auto i) {
254+ if (i () >= rank && (i () - rank) % num_threads == 0 ) {
255+ auto atom_word = cuda::atomic_ref<word_type, thread_scope>{
256+ *(words_ + (block_index * words_per_block + i ()))};
257+ atom_word.fetch_or (policy_.word_pattern (hash_value, i ()), cuda::memory_order_relaxed);
258+ }
259+ });
254260 }
255261 }
256262
@@ -330,11 +336,12 @@ class bloom_filter_impl {
330336 auto const stored_pattern = this ->vec_load_words <words_per_block>(
331337 policy_.block_index (hash_value, num_blocks_) * words_per_block);
332338
333- #pragma unroll words_per_block
334- for (uint32_t i = 0 ; i < words_per_block; ++i) {
335- auto const expected_pattern = policy_.word_pattern (hash_value, i);
336- if ((stored_pattern[i] & expected_pattern) != expected_pattern) { return false ; }
337- }
339+ bool result = true ;
340+ cuda::static_for<words_per_block>([&](auto i) {
341+ auto const expected_pattern = policy_.word_pattern (hash_value, i ());
342+ if ((stored_pattern[i ()] & expected_pattern) != expected_pattern) { result = false ; }
343+ });
344+ if (!result) { return false ; }
338345
339346 return true ;
340347 }
@@ -354,17 +361,17 @@ class bloom_filter_impl {
354361 auto const hash_value = policy_.hash (key);
355362 bool success = true ;
356363
357- # pragma unroll
358- for ( uint32_t i = rank; i < optimal_num_threads; i += num_threads ) {
359- auto const thread_offset = i * words_per_thread;
360- auto const stored_pattern = this ->vec_load_words <words_per_thread>(
361- policy_.block_index (hash_value, num_blocks_) * words_per_block + thread_offset);
362- # pragma unroll words_per_thread
363- for ( uint32_t j = 0 ; j < words_per_thread; ++j) {
364- auto const expected_pattern = policy_. word_pattern (hash_value, thread_offset + j);
365- if ((stored_pattern[j] & expected_pattern) != expected_pattern) { success = false ; }
364+ cuda::static_for<optimal_num_threads>([&]( auto i) {
365+ if ( i () > = rank && ( i () - rank) % num_threads == 0 ) {
366+ auto const thread_offset = i () * words_per_thread;
367+ auto const stored_pattern = this ->vec_load_words <words_per_thread>(
368+ policy_.block_index (hash_value, num_blocks_) * words_per_block + thread_offset);
369+ cuda::static_for< words_per_thread>([&]( auto j) {
370+ auto const expected_pattern = policy_. word_pattern (hash_value, thread_offset + j ());
371+ if ((stored_pattern[ j ()] & expected_pattern) != expected_pattern) { success = false ; }
372+ });
366373 }
367- }
374+ });
368375
369376 return group.all (success);
370377 }
0 commit comments