Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
0e9f61c
covnert max_topk to a runtime parameter
seunghwak Nov 4, 2025
444b946
convert max_candidates to a runtime parameter
seunghwak Nov 5, 2025
54a1805
convert max_elements to a runtime parameter
seunghwak Nov 5, 2025
6fc87a2
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 5, 2025
25a0b8d
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 6, 2025
a13112f
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 8, 2025
aa5c116
tighter bound on array size
seunghwak Nov 11, 2025
178459a
undo most of the changes in bitonic.hpp except for the less unrolling…
seunghwak Nov 12, 2025
6aae61d
remove __inline__ from topk_by_radix (to lower register pressure)
seunghwak Nov 12, 2025
8aaf11b
branch outside topk_by_bitonic_sort
seunghwak Nov 12, 2025
0735129
use shared memory when I need to create large stack arrays for bitoni…
seunghwak Nov 12, 2025
3c9ee5f
branch before calling topky_by_bitonic_sort_and_merge
seunghwak Nov 13, 2025
68b191b
undo changes in topk_cta_11_core (branch in the caller site)
seunghwak Nov 13, 2025
89d9698
fix build error
seunghwak Nov 13, 2025
a9bfab7
update max_itopk setting in single-CTA radix sort based search to mat…
seunghwak Nov 14, 2025
a0be265
create non-template wrapper functions to prevent high register pressu…
seunghwak Nov 14, 2025
52141c3
remove unnecessary include statements
seunghwak Nov 15, 2025
15bdc84
use smem to reduce register pressure
seunghwak Nov 17, 2025
c41d548
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 17, 2025
2a8eac0
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 17, 2025
5a117a5
fix build error after pulling new updates
seunghwak Nov 17, 2025
4cd6f8f
undo using shared memory to store key, value pairs
seunghwak Nov 18, 2025
c8a9673
delete dead code
seunghwak Nov 18, 2025
77d657c
fix an error
seunghwak Nov 18, 2025
efcf489
tweak register pressure
seunghwak Nov 18, 2025
55b80f4
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 18, 2025
bde5ac5
copyright year
seunghwak Nov 18, 2025
49c1c64
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 19, 2025
901bad6
undo conditional unrolling in bitonic.hpp (this significantly slows d…
seunghwak Nov 23, 2025
4f7ef48
final performance tweak
seunghwak Nov 24, 2025
a596c9f
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/bitonic.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down
1 change: 0 additions & 1 deletion cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "factory.cuh"
#include "sample_filter_utils.cuh"
#include "search_plan.cuh"
#include "search_single_cta_inst.cuh"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
Expand Down
92 changes: 73 additions & 19 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num
INDEX_T* indices, // [num_elements]
const uint32_t num_elements)
{
const unsigned warp_id = threadIdx.x / 32;
const unsigned warp_id = threadIdx.x / raft::warp_size();
if (warp_id > 0) { return; }
const unsigned lane_id = threadIdx.x % 32;
constexpr unsigned N = (MAX_ELEMENTS + 31) / 32;
const unsigned lane_id = threadIdx.x % raft::warp_size();
constexpr unsigned N = (MAX_ELEMENTS + (raft::warp_size() - 1)) / raft::warp_size();
float key[N];
INDEX_T val[N];
for (unsigned i = 0; i < N; i++) {
unsigned j = lane_id + (32 * i);
unsigned j = lane_id + (raft::warp_size() * i);
if (j < num_elements) {
key[i] = distances[j];
val[i] = indices[j];
Expand All @@ -142,13 +142,34 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num
}
}

RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_64(
float* distances, // [num_elements]
uint32_t* indices, // [num_elements]
const uint32_t num_elements)
{
topk_by_bitonic_sort<64, uint32_t>(distances, indices, num_elements);
}

RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_128(
float* distances, // [num_elements]
uint32_t* indices, // [num_elements]
const uint32_t num_elements)
{
topk_by_bitonic_sort<128, uint32_t>(distances, indices, num_elements);
}

RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_256(
float* distances, // [num_elements]
uint32_t* indices, // [num_elements]
const uint32_t num_elements)
{
topk_by_bitonic_sort<256, uint32_t>(distances, indices, num_elements);
}

//
// multiple CTAs per single query
//
template <std::uint32_t MAX_ELEMENTS,
class DATASET_DESCRIPTOR_T,
class SourceIndexT,
class SAMPLE_FILTER_T>
template <class DATASET_DESCRIPTOR_T, class SourceIndexT, class SAMPLE_FILTER_T>
RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
typename DATASET_DESCRIPTOR_T::INDEX_T* const
result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size]
Expand All @@ -157,6 +178,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
const DATASET_DESCRIPTOR_T* dataset_desc,
const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim]
const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree]
const uint32_t max_elements,
const uint32_t graph_degree,
const SourceIndexT* source_indices_ptr, // [num_queries, search_width]
const unsigned num_distilation,
Expand Down Expand Up @@ -211,7 +233,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
// |<--- result_buffer_size_32 --->|
const auto result_buffer_size = itopk_size + graph_degree;
const auto result_buffer_size_32 = raft::round_up_safe<uint32_t>(result_buffer_size, 32);
assert(result_buffer_size_32 <= MAX_ELEMENTS);
assert(result_buffer_size_32 <= max_elements);

// Set smem working buffer for the distance calculation
dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id);
Expand Down Expand Up @@ -268,8 +290,33 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
_CLK_START();
if (threadIdx.x < 32) {
// [1st warp] Topk with bitonic sort
topk_by_bitonic_sort<MAX_ELEMENTS, INDEX_T>(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
if constexpr (std::is_same_v<INDEX_T, uint32_t>) {
// use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort
// function (vs post-inlining, this impacts register pressure)
if (max_elements <= 64) {
topk_by_bitonic_sort_wrapper_64(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
} else if (max_elements <= 128) {
topk_by_bitonic_sort_wrapper_128(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
} else {
assert(max_elements <= 256);
topk_by_bitonic_sort_wrapper_256(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
}
} else {
if (max_elements <= 64) {
topk_by_bitonic_sort<64, INDEX_T>(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
} else if (max_elements <= 128) {
topk_by_bitonic_sort<128, INDEX_T>(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
} else {
assert(max_elements <= 256);
topk_by_bitonic_sort<256, INDEX_T>(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
}
}
}
__syncthreads();
_CLK_REC(clk_topk);
Expand Down Expand Up @@ -487,17 +534,12 @@ struct search_kernel_config {
// Search kernel function type. Note that the actual values for the template value
// parameters do not matter, because they are not part of the function signature. The
// second to fourth value parameters will be selected by the choose_* functions below.
using kernel_t =
decltype(&search_kernel<128, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>);
using kernel_t = decltype(&search_kernel<DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>);

static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t
{
if (result_buffer_size <= 64) {
return search_kernel<64, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>;
} else if (result_buffer_size <= 128) {
return search_kernel<128, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>;
} else if (result_buffer_size <= 256) {
return search_kernel<256, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>;
if (result_buffer_size <= 256) {
return search_kernel<DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>;
}
THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256);
}
Expand Down Expand Up @@ -536,6 +578,17 @@ void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dat
SourceIndexT,
SampleFilterT>::choose_buffer_size(result_buffer_size, block_size);

uint32_t max_elements{};
if (result_buffer_size <= 64) {
max_elements = 64;
} else if (result_buffer_size <= 128) {
max_elements = 128;
} else if (result_buffer_size <= 256) {
max_elements = 256;
} else {
THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256);
}

RAFT_CUDA_TRY(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Initialize hash table
Expand All @@ -560,6 +613,7 @@ void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dat
dataset_desc.dev_ptr(stream),
queries_ptr,
graph.data_handle(),
max_elements,
graph.extent(1),
source_indices_ptr,
ps.num_random_samplings,
Expand Down
1 change: 0 additions & 1 deletion cpp/src/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "compute_distance-ext.cuh"
#include <cuvs/neighbors/common.hpp>
#include <raft/core/resource/cuda_stream.hpp>
// #include "search_single_cta_inst.cuh"
// #include "topk_for_cagra/topk.h"

#include <raft/core/device_mdspan.hpp>
Expand Down
31 changes: 11 additions & 20 deletions cpp/src/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ struct search
//
// Determine the thread block size
//
constexpr unsigned min_block_size = 64; // 32 or 64
constexpr unsigned min_block_size = 64;
constexpr unsigned min_block_size_radix = 256;
constexpr unsigned max_block_size = 1024;
//
Expand All @@ -129,13 +129,17 @@ struct search
sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t);

std::uint32_t additional_smem_size = 0;
if (num_itopk_candidates > 256) {
// Tentatively calculate the required share memory size when radix
// sort based topk is used, assuming the block size is the maximum.
if (num_itopk_candidates > 256) { // radix sort
// Tentatively calculate the required shared memory size when radix sort based topk is used,
// assuming the block size is the maximum.
if (itopk_size <= 256) {
additional_smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t);
constexpr unsigned MAX_ITOPK = 256;
additional_smem_size +=
topk_by_radix_sort<INDEX_T>::smem_size(MAX_ITOPK) * sizeof(std::uint32_t);
} else {
additional_smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t);
constexpr unsigned MAX_ITOPK = 512;
additional_smem_size +=
topk_by_radix_sort<INDEX_T>::smem_size(MAX_ITOPK) * sizeof(std::uint32_t);
}
}

Expand All @@ -152,7 +156,7 @@ struct search
if (block_size == 0) {
block_size = min_block_size;

if (num_itopk_candidates > 256) {
if (num_itopk_candidates > 256) { // radix sort
// radix-based topk is used.
block_size = min_block_size_radix;

Expand Down Expand Up @@ -190,19 +194,6 @@ struct search
max_block_size);
thread_block_size = block_size;

if (num_itopk_candidates <= 256) {
RAFT_LOG_DEBUG("# bitonic-sort based topk routine is used");
} else {
RAFT_LOG_DEBUG("# radix-sort based topk routine is used");
smem_size = base_smem_size;
if (itopk_size <= 256) {
constexpr unsigned MAX_ITOPK = 256;
smem_size += topk_by_radix_sort<MAX_ITOPK, INDEX_T>::smem_size * sizeof(std::uint32_t);
} else {
constexpr unsigned MAX_ITOPK = 512;
smem_size += topk_by_radix_sort<MAX_ITOPK, INDEX_T>::smem_size * sizeof(std::uint32_t);
}
}
RAFT_LOG_DEBUG("# smem_size: %u", smem_size);
hashmap_size = 0;
if (small_hash_bitlen == 0 && !this->persistent) {
Expand Down
Loading
Loading