Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 1 addition & 100 deletions cpp/src/neighbors/scann/detail/scann_quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,105 +18,6 @@ namespace cuvs::neighbors::experimental::scann::detail {
/**
 * Fix the internal indexing type to avoid integer underflows/overflows.
 *
 * Signed 64-bit alias used for size arithmetic (row counts, code row lengths, launch
 * dimensions) regardless of the index type a caller's mdspan uses.
 */
using ix_t = int64_t;

/**
 * Write, for every dataset row, its VQ label followed by one PQ code per subspace.
 *
 * Work distribution: threads are grouped into sub-warps of min(raft::WarpSize, 2^PqBits) lanes,
 * and each sub-warp handles one row, derived from the flat 1-D thread index
 * (threadIdx.x + BlockSize * blockIdx.x). Sub-warps whose row index is past
 * out_codes.extent(0) return immediately. Every lane of a sub-warp calls compute_code;
 * only lane 0 stores results.
 *
 * Launch expectations: 1-D blocks of BlockSize threads and enough blocks for every row to get
 * a sub-warp. No shared memory is declared by this kernel itself.
 *
 * Output row layout: a LabelT at column 0, then the codes written through a
 * bitfield_view_t<PqBits> that starts right after the label. The row start is reinterpreted as
 * LabelT*, so each row of out_codes must be suitably aligned for LabelT.
 *
 * @tparam BlockSize number of threads per block (also used as the launch bound)
 * @tparam PqBits    bits per PQ code; each subspace codebook has 2^PqBits rows
 * @tparam DataT     element type of the dataset
 * @tparam MathT     element type of the VQ/PQ centers
 * @tparam IdxT      index type of the dataset, labels and output views
 * @tparam LabelT    type of the VQ label stored at the start of every output row
 *
 * @param[out] out_codes  [n_rows, row_len] destination for labels and codes
 * @param[in]  dataset    input rows forwarded to compute_code
 * @param[in]  vq_centers VQ centers forwarded to compute_code; its extent(1) divided (rounding
 *                        up) by pq_centers.extent(1) gives the number of subspaces
 * @param[in]  vq_labels  per-row VQ label, indexed by row
 * @param[in]  pq_centers per-subspace codebooks laid out back to back, each
 *                        2^PqBits x pq_centers.extent(1)
 */
template <uint32_t BlockSize,
          uint32_t PqBits,
          typename DataT,
          typename MathT,
          typename IdxT,
          typename LabelT>
__launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_subspaces_kernel(
  raft::device_matrix_view<uint8_t, IdxT, raft::row_major> out_codes,
  raft::device_matrix_view<const DataT, IdxT, raft::row_major> dataset,
  raft::device_matrix_view<const MathT, uint32_t, raft::row_major> vq_centers,
  raft::device_vector_view<const LabelT, IdxT, raft::row_major> vq_labels,
  raft::device_matrix_view<const MathT, uint32_t, raft::row_major> pq_centers)
{
  constexpr uint32_t kLanesPerRow = std::min<uint32_t>(raft::WarpSize, 1u << PqBits);
  using lane_group                = raft::Pow2<kLanesPerRow>;

  // One sub-warp per row: the row index is the flat thread index divided by the sub-warp size.
  const auto flat_thread_ix = IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x};
  const IdxT row_ix         = lane_group::div(flat_thread_ix);
  if (row_ix >= out_codes.extent(0)) { return; }

  // Row length of one PQ center and the resulting number of subspaces.
  const auto pq_len         = pq_centers.extent(1);
  const uint32_t n_subspace = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_len);

  // Only the first lane of each sub-warp writes to global memory.
  const bool is_writer = lane_group::mod(threadIdx.x) == 0;

  // The VQ label occupies the first sizeof(LabelT) bytes of the output row.
  const LabelT row_label = vq_labels(row_ix);
  auto* label_slot       = reinterpret_cast<LabelT*>(&out_codes(row_ix, 0));
  if (is_writer) { *label_slot = row_label; }

  // The PQ codes start immediately after the label.
  auto* packed_codes = reinterpret_cast<uint8_t*>(label_slot + 1);
  cuvs::neighbors::ivf_pq::detail::bitfield_view_t<PqBits> packed_view{packed_codes};

  for (uint32_t subspace = 0; subspace < n_subspace; ++subspace) {
    // Select the codebook belonging to this subspace.
    int codebook_offset = subspace * pq_len * (1 << PqBits);
    auto codebook       = raft::make_device_matrix_view(
      pq_centers.data_handle() + codebook_offset, (uint32_t)(1 << PqBits), pq_len);

    // All lanes of the sub-warp take part in the call.
    uint8_t pq_code = cuvs::neighbors::detail::compute_code<kLanesPerRow>(
      dataset, vq_centers, codebook, row_ix, subspace, row_label);

    // TODO: this writes in global memory one byte per warp, which is very slow.
    //  It's better to keep the codes in the shared memory or registers and dump them at once.
    if (is_writer) { packed_view[subspace] = pq_code; }
  }
}

/**
 * Encode a dataset into rows of [VQ label | PQ codes] on the device.
 *
 * Each output row is sizeof(label_t) bytes for the VQ label plus pq_dim * pq_bits bits of PQ
 * codes, rounded up to a whole number of label_t words, so that every row stays aligned for
 * reading the label.
 *
 * The VQ labels are obtained from predict_vq, then process_and_fill_codes_subspaces_kernel is
 * enqueued on the CUDA stream of `res`. This function checks for launch errors but does not
 * synchronize the stream.
 *
 * An empty dataset (zero rows) yields an empty codes matrix without launching anything: a grid
 * with zero blocks is not a valid launch configuration.
 *
 * @tparam MathT    element type of the VQ/PQ centers
 * @tparam IdxT     index type of the returned matrix
 * @tparam DatasetT device matrix view type of the input; must expose value_type and extent()
 *
 * @param[in] res        raft resources providing the stream and allocations
 * @param[in] params     quantization parameters; pq_dim and pq_bits are read
 * @param[in] dataset    input rows to encode
 * @param[in] vq_centers VQ centers
 * @param[in] pq_centers PQ centers, one codebook per subspace
 *
 * @return device matrix [n_rows, codes_rowlen] of bytes holding the labels and codes
 *
 * Fails via RAFT_FAIL when params.pq_bits is neither 4 nor 8.
 */
template <typename MathT, typename IdxT, typename DatasetT>
auto process_and_fill_codes_subspaces(
  const raft::resources& res,
  const vpq_params& params,
  const DatasetT& dataset,
  raft::device_matrix_view<const MathT, uint32_t, raft::row_major> vq_centers,
  raft::device_matrix_view<const MathT, uint32_t, raft::row_major> pq_centers)
  -> raft::device_matrix<uint8_t, IdxT, raft::row_major>
{
  using data_t  = typename DatasetT::value_type;
  using label_t = uint32_t;

  // TODO: with scaling workspace we could choose the batch size dynamically
  constexpr ix_t kBlockSize = 256;

  const ix_t n_rows  = dataset.extent(0);
  const ix_t pq_dim  = params.pq_dim;
  const ix_t pq_bits = params.pq_bits;

  // Pick the kernel instantiation first: this also validates pq_bits before it is used as a
  // shift count below.
  auto kernel = [](uint32_t bits) {
    switch (bits) {
      case 4:
        return process_and_fill_codes_subspaces_kernel<kBlockSize, 4, data_t, MathT, IdxT, label_t>;
      case 8:
        return process_and_fill_codes_subspaces_kernel<kBlockSize, 8, data_t, MathT, IdxT, label_t>;
      default: RAFT_FAIL("Invalid pq_bits (%u), the value must be 4 or 8", bits);
    }
  }(static_cast<uint32_t>(pq_bits));

  const ix_t pq_n_centers = ix_t{1} << pq_bits;
  // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels.
  const ix_t codes_rowlen =
    sizeof(label_t) * (1 + raft::div_rounding_up_safe<ix_t>(pq_dim * pq_bits, 8 * sizeof(label_t)));

  auto codes = raft::make_device_matrix<uint8_t, IdxT, raft::row_major>(res, n_rows, codes_rowlen);

  // Nothing to encode; launching with zero blocks would be reported as an invalid configuration.
  if (n_rows == 0) { return codes; }

  auto stream = raft::resource::get_cuda_stream(res);

  auto labels = cuvs::neighbors::detail::predict_vq<label_t>(res, dataset, vq_centers);

  // The kernel assigns one sub-warp of `threads_per_vec` lanes to each row.
  const ix_t threads_per_vec = std::min<ix_t>(raft::WarpSize, pq_n_centers);
  const ix_t rows_per_block  = kBlockSize / threads_per_vec;
  dim3 threads(kBlockSize, 1, 1);
  dim3 blocks(raft::div_rounding_up_safe<ix_t>(n_rows, rows_per_block), 1, 1);

  kernel<<<blocks, threads, 0, stream>>>(
    raft::make_device_matrix_view<uint8_t, IdxT>(codes.data_handle(), n_rows, codes_rowlen),
    dataset,
    vq_centers,
    raft::make_const_mdspan(labels.view()),
    pq_centers);

  // Surfaces launch-configuration errors; execution errors appear at a later synchronization.
  RAFT_CUDA_TRY(cudaPeekAtLastError());

  return codes;
}

template <typename T>
auto create_pq_codebook(raft::resources const& res,
raft::device_matrix_view<const T, int64_t> residuals,
Expand Down Expand Up @@ -203,7 +104,7 @@ auto quantize_residuals(raft::resources const& res,
vq_codebook.size() * sizeof(T),
raft::resource::get_cuda_stream(res)));

auto codes = process_and_fill_codes_subspaces<T, IdxT>(
auto codes = cuvs::neighbors::detail::process_and_fill_codes_subspaces<T, IdxT>(
res, ps, residuals, raft::make_const_mdspan(vq_codebook.view()), pq_codebook);

return codes;
Expand Down
Loading