Skip to content

Commit eb9319b

Browse files
authored
Add host find_if APIs for all hash tables (#638)
This migrates the existing OA `find` kernel to `find_if_n` and adds `find_if(_async)` APIs for all hash tables.
1 parent d39d59a commit eb9319b

File tree

14 files changed

+588
-31
lines changed

14 files changed

+588
-31
lines changed

include/cuco/detail/open_addressing/kernels.cuh

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -329,19 +329,33 @@ struct find_buffer<Container, cuda::std::void_t<typename Container::mapped_type>
329329
* @tparam CGSize Number of threads in each CG
330330
* @tparam BlockSize The size of the thread block
331331
* @tparam InputIt Device accessible input iterator
332+
* @tparam StencilIt Device accessible random access iterator whose value_type is
333+
* convertible to Predicate's argument type
334+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
335+
* and argument type is convertible from `std::iterator_traits<StencilIt>::value_type`
332336
* @tparam OutputIt Device accessible output iterator
333337
* @tparam Ref Type of non-owning device ref allowing access to storage
334338
*
335339
* @param first Beginning of the sequence of keys
336340
* @param n Number of keys to query
341+
* @param stencil Beginning of the stencil sequence
342+
* @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
337343
* @param output_begin Beginning of the sequence of matched payloads retrieved for each key
338344
* @param ref Non-owning container device ref used to access the slot storage
339345
*/
340-
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename OutputIt, typename Ref>
341-
CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first,
342-
cuco::detail::index_type n,
343-
OutputIt output_begin,
344-
Ref ref)
346+
template <int32_t CGSize,
347+
int32_t BlockSize,
348+
typename InputIt,
349+
typename StencilIt,
350+
typename Predicate,
351+
typename OutputIt,
352+
typename Ref>
353+
CUCO_KERNEL __launch_bounds__(BlockSize) void find_if_n(InputIt first,
354+
cuco::detail::index_type n,
355+
StencilIt stencil,
356+
Predicate pred,
357+
OutputIt output_begin,
358+
Ref ref)
345359
{
346360
namespace cg = cooperative_groups;
347361

@@ -382,7 +396,7 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first,
382396
* synchronizing before writing back to global, we no longer rely on L1, preventing the
383397
* increase in sector stores from L2 to global and improving performance.
384398
*/
385-
output_buffer[thread_idx] = output(found);
399+
output_buffer[thread_idx] = pred(*(stencil + idx)) ? output(found) : sentinel;
386400
}
387401
block.sync();
388402
if (idx < n) { *(output_begin + idx) = output_buffer[thread_idx]; }
@@ -392,7 +406,9 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first,
392406
typename std::iterator_traits<InputIt>::value_type const& key = *(first + idx);
393407
auto const found = ref.find(tile, key);
394408

395-
if (tile.thread_rank() == 0) { *(output_begin + idx) = output(found); }
409+
if (tile.thread_rank() == 0) {
410+
*(output_begin + idx) = pred(*(stencil + idx)) ? output(found) : sentinel;
411+
}
396412
}
397413
}
398414
idx += loop_stride;

include/cuco/detail/open_addressing/open_addressing_impl.cuh

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,15 +565,59 @@ class open_addressing_impl {
565565
OutputIt output_begin,
566566
Ref container_ref,
567567
cuda::stream_ref stream) const noexcept
568+
{
569+
auto const always_true = thrust::constant_iterator<bool>{true};
570+
571+
this->find_if_async(
572+
first, last, always_true, thrust::identity{}, output_begin, container_ref, stream);
573+
}
574+
575+
/**
576+
* @brief For all keys in the range `[first, last)`, asynchronously finds
577+
* a match with its key equivalent to the query key.
578+
*
579+
* @note If `pred( *(stencil + i) )` is true, stores the payload of the
580+
* matched key or the `empty_value_sentinel` to `(output_begin + i)`. If `pred( *(stencil + i) )`
581+
* is false, stores `empty_value_sentinel` to `(output_begin + i)`.
582+
*
583+
* @tparam InputIt Device accessible input iterator
584+
* @tparam StencilIt Device accessible random access iterator whose value_type is
585+
* convertible to Predicate's argument type
586+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
587+
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
588+
* @tparam OutputIt Device accessible output iterator
589+
* @tparam Ref Type of non-owning device container ref allowing access to storage
590+
*
591+
* @param first Beginning of the sequence of keys
592+
* @param last End of the sequence of keys
593+
* @param stencil Beginning of the stencil sequence
594+
* @param pred Predicate to test on every element in the range `[stencil, stencil +
595+
* std::distance(first, last))`
596+
* @param output_begin Beginning of the sequence of matches retrieved for each key
597+
* @param container_ref Non-owning device container ref used to access the slot storage
598+
* @param stream Stream used for executing the kernels
599+
*/
600+
template <typename InputIt,
601+
typename StencilIt,
602+
typename Predicate,
603+
typename OutputIt,
604+
typename Ref>
605+
void find_if_async(InputIt first,
606+
InputIt last,
607+
StencilIt stencil,
608+
Predicate pred,
609+
OutputIt output_begin,
610+
Ref container_ref,
611+
cuda::stream_ref stream) const noexcept
568612
{
569613
auto const num_keys = cuco::detail::distance(first, last);
570614
if (num_keys == 0) { return; }
571615

572616
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
573617

574-
detail::find<cg_size, cuco::detail::default_block_size()>
618+
detail::find_if_n<cg_size, cuco::detail::default_block_size()>
575619
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
576-
first, num_keys, output_begin, container_ref);
620+
first, num_keys, stencil, pred, output_begin, container_ref);
577621
}
578622

579623
/**

include/cuco/detail/static_map/static_map.inl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,47 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
491491
impl_->find_async(first, last, output_begin, ref(op::find), stream);
492492
}
493493

494+
template <class Key,
495+
class T,
496+
class Extent,
497+
cuda::thread_scope Scope,
498+
class KeyEqual,
499+
class ProbingScheme,
500+
class Allocator,
501+
class Storage>
502+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
503+
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if(
504+
InputIt first,
505+
InputIt last,
506+
StencilIt stencil,
507+
Predicate pred,
508+
OutputIt output_begin,
509+
cuda::stream_ref stream) const
510+
{
511+
this->find_if_async(first, last, stencil, pred, output_begin, stream);
512+
stream.wait();
513+
}
514+
515+
template <class Key,
516+
class T,
517+
class Extent,
518+
cuda::thread_scope Scope,
519+
class KeyEqual,
520+
class ProbingScheme,
521+
class Allocator,
522+
class Storage>
523+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
524+
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if_async(
525+
InputIt first,
526+
InputIt last,
527+
StencilIt stencil,
528+
Predicate pred,
529+
OutputIt output_begin,
530+
cuda::stream_ref stream) const
531+
{
532+
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
533+
}
534+
494535
template <class Key,
495536
class T,
496537
class Extent,

include/cuco/detail/static_multimap/static_multimap.inl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,47 @@ void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator,
315315
impl_->find_async(first, last, output_begin, ref(op::find), stream);
316316
}
317317

318+
template <class Key,
319+
class T,
320+
class Extent,
321+
cuda::thread_scope Scope,
322+
class KeyEqual,
323+
class ProbingScheme,
324+
class Allocator,
325+
class Storage>
326+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
327+
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if(
328+
InputIt first,
329+
InputIt last,
330+
StencilIt stencil,
331+
Predicate pred,
332+
OutputIt output_begin,
333+
cuda::stream_ref stream) const
334+
{
335+
this->find_if_async(first, last, stencil, pred, output_begin, stream);
336+
stream.wait();
337+
}
338+
339+
template <class Key,
340+
class T,
341+
class Extent,
342+
cuda::thread_scope Scope,
343+
class KeyEqual,
344+
class ProbingScheme,
345+
class Allocator,
346+
class Storage>
347+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
348+
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
349+
find_if_async(InputIt first,
350+
InputIt last,
351+
StencilIt stencil,
352+
Predicate pred,
353+
OutputIt output_begin,
354+
cuda::stream_ref stream) const
355+
{
356+
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
357+
}
358+
318359
template <class Key,
319360
class T,
320361
class Extent,

include/cuco/detail/static_multiset/static_multiset.inl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,45 @@ void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Sto
277277
impl_->find_async(first, last, output_begin, ref(op::find), stream);
278278
}
279279

280+
template <class Key,
281+
class Extent,
282+
cuda::thread_scope Scope,
283+
class KeyEqual,
284+
class ProbingScheme,
285+
class Allocator,
286+
class Storage>
287+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
288+
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if(
289+
InputIt first,
290+
InputIt last,
291+
StencilIt stencil,
292+
Predicate pred,
293+
OutputIt output_begin,
294+
cuda::stream_ref stream) const
295+
{
296+
this->find_if_async(first, last, stencil, pred, output_begin, stream);
297+
stream.wait();
298+
}
299+
300+
template <class Key,
301+
class Extent,
302+
cuda::thread_scope Scope,
303+
class KeyEqual,
304+
class ProbingScheme,
305+
class Allocator,
306+
class Storage>
307+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
308+
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
309+
find_if_async(InputIt first,
310+
InputIt last,
311+
StencilIt stencil,
312+
Predicate pred,
313+
OutputIt output_begin,
314+
cuda::stream_ref stream) const
315+
{
316+
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
317+
}
318+
280319
template <class Key,
281320
class Extent,
282321
cuda::thread_scope Scope,

include/cuco/detail/static_set/static_set.inl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,45 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
338338
impl_->find_async(first, last, output_begin, ref(op::find), stream);
339339
}
340340

341+
template <class Key,
342+
class Extent,
343+
cuda::thread_scope Scope,
344+
class KeyEqual,
345+
class ProbingScheme,
346+
class Allocator,
347+
class Storage>
348+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
349+
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if(
350+
InputIt first,
351+
InputIt last,
352+
StencilIt stencil,
353+
Predicate pred,
354+
OutputIt output_begin,
355+
cuda::stream_ref stream) const
356+
{
357+
this->find_if_async(first, last, stencil, pred, output_begin, stream);
358+
stream.wait();
359+
}
360+
361+
template <class Key,
362+
class Extent,
363+
cuda::thread_scope Scope,
364+
class KeyEqual,
365+
class ProbingScheme,
366+
class Allocator,
367+
class Storage>
368+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
369+
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if_async(
370+
InputIt first,
371+
InputIt last,
372+
StencilIt stencil,
373+
Predicate pred,
374+
OutputIt output_begin,
375+
cuda::stream_ref stream) const
376+
{
377+
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
378+
}
379+
341380
template <class Key,
342381
class Extent,
343382
cuda::thread_scope Scope,

include/cuco/static_map.cuh

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,70 @@ class static_map {
767767
OutputIt output_begin,
768768
cuda::stream_ref stream = {}) const;
769769

770+
/**
771+
* @brief For all keys in the range `[first, last)`, finds a match with its key equivalent to the
772+
* query key.
773+
*
774+
* @note If `pred( *(stencil + i) )` is true, stores the payload of the
775+
* matched key or the `empty_value_sentinel` to `(output_begin + i)`. If `pred( *(stencil + i) )`
776+
* is false, always stores the `empty_value_sentinel` to `(output_begin + i)`.
777+
* @note This function synchronizes the given stream. For asynchronous execution use
778+
* `find_if_async`.
779+
*
780+
* @tparam InputIt Device accessible input iterator
781+
* @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to
782+
* Predicate's argument type
783+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
784+
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
785+
* @tparam OutputIt Device accessible output iterator
786+
*
787+
* @param first Beginning of the sequence of keys
788+
* @param last End of the sequence of keys
789+
* @param stencil Beginning of the stencil sequence
790+
* @param pred Predicate to test on every element in the range `[stencil, stencil +
791+
* std::distance(first, last))`
792+
* @param output_begin Beginning of the sequence of matches retrieved for each key
793+
* @param stream Stream used for executing the kernels
794+
*/
795+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
796+
void find_if(InputIt first,
797+
InputIt last,
798+
StencilIt stencil,
799+
Predicate pred,
800+
OutputIt output_begin,
801+
cuda::stream_ref stream = {}) const;
802+
803+
/**
804+
* @brief For all keys in the range `[first, last)`, asynchronously finds
805+
* a match with its key equivalent to the query key.
806+
*
807+
* @note If `pred( *(stencil + i) )` is true, stores the payload of the
808+
* matched key or the `empty_value_sentinel` to `(output_begin + i)`. If `pred( *(stencil + i) )`
809+
* is false, always stores the `empty_value_sentinel` to `(output_begin + i)`.
810+
*
811+
* @tparam InputIt Device accessible input iterator
812+
* @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to
813+
* Predicate's argument type
814+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
815+
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
816+
* @tparam OutputIt Device accessible output iterator
817+
*
818+
* @param first Beginning of the sequence of keys
819+
* @param last End of the sequence of keys
820+
* @param stencil Beginning of the stencil sequence
821+
* @param pred Predicate to test on every element in the range `[stencil, stencil +
822+
* std::distance(first, last))`
823+
* @param output_begin Beginning of the sequence of matches retrieved for each key
824+
* @param stream Stream used for executing the kernels
825+
*/
826+
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
827+
void find_if_async(InputIt first,
828+
InputIt last,
829+
StencilIt stencil,
830+
Predicate pred,
831+
OutputIt output_begin,
832+
cuda::stream_ref stream = {}) const;
833+
770834
/**
771835
* @brief Applies the given function object `callback_op` to the copy of every filled slot in the
772836
* container

0 commit comments

Comments
 (0)