Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 8 additions & 33 deletions c/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,33 +418,6 @@ cuvsError_t cuvsCagraIndexGetDataset(cuvsCagraIndex_t index, DLManagedTensor* da
*/
cuvsError_t cuvsCagraIndexGetGraph(cuvsCagraIndex_t index, DLManagedTensor* graph);

/**
* @}
*/

/**
* @defgroup cagra_c_merge_params C API for CAGRA index merge parameters
* @{
*/

/**
* @brief Supplemental parameters for merging CAGRA indices
*
*/

struct cuvsCagraMergeParams {
/** Index parameters used for the merged output index (metric, graph degrees, build
 *  algorithm). cuvsCagraMergeParamsCreate allocates this handle and
 *  cuvsCagraMergeParamsDestroy releases it together with the merge params. */
cuvsCagraIndexParams_t output_index_params;
/** Merge strategy; cuvsCagraMergeParamsCreate initializes it to MERGE_STRATEGY_PHYSICAL. */
cuvsMergeStrategy strategy;
};

typedef struct cuvsCagraMergeParams* cuvsCagraMergeParams_t;

/**
 * @brief Allocate CAGRA merge params with default values
 *
 * The nested `output_index_params` is allocated with `cuvsCagraIndexParamsCreate` and
 * `strategy` is set to `MERGE_STRATEGY_PHYSICAL`.
 *
 * @param[out] params receives the newly allocated cuvsCagraMergeParams_t
 * @return cuvsError_t
 */
cuvsError_t cuvsCagraMergeParamsCreate(cuvsCagraMergeParams_t* params);

/**
 * @brief De-allocate CAGRA merge params
 *
 * Also releases the nested `output_index_params` via `cuvsCagraIndexParamsDestroy`, so the
 * caller must not destroy that handle separately.
 *
 * @param[in] params cuvsCagraMergeParams_t to de-allocate
 * @return cuvsError_t
 */
cuvsError_t cuvsCagraMergeParamsDestroy(cuvsCagraMergeParams_t params);

/**
* @}
*/
Expand Down Expand Up @@ -718,7 +691,7 @@ cuvsError_t cuvsCagraIndexFromArgs(cuvsResources_t res,
*
* All input indices must have been built with the same data type (`index.dtype`) and
* have the same dimensionality (`index.dims`). The merged index uses the output
* parameters specified in `cuvsCagraMergeParams`.
* parameters specified in `cuvsCagraIndexParams`.
*
* Input indices must have:
* - `index.dtype.code` and `index.dtype.bits` matching across all indices.
Expand All @@ -745,29 +718,31 @@ cuvsError_t cuvsCagraIndexFromArgs(cuvsResources_t res,
*
* // Assume index1 and index2 have been built using cuvsCagraBuild
*
* cuvsCagraMergeParams_t merge_params;
* cuvsError_t params_create_status = cuvsCagraMergeParamsCreate(&merge_params);
* cuvsCagraIndexParams_t merge_params;
* cuvsError_t params_create_status = cuvsCagraIndexParamsCreate(&merge_params);
*
* cuvsError_t merge_status = cuvsCagraMerge(res, merge_params, (cuvsCagraIndex_t[]){index1,
* index2}, 2, merged_index);
*
* // Use merged_index for search operations
*
* cuvsError_t params_destroy_status = cuvsCagraMergeParamsDestroy(merge_params);
* cuvsError_t params_destroy_status = cuvsCagraIndexParamsDestroy(merge_params);
* cuvsError_t res_destroy_status = cuvsResourcesDestroy(res);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] params cuvsCagraMergeParams_t parameters controlling merge behavior
* @param[in] params cuvsCagraIndexParams_t parameters controlling merge behavior
* @param[in] indices Array of input cuvsCagraIndex_t handles to merge
* @param[in] num_indices Number of input indices
* @param[in] filter Filter that can be used to filter out vectors from the merged index
* @param[out] output_index Output handle that will store the merged index.
* Must be initialized using `cuvsCagraIndexCreate` before use.
*/
cuvsError_t cuvsCagraMerge(cuvsResources_t res,
cuvsCagraMergeParams_t params,
cuvsCagraIndexParams_t params,
cuvsCagraIndex_t* indices,
size_t num_indices,
cuvsFilter filter,
cuvsCagraIndex_t output_index);

/**
Expand Down
74 changes: 34 additions & 40 deletions c/src/neighbors/cagra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,23 +285,23 @@ void* _deserialize(cuvsResources_t res, const char* filename)

template <typename T>
void* _merge(cuvsResources_t res,
cuvsCagraMergeParams params,
cuvsCagraIndexParams params,
cuvsCagraIndex_t* indices,
size_t num_indices)
size_t num_indices,
cuvsFilter filter)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
cuvs::neighbors::cagra::merge_params merge_params_cpp;
auto& out_idx_params = *params.output_index_params;
cuvs::neighbors::cagra::index_params params_cpp;

merge_params_cpp.output_index_params.metric =
static_cast<cuvs::distance::DistanceType>((int)out_idx_params.metric);
merge_params_cpp.output_index_params.intermediate_graph_degree =
out_idx_params.intermediate_graph_degree;
merge_params_cpp.output_index_params.graph_degree = out_idx_params.graph_degree;
params_cpp.metric =
static_cast<cuvs::distance::DistanceType>((int)params.metric);
params_cpp.intermediate_graph_degree =
params.intermediate_graph_degree;
params_cpp.graph_degree = params.graph_degree;

int64_t total_size = 0;
int64_t dim = 0;
if (out_idx_params.build_algo == cuvsCagraGraphBuildAlgo::IVF_PQ) {
if (params.build_algo == cuvsCagraGraphBuildAlgo::IVF_PQ) {
auto first_idx_ptr =
reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(indices[0]->addr);
dim = first_idx_ptr->dim();
Expand All @@ -312,9 +312,9 @@ void* _merge(cuvsResources_t res,
}
}

_set_graph_build_params(merge_params_cpp.output_index_params.graph_build_params,
out_idx_params,
out_idx_params.build_algo,
_set_graph_build_params(params_cpp.graph_build_params,
params,
params.build_algo,
total_size,
dim);

Expand All @@ -325,10 +325,21 @@ void* _merge(cuvsResources_t res,
index_ptrs.push_back(idx_ptr);
}

auto merged_index = new cuvs::neighbors::cagra::index<T, uint32_t>(
cuvs::neighbors::cagra::merge(*res_ptr, merge_params_cpp, index_ptrs));

return merged_index;
if (filter.type == NO_FILTER) {
return new cuvs::neighbors::cagra::index<T, uint32_t>(
cuvs::neighbors::cagra::merge(*res_ptr, params_cpp, index_ptrs));
} else if (filter.type == BITSET) {
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t, raft::row_major>;
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor);
cuvs::core::bitset_view<std::uint32_t, int64_t> removed_indices_bitset(
removed_indices, total_size);
auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset);
return new cuvs::neighbors::cagra::index<T, uint32_t>(
cuvs::neighbors::cagra::merge(*res_ptr, params_cpp, index_ptrs, bitset_filter_obj));
} else {
RAFT_FAIL("Unsupported filter type: BITMAP");
}
}

template <typename T, typename IdxT>
Expand Down Expand Up @@ -601,9 +612,10 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
}

extern "C" cuvsError_t cuvsCagraMerge(cuvsResources_t res,
cuvsCagraMergeParams_t params,
cuvsCagraIndexParams_t params,
cuvsCagraIndex_t* indices,
size_t num_indices,
cuvsFilter filter,
cuvsCagraIndex_t output_index)
{
return cuvs::core::translate_exceptions([=] {
Expand All @@ -621,16 +633,16 @@ extern "C" cuvsError_t cuvsCagraMerge(cuvsResources_t res,
// Dispatch based on data type
if (dtype.code == kDLFloat && dtype.bits == 32) {
output_index->addr =
reinterpret_cast<uintptr_t>(_merge<float>(res, *params, indices, num_indices));
reinterpret_cast<uintptr_t>(_merge<float>(res, *params, indices, num_indices, filter));
} else if (dtype.code == kDLFloat && dtype.bits == 16) {
output_index->addr =
reinterpret_cast<uintptr_t>(_merge<half>(res, *params, indices, num_indices));
reinterpret_cast<uintptr_t>(_merge<half>(res, *params, indices, num_indices, filter));
} else if (dtype.code == kDLInt && dtype.bits == 8) {
output_index->addr =
reinterpret_cast<uintptr_t>(_merge<int8_t>(res, *params, indices, num_indices));
reinterpret_cast<uintptr_t>(_merge<int8_t>(res, *params, indices, num_indices, filter));
} else if (dtype.code == kDLUInt && dtype.bits == 8) {
output_index->addr =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah- I was trying to figure out how you were going to support allowing a user to pass in the "index" to populate in the C layer while only having the option to return a newly created index in the C++ layer. Not opposed to this solution at all. Just good that we're not doing a copy!

reinterpret_cast<uintptr_t>(_merge<uint8_t>(res, *params, indices, num_indices));
reinterpret_cast<uintptr_t>(_merge<uint8_t>(res, *params, indices, num_indices, filter));
} else {
RAFT_FAIL("Unsupported index data type: code=%d, bits=%d", dtype.code, dtype.bits);
}
Expand Down Expand Up @@ -708,24 +720,6 @@ extern "C" cuvsError_t cuvsCagraSearchParamsDestroy(cuvsCagraSearchParams_t para
return cuvs::core::translate_exceptions([=] { delete params; });
}

/**
 * Allocate a cuvsCagraMergeParams object with default values.
 *
 * The nested output index params are allocated with cuvsCagraIndexParamsCreate and the
 * strategy defaults to MERGE_STRATEGY_PHYSICAL.
 *
 * @param[out] params receives the newly allocated merge params; must not be NULL
 * @return cuvsError_t produced by translate_exceptions
 */
extern "C" cuvsError_t cuvsCagraMergeParamsCreate(cuvsCagraMergeParams_t* params)
{
  return cuvs::core::translate_exceptions([=] {
    // Report a NULL output pointer as an error instead of dereferencing it below.
    if (params == nullptr) { RAFT_FAIL("params must not be NULL"); }

    cuvsCagraIndexParams_t idx_params = nullptr;
    // Do not ignore the nested allocation status: a failure here would otherwise leave
    // output_index_params pointing at an invalid handle.
    if (cuvsCagraIndexParamsCreate(&idx_params) != CUVS_SUCCESS) {
      RAFT_FAIL("Failed to allocate CAGRA index params for merge params");
    }

    try {
      *params = new cuvsCagraMergeParams{.output_index_params = idx_params,
                                         .strategy            = MERGE_STRATEGY_PHYSICAL};
    } catch (...) {
      // Release the nested params if the outer allocation throws, then let
      // translate_exceptions convert the exception into an error code.
      cuvsCagraIndexParamsDestroy(idx_params);
      throw;
    }
  });
}

/**
 * De-allocate a cuvsCagraMergeParams object together with its nested output index params.
 *
 * @param[in] params merge params to destroy; NULL is accepted and treated as a no-op
 * @return cuvsError_t produced by translate_exceptions
 */
extern "C" cuvsError_t cuvsCagraMergeParamsDestroy(cuvsCagraMergeParams_t params)
{
  return cuvs::core::translate_exceptions([=] {
    // Destroying NULL is a no-op (as with a plain `delete`), rather than a NULL dereference.
    if (params == nullptr) { return; }

    const cuvsError_t nested_status = cuvsCagraIndexParamsDestroy(params->output_index_params);
    // Always free the outer object, even if releasing the nested params failed.
    delete params;

    if (nested_status != CUVS_SUCCESS) {
      RAFT_FAIL("Failed to destroy CAGRA index params owned by merge params");
    }
  });
}

extern "C" cuvsError_t cuvsCagraDeserialize(cuvsResources_t res,
const char* filename,
cuvsCagraIndex_t index)
Expand Down
12 changes: 5 additions & 7 deletions c/tests/neighbors/ann_cagra_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -509,13 +509,15 @@ TEST(CagraC, BuildMergeSearch)
ASSERT_EQ(cuvsCagraBuild(res, build_params, &main_dataset_tensor, index_main), CUVS_SUCCESS);
ASSERT_EQ(cuvsCagraBuild(res, build_params, &additional_dataset_tensor, index_add), CUVS_SUCCESS);

cuvsCagraMergeParams_t merge_params;
cuvsCagraMergeParamsCreate(&merge_params);
cuvsCagraIndex_t index_merged;
cuvsCagraIndexCreate(&index_merged);

cuvsFilter filter;
filter.type = NO_FILTER;
filter.addr = 0;

cuvsCagraIndex_t index_array[2] = {index_main, index_add};
ASSERT_EQ(cuvsCagraMerge(res, merge_params, index_array, 2, index_merged), CUVS_SUCCESS);
ASSERT_EQ(cuvsCagraMerge(res, build_params, index_array, 2, filter, index_merged), CUVS_SUCCESS);

int64_t merged_dim = -1;
ASSERT_EQ(cuvsCagraIndexGetDims(index_merged, &merged_dim), CUVS_SUCCESS);
Expand Down Expand Up @@ -547,9 +549,6 @@ TEST(CagraC, BuildMergeSearch)
cuvsCagraSearchParamsCreate(&search_params);
(*search_params).itopk_size = 1;

cuvsFilter filter;
filter.type = NO_FILTER;
filter.addr = 0;
ASSERT_EQ(cuvsCagraSearch(res,
search_params,
index_merged,
Expand All @@ -569,7 +568,6 @@ TEST(CagraC, BuildMergeSearch)
EXPECT_NEAR(distance_host, 0.0f, 1e-6);

cuvsCagraSearchParamsDestroy(search_params);
cuvsCagraMergeParamsDestroy(merge_params);
cuvsCagraIndexParamsDestroy(build_params);
cuvsCagraIndexDestroy(index_merged);
cuvsCagraIndexDestroy(index_add);
Expand Down
Loading