Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ if(NOT BUILD_CPU_ONLY)
src/neighbors/iface/iface_pq_int8_t_int64_t.cu
src/neighbors/iface/iface_pq_uint8_t_int64_t.cu
src/neighbors/detail/cagra/topk_for_cagra/topk.cu
src/neighbors/detail/vpq_dataset_subspaces.cu
src/neighbors/dynamic_batching.cu
src/neighbors/cagra_index_wrapper.cu
src/neighbors/composite/index.cu
Expand Down
9 changes: 3 additions & 6 deletions cpp/src/neighbors/detail/vamana/vamana_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
#pragma once

#include "../../../sparse/neighbors/cross_component_nn.cuh"
#include "../../detail/vpq_dataset.cuh"
#include "../../detail/ann_utils.cuh"
#include "../../detail/vpq_dataset_subspaces.hpp"
#include "greedy_search.cuh"
#include "robust_prune.cuh"
#include "vamana_structs.cuh"
#include <cuvs/neighbors/vamana.hpp>

#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/kmeans_types.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
Expand All @@ -26,10 +25,8 @@
#include <raft/matrix/copy.cuh>
#include <raft/matrix/init.cuh>
#include <raft/matrix/slice.cuh>
#include <raft/random/make_blobs.cuh>

#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/unique.h>

#include <cuvs/distance/distance.hpp>
Expand Down Expand Up @@ -561,7 +558,7 @@ auto quantize_all_vectors(raft::resources const& res,
auto vq_codebook = raft::make_device_matrix<float, uint32_t, raft::row_major>(res, 1, dim);
raft::matrix::fill<float>(res, vq_codebook.view(), 0.0);

auto codes = cuvs::neighbors::detail::process_and_fill_codes_subspaces<float, int64_t>(
auto codes = cuvs::neighbors::detail::process_and_fill_codes_subspaces(
res, ps, residuals, raft::make_const_mdspan(vq_codebook.view()), pq_codebook);
return codes;
}
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/detail/vamana/vamana_codebooks.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <raft/core/error.hpp>

#include <fstream>
#include <vector>

namespace cuvs::neighbors::vamana::detail {
Expand Down
32 changes: 32 additions & 0 deletions cpp/src/neighbors/detail/vpq_dataset_subspaces.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#include "vpq_dataset.cuh"

namespace cuvs::neighbors::detail {

#define PROCESS_AND_FILL_CODES_SUBSPACES_IMPL(MathT, IdxT, DatasetT) \
auto process_and_fill_codes_subspaces( \
const raft::resources& res, \
const vpq_params& params, \
DatasetT dataset, \
raft::device_matrix_view<const MathT, uint32_t, raft::row_major> vq_centers, \
raft::device_matrix_view<const MathT, uint32_t, raft::row_major> pq_centers) \
-> raft::device_matrix<uint8_t, IdxT, raft::row_major> \
{ \
return process_and_fill_codes_subspaces<MathT, IdxT, DatasetT>( \
res, params, dataset, vq_centers, pq_centers); \
}

#define COMMA ,

PROCESS_AND_FILL_CODES_SUBSPACES_IMPL(
float, int64_t, raft::device_matrix_view<const float COMMA int64_t COMMA raft::row_major>)
PROCESS_AND_FILL_CODES_SUBSPACES_IMPL(
double, int64_t, raft::device_matrix_view<const double COMMA int64_t COMMA raft::row_major>)

#undef PROCESS_AND_FILL_CODES_SUBSPACES_IMPL

} // namespace cuvs::neighbors::detail
30 changes: 30 additions & 0 deletions cpp/src/neighbors/detail/vpq_dataset_subspaces.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

namespace cuvs::neighbors::detail {

auto process_and_fill_codes_subspaces(
const raft::resources& res,
const vpq_params& params,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
raft::device_matrix_view<const float, uint32_t, raft::row_major> vq_centers,
raft::device_matrix_view<const float, uint32_t, raft::row_major> pq_centers)
-> raft::device_matrix<uint8_t, int64_t, raft::row_major>;

auto process_and_fill_codes_subspaces(
const raft::resources& res,
const vpq_params& params,
raft::device_matrix_view<const double, int64_t, raft::row_major> dataset,
raft::device_matrix_view<const double, uint32_t, raft::row_major> vq_centers,
raft::device_matrix_view<const double, uint32_t, raft::row_major> pq_centers)
-> raft::device_matrix<uint8_t, int64_t, raft::row_major>;

} // namespace cuvs::neighbors::detail
102 changes: 2 additions & 100 deletions cpp/src/neighbors/scann/detail/scann_quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/

#include "../../detail/vpq_dataset.cuh"
#include "../../detail/vpq_dataset_subspaces.hpp"
#include <chrono>
#include <cmath>
#include <cuvs/neighbors/common.hpp>
Expand All @@ -18,105 +19,6 @@ namespace cuvs::neighbors::experimental::scann::detail {
/** Fix the internal indexing type to avoid integer underflows/overflows */
using ix_t = int64_t;

template <uint32_t BlockSize,
uint32_t PqBits,
typename DataT,
typename MathT,
typename IdxT,
typename LabelT>
__launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_subspaces_kernel(
raft::device_matrix_view<uint8_t, IdxT, raft::row_major> out_codes,
raft::device_matrix_view<const DataT, IdxT, raft::row_major> dataset,
raft::device_matrix_view<const MathT, uint32_t, raft::row_major> vq_centers,
raft::device_vector_view<const LabelT, IdxT, raft::row_major> vq_labels,
raft::device_matrix_view<const MathT, uint32_t, raft::row_major> pq_centers)
{
constexpr uint32_t kSubWarpSize = std::min<uint32_t>(raft::WarpSize, 1u << PqBits);
using subwarp_align = raft::Pow2<kSubWarpSize>;
const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x});
if (row_ix >= out_codes.extent(0)) { return; }

const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1));

const uint32_t lane_id = raft::Pow2<kSubWarpSize>::mod(threadIdx.x);
const LabelT vq_label = vq_labels(row_ix);

// write label
auto* out_label_ptr = reinterpret_cast<LabelT*>(&out_codes(row_ix, 0));
if (lane_id == 0) { *out_label_ptr = vq_label; }

auto* out_codes_ptr = reinterpret_cast<uint8_t*>(out_label_ptr + 1);
cuvs::neighbors::ivf_pq::detail::bitfield_view_t<PqBits> code_view{out_codes_ptr};
for (uint32_t j = 0; j < pq_dim; j++) {
// find PQ label
int subspace_offset = j * pq_centers.extent(1) * (1 << PqBits);
auto pq_subspace_view = raft::make_device_matrix_view(
pq_centers.data_handle() + subspace_offset, (uint32_t)(1 << PqBits), pq_centers.extent(1));
uint8_t code = cuvs::neighbors::detail::compute_code<kSubWarpSize>(
dataset, vq_centers, pq_subspace_view, row_ix, j, vq_label);
// TODO: this writes in global memory one byte per warp, which is very slow.
// It's better to keep the codes in the shared memory or registers and dump them at once.
if (lane_id == 0) { code_view[j] = code; }
}
}

template <typename MathT, typename IdxT, typename DatasetT>
auto process_and_fill_codes_subspaces(
const raft::resources& res,
const vpq_params& params,
const DatasetT& dataset,
raft::device_matrix_view<const MathT, uint32_t, raft::row_major> vq_centers,
raft::device_matrix_view<const MathT, uint32_t, raft::row_major> pq_centers)
-> raft::device_matrix<uint8_t, IdxT, raft::row_major>
{
using data_t = typename DatasetT::value_type;
using cdataset_t = vpq_dataset<MathT, IdxT>;
using label_t = uint32_t;

const ix_t n_rows = dataset.extent(0);
const ix_t dim = dataset.extent(1);
const ix_t pq_dim = params.pq_dim;
const ix_t pq_bits = params.pq_bits;
const ix_t pq_n_centers = ix_t{1} << pq_bits;
// NB: codes must be aligned at least to sizeof(label_t) to be able to read labels.
const ix_t codes_rowlen =
sizeof(label_t) * (1 + raft::div_rounding_up_safe<ix_t>(pq_dim * pq_bits, 8 * sizeof(label_t)));

auto codes = raft::make_device_matrix<uint8_t, IdxT, raft::row_major>(res, n_rows, codes_rowlen);

auto stream = raft::resource::get_cuda_stream(res);

// TODO: with scaling workspace we could choose the batch size dynamically
constexpr ix_t kBlockSize = 256;
const ix_t threads_per_vec = std::min<ix_t>(raft::WarpSize, pq_n_centers);
dim3 threads(kBlockSize, 1, 1);

auto kernel = [](uint32_t pq_bits) {
switch (pq_bits) {
case 4:
return process_and_fill_codes_subspaces_kernel<kBlockSize, 4, data_t, MathT, IdxT, label_t>;
case 8:
return process_and_fill_codes_subspaces_kernel<kBlockSize, 8, data_t, MathT, IdxT, label_t>;
default: RAFT_FAIL("Invalid pq_bits (%u), the value must be 4 or 8", pq_bits);
}
}(pq_bits);

auto labels = cuvs::neighbors::detail::predict_vq<label_t>(res, dataset, vq_centers);

dim3 blocks(raft::div_rounding_up_safe<ix_t>(n_rows, kBlockSize / threads_per_vec), 1, 1);

kernel<<<blocks, threads, 0, stream>>>(
raft::make_device_matrix_view<uint8_t, IdxT>(codes.data_handle(), n_rows, codes_rowlen),
dataset,
vq_centers,
raft::make_const_mdspan(labels.view()),
pq_centers);

RAFT_CUDA_TRY(cudaPeekAtLastError());

return codes;
}

template <typename T>
auto create_pq_codebook(raft::resources const& res,
raft::device_matrix_view<const T, int64_t> residuals,
Expand Down Expand Up @@ -203,7 +105,7 @@ auto quantize_residuals(raft::resources const& res,
vq_codebook.size() * sizeof(T),
raft::resource::get_cuda_stream(res)));

auto codes = process_and_fill_codes_subspaces<T, IdxT>(
auto codes = cuvs::neighbors::detail::process_and_fill_codes_subspaces(
res, ps, residuals, raft::make_const_mdspan(vq_codebook.view()), pq_codebook);

return codes;
Expand Down
Loading