Skip to content

Commit bc0d725

Browse files
committed
Fix a circular inclusion
1 parent b48eac3 commit bc0d725

File tree

4 files changed

+124
-130
lines changed

4 files changed

+124
-130
lines changed

cmake/header_testing.cmake

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,13 @@ function(cuco_add_header_test label definitions)
4141

4242
# List of headers that have known issues or are not meant to be included directly
4343
set(excluded_headers
44-
# Headers with circular dependencies that are not meant to be included directly
45-
cuco/detail/static_map/helpers.cuh
44+
# Add any headers that should be excluded from testing here
45+
# Example: cuco/internal_header.cuh
4646
)
4747

4848
# Remove excluded headers
4949
if(excluded_headers)
5050
list(REMOVE_ITEM headers ${excluded_headers})
51-
list(LENGTH headers headers_count_after_exclusion)
52-
message(STATUS "After exclusion: ${headers_count_after_exclusion} headers remaining")
5351
endif()
5452

5553
# Only test with CUDA compiler since cuco is device-only

include/cuco/detail/static_map/helpers.cuh

Lines changed: 0 additions & 123 deletions
This file was deleted.

include/cuco/detail/static_map/static_map.inl

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
*/
1616

1717
#include <cuco/detail/bitwise_compare.cuh>
18-
#include <cuco/detail/static_map/helpers.cuh>
1918
#include <cuco/detail/static_map/kernels.cuh>
2019
#include <cuco/detail/utility/cuda.hpp>
2120
#include <cuco/detail/utils.hpp>
21+
#include <cuco/extent.cuh>
2222
#include <cuco/operator.hpp>
2323
#include <cuco/static_map_ref.cuh>
24+
#include <cuco/storage.cuh>
2425

2526
#include <cuda/std/tuple>
2627
#include <cuda/stream_ref>
@@ -352,7 +353,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
352353
{
353354
auto constexpr has_init = false;
354355
auto const init = this->empty_value_sentinel(); // use empty_sentinel as unused init value
355-
detail::static_map_ns::dispatch_insert_or_apply<has_init, cg_size, Allocator>(
356+
this->dispatch_insert_or_apply<has_init, cg_size, Allocator>(
356357
first, last, init, op, ref(op::insert_or_apply), stream);
357358
}
358359

@@ -370,7 +371,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
370371
InputIt first, InputIt last, Init init, Op op, cuda::stream_ref stream) noexcept
371372
{
372373
auto constexpr has_init = true;
373-
detail::static_map_ns::dispatch_insert_or_apply<has_init, cg_size, Allocator>(
374+
this->dispatch_insert_or_apply<has_init, cg_size, Allocator>(
374375
first, last, init, op, ref(op::insert_or_apply), stream);
375376
}
376377

@@ -960,4 +961,92 @@ auto static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
960961
cuda_thread_scope<Scope>{},
961962
impl_->storage_ref()};
962963
}
964+
965+
template <class Key,
966+
class T,
967+
class Extent,
968+
cuda::thread_scope Scope,
969+
class KeyEqual,
970+
class ProbingScheme,
971+
class Allocator,
972+
class Storage>
973+
template <bool HasInit,
974+
int32_t CGSize,
975+
typename AllocatorType,
976+
typename InputIt,
977+
typename InitType,
978+
typename OpType,
979+
typename RefType>
980+
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
981+
dispatch_insert_or_apply(
982+
InputIt first, InputIt last, InitType init, OpType op, RefType ref, cuda::stream_ref stream)
983+
{
984+
auto const num = cuco::detail::distance(first, last);
985+
if (num == 0) { return; }
986+
987+
int32_t const default_grid_size = cuco::detail::grid_size(num, CGSize);
988+
989+
if constexpr (CGSize == 1) {
990+
using shmem_size_type = int32_t;
991+
992+
int32_t constexpr shmem_block_size = 1024;
993+
shmem_size_type constexpr cardinality_threshold = shmem_block_size;
994+
shmem_size_type constexpr shared_map_num_elements = cardinality_threshold + shmem_block_size;
995+
float constexpr load_factor = 0.7;
996+
shmem_size_type constexpr shared_map_size =
997+
static_cast<shmem_size_type>((1.0 / load_factor) * shared_map_num_elements);
998+
999+
using extent_type = cuco::extent<shmem_size_type, shared_map_size>;
1000+
using shared_map_type = cuco::static_map<typename RefType::key_type,
1001+
typename RefType::mapped_type,
1002+
extent_type,
1003+
cuda::thread_scope_block,
1004+
typename RefType::key_equal,
1005+
typename RefType::probing_scheme_type,
1006+
AllocatorType,
1007+
cuco::storage<1>>;
1008+
1009+
using shared_map_ref_type = typename shared_map_type::template ref_type<>;
1010+
auto constexpr bucket_extent =
1011+
cuco::make_valid_extent<typename shared_map_ref_type::probing_scheme_type,
1012+
typename shared_map_ref_type::storage_ref_type>(extent_type{});
1013+
1014+
auto insert_or_apply_shmem_fn_ptr =
1015+
cuco::detail::static_map_ns::insert_or_apply_shmem<HasInit,
1016+
CGSize,
1017+
shmem_block_size,
1018+
shared_map_ref_type,
1019+
InputIt,
1020+
InitType,
1021+
OpType,
1022+
RefType>;
1023+
1024+
int32_t const max_op_grid_size =
1025+
cuco::detail::max_occupancy_grid_size(shmem_block_size, insert_or_apply_shmem_fn_ptr);
1026+
1027+
int32_t const shmem_default_grid_size =
1028+
cuco::detail::grid_size(num, CGSize, cuco::detail::default_stride(), shmem_block_size);
1029+
1030+
auto const shmem_grid_size = std::min(shmem_default_grid_size, max_op_grid_size);
1031+
auto const num_elements_per_thread = num / (shmem_grid_size * shmem_block_size);
1032+
1033+
// use shared memory only if each thread has at least 3 elements to process
1034+
if (num_elements_per_thread > 2) {
1035+
cuco::detail::static_map_ns::
1036+
insert_or_apply_shmem<HasInit, CGSize, shmem_block_size, shared_map_ref_type>
1037+
<<<shmem_grid_size, shmem_block_size, 0, stream.get()>>>(
1038+
first, num, init, op, ref, bucket_extent);
1039+
} else {
1040+
cuco::detail::static_map_ns::
1041+
insert_or_apply<HasInit, CGSize, cuco::detail::default_block_size()>
1042+
<<<default_grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
1043+
first, num, init, op, ref);
1044+
}
1045+
} else {
1046+
cuco::detail::static_map_ns::
1047+
insert_or_apply<HasInit, CGSize, cuco::detail::default_block_size()>
1048+
<<<default_grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
1049+
first, num, init, op, ref);
1050+
}
1051+
}
9631052
} // namespace cuco

include/cuco/static_map.cuh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,36 @@ class static_map {
12151215
[[nodiscard]] auto ref(Operators... ops) const noexcept;
12161216

12171217
private:
1218+
/**
1219+
* @brief Dispatches to shared memory map kernel if `num_elements_per_thread > 2`, else
1220+
* falls back to global memory map kernel.
1221+
*
1222+
* @tparam HasInit Boolean to dispatch based on init parameter
1223+
* @tparam CGSize Number of threads in each CG
1224+
* @tparam AllocatorType Allocator type used to create the shared memory map
1225+
* @tparam InputIt Device accessible input iterator whose `value_type` is
1226+
* convertible to the `value_type` of the data structure
1227+
* @tparam InitType Type of init value convertible to payload type
1228+
* @tparam OpType Callable type used to perform `apply` operation.
1229+
* @tparam RefType Type of non-owning device ref allowing access to storage
1230+
*
1231+
* @param first Beginning of the sequence of input elements
1232+
* @param last End of the sequence of input elements
1233+
* @param init The init value of the `op`
1234+
* @param op Callable object to perform apply operation.
1235+
* @param ref Non-owning container device ref used to access the slot storage
1236+
* @param stream CUDA stream used for insert_or_apply operation
1237+
*/
1238+
template <bool HasInit,
1239+
int32_t CGSize,
1240+
typename AllocatorType,
1241+
typename InputIt,
1242+
typename InitType,
1243+
typename OpType,
1244+
typename RefType>
1245+
void dispatch_insert_or_apply(
1246+
InputIt first, InputIt last, InitType init, OpType op, RefType ref, cuda::stream_ref stream);
1247+
12181248
std::unique_ptr<impl_type> impl_; ///< Static map implementation
12191249
mapped_type empty_value_sentinel_; ///< Sentinel value that indicates an empty payload
12201250
};

0 commit comments

Comments
 (0)