diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ababb22548..c7e167c7ea 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -438,6 +438,7 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/iface/iface_pq_int8_t_int64_t.cu src/neighbors/iface/iface_pq_uint8_t_int64_t.cu src/neighbors/detail/cagra/topk_for_cagra/topk.cu + src/neighbors/detail/vpq_dataset_subspaces.cu src/neighbors/dynamic_batching.cu src/neighbors/cagra_index_wrapper.cu src/neighbors/composite/index.cu diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index caf08d770c..990d753b82 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -6,14 +6,13 @@ #pragma once #include "../../../sparse/neighbors/cross_component_nn.cuh" -#include "../../detail/vpq_dataset.cuh" +#include "../../detail/ann_utils.cuh" +#include "../../detail/vpq_dataset_subspaces.hpp" #include "greedy_search.cuh" #include "robust_prune.cuh" #include "vamana_structs.cuh" #include -#include -#include #include #include #include @@ -26,10 +25,8 @@ #include #include #include -#include #include -#include #include #include @@ -561,7 +558,7 @@ auto quantize_all_vectors(raft::resources const& res, auto vq_codebook = raft::make_device_matrix(res, 1, dim); raft::matrix::fill(res, vq_codebook.view(), 0.0); - auto codes = cuvs::neighbors::detail::process_and_fill_codes_subspaces( + auto codes = cuvs::neighbors::detail::process_and_fill_codes_subspaces( res, ps, residuals, raft::make_const_mdspan(vq_codebook.view()), pq_codebook); return codes; } diff --git a/cpp/src/neighbors/detail/vamana/vamana_codebooks.cuh b/cpp/src/neighbors/detail/vamana/vamana_codebooks.cuh index 9a2a459c43..1cf1fe0275 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_codebooks.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_codebooks.cuh @@ -9,6 +9,7 @@ #include +#include #include namespace cuvs::neighbors::vamana::detail { diff --git a/cpp/src/neighbors/detail/vpq_dataset_subspaces.cu b/cpp/src/neighbors/detail/vpq_dataset_subspaces.cu new file mode 100644 index 0000000000..5cff64332d --- /dev/null +++ b/cpp/src/neighbors/detail/vpq_dataset_subspaces.cu @@ -0,0 +1,32 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "vpq_dataset.cuh" + +namespace cuvs::neighbors::detail { + +#define PROCESS_AND_FILL_CODES_SUBSPACES_IMPL(MathT, IdxT, DatasetT) \ + auto process_and_fill_codes_subspaces( \ + const raft::resources& res, \ + const vpq_params& params, \ + DatasetT dataset, \ + raft::device_matrix_view vq_centers, \ + raft::device_matrix_view pq_centers) \ + -> raft::device_matrix \ + { \ + return process_and_fill_codes_subspaces( \ + res, params, dataset, vq_centers, pq_centers); \ + } + +#define COMMA , + +PROCESS_AND_FILL_CODES_SUBSPACES_IMPL( + float, int64_t, raft::device_matrix_view) +PROCESS_AND_FILL_CODES_SUBSPACES_IMPL( + double, int64_t, raft::device_matrix_view) + +#undef PROCESS_AND_FILL_CODES_SUBSPACES_IMPL + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/vpq_dataset_subspaces.hpp b/cpp/src/neighbors/detail/vpq_dataset_subspaces.hpp new file mode 100644 index 0000000000..2a8e5f2421 --- /dev/null +++ b/cpp/src/neighbors/detail/vpq_dataset_subspaces.hpp @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +namespace cuvs::neighbors::detail { + +auto process_and_fill_codes_subspaces( + const raft::resources& res, + const vpq_params& params, + raft::device_matrix_view dataset, + raft::device_matrix_view vq_centers, + raft::device_matrix_view pq_centers) + -> raft::device_matrix; + +auto process_and_fill_codes_subspaces( + const raft::resources& res, + const vpq_params& params, + raft::device_matrix_view dataset, + raft::device_matrix_view vq_centers, + raft::device_matrix_view pq_centers) + -> raft::device_matrix; + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/scann/detail/scann_quantize.cuh b/cpp/src/neighbors/scann/detail/scann_quantize.cuh index 69c7ca08e8..7ad5684d7d 100644 --- a/cpp/src/neighbors/scann/detail/scann_quantize.cuh +++ b/cpp/src/neighbors/scann/detail/scann_quantize.cuh @@ -4,6 +4,7 @@ */ #include "../../detail/vpq_dataset.cuh" +#include "../../detail/vpq_dataset_subspaces.hpp" #include #include #include @@ -18,105 +19,6 @@ namespace cuvs::neighbors::experimental::scann::detail { /** Fix the internal indexing type to avoid integer underflows/overflows */ using ix_t = int64_t; -template -__launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_subspaces_kernel( - raft::device_matrix_view out_codes, - raft::device_matrix_view dataset, - raft::device_matrix_view vq_centers, - raft::device_vector_view vq_labels, - raft::device_matrix_view pq_centers) -{ - constexpr uint32_t kSubWarpSize = std::min(raft::WarpSize, 1u << PqBits); - using subwarp_align = raft::Pow2; - const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); - if (row_ix >= out_codes.extent(0)) { return; } - - const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1)); - - const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); - const LabelT vq_label = vq_labels(row_ix); - - // write label - auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); - if (lane_id == 0) { *out_label_ptr = vq_label; } - - auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); - cuvs::neighbors::ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; - for (uint32_t j = 0; j < pq_dim; j++) { - // find PQ label - int subspace_offset = j * pq_centers.extent(1) * (1 << PqBits); - auto pq_subspace_view = raft::make_device_matrix_view( - pq_centers.data_handle() + subspace_offset, (uint32_t)(1 << PqBits), pq_centers.extent(1)); - uint8_t code = cuvs::neighbors::detail::compute_code( - dataset, vq_centers, pq_subspace_view, row_ix, j, vq_label); - // TODO: this writes in global memory one byte per warp, which is very slow. - // It's better to keep the codes in the shared memory or registers and dump them at once. - if (lane_id == 0) { code_view[j] = code; } - } -} - -template -auto process_and_fill_codes_subspaces( - const raft::resources& res, - const vpq_params& params, - const DatasetT& dataset, - raft::device_matrix_view vq_centers, - raft::device_matrix_view pq_centers) - -> raft::device_matrix -{ - using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; - using label_t = uint32_t; - - const ix_t n_rows = dataset.extent(0); - const ix_t dim = dataset.extent(1); - const ix_t pq_dim = params.pq_dim; - const ix_t pq_bits = params.pq_bits; - const ix_t pq_n_centers = ix_t{1} << pq_bits; - // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. - const ix_t codes_rowlen = - sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); - - auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); - - auto stream = raft::resource::get_cuda_stream(res); - - // TODO: with scaling workspace we could choose the batch size dynamically - constexpr ix_t kBlockSize = 256; - const ix_t threads_per_vec = std::min(raft::WarpSize, pq_n_centers); - dim3 threads(kBlockSize, 1, 1); - - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: - return process_and_fill_codes_subspaces_kernel; - case 8: - return process_and_fill_codes_subspaces_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be 4 or 8", pq_bits); - } - }(pq_bits); - - auto labels = cuvs::neighbors::detail::predict_vq(res, dataset, vq_centers); - - dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - - kernel<<>>( - raft::make_device_matrix_view(codes.data_handle(), n_rows, codes_rowlen), - dataset, - vq_centers, - raft::make_const_mdspan(labels.view()), - pq_centers); - - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - return codes; -} - template auto create_pq_codebook(raft::resources const& res, raft::device_matrix_view residuals, @@ -203,7 +105,7 @@ auto quantize_residuals(raft::resources const& res, vq_codebook.size() * sizeof(T), raft::resource::get_cuda_stream(res))); - auto codes = process_and_fill_codes_subspaces( + auto codes = cuvs::neighbors::detail::process_and_fill_codes_subspaces( res, ps, residuals, raft::make_const_mdspan(vq_codebook.view()), pq_codebook); return codes;