Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 53 additions & 160 deletions c/parallel/src/radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

#include <cub/detail/choose_offset.cuh>
#include <cub/detail/launcher/cuda_driver.cuh>
#include <cub/detail/ptx-json-parser.cuh>
#include <cub/device/device_radix_sort.cuh>

#include <format>
Expand All @@ -31,92 +30,6 @@ static_assert(std::is_same_v<cub::detail::choose_offset_t<OffsetT>, OffsetT>, "O

namespace radix_sort
{
using namespace cub::detail::radix_sort_runtime_policies;

struct radix_sort_runtime_tuning_policy
{
RuntimeRadixSortHistogramAgentPolicy histogram;
RuntimeRadixSortExclusiveSumAgentPolicy exclusive_sum;
RuntimeRadixSortOnesweepAgentPolicy onesweep;
cub::detail::RuntimeScanAgentPolicy scan;
cub::detail::RuntimeRadixSortDownsweepAgentPolicy downsweep;
cub::detail::RuntimeRadixSortDownsweepAgentPolicy alt_downsweep;
RuntimeRadixSortUpsweepAgentPolicy upsweep;
RuntimeRadixSortUpsweepAgentPolicy alt_upsweep;
cub::detail::RuntimeRadixSortDownsweepAgentPolicy single_tile;
bool is_onesweep;

auto Histogram() const
{
return histogram;
}

auto ExclusiveSum() const
{
return exclusive_sum;
}

auto Onesweep() const
{
return onesweep;
}

auto Scan() const
{
return scan;
}

auto Downsweep() const
{
return downsweep;
}

auto AltDownsweep() const
{
return alt_downsweep;
}

auto Upsweep() const
{
return upsweep;
}

auto AltUpsweep() const
{
return alt_upsweep;
}

auto SingleTile() const
{
return single_tile;
}

bool IsOnesweep() const
{
return is_onesweep;
}

template <typename PolicyT>
CUB_RUNTIME_FUNCTION static constexpr int RadixBits(PolicyT policy)
{
return policy.RadixBits();
}

template <typename PolicyT>
CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(PolicyT policy)
{
return policy.BlockThreads();
}

using MaxPolicy = radix_sort_runtime_tuning_policy;

template <typename F>
cudaError_t Invoke(int, F& op)
{
return op.template Invoke<radix_sort_runtime_tuning_policy>(*this);
}
};

std::string get_single_tile_kernel_name(
std::string_view chained_policy_t,
cccl_sort_order_t sort_order,
Expand Down Expand Up @@ -290,12 +203,10 @@ CUresult cccl_device_radix_sort_build_ex(
{
const char* name = "test";

const int cc = cc_major * 10 + cc_minor;
const auto key_cpp = cccl_type_enum_to_name(input_keys_it.value_type.type);
const auto value_cpp =
input_values_it.type == cccl_iterator_kind_t::CCCL_POINTER && input_values_it.state == nullptr
? "cub::NullType"
: cccl_type_enum_to_name(input_values_it.value_type.type);
const auto keys_only =
input_values_it.type == cccl_iterator_kind_t::CCCL_POINTER && input_values_it.state == nullptr;
const auto value_cpp = keys_only ? "cub::NullType" : cccl_type_enum_to_name(input_values_it.value_type.type);
const std::string op_src =
(decomposer.name == nullptr || (decomposer.name != nullptr && decomposer.name[0] == '\0'))
? "using op_wrapper = cub::detail::identity_decomposer_t;"
Expand All @@ -305,8 +216,32 @@ CUresult cccl_device_radix_sort_build_ex(
std::string offset_t;
check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));

const auto policy_hub_expr =
std::format("cub::detail::radix_sort::policy_hub<{}, {}, {}>", key_cpp, value_cpp, offset_t);
// TODO(bgruber): generalize this somewhere
const auto key_type = [&] {
switch (input_keys_it.value_type.type)
{
case CCCL_FLOAT32:
return cub::detail::type_t::float32;
case CCCL_FLOAT64:
return cub::detail::type_t::float64;
default:
return cub::detail::type_t::other;
}
}();

const auto cub_arch_policies = cub::detail::radix_sort::arch_policies{
static_cast<int>(input_keys_it.value_type.size),
// FIXME(bgruber): input_values_it.value_type.size is 4 when it represents cub::NullType, which is very odd
keys_only ? 0 : static_cast<int>(input_values_it.value_type.size),
int{sizeof(OffsetT)},
key_type};

// TODO(bgruber): drop this if tuning policies become formattable
std::stringstream cub_arch_policies_str;
cub_arch_policies_str << cub_arch_policies(cuda::to_arch_id(cuda::compute_capability{cc_major, cc_minor}));

auto policy_hub_expr =
std::format("cub::detail::radix_sort::arch_policies_from_types<{}, {}, {}>", key_cpp, value_cpp, offset_t);

const std::string final_src = std::format(
R"XXX(
Expand All @@ -321,21 +256,18 @@ struct __align__({3}) values_storage_t {{
char data[{2}];
}};
{4}
using {5} = {6}::MaxPolicy;

#include <cub/detail/ptx-json/json.cuh>
__device__ consteval auto& policy_generator() {{
return ptx_json::id<ptx_json::string("device_radix_sort_policy")>()
= cub::detail::radix_sort::RadixSortPolicyWrapper<{5}::ActivePolicy>::EncodedPolicy();
}}
using device_radix_sort_policy = {5};
using namespace cub;
using namespace cub::detail::radix_sort;
static_assert(device_radix_sort_policy()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {6}, "Host generated and JIT compiled policy mismatch");
)XXX",
input_keys_it.value_type.size, // 0
input_keys_it.value_type.alignment, // 1
input_values_it.value_type.size, // 2
input_values_it.value_type.alignment, // 3
op_src, // 4
chained_policy_t, // 5
policy_hub_expr); // 6
policy_hub_expr, // 5
cub_arch_policies_str.view()); // 6

#if false // CCCL_DEBUGGING_SWITCH
fflush(stderr);
Expand Down Expand Up @@ -380,7 +312,6 @@ __device__ consteval auto& policy_generator() {{
"-rdc=true",
"-dlto",
"-DCUB_DISABLE_CDP",
"-DCUB_ENABLE_POLICY_PTX_JSON",
"-std=c++20"};

cccl::detail::extend_args_with_build_config(args, config);
Expand Down Expand Up @@ -434,43 +365,13 @@ __device__ consteval auto& policy_generator() {{
&build_ptr->exclusive_sum_kernel, build_ptr->library, exclusive_sum_kernel_lowered_name.c_str()));
check(cuLibraryGetKernel(&build_ptr->onesweep_kernel, build_ptr->library, onesweep_kernel_lowered_name.c_str()));

nlohmann::json runtime_policy =
cub::detail::ptx_json::parse("device_radix_sort_policy", {result.data.get(), result.size});

using namespace cub::detail::radix_sort_runtime_policies;
using cub::detail::RuntimeScanAgentPolicy;
auto single_tile_policy =
cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "SingleTilePolicy");
auto onesweep_policy = RuntimeRadixSortOnesweepAgentPolicy::from_json(runtime_policy, "OnesweepPolicy");
auto upsweep_policy = RuntimeRadixSortUpsweepAgentPolicy::from_json(runtime_policy, "UpsweepPolicy");
auto alt_upsweep_policy = RuntimeRadixSortUpsweepAgentPolicy::from_json(runtime_policy, "AltUpsweepPolicy");
auto downsweep_policy =
cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "DownsweepPolicy");
auto alt_downsweep_policy =
cub::detail::RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "AltDownsweepPolicy");
auto histogram_policy = RuntimeRadixSortHistogramAgentPolicy::from_json(runtime_policy, "HistogramPolicy");
auto exclusive_sum_policy =
RuntimeRadixSortExclusiveSumAgentPolicy::from_json(runtime_policy, "ExclusiveSumPolicy");
auto scan_policy = RuntimeScanAgentPolicy::from_json(runtime_policy, "ScanPolicy");
auto is_onesweep = runtime_policy["Onesweep"].get<bool>();

build_ptr->cc = cc;
build_ptr->cc = cc_major * 10 + cc_minor;
build_ptr->cubin = (void*) result.data.release();
build_ptr->cubin_size = result.size;
build_ptr->key_type = input_keys_it.value_type;
build_ptr->value_type = input_values_it.value_type;
build_ptr->order = sort_order;
build_ptr->runtime_policy = new radix_sort::radix_sort_runtime_tuning_policy{
histogram_policy,
exclusive_sum_policy,
onesweep_policy,
scan_policy,
downsweep_policy,
alt_downsweep_policy,
upsweep_policy,
alt_upsweep_policy,
single_tile_policy,
is_onesweep};
build_ptr->runtime_policy = new cub::detail::radix_sort::arch_policies{cub_arch_policies};
}
catch (const std::exception& exc)
{
Expand Down Expand Up @@ -529,29 +430,20 @@ CUresult cccl_device_radix_sort_impl(
cub::DoubleBuffer<indirect_arg_t> d_values_buffer(
*static_cast<indirect_arg_t**>(&val_arg_in), *static_cast<indirect_arg_t**>(&val_arg_out));

auto exec_status = cub::DispatchRadixSort<
Order,
indirect_arg_t,
indirect_arg_t,
OffsetT,
indirect_arg_t,
radix_sort::radix_sort_runtime_tuning_policy,
radix_sort::radix_sort_kernel_source,
cub::detail::CudaDriverLauncherFactory>::
Dispatch(
d_temp_storage,
*temp_storage_bytes,
d_keys_buffer,
d_values_buffer,
num_items,
begin_bit,
end_bit,
is_overwrite_okay,
stream,
decomposer,
{build},
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc},
*reinterpret_cast<radix_sort::radix_sort_runtime_tuning_policy*>(build.runtime_policy));
auto exec_status = cub::detail::radix_sort::dispatch<Order>(
d_temp_storage,
*temp_storage_bytes,
d_keys_buffer,
d_values_buffer,
num_items,
begin_bit,
end_bit,
is_overwrite_okay,
stream,
decomposer,
*static_cast<cub::detail::radix_sort::arch_policies*>(build.runtime_policy),
radix_sort::radix_sort_kernel_source{build},
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc});

*selector = d_keys_buffer.selector;
error = static_cast<CUresult>(exec_status);
Expand Down Expand Up @@ -649,8 +541,9 @@ CUresult cccl_device_radix_sort_cleanup(cccl_device_radix_sort_build_result_t* b
return CUDA_ERROR_INVALID_VALUE;
}

using namespace cub::detail::radix_sort;
std::unique_ptr<char[]> cubin(reinterpret_cast<char*>(build_ptr->cubin));
std::unique_ptr<char[]> runtime_policy(reinterpret_cast<char*>(build_ptr->runtime_policy));
std::unique_ptr<arch_policies> policy(static_cast<arch_policies*>(build_ptr->runtime_policy));
check(cuLibraryUnload(build_ptr->library));
}
catch (const std::exception& exc)
Expand Down
21 changes: 19 additions & 2 deletions cub/cub/agent/agent_radix_sort_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,38 @@
#include <cuda/std/__algorithm/max.h>
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/__functional/operations.h>
#include <cuda/std/__type_traits/is_void.h>

CUB_NAMESPACE_BEGIN

//! @param ComputeT If void, use NOMINAL_4B_NUM_PARTS directly for NUM_PARTS. Otherwise, perform scaling.
template <int BlockThreads, int ItemsPerThread, int NOMINAL_4B_NUM_PARTS, typename ComputeT, int RadixBits>
struct AgentRadixSortHistogramPolicy
{
static constexpr int BLOCK_THREADS = BlockThreads;
static constexpr int ITEMS_PER_THREAD = ItemsPerThread;

// need to discard sizeof(ComputeType) in case it's void
template <typename ComputeType = ComputeT>
_CCCL_API static constexpr int num_parts_helper()
{
if constexpr (::cuda::std::is_void_v<ComputeT>)
{
return NOMINAL_4B_NUM_PARTS;
}
else
{
return ::cuda::std::max(1, NOMINAL_4B_NUM_PARTS * 4 / ::cuda::std::max(int{sizeof(ComputeType)}, 4));
}
}

/** NUM_PARTS is the number of private histograms (parts) each histogram is split
* into. Each warp lane is assigned to a specific part based on the lane
* ID. However, lanes with the same ID in different warp use the same private
* histogram. This arrangement helps reduce the degree of conflicts in atomic
* operations. */
static constexpr int NUM_PARTS =
::cuda::std::max(1, NOMINAL_4B_NUM_PARTS * 4 / ::cuda::std::max(int{sizeof(ComputeT)}, 4));
static constexpr int NUM_PARTS = num_parts_helper<ComputeT>();

static constexpr int RADIX_BITS = RadixBits;
};

Expand Down
1 change: 1 addition & 0 deletions cub/cub/block/block_load.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,7 @@ inline ::std::ostream& operator<<(::std::ostream& os, BlockLoadAlgorithm algo)
#endif // !_CCCL_COMPILER(NVRTC)

//! @rst

//! The BlockLoad class provides :ref:`collective <collective-primitives>` data movement methods for loading a linear
//! segment of items from memory into a :ref:`blocked arrangement <flexible-data-arrangement>` across a CUDA thread
//! block.
Expand Down
Loading