1515 */
1616
1717#include < cuco/detail/bitwise_compare.cuh>
18- #include < cuco/detail/static_map/helpers.cuh>
1918#include < cuco/detail/static_map/kernels.cuh>
2019#include < cuco/detail/utility/cuda.hpp>
2120#include < cuco/detail/utils.hpp>
21+ #include < cuco/extent.cuh>
2222#include < cuco/operator.hpp>
2323#include < cuco/static_map_ref.cuh>
24+ #include < cuco/storage.cuh>
2425
2526#include < cuda/std/tuple>
2627#include < cuda/stream_ref>
@@ -352,7 +353,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
352353{
353354 auto constexpr has_init = false ;
354355 auto const init = this ->empty_value_sentinel (); // use empty_sentinel as unused init value
355- detail::static_map_ns:: dispatch_insert_or_apply<has_init, cg_size, Allocator>(
356+ this -> dispatch_insert_or_apply <has_init, cg_size, Allocator>(
356357 first, last, init, op, ref (op::insert_or_apply), stream);
357358}
358359
@@ -370,7 +371,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
370371 InputIt first, InputIt last, Init init, Op op, cuda::stream_ref stream) noexcept
371372{
372373 auto constexpr has_init = true ;
373- detail::static_map_ns:: dispatch_insert_or_apply<has_init, cg_size, Allocator>(
374+ this -> dispatch_insert_or_apply <has_init, cg_size, Allocator>(
374375 first, last, init, op, ref (op::insert_or_apply), stream);
375376}
376377
@@ -960,4 +961,92 @@ auto static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
960961 cuda_thread_scope<Scope>{},
961962 impl_->storage_ref ()};
962963}
964+
/**
 * @brief Selects and launches the `insert_or_apply` kernel for the input range `[first, last)`.
 *
 * Returns immediately when the range is empty. When `CGSize == 1`, a block-scoped map type is
 * described at compile time and the `insert_or_apply_shmem` kernel is launched if every thread of
 * that launch would process at least three input elements; in all other cases the regular
 * `insert_or_apply` kernel is launched with the default grid and block size.
 *
 * @tparam HasInit Forwarded to the kernels as their first template argument
 * @tparam CGSize Cooperative group size; only `CGSize == 1` is eligible for the shared-memory path
 * @tparam AllocatorType Allocator type used to spell the block-scoped map type
 * @tparam InputIt Input iterator type
 * @tparam InitType Type of the init value forwarded to the kernels
 * @tparam OpType Type of the operation forwarded to the kernels
 * @tparam RefType Container ref type forwarded to the kernels
 *
 * @param first Beginning of the input range
 * @param last End of the input range
 * @param init Init value forwarded to the kernels
 * @param op Operation forwarded to the kernels
 * @param ref Container ref forwarded to the kernels
 * @param stream CUDA stream on which the selected kernel is launched
 */
template <class Key,
          class T,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <bool HasInit,
          int32_t CGSize,
          typename AllocatorType,
          typename InputIt,
          typename InitType,
          typename OpType,
          typename RefType>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
  dispatch_insert_or_apply(
    InputIt first, InputIt last, InitType init, OpType op, RefType ref, cuda::stream_ref stream)
{
  auto const num = cuco::detail::distance(first, last);
  if (num == 0) { return; }

  int32_t const default_grid_size = cuco::detail::grid_size(num, CGSize);

  if constexpr (CGSize == 1) {
    using shmem_size_type = int32_t;

    // Compile-time sizing of the block-scoped map: room for `cardinality_threshold` plus one
    // block's worth of elements, scaled up by the inverse of the target load factor.
    int32_t constexpr shmem_block_size                = 1024;
    shmem_size_type constexpr cardinality_threshold   = shmem_block_size;
    shmem_size_type constexpr shared_map_num_elements = cardinality_threshold + shmem_block_size;
    float constexpr load_factor                       = 0.7;
    shmem_size_type constexpr shared_map_size =
      static_cast<shmem_size_type>((1.0 / load_factor) * shared_map_num_elements);

    using extent_type     = cuco::extent<shmem_size_type, shared_map_size>;
    using shared_map_type = cuco::static_map<typename RefType::key_type,
                                             typename RefType::mapped_type,
                                             extent_type,
                                             cuda::thread_scope_block,
                                             typename RefType::key_equal,
                                             typename RefType::probing_scheme_type,
                                             AllocatorType,
                                             cuco::storage<1>>;

    using shared_map_ref_type    = typename shared_map_type::template ref_type<>;
    auto constexpr bucket_extent =
      cuco::make_valid_extent<typename shared_map_ref_type::probing_scheme_type,
                              typename shared_map_ref_type::storage_ref_type>(extent_type{});

    // Fully specialized kernel pointer, needed only for the occupancy query below.
    auto insert_or_apply_shmem_fn_ptr =
      cuco::detail::static_map_ns::insert_or_apply_shmem<HasInit,
                                                         CGSize,
                                                         shmem_block_size,
                                                         shared_map_ref_type,
                                                         InputIt,
                                                         InitType,
                                                         OpType,
                                                         RefType>;

    int32_t const max_op_grid_size =
      cuco::detail::max_occupancy_grid_size(shmem_block_size, insert_or_apply_shmem_fn_ptr);

    int32_t const shmem_default_grid_size =
      cuco::detail::grid_size(num, CGSize, cuco::detail::default_stride(), shmem_block_size);

    auto const shmem_grid_size = std::min(shmem_default_grid_size, max_op_grid_size);

    // Guard the division: if the occupancy query yields a zero grid size, the divisor would be
    // zero and the host-side integer division would trap. In that case the shared-memory kernel
    // cannot be launched anyway, so fall through to the regular kernel.
    bool use_shmem = false;
    if (shmem_grid_size > 0) {
      auto const num_elements_per_thread = num / (shmem_grid_size * shmem_block_size);
      // use shared memory only if each thread has at least 3 elements to process
      use_shmem = num_elements_per_thread > 2;
    }

    if (use_shmem) {
      cuco::detail::static_map_ns::
        insert_or_apply_shmem<HasInit, CGSize, shmem_block_size, shared_map_ref_type>
        <<<shmem_grid_size, shmem_block_size, 0, stream.get()>>>(
          first, num, init, op, ref, bucket_extent);
      return;
    }
  }

  // Regular path: used for CGSize != 1 and whenever the shared-memory path was not selected.
  cuco::detail::static_map_ns::
    insert_or_apply<HasInit, CGSize, cuco::detail::default_block_size()>
    <<<default_grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
      first, num, init, op, ref);
}
9631052} // namespace cuco
0 commit comments