diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index af15e4a399..810d0d27b8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -480,6 +480,7 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/ivf_pq/detail/ivf_pq_build_extend_half_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_build_extend_int8_t_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_build_extend_uint8_t_int64_t.cu + src/neighbors/ivf_pq/detail/ivf_pq_build_precomputed_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_false.cu src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_true.cu src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_half.cu diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 6836757cea..efd82a6f62 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -278,6 +278,72 @@ constexpr typename list_spec::list_extents list_spec:: template using list_data = ivf::list; +using pq_centers_extents = std::experimental:: + extents; + +template +class index_iface { + public: + virtual ~index_iface() = default; + + virtual cuvs::distance::DistanceType metric() const noexcept = 0; + virtual codebook_gen codebook_kind() const noexcept = 0; + virtual uint32_t dim() const noexcept = 0; + virtual uint32_t dim_ext() const noexcept = 0; + virtual uint32_t rot_dim() const noexcept = 0; + virtual uint32_t pq_bits() const noexcept = 0; + virtual uint32_t pq_dim() const noexcept = 0; + virtual uint32_t pq_len() const noexcept = 0; + virtual uint32_t pq_book_size() const noexcept = 0; + virtual uint32_t n_lists() const noexcept = 0; + virtual bool conservative_memory_allocation() const noexcept = 0; + virtual uint32_t get_list_size_in_bytes(uint32_t label) const noexcept = 0; + + virtual std::vector>>& lists() noexcept = 0; + virtual const std::vector>>& lists() const noexcept = 0; + + virtual raft::device_vector_view list_sizes() noexcept = 0; + virtual raft::device_vector_view list_sizes() + const noexcept = 0; + + virtual raft::device_vector_view data_ptrs() noexcept = 0; + virtual raft::device_vector_view data_ptrs() + const noexcept = 0; + + virtual raft::device_vector_view inds_ptrs() noexcept = 0; + virtual raft::device_vector_view inds_ptrs() + const noexcept = 0; + + virtual raft::host_vector_view accum_sorted_sizes() noexcept = 0; + virtual raft::host_vector_view accum_sorted_sizes() + const noexcept = 0; + + virtual raft::device_mdspan pq_centers() noexcept = 0; + virtual raft::device_mdspan pq_centers() + const noexcept = 0; + + virtual raft::device_matrix_view centers() noexcept = 0; + virtual raft::device_matrix_view centers() + const noexcept = 0; + + virtual raft::device_matrix_view centers_rot() noexcept = 0; + virtual raft::device_matrix_view centers_rot() + const noexcept = 0; + + virtual raft::device_matrix_view rotation_matrix() noexcept = 0; + virtual raft::device_matrix_view rotation_matrix() + const noexcept = 0; + + virtual raft::device_matrix_view rotation_matrix_int8( + const raft::resources& res) const = 0; + virtual raft::device_matrix_view rotation_matrix_half( + const raft::resources& res) const = 0; + virtual raft::device_matrix_view centers_int8( + const raft::resources& res) const = 0; + virtual raft::device_matrix_view centers_half( + const raft::resources& res) const = 0; +}; + /** * @defgroup ivf_pq_cpp_index IVF-PQ index * @{ @@ -328,17 +394,14 @@ using list_data = ivf::list; * */ template -struct index : cuvs::neighbors::index { +class index : public index_iface, cuvs::neighbors::index { + public: using index_params_type = ivf_pq::index_params; using search_params_type = ivf_pq::search_params; using index_type = IdxT; static_assert(!raft::is_narrowing_v, "IdxT must be able to represent all values of uint32_t"); - using pq_centers_extents = std::experimental:: - extents; - - public: index(const index&) = delete; index(index&&) = default; auto operator=(const index&) -> index& = delete; @@ -353,7 +416,20 @@ struct index : cuvs::neighbors::index { */ index(raft::resources const& handle); - /** Construct an empty index. It needs to be trained and then populated. */ + /** + * @brief Construct an index with specified parameters. + * + * This constructor creates an owning index with the given parameters. + * + * @param handle RAFT resources handle + * @param metric Distance metric for clustering + * @param codebook_kind How PQ codebooks are created + * @param n_lists Number of inverted lists (clusters) + * @param dim Dimensionality of the input data + * @param pq_bits Bit length of vector elements after PQ compression + * @param pq_dim Dimensionality after PQ compression (0 = auto-select) + * @param conservative_memory_allocation Memory allocation strategy + */ index(raft::resources const& handle, cuvs::distance::DistanceType metric, codebook_gen codebook_kind, @@ -363,14 +439,20 @@ struct index : cuvs::neighbors::index { uint32_t pq_dim = 0, bool conservative_memory_allocation = false); - /** Construct an empty index. It needs to be trained and then populated. */ + /** + * @brief Construct an index from index parameters. + * + * @param handle RAFT resources handle + * @param params Index parameters + * @param dim Dimensionality of the input data + */ index(raft::resources const& handle, const index_params& params, uint32_t dim); /** Total length of the index. */ IdxT size() const noexcept; /** Dimensionality of the input data. */ - uint32_t dim() const noexcept; + uint32_t dim() const noexcept override; /** * Dimensionality of the cluster centers: @@ -385,10 +467,10 @@ struct index : cuvs::neighbors::index { uint32_t rot_dim() const noexcept; /** The bit length of an encoded vector element after compression by PQ. */ - uint32_t pq_bits() const noexcept; + uint32_t pq_bits() const noexcept override; /** The dimensionality of an encoded vector after compression by PQ. */ - uint32_t pq_dim() const noexcept; + uint32_t pq_dim() const noexcept override; /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ uint32_t pq_len() const noexcept; @@ -397,10 +479,10 @@ struct index : cuvs::neighbors::index { uint32_t pq_book_size() const noexcept; /** Distance metric used for clustering. */ - cuvs::distance::DistanceType metric() const noexcept; + cuvs::distance::DistanceType metric() const noexcept override; /** How PQ codebooks are created. */ - codebook_gen codebook_kind() const noexcept; + codebook_gen codebook_kind() const noexcept override; /** Number of clusters/inverted lists (first level quantization). */ uint32_t n_lists() const noexcept; @@ -409,7 +491,7 @@ struct index : cuvs::neighbors::index { * Whether to use convervative memory allocation when extending the list (cluster) data * (see index_params.conservative_memory_allocation). */ - bool conservative_memory_allocation() const noexcept; + bool conservative_memory_allocation() const noexcept override; /** * PQ cluster centers @@ -417,30 +499,33 @@ struct index : cuvs::neighbors::index { * - codebook_gen::PER_SUBSPACE: [pq_dim , pq_len, pq_book_size] * - codebook_gen::PER_CLUSTER: [n_lists, pq_len, pq_book_size] */ - raft::device_mdspan pq_centers() noexcept; - raft::device_mdspan pq_centers() const noexcept; + raft::device_mdspan pq_centers() noexcept override; + raft::device_mdspan pq_centers() + const noexcept override; /** Lists' data and indices. */ - std::vector>>& lists() noexcept; - const std::vector>>& lists() const noexcept; + std::vector>>& lists() noexcept override; + const std::vector>>& lists() const noexcept override; /** Pointers to the inverted lists (clusters) data [n_lists]. */ - raft::device_vector_view data_ptrs() noexcept; + raft::device_vector_view data_ptrs() noexcept override; raft::device_vector_view data_ptrs() - const noexcept; + const noexcept override; /** Pointers to the inverted lists (clusters) indices [n_lists]. */ - raft::device_vector_view inds_ptrs() noexcept; - raft::device_vector_view inds_ptrs() const noexcept; + raft::device_vector_view inds_ptrs() noexcept override; + raft::device_vector_view inds_ptrs() + const noexcept override; /** The transform matrix (original space -> rotated padded space) [rot_dim, dim] */ - raft::device_matrix_view rotation_matrix() noexcept; - raft::device_matrix_view rotation_matrix() const noexcept; + raft::device_matrix_view rotation_matrix() noexcept override; + raft::device_matrix_view rotation_matrix() + const noexcept override; raft::device_matrix_view rotation_matrix_int8( - const raft::resources& res) const; + const raft::resources& res) const override; raft::device_matrix_view rotation_matrix_half( - const raft::resources& res) const; + const raft::resources& res) const override; /** * Accumulated list sizes, sorted in descending order [n_lists + 1]. @@ -451,25 +536,29 @@ struct index : cuvs::neighbors::index { * * This span is used during search to estimate the maximum size of the workspace. */ - raft::host_vector_view accum_sorted_sizes() noexcept; - raft::host_vector_view accum_sorted_sizes() const noexcept; + raft::host_vector_view accum_sorted_sizes() noexcept override; + raft::host_vector_view accum_sorted_sizes() + const noexcept override; /** Sizes of the lists [n_lists]. */ - raft::device_vector_view list_sizes() noexcept; - raft::device_vector_view list_sizes() const noexcept; + raft::device_vector_view list_sizes() noexcept override; + raft::device_vector_view list_sizes() + const noexcept override; /** Cluster centers corresponding to the lists in the original space [n_lists, dim_ext] */ - raft::device_matrix_view centers() noexcept; - raft::device_matrix_view centers() const noexcept; + raft::device_matrix_view centers() noexcept override; + raft::device_matrix_view centers() + const noexcept override; raft::device_matrix_view centers_int8( - const raft::resources& res) const; + const raft::resources& res) const override; raft::device_matrix_view centers_half( - const raft::resources& res) const; + const raft::resources& res) const override; /** Cluster centers corresponding to the lists in the rotated space [n_lists, rot_dim] */ - raft::device_matrix_view centers_rot() noexcept; - raft::device_matrix_view centers_rot() const noexcept; + raft::device_matrix_view centers_rot() noexcept override; + raft::device_matrix_view centers_rot() + const noexcept override; /** fetch size of a particular IVF list in bytes using the list extents. * Usage example: @@ -480,50 +569,31 @@ struct index : cuvs::neighbors::index { * // extend the IVF lists while building the index * index_params.add_data_on_build = true; * // create and fill the index from a [N, D] dataset - * auto index = cuvs::neighbors::ivf_pq::build(res, index_params, dataset, N, D); + * auto index = cuvs::neighbors::ivf_pq::build(res, index_params, dataset); * // Fetch the size of the fourth list * uint32_t size = index.get_list_size_in_bytes(3); * @endcode * * @param[in] label list ID */ - uint32_t get_list_size_in_bytes(uint32_t label); + uint32_t get_list_size_in_bytes(uint32_t label) const noexcept override; - private: - cuvs::distance::DistanceType metric_; - codebook_gen codebook_kind_; - uint32_t dim_; - uint32_t pq_bits_; - uint32_t pq_dim_; - bool conservative_memory_allocation_; - - // Primary data members - std::vector>> lists_; - raft::device_vector list_sizes_; - raft::device_mdarray pq_centers_; - raft::device_matrix centers_; - raft::device_matrix centers_rot_; - raft::device_matrix rotation_matrix_; - - // Lazy-initialized low-precision variants of index members - for low-precision coarse search. - // These are never serialized and not touched during build/extend. - mutable std::optional> centers_int8_; - mutable std::optional> centers_half_; - mutable std::optional> - rotation_matrix_int8_; - mutable std::optional> rotation_matrix_half_; - - // Computed members for accelerating search. - raft::device_vector data_ptrs_; - raft::device_vector inds_ptrs_; - raft::host_vector accum_sorted_sizes_; - - /** Throw an error if the index content is inconsistent. */ - void check_consistency(); - - pq_centers_extents make_pq_centers_extents(); + /** + * @brief Construct index from implementation pointer. + * + * This constructor is used internally by build/extend/deserialize functions. + * + * @param impl Implementation pointer (owning or view) + */ + explicit index(std::unique_ptr> impl); + + static pq_centers_extents make_pq_centers_extents( + uint32_t dim, uint32_t pq_dim, uint32_t pq_bits, codebook_gen codebook_kind, uint32_t n_lists); static uint32_t calculate_pq_dim(uint32_t dim); + + private: + std::unique_ptr> impl_; }; /** * @} @@ -996,6 +1066,129 @@ void build(raft::resources const& handle, const cuvs::neighbors::ivf_pq::index_params& index_params, raft::host_matrix_view dataset, cuvs::neighbors::ivf_pq::index* idx); + +/** + * @brief Build a view-type IVF-PQ index from device memory centroids and codebook. + * + * This function creates a non-owning index that references the provided device data directly. + * All parameters must be provided with correct extents. The caller is responsible for ensuring + * the lifetime of the input data exceeds the lifetime of the returned index. + * + * The index_params must be consistent with the provided matrices. Specifically: + * - index_params.codebook_kind determines the expected shape of pq_centers + * - index_params.metric will be stored in the index + * - index_params.conservative_memory_allocation will be stored in the index + * The function will verify consistency between index_params, dim, and the matrix extents. + * + * @tparam IdxT Type of indices (default: int64_t) + * + * @param[in] handle raft resources handle + * @param[in] index_params configure the index (metric, codebook_kind, etc.). Must be consistent + * with the provided matrices. + * @param[in] dim dimensionality of the input data + * @param[in] pq_centers PQ codebook on device memory with required extents: + * - codebook_gen::PER_SUBSPACE: [pq_dim, pq_len, pq_book_size] + * - codebook_gen::PER_CLUSTER: [n_lists, pq_len, pq_book_size] + * @param[in] centers Cluster centers in the original space [n_lists, dim_ext] + * where dim_ext = round_up(dim + 1, 8) + * @param[in] centers_rot Rotated cluster centers [n_lists, rot_dim] + * where rot_dim = pq_len * pq_dim + * @param[in] rotation_matrix Transform matrix (original space -> rotated padded space) [rot_dim, + * dim] + * + * @return A view-type ivf_pq index that references the provided data + */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + const uint32_t dim, + raft::device_mdspan, raft::row_major> pq_centers, + raft::device_matrix_view centers, + raft::device_matrix_view centers_rot, + raft::device_matrix_view rotation_matrix) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Build an IVF-PQ index from device memory centroids and codebook. + * + * This function creates a non-owning index that references the provided device data directly. + * All parameters must be provided with correct extents. The caller is responsible for ensuring + * the lifetime of the input data exceeds the lifetime of the returned index. + * + * The index_params must be consistent with the provided matrices. Specifically: + * - index_params.codebook_kind determines the expected shape of pq_centers + * - index_params.metric will be stored in the index + * - index_params.conservative_memory_allocation will be stored in the index + * The function will verify consistency between index_params, dim, and the matrix extents. + * + * @tparam IdxT Type of indices (default: int64_t) + * + * @param[in] handle raft resources handle + * @param[in] index_params configure the index (metric, codebook_kind, etc.). Must be consistent + * with the provided matrices. + * @param[in] dim dimensionality of the input data + * @param[in] pq_centers PQ codebook on device memory with required extents: + * - codebook_gen::PER_SUBSPACE: [pq_dim, pq_len, pq_book_size] + * - codebook_gen::PER_CLUSTER: [n_lists, pq_len, pq_book_size] + * @param[in] centers Cluster centers in the original space [n_lists, dim_ext] + * where dim_ext = round_up(dim + 1, 8) + * @param[in] centers_rot Rotated cluster centers [n_lists, rot_dim] + * where rot_dim = pq_len * pq_dim + * @param[in] rotation_matrix Transform matrix (original space -> rotated padded space) [rot_dim, + * dim] + * @param[out] idx pointer to ivf_pq::index + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + const uint32_t dim, + raft::device_mdspan, raft::row_major> pq_centers, + raft::device_matrix_view centers, + raft::device_matrix_view centers_rot, + raft::device_matrix_view rotation_matrix, + cuvs::neighbors::ivf_pq::index* idx); + +/** + * @brief Build an IVF-PQ index from host memory centroids and codebook (in-place). + * + * @param[in] handle raft resources handle + * @param[in] index_params configure the index building + * @param[in] dim dimensionality of the input data + * @param[in] pq_centers PQ codebook + * @param[in] centers Cluster centers + * @param[in] centers_rot Optional rotated cluster centers + * @param[in] rotation_matrix Optional rotation matrix + * @param[out] idx pointer to ivf_pq::index + */ +auto build( + raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + const uint32_t dim, + raft::host_mdspan, raft::row_major> pq_centers, + raft::host_matrix_view centers, + std::optional> centers_rot, + std::optional> rotation_matrix) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Build an IVF-PQ index from host memory centroids and codebook (in-place). + * + * @param[in] handle raft resources handle + * @param[in] index_params configure the index building + * @param[in] dim dimensionality of the input data + * @param[in] pq_centers PQ codebook on host memory + * @param[in] centers Cluster centers on host memory + * @param[in] centers_rot Optional rotated cluster centers on host + * @param[in] rotation_matrix Optional rotation matrix on host + * @param[out] idx pointer to IVF-PQ index to be built + */ +void build( + raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + const uint32_t dim, + raft::host_mdspan, raft::row_major> pq_centers, + raft::host_matrix_view centers, + std::optional> centers_rot, + std::optional> rotation_matrix, + cuvs::neighbors::ivf_pq::index* idx); /** * @} */ @@ -1117,6 +1310,7 @@ void extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, cuvs::neighbors::ivf_pq::index* idx); + /** * @brief Extend the index with the new data. * @@ -2847,35 +3041,50 @@ void make_rotation_matrix(raft::resources const& res, bool force_random_rotation); /** - * @brief Public helper API for externally modifying the index's IVF centroids. - * NB: The index must be reset before this. Use raft::neighbors::ivf_pq::extend to construct IVF - lists according to new centroids. + * @brief Pad cluster centers with their L2 norms for efficient GEMM operations. * - * Usage example: - * @code{.cpp} - * raft::resources res; - * // allocate the buffer for the input centers - * auto cluster_centers = raft::make_device_matrix(res, index.n_lists(), - index.dim()); - * ... prepare ivf centroids in cluster_centers ... - * // reset the index - * reset_index(res, &index); - * // recompute the state of the index - * cuvs::neighbors::ivf_pq::helpers::recompute_internal_state(res, index); - * // Write the IVF centroids - * cuvs::neighbors::ivf_pq::helpers::set_centers( - res, - &index, - cluster_centers); - * @endcode + * This function takes cluster centers and pads them with their L2 norms to create + * extended centers suitable for coarse search operations. The output has dimensions + * [n_centers, dim_ext] where dim_ext = round_up(dim + 1, 8). * * @param[in] res raft resource - * @param[inout] index pointer to IVF-PQ index - * @param[in] cluster_centers new cluster centers [index.n_lists(), index.dim()] + * @param[in] centers cluster centers [n_centers, dim] + * @param[out] padded_centers padded centers with norms [n_centers, dim_ext] */ -void set_centers(raft::resources const& res, - index* index, - raft::device_matrix_view cluster_centers); +void pad_centers_with_norms( + raft::resources const& res, + raft::device_matrix_view centers, + raft::device_matrix_view padded_centers); + +/** + * @brief Pad cluster centers with their L2 norms for efficient GEMM operations. + * + * This function takes cluster centers and pads them with their L2 norms to create + * extended centers suitable for coarse search operations. The output has dimensions + * [n_centers, dim_ext] where dim_ext = round_up(dim + 1, 8). + * + * @param[in] res raft resource + * @param[in] centers cluster centers [n_centers, dim] + * @param[out] padded_centers padded centers with norms [n_centers, dim_ext] + */ +void pad_centers_with_norms( + raft::resources const& res, + raft::host_matrix_view centers, + raft::device_matrix_view padded_centers); + +/** + * @brief Rotate padded centers with the rotation matrix. + * + * @param[in] res raft resource + * @param[in] padded_centers padded centers [n_centers, dim_ext] + * @param[in] rotation_matrix rotation matrix [rot_dim, dim] + * @param[out] rotated_centers rotated centers [n_centers, rot_dim] + */ +void rotate_padded_centers( + raft::resources const& res, + raft::device_matrix_view padded_centers, + raft::device_matrix_view rotation_matrix, + raft::device_matrix_view rotated_centers); /** * @brief Public helper API for fetching a trained index's IVF centroids @@ -2930,6 +3139,35 @@ void extract_centers(raft::resources const& res, */ void recompute_internal_state(const raft::resources& res, index* index); +/** + * @brief Generate a rotation matrix into user-provided buffer (standalone version). + * + * This standalone helper generates a rotation matrix without requiring an index object. + * Users can call this to prepare a rotation matrix before building from precomputed data. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * uint32_t dim = 128, pq_dim = 32; + * uint32_t rot_dim = pq_dim * ((dim + pq_dim - 1) / pq_dim); // rounded up + * + * // Allocate rotation matrix buffer [rot_dim, dim] + * auto rotation_matrix = raft::make_device_matrix(res, rot_dim, dim); + * + * // Generate the rotation matrix + * ivf_pq::helpers::make_rotation_matrix( + * res, rotation_matrix.view(), true); + * @endcode + * + * @param[in] res raft resource + * @param[out] rotation_matrix Output buffer [rot_dim, dim] for the rotation matrix + * @param[in] force_random_rotation If false and rot_dim == dim, creates identity matrix. + * If true or rot_dim != dim, creates random orthogonal matrix. + */ +void make_rotation_matrix( + raft::resources const& res, + raft::device_matrix_view rotation_matrix, + bool force_random_rotation); /** * @} */ diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_precomputed_inst.cuh b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_precomputed_inst.cuh new file mode 100644 index 0000000000..c6a96e4451 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_precomputed_inst.cuh @@ -0,0 +1,67 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include "../ivf_pq_build.cuh" + +namespace cuvs::neighbors::ivf_pq { +#define CUVS_INST_IVF_PQ_BUILD_PRECOMPUTED(IdxT) \ + auto build( \ + raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& index_params, \ + const uint32_t dim, \ + raft::host_mdspan, raft::row_major> pq_centers, \ + raft::host_matrix_view centers, \ + std::optional> centers_rot, \ + std::optional> rotation_matrix) \ + -> cuvs::neighbors::ivf_pq::index \ + { \ + return detail::build( \ + handle, index_params, dim, pq_centers, centers, centers_rot, rotation_matrix); \ + } \ + void build( \ + raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& index_params, \ + const uint32_t dim, \ + raft::host_mdspan, raft::row_major> pq_centers, \ + raft::host_matrix_view centers, \ + std::optional> centers_rot, \ + std::optional> rotation_matrix, \ + cuvs::neighbors::ivf_pq::index* idx) \ + { \ + detail::build( \ + handle, index_params, dim, pq_centers, centers, centers_rot, rotation_matrix, idx); \ + } \ + auto build( \ + raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& index_params, \ + const uint32_t dim, \ + raft::device_mdspan, raft::row_major> pq_centers, \ + raft::device_matrix_view centers, \ + raft::device_matrix_view centers_rot, \ + raft::device_matrix_view rotation_matrix) \ + -> cuvs::neighbors::ivf_pq::index \ + { \ + return detail::build( \ + handle, index_params, dim, pq_centers, centers, centers_rot, rotation_matrix); \ + } \ + void build( \ + raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& index_params, \ + const uint32_t dim, \ + raft::device_mdspan, raft::row_major> pq_centers, \ + raft::device_matrix_view centers, \ + raft::device_matrix_view centers_rot, \ + raft::device_matrix_view rotation_matrix, \ + cuvs::neighbors::ivf_pq::index* idx) \ + { \ + detail::build( \ + handle, index_params, dim, pq_centers, centers, centers_rot, rotation_matrix, idx); \ + } + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_precomputed_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_precomputed_int64_t.cu new file mode 100644 index 0000000000..85d65c9080 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_precomputed_int64_t.cu @@ -0,0 +1,15 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "ivf_pq_build_precomputed_inst.cuh" + +namespace cuvs::neighbors::ivf_pq { +CUVS_INST_IVF_PQ_BUILD_PRECOMPUTED(int64_t); + +#undef CUVS_INST_IVF_PQ_BUILD_PRECOMPUTED + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 31f2a12989..d576290371 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -8,6 +8,7 @@ #include "../../core/nvtx.hpp" #include "../ivf_common.cuh" #include "../ivf_list.cuh" +#include "../ivf_pq_impl.hpp" #include "ivf_pq_codepacking.cuh" #include "ivf_pq_contiguous_list_data.cuh" #include "ivf_pq_process_and_fill_codes.cuh" @@ -243,56 +244,61 @@ auto calculate_offsets_and_indices(IdxT n_rows, return max_cluster_size; } -template -void set_centers(raft::resources const& handle, index* index, const float* cluster_centers) +template +void pad_centers_with_norms( + raft::resources const& res, + raft::mdspan, raft::row_major, accessor> centers, + raft::device_matrix_view padded_centers) { - auto stream = raft::resource::get_cuda_stream(handle); - auto* device_memory = raft::resource::get_workspace_resource(handle); + auto stream = raft::resource::get_cuda_stream(res); // Make sure to have trailing zeroes between dim and dim_ext; // We rely on this to enable padded tensor gemm kernels during coarse search. cuvs::spatial::knn::detail::utils::memzero( - index->centers().data_handle(), index->centers().size(), stream); + padded_centers.data_handle(), padded_centers.size(), stream); // combine cluster_centers and their norms - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(), - sizeof(float) * index->dim_ext(), - cluster_centers, - sizeof(float) * index->dim(), - sizeof(float) * index->dim(), - index->n_lists(), + RAFT_CUDA_TRY(cudaMemcpy2DAsync(padded_centers.data_handle(), + sizeof(float) * padded_centers.extent(1), + centers.data_handle(), + sizeof(float) * centers.extent(1), + sizeof(float) * centers.extent(1), + centers.extent(0), cudaMemcpyDefault, stream)); - rmm::device_uvector center_norms(index->n_lists(), stream, device_memory); + rmm::device_uvector center_norms(centers.extent(0), stream); raft::linalg::rowNorm( - center_norms.data(), cluster_centers, index->dim(), index->n_lists(), stream); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + index->dim(), - sizeof(float) * index->dim_ext(), + center_norms.data(), centers.data_handle(), centers.extent(1), centers.extent(0), stream); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(padded_centers.data_handle() + centers.extent(1), + sizeof(float) * padded_centers.extent(1), center_norms.data(), sizeof(float), sizeof(float), - index->n_lists(), + padded_centers.extent(0), cudaMemcpyDefault, stream)); +} - // Rotate cluster_centers - float alpha = 1.0; - float beta = 0.0; - raft::linalg::gemm(handle, - true, - false, - index->rot_dim(), - index->n_lists(), - index->dim(), - &alpha, - index->rotation_matrix().data_handle(), - index->dim(), - cluster_centers, - index->dim(), - &beta, - index->centers_rot().data_handle(), - index->rot_dim(), - raft::resource::get_cuda_stream(handle)); +template +void set_centers(raft::resources const& handle, index* index, const float* cluster_centers) +{ + switch (utils::check_pointer_residency(cluster_centers)) { + case utils::pointer_residency::host_only: + cuvs::neighbors::ivf_pq::helpers::pad_centers_with_norms( + handle, + raft::make_host_matrix_view( + cluster_centers, index->n_lists(), index->dim()), + index->centers()); + break; + default: + cuvs::neighbors::ivf_pq::helpers::pad_centers_with_norms( + handle, + raft::make_device_matrix_view( + cluster_centers, index->n_lists(), index->dim()), + index->centers()); + } + cuvs::neighbors::ivf_pq::helpers::rotate_padded_centers( + handle, index->centers(), index->rotation_matrix(), index->centers_rot()); } template @@ -1040,14 +1046,15 @@ auto clone(const raft::resources& res, const index& source) -> index { auto stream = raft::resource::get_cuda_stream(res); - // Allocate the new index + // Allocate the new owning index index target(res, source.metric(), source.codebook_kind(), source.n_lists(), source.dim(), source.pq_bits(), - source.pq_dim()); + source.pq_dim(), + source.conservative_memory_allocation()); // raft::copy the independent parts raft::copy(target.list_sizes().data_handle(), @@ -1315,20 +1322,26 @@ auto build(raft::resources const& handle, RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); - auto stream = raft::resource::get_cuda_stream(handle); + index idx(handle, + params.metric, + params.codebook_kind, + params.n_lists, + dim, + params.pq_bits, + params.pq_dim, + params.conservative_memory_allocation); - index index(handle, params, dim); - utils::memzero( - index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); - utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); - utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); + auto stream = raft::resource::get_cuda_stream(handle); + utils::memzero(idx.accum_sorted_sizes().data_handle(), idx.accum_sorted_sizes().size(), stream); + utils::memzero(idx.list_sizes().data_handle(), idx.list_sizes().size(), stream); + utils::memzero(idx.data_ptrs().data_handle(), idx.data_ptrs().size(), stream); + utils::memzero(idx.inds_ptrs().data_handle(), idx.inds_ptrs().size(), stream); { raft::random::RngState random_state{137}; auto trainset_ratio = std::max( 1, - size_t(n_rows) / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); + size_t(n_rows) / std::max(params.kmeans_trainset_fraction * n_rows, idx.n_lists())); size_t n_rows_train = n_rows / trainset_ratio; rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle); @@ -1338,7 +1351,7 @@ auto build(raft::resources const& handle, constexpr size_t kTolerableRatio = 4; rmm::device_async_resource_ref big_memory_resource = raft::resource::get_large_workspace_resource(handle); - if (sizeof(float) * n_rows_train * index.dim() * kTolerableRatio < + if (sizeof(float) * n_rows_train * idx.dim() * kTolerableRatio < raft::resource::get_workspace_free_bytes(handle)) { big_memory_resource = device_memory; } @@ -1382,18 +1395,18 @@ auto build(raft::resources const& handle, // NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters, // dim_ext]! rmm::device_uvector cluster_centers_buf( - index.n_lists() * index.dim(), stream, device_memory); + idx.n_lists() * idx.dim(), stream, device_memory); auto cluster_centers = cluster_centers_buf.data(); // Train balanced hierarchical kmeans clustering auto trainset_const_view = raft::make_const_mdspan(trainset.view()); auto centers_view = raft::make_device_matrix_view( - cluster_centers, index.n_lists(), index.dim()); + cluster_centers, idx.n_lists(), idx.dim()); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = static_cast((int)index.metric()); + kmeans_params.metric = static_cast((int)idx.metric()); - if (index.metric() == distance::DistanceType::CosineExpanded) { + if (idx.metric() == distance::DistanceType::CosineExpanded) { raft::linalg::row_normalize( handle, trainset_const_view, trainset.view()); } @@ -1402,8 +1415,8 @@ auto build(raft::resources const& handle, // Trainset labels are needed for training PQ codebooks rmm::device_uvector labels(n_rows_train, stream, big_memory_resource); auto centers_const_view = raft::make_device_matrix_view( - cluster_centers, index.n_lists(), index.dim()); - if (index.metric() == distance::DistanceType::CosineExpanded) { + cluster_centers, idx.n_lists(), idx.dim()); + if (idx.metric() == distance::DistanceType::CosineExpanded) { raft::linalg::row_normalize(handle, centers_const_view, centers_view); } auto labels_view = @@ -1412,15 +1425,15 @@ auto build(raft::resources const& handle, handle, kmeans_params, trainset_const_view, centers_const_view, labels_view); // Make rotation matrix - helpers::make_rotation_matrix(handle, &index, params.force_random_rotation); + helpers::make_rotation_matrix(handle, &idx, params.force_random_rotation); - helpers::set_centers(handle, &index, raft::make_const_mdspan(centers_view)); + set_centers(handle, &idx, cluster_centers); // Train PQ codebooks - switch (index.codebook_kind()) { + switch (idx.codebook_kind()) { case codebook_gen::PER_SUBSPACE: train_per_subset(handle, - index, + idx, n_rows_train, trainset.data_handle(), labels.data(), @@ -1429,7 +1442,7 @@ auto build(raft::resources const& handle, break; case codebook_gen::PER_CLUSTER: train_per_cluster(handle, - index, + idx, n_rows_train, trainset.data_handle(), labels.data(), @@ -1442,9 +1455,9 @@ auto build(raft::resources const& handle, // add the data if necessary if (params.add_data_on_build) { - detail::extend(handle, &index, dataset.data_handle(), nullptr, n_rows); + detail::extend(handle, &idx, dataset.data_handle(), nullptr, n_rows); } - return index; + return idx; } template @@ -1456,6 +1469,95 @@ void build(raft::resources const& handle, *index = build(handle, params, dataset); } +template +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + const uint32_t dim, + raft::device_mdspan, raft::row_major> pq_centers, + raft::device_matrix_view centers, + raft::device_matrix_view centers_rot, + raft::device_matrix_view rotation_matrix) + -> cuvs::neighbors::ivf_pq::index +{ + raft::common::nvtx::range fun_scope("ivf_pq::build(%u)", dim); + auto stream = raft::resource::get_cuda_stream(handle); + + auto pq_dim = index_params.pq_dim == 0 ? index::calculate_pq_dim(dim) : index_params.pq_dim; + auto expected_pq_centers_extents = index::make_pq_centers_extents( + dim, pq_dim, index_params.pq_bits, index_params.codebook_kind, index_params.n_lists); + RAFT_EXPECTS(pq_centers.extent(0) == expected_pq_centers_extents.extent(0) && + pq_centers.extent(1) == expected_pq_centers_extents.extent(1) && + pq_centers.extent(2) == expected_pq_centers_extents.extent(2), + "pq_centers must have extent [%u, %u, %u]. Got [%u, %u, %u]", + expected_pq_centers_extents.extent(0), + expected_pq_centers_extents.extent(1), + expected_pq_centers_extents.extent(2), + pq_centers.extent(0), + pq_centers.extent(1), + pq_centers.extent(2)); + + RAFT_EXPECTS( + centers.extent(0) == index_params.n_lists && + centers.extent(1) == raft::round_up_safe(dim + 1, 8u), + "centers must have extent [n_lists, round_up(dim + 1, 8)]. Expected [%u, %u], got [%u, %u]", + index_params.n_lists, + raft::round_up_safe(dim + 1, 8u), + centers.extent(0), + centers.extent(1)); + + auto pq_len = raft::div_rounding_up_unsafe(dim, pq_dim); + RAFT_EXPECTS(rotation_matrix.extent(0) == pq_len * pq_dim && rotation_matrix.extent(1) == dim, + "rotation_matrix must have extent [rot_dim, dim] = [%u, %u]. Got [%u, %u]", + pq_len * pq_dim, + dim, + rotation_matrix.extent(0), + rotation_matrix.extent(1)); + + RAFT_EXPECTS( + centers_rot.extent(0) == index_params.n_lists && centers_rot.extent(1) == pq_len * pq_dim, + "centers_rot must have extent [n_lists, pq_len * pq_dim]. Expected [%u, %u], got [%u, %u]", + index_params.n_lists, + pq_len * pq_dim, + centers_rot.extent(0), + centers_rot.extent(1)); + + auto impl = std::make_unique>(handle, + index_params.metric, + index_params.codebook_kind, + index_params.n_lists, + dim, + index_params.pq_bits, + pq_dim, + index_params.conservative_memory_allocation, + pq_centers, + centers, + centers_rot, + rotation_matrix); + + index view_index(std::move(impl)); + + utils::memzero( + view_index.accum_sorted_sizes().data_handle(), view_index.accum_sorted_sizes().size(), stream); + utils::memzero(view_index.list_sizes().data_handle(), view_index.list_sizes().size(), stream); + utils::memzero(view_index.data_ptrs().data_handle(), view_index.data_ptrs().size(), stream); + utils::memzero(view_index.inds_ptrs().data_handle(), view_index.inds_ptrs().size(), stream); + + return view_index; +} + +template +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + const uint32_t dim, + raft::device_mdspan, raft::row_major> pq_centers, + raft::device_matrix_view centers, + raft::device_matrix_view centers_rot, + raft::device_matrix_view rotation_matrix, + index* idx) +{ + *idx = build(handle, index_params, dim, pq_centers, centers, centers_rot, rotation_matrix); +} + template auto extend( raft::resources const& handle, @@ -1480,6 +1582,7 @@ auto extend( n_rows); } +// In-place extend for base class pointer (clones, extends, moves back) template void extend( raft::resources const& handle, @@ -1504,6 +1607,120 @@ void extend( n_rows); } +template +auto build( + raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + const uint32_t dim, + raft::host_mdspan, raft::row_major> pq_centers, + raft::host_matrix_view centers, + std::optional> centers_rot, + std::optional> rotation_matrix) + -> cuvs::neighbors::ivf_pq::index +{ + raft::common::nvtx::range fun_scope( + "ivf_pq::build_from_host(%u)", dim); + auto stream = raft::resource::get_cuda_stream(handle); + + auto pq_dim = index_params.pq_dim == 0 ? index::calculate_pq_dim(dim) : index_params.pq_dim; + + index owning_index(handle, + index_params.metric, + index_params.codebook_kind, + index_params.n_lists, + dim, + index_params.pq_bits, + pq_dim, + index_params.conservative_memory_allocation); + + utils::memzero(owning_index.accum_sorted_sizes().data_handle(), + owning_index.accum_sorted_sizes().size(), + stream); + utils::memzero(owning_index.list_sizes().data_handle(), owning_index.list_sizes().size(), stream); + utils::memzero(owning_index.data_ptrs().data_handle(), owning_index.data_ptrs().size(), stream); + utils::memzero(owning_index.inds_ptrs().data_handle(), owning_index.inds_ptrs().size(), stream); + + RAFT_EXPECTS( + (centers.extent(1) == dim || centers.extent(1) == raft::round_up_safe(dim + 1, 8u)) && + centers.extent(0) == owning_index.n_lists(), + "centers must have extent [n_lists, dim] or [n_lists, round_up(dim + 1, 8)]. " + "Got centers.extent(1)=%u, expected dim=%u or round_up(dim + 1, 8)=%u, and " + "centers.extent(0)=%u, expected n_lists=%u", + centers.extent(1), + dim, + raft::round_up_safe(dim + 1, 8u), + centers.extent(0), + owning_index.n_lists()); + + if (centers.extent(1) == owning_index.dim_ext()) { + raft::copy(owning_index.centers().data_handle(), + centers.data_handle(), + owning_index.centers().size(), + stream); + } else { + cuvs::neighbors::ivf_pq::helpers::pad_centers_with_norms( + handle, centers, owning_index.centers()); + } + + if (rotation_matrix.has_value()) { + RAFT_EXPECTS(rotation_matrix.value().extent(0) == owning_index.rot_dim() && + rotation_matrix.value().extent(1) == dim, + "rotation_matrix must have extent [rot_dim, dim] = [%u, %u]. Got [%u, %u]", + owning_index.rot_dim(), + dim, + rotation_matrix.value().extent(0), + rotation_matrix.value().extent(1)); + } else { + helpers::make_rotation_matrix(handle, &owning_index, index_params.force_random_rotation); + } + + if (centers_rot.has_value()) { + RAFT_EXPECTS(centers_rot.value().extent(0) == owning_index.n_lists() && + centers_rot.value().extent(1) == owning_index.rot_dim(), + "centers_rot must have extent [n_lists, rot_dim]. Expected [%u, %u], got [%u, %u]", + owning_index.n_lists(), + owning_index.rot_dim(), + centers_rot.value().extent(0), + centers_rot.value().extent(1)); + raft::copy(owning_index.centers_rot().data_handle(), + centers_rot.value().data_handle(), + centers_rot.value().size(), + stream); + } else { + cuvs::neighbors::ivf_pq::helpers::rotate_padded_centers( + handle, owning_index.centers(), owning_index.rotation_matrix(), owning_index.centers_rot()); + } + + RAFT_EXPECTS(pq_centers.extent(0) == owning_index.pq_centers().extent(0) && + pq_centers.extent(1) == owning_index.pq_centers().extent(1) && + pq_centers.extent(2) == owning_index.pq_centers().extent(2), + "pq_centers must have extent [%u, %u, %u]. Got [%u, %u, %u]", + owning_index.pq_centers().extent(0), + owning_index.pq_centers().extent(1), + owning_index.pq_centers().extent(2), + pq_centers.extent(0), + pq_centers.extent(1), + pq_centers.extent(2)); + raft::copy( + owning_index.pq_centers().data_handle(), pq_centers.data_handle(), pq_centers.size(), stream); + + return owning_index; +} + +template +void build( + raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + const uint32_t dim, + raft::host_mdspan, raft::row_major> pq_centers, + raft::host_matrix_view centers, + std::optional> centers_rot, + std::optional> rotation_matrix, + index* idx) +{ + *idx = build(handle, index_params, dim, pq_centers, centers, centers_rot, rotation_matrix); +} + template inline void extract_centers(raft::resources const& res, const cuvs::neighbors::ivf_pq::index& index, diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu b/cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu index 983d0d8b95..7ef12993c8 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,7 +10,8 @@ #include #include -namespace cuvs::neighbors::ivf_pq::helpers { +namespace cuvs::neighbors::ivf_pq { +namespace helpers { namespace codepacker { @@ -267,18 +268,6 @@ void make_rotation_matrix(raft::resources const& res, index->rotation_matrix().data_handle()); } -void set_centers(raft::resources const& handle, - index* index, - raft::device_matrix_view cluster_centers) -{ - RAFT_EXPECTS(cluster_centers.extent(0) == index->n_lists(), - "Number of rows in the new centers must be equal to the number of IVF lists"); - RAFT_EXPECTS(cluster_centers.extent(1) == index->dim(), - "Number of columns in the new cluster centers and index dim are different"); - RAFT_EXPECTS(index->size() == 0, "Index must be empty"); - detail::set_centers(handle, index, cluster_centers.data_handle()); -} - void extract_centers(raft::resources const& res, const cuvs::neighbors::ivf_pq::index& index, raft::device_matrix_view cluster_centers) @@ -298,4 +287,85 @@ void recompute_internal_state(const raft::resources& res, index* index) ivf::detail::recompute_internal_state(res, *index); } -} // namespace cuvs::neighbors::ivf_pq::helpers +void make_rotation_matrix( + raft::resources const& res, + raft::device_matrix_view rotation_matrix, + bool force_random_rotation) +{ + RAFT_EXPECTS(rotation_matrix.extent(0) > 0 && rotation_matrix.extent(1) > 0, + "rotation_matrix must have non-zero extents"); + + uint32_t rot_dim = rotation_matrix.extent(0); + uint32_t dim = rotation_matrix.extent(1); + + make_rotation_matrix(res, force_random_rotation, rot_dim, dim, rotation_matrix.data_handle()); +} + +void pad_centers_with_norms( + raft::resources const& res, + raft::device_matrix_view centers, + raft::device_matrix_view padded_centers) +{ + detail::pad_centers_with_norms(res, centers, padded_centers); +} + +void pad_centers_with_norms( + raft::resources const& res, + raft::host_matrix_view centers, + raft::device_matrix_view padded_centers) +{ + detail::pad_centers_with_norms(res, centers, padded_centers); +} + +void rotate_padded_centers( + raft::resources const& res, + raft::device_matrix_view padded_centers, + raft::device_matrix_view rotation_matrix, + raft::device_matrix_view rotated_centers) +{ + uint32_t n_lists = padded_centers.extent(0); + uint32_t centers_dim = padded_centers.extent(1); + uint32_t rot_dim = rotation_matrix.extent(0); + uint32_t dim = rotation_matrix.extent(1); + + RAFT_EXPECTS(rotated_centers.extent(0) == n_lists, + "centers_rot must have extent(0) == n_lists. Got centers_rot.extent(0) = %u, " + "expected %u", + rotated_centers.extent(0), + n_lists); + RAFT_EXPECTS(rotated_centers.extent(1) == rot_dim, + "centers_rot must have extent(1) == rot_dim. Got centers_rot.extent(1) = %u, " + "expected %u", + rotated_centers.extent(1), + rot_dim); + RAFT_EXPECTS(centers_dim >= dim, + "centers must have at least dim columns. Got centers.extent(1) = %u, " + "expected >= %u", + centers_dim, + dim); + + auto stream = raft::resource::get_cuda_stream(res); + + float alpha = 1.0f; + float beta = 0.0f; + + raft::linalg::gemm(res, + true, // transpose rotation_matrix + false, // don't transpose centers + rot_dim, + n_lists, + dim, + &alpha, + rotation_matrix.data_handle(), + dim, // lda (leading dim of rotation_matrix) + padded_centers.data_handle(), + centers_dim, // ldb (leading dim of centers, accounting for potential padding) + &beta, + rotated_centers.data_handle(), + rot_dim, // ldc (leading dim of output) + stream); +} + +} // namespace helpers + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh index 09d5cfce2b..b188f3c3cf 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -134,9 +134,9 @@ auto deserialize(raft::resources const& handle_, std::istream& is) -> index(pq_bits), static_cast(n_lists)); - auto index = cuvs::neighbors::ivf_pq::index( - handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma); + index index(handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma); + // Deserialize center/matrix data raft::deserialize_mdspan(handle_, is, index.pq_centers()); raft::deserialize_mdspan(handle_, is, index.centers()); raft::deserialize_mdspan(handle_, is, index.centers_rot()); diff --git a/cpp/src/neighbors/ivf_pq_impl.hpp b/cpp/src/neighbors/ivf_pq_impl.hpp new file mode 100644 index 0000000000..28618b27b7 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq_impl.hpp @@ -0,0 +1,182 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_pq { + +template +class index_impl : public index_iface { + public: + index_impl(raft::resources const& handle, + cuvs::distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t n_lists, + uint32_t dim, + uint32_t pq_bits, + uint32_t pq_dim, + bool conservative_memory_allocation); + + ~index_impl() = default; + index_impl(index_impl&&) = default; + index_impl& operator=(index_impl&&) = default; + index_impl(const index_impl&) = delete; + index_impl& operator=(const index_impl&) = delete; + + cuvs::distance::DistanceType metric() const noexcept override; + codebook_gen codebook_kind() const noexcept override; + uint32_t dim() const noexcept override; + uint32_t dim_ext() const noexcept override; + uint32_t rot_dim() const noexcept override; + uint32_t pq_bits() const noexcept override; + uint32_t pq_dim() const noexcept override; + uint32_t pq_len() const noexcept override; + uint32_t pq_book_size() const noexcept override; + uint32_t n_lists() const noexcept override; + bool conservative_memory_allocation() const noexcept override; + + std::vector>>& lists() noexcept override; + const std::vector>>& lists() const noexcept override; + + raft::device_vector_view list_sizes() noexcept override; + raft::device_vector_view list_sizes() + const noexcept override; + + raft::device_vector_view data_ptrs() noexcept override; + raft::device_vector_view data_ptrs() + const noexcept override; + + raft::device_vector_view inds_ptrs() noexcept override; + raft::device_vector_view inds_ptrs() + const noexcept override; + + raft::host_vector_view accum_sorted_sizes() noexcept override; + raft::host_vector_view accum_sorted_sizes() + const noexcept override; + + raft::device_matrix_view rotation_matrix_int8( + const raft::resources& res) const override; + raft::device_matrix_view rotation_matrix_half( + const raft::resources& res) const override; + raft::device_matrix_view centers_int8( + const raft::resources& res) const override; + raft::device_matrix_view centers_half( + const raft::resources& res) const override; + + uint32_t get_list_size_in_bytes(uint32_t label) const noexcept override; + + protected: + cuvs::distance::DistanceType metric_; + codebook_gen codebook_kind_; + uint32_t dim_; + uint32_t pq_bits_; + uint32_t pq_dim_; + bool conservative_memory_allocation_; + + std::vector>> lists_; + raft::device_vector list_sizes_; + raft::device_vector data_ptrs_; + raft::device_vector inds_ptrs_; + raft::host_vector accum_sorted_sizes_; + + mutable std::optional> centers_int8_; + mutable std::optional> centers_half_; + mutable std::optional> + rotation_matrix_int8_; + mutable std::optional> rotation_matrix_half_; + + void check_consistency(); +}; + +template +class owning_impl : public index_impl { + public: + owning_impl(raft::resources const& handle, + cuvs::distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t n_lists, + uint32_t dim, + uint32_t pq_bits, + uint32_t pq_dim, + bool conservative_memory_allocation); + + ~owning_impl() = default; + + owning_impl(owning_impl&&) = default; + owning_impl& operator=(owning_impl&&) = default; + owning_impl(const owning_impl&) = delete; + owning_impl& operator=(const owning_impl&) = delete; + + raft::device_mdspan pq_centers() noexcept override; + raft::device_mdspan pq_centers() + const noexcept override; + + raft::device_matrix_view centers() noexcept override; + raft::device_matrix_view centers() + const noexcept override; + + raft::device_matrix_view centers_rot() noexcept override; + raft::device_matrix_view centers_rot() + const noexcept override; + + raft::device_matrix_view rotation_matrix() noexcept override; + raft::device_matrix_view rotation_matrix() + const noexcept override; + + private: + raft::device_mdarray pq_centers_; + raft::device_matrix centers_; + raft::device_matrix centers_rot_; + raft::device_matrix rotation_matrix_; +}; + +template +class view_impl : public index_impl { + public: + view_impl(raft::resources const& handle, + cuvs::distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t n_lists, + uint32_t dim, + uint32_t pq_bits, + uint32_t pq_dim, + bool conservative_memory_allocation, + raft::device_mdspan pq_centers_view, + raft::device_matrix_view centers_view, + raft::device_matrix_view centers_rot_view, + raft::device_matrix_view rotation_matrix_view); + + ~view_impl() = default; + view_impl(view_impl&&) = default; + view_impl& operator=(view_impl&&) = default; + view_impl(const view_impl&) = delete; + view_impl& operator=(const view_impl&) = delete; + + raft::device_mdspan pq_centers() noexcept override; + raft::device_mdspan pq_centers() + const noexcept override; + + raft::device_matrix_view centers() noexcept override; + raft::device_matrix_view centers() + const noexcept override; + + raft::device_matrix_view centers_rot() noexcept override; + raft::device_matrix_view centers_rot() + const noexcept override; + + raft::device_matrix_view rotation_matrix() noexcept override; + raft::device_matrix_view rotation_matrix() + const noexcept override; + + private: + raft::device_mdspan pq_centers_view_; + raft::device_matrix_view centers_view_; + raft::device_matrix_view centers_rot_view_; + raft::device_matrix_view rotation_matrix_view_; +}; + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq_index.cu b/cpp/src/neighbors/ivf_pq_index.cu index 6c99f2c693..140e74f7ec 100644 --- a/cpp/src/neighbors/ivf_pq_index.cu +++ b/cpp/src/neighbors/ivf_pq_index.cu @@ -6,6 +6,7 @@ #include #include "detail/ann_utils.cuh" +#include "ivf_pq_impl.hpp" #include #include @@ -14,6 +15,343 @@ #include namespace cuvs::neighbors::ivf_pq { + +template +index_impl::index_impl(raft::resources const& handle, + cuvs::distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t n_lists, + uint32_t dim, + uint32_t pq_bits, + uint32_t pq_dim, + bool conservative_memory_allocation) + : metric_(metric), + codebook_kind_(codebook_kind), + dim_(dim), + pq_bits_(pq_bits), + pq_dim_(pq_dim == 0 ? index::calculate_pq_dim(dim) : pq_dim), + conservative_memory_allocation_(conservative_memory_allocation), + lists_(n_lists), + list_sizes_{raft::make_device_vector(handle, n_lists)}, + data_ptrs_{raft::make_device_vector(handle, n_lists)}, + inds_ptrs_{raft::make_device_vector(handle, n_lists)}, + accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)} +{ + check_consistency(); + accum_sorted_sizes_(n_lists) = 0; +} + +template +cuvs::distance::DistanceType index_impl::metric() const noexcept +{ + return metric_; +} + +template +codebook_gen index_impl::codebook_kind() const noexcept +{ + return codebook_kind_; +} + +template +uint32_t index_impl::dim() const noexcept +{ + return dim_; +} + +template +uint32_t index_impl::dim_ext() const noexcept +{ + return raft::round_up_safe(dim_ + 1, 8u); +} + +template +uint32_t index_impl::rot_dim() const noexcept +{ + return pq_len() * pq_dim_; +} + +template +uint32_t index_impl::pq_bits() const noexcept +{ + return pq_bits_; +} + +template +uint32_t index_impl::pq_dim() const noexcept +{ + return pq_dim_; +} + +template +uint32_t index_impl::pq_len() const noexcept +{ + return raft::div_rounding_up_unsafe(dim_, pq_dim_); +} + +template +uint32_t index_impl::pq_book_size() const noexcept +{ + return 1 << pq_bits_; +} + +template +uint32_t index_impl::n_lists() const noexcept +{ + return lists_.size(); +} + +template +bool index_impl::conservative_memory_allocation() const noexcept +{ + return conservative_memory_allocation_; +} + +template +std::vector>>& index_impl::lists() noexcept +{ + return lists_; +} + +template +const std::vector>>& index_impl::lists() const noexcept +{ + return lists_; +} + +template +raft::device_vector_view +index_impl::list_sizes() noexcept +{ + return list_sizes_.view(); +} + +template +raft::device_vector_view index_impl::list_sizes() + const noexcept +{ + return list_sizes_.view(); +} + +template +raft::device_vector_view index_impl::data_ptrs() noexcept +{ + return data_ptrs_.view(); +} + +template +raft::device_vector_view +index_impl::data_ptrs() const noexcept +{ + return data_ptrs_.view(); +} + +template +raft::device_vector_view index_impl::inds_ptrs() noexcept +{ + return inds_ptrs_.view(); +} + +template +raft::device_vector_view index_impl::inds_ptrs() + const noexcept +{ + return raft::make_mdspan( + inds_ptrs_.data_handle(), inds_ptrs_.extents()); +} + +template +raft::host_vector_view +index_impl::accum_sorted_sizes() noexcept +{ + return accum_sorted_sizes_.view(); +} + +template +raft::host_vector_view index_impl::accum_sorted_sizes() + const noexcept +{ + return accum_sorted_sizes_.view(); +} + +template +owning_impl::owning_impl(raft::resources const& handle, + cuvs::distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t n_lists, + uint32_t dim, + uint32_t pq_bits, + uint32_t pq_dim, + bool conservative_memory_allocation) + : index_impl( + handle, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, conservative_memory_allocation), + pq_centers_{raft::make_device_mdarray( + handle, index::make_pq_centers_extents(dim, pq_dim, pq_bits, codebook_kind, n_lists))}, + centers_{ + raft::make_device_matrix(handle, n_lists, raft::round_up_safe(dim + 1, 8u))}, + centers_rot_{raft::make_device_matrix( + handle, n_lists, raft::div_rounding_up_unsafe(dim, pq_dim) * pq_dim)}, + rotation_matrix_{raft::make_device_matrix( + handle, raft::div_rounding_up_unsafe(dim, pq_dim) * pq_dim, dim)} +{ +} + +template +pq_centers_extents index::make_pq_centers_extents( + uint32_t dim, uint32_t pq_dim, uint32_t pq_bits, codebook_gen codebook_kind, uint32_t n_lists) +{ + uint32_t pq_len = raft::div_rounding_up_unsafe(dim, pq_dim); + uint32_t pq_book_size = 1u << pq_bits; + switch (codebook_kind) { + case codebook_gen::PER_SUBSPACE: + return raft::make_extents(pq_dim, pq_len, pq_book_size); + case codebook_gen::PER_CLUSTER: + return raft::make_extents(n_lists, pq_len, pq_book_size); + default: RAFT_FAIL("Unreachable code"); + } +} + +template +view_impl::view_impl( + raft::resources const& handle, + cuvs::distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t n_lists, + uint32_t dim, + uint32_t pq_bits, + uint32_t pq_dim, + bool conservative_memory_allocation, + raft::device_mdspan pq_centers_view, + raft::device_matrix_view centers_view, + raft::device_matrix_view centers_rot_view, + raft::device_matrix_view rotation_matrix_view) + : index_impl( + handle, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, conservative_memory_allocation), + pq_centers_view_(pq_centers_view), + centers_view_(centers_view), + centers_rot_view_(centers_rot_view), + rotation_matrix_view_(rotation_matrix_view) +{ +} + +template +raft::device_mdspan +owning_impl::pq_centers() noexcept +{ + return pq_centers_.view(); +} + +template +raft::device_mdspan +owning_impl::pq_centers() const noexcept +{ + return pq_centers_.view(); +} + +template +raft::device_matrix_view owning_impl::centers() noexcept +{ + return centers_.view(); +} + +template +raft::device_matrix_view owning_impl::centers() + const noexcept +{ + return centers_.view(); +} + +template +raft::device_matrix_view owning_impl::centers_rot() noexcept +{ + return centers_rot_.view(); +} + +template +raft::device_matrix_view owning_impl::centers_rot() + const noexcept +{ + return centers_rot_.view(); +} + +template +raft::device_matrix_view +owning_impl::rotation_matrix() noexcept +{ + return rotation_matrix_.view(); +} + +template +raft::device_matrix_view +owning_impl::rotation_matrix() const noexcept +{ + return rotation_matrix_.view(); +} + +template +raft::device_mdspan +view_impl::pq_centers() noexcept +{ + return raft::mdspan( + const_cast(pq_centers_view_.data_handle()), pq_centers_view_.extents()); +} + +template +raft::device_mdspan view_impl::pq_centers() + const noexcept +{ + return pq_centers_view_; +} + +template +raft::device_matrix_view view_impl::centers() noexcept +{ + return raft::make_device_matrix_view( + const_cast(centers_view_.data_handle()), + centers_view_.extent(0), + centers_view_.extent(1)); +} + +template +raft::device_matrix_view view_impl::centers() + const noexcept +{ + return centers_view_; +} + +template +raft::device_matrix_view view_impl::centers_rot() noexcept +{ + return raft::make_device_matrix_view( + const_cast(centers_rot_view_.data_handle()), + centers_rot_view_.extent(0), + centers_rot_view_.extent(1)); +} + +template +raft::device_matrix_view view_impl::centers_rot() + const noexcept +{ + return centers_rot_view_; +} + +template +raft::device_matrix_view +view_impl::rotation_matrix() noexcept +{ + return raft::make_device_matrix_view( + const_cast(rotation_matrix_view_.data_handle()), + rotation_matrix_view_.extent(0), + rotation_matrix_view_.extent(1)); +} + +template +raft::device_matrix_view view_impl::rotation_matrix() + const noexcept +{ + return rotation_matrix_view_; +} + index_params index_params::from_dataset(raft::matrix_extent dataset, cuvs::distance::DistanceType metric) { @@ -31,32 +369,21 @@ index_params index_params::from_dataset(raft::matrix_extent dataset, } template -index::index(raft::resources const& handle) - // this constructor is just for a temporary index, for use in the deserialization - // api. all the parameters here will get replaced with loaded values - that aren't - // necessarily known ahead of time before deserialization. - // TODO: do we even need a handle here - could just construct one? - : index(handle, - cuvs::distance::DistanceType::L2Expanded, - codebook_gen::PER_SUBSPACE, - 0, - 0, - 8, - 0, - true) +index::index(std::unique_ptr> impl) + : cuvs::neighbors::index(), impl_(std::move(impl)) { } template -index::index(raft::resources const& handle, const index_params& params, uint32_t dim) - : index(handle, - params.metric, - params.codebook_kind, - params.n_lists, - dim, - params.pq_bits, - params.pq_dim, - params.conservative_memory_allocation) +index::index(raft::resources const& handle) + : index(std::make_unique>(handle, + cuvs::distance::DistanceType::L2Expanded, + codebook_gen::PER_SUBSPACE, + 0, + 0, + 8, + 1, + true)) { } @@ -69,266 +396,247 @@ index::index(raft::resources const& handle, uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation) - : cuvs::neighbors::index(), - metric_(metric), - codebook_kind_(codebook_kind), - dim_(dim), - pq_bits_(pq_bits), - pq_dim_(pq_dim == 0 ? calculate_pq_dim(dim) : pq_dim), - conservative_memory_allocation_(conservative_memory_allocation), - lists_{n_lists}, - list_sizes_{raft::make_device_vector(handle, n_lists)}, - pq_centers_{raft::make_device_mdarray(handle, make_pq_centers_extents())}, - centers_{raft::make_device_matrix(handle, n_lists, this->dim_ext())}, - centers_rot_{raft::make_device_matrix(handle, n_lists, this->rot_dim())}, - rotation_matrix_{ - raft::make_device_matrix(handle, this->rot_dim(), this->dim())}, - data_ptrs_{raft::make_device_vector(handle, n_lists)}, - inds_ptrs_{raft::make_device_vector(handle, n_lists)}, - accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)} + : index( + std::make_unique>(handle, + metric, + codebook_kind, + n_lists, + dim, + pq_bits, + pq_dim == 0 ? index::calculate_pq_dim(dim) : pq_dim, + conservative_memory_allocation)) +{ +} + +template +index::index(raft::resources const& handle, const index_params& params, uint32_t dim) + : index(handle, + params.metric, + params.codebook_kind, + params.n_lists, + dim, + params.pq_bits, + params.pq_dim, + params.conservative_memory_allocation) { - check_consistency(); - accum_sorted_sizes_(n_lists) = 0; } +// Delegation methods - forward to impl accessor methods template IdxT index::size() const noexcept { - return accum_sorted_sizes_(n_lists()); + return accum_sorted_sizes()(n_lists()); } template uint32_t index::dim() const noexcept { - return dim_; + return impl_->dim(); } template uint32_t index::dim_ext() const noexcept { - return raft::round_up_safe(dim() + 1, 8u); + return impl_->dim_ext(); } template uint32_t index::rot_dim() const noexcept { - return pq_len() * pq_dim(); + return impl_->rot_dim(); } template uint32_t index::pq_bits() const noexcept { - return pq_bits_; + return impl_->pq_bits(); } template uint32_t index::pq_dim() const noexcept { - return pq_dim_; + return impl_->pq_dim(); } template uint32_t index::pq_len() const noexcept { - return raft::div_rounding_up_unsafe(dim(), pq_dim()); + return impl_->pq_len(); } template uint32_t index::pq_book_size() const noexcept { - return 1 << pq_bits(); + return impl_->pq_book_size(); } template cuvs::distance::DistanceType index::metric() const noexcept { - return metric_; + return impl_->metric(); } template codebook_gen index::codebook_kind() const noexcept { - return codebook_kind_; + return impl_->codebook_kind(); } template uint32_t index::n_lists() const noexcept { - return lists_.size(); + return impl_->n_lists(); } template bool index::conservative_memory_allocation() const noexcept { - return conservative_memory_allocation_; + return impl_->conservative_memory_allocation(); } template -raft::device_mdspan::pq_centers_extents, - raft::row_major> -index::pq_centers() noexcept +raft::device_mdspan index::pq_centers() noexcept { - return pq_centers_.view(); + return impl_->pq_centers(); } template -raft::device_mdspan::pq_centers_extents, - raft::row_major> -index::pq_centers() const noexcept +raft::device_mdspan index::pq_centers() + const noexcept { - return pq_centers_.view(); + return impl_->pq_centers(); } template -std::vector>>& index::lists() noexcept +raft::device_matrix_view index::centers() noexcept { - return lists_; + return impl_->centers(); } template -const std::vector>>& index::lists() const noexcept +raft::device_matrix_view index::centers() + const noexcept { - return lists_; + return impl_->centers(); } template -raft::device_vector_view index::data_ptrs() noexcept +raft::device_matrix_view index::centers_rot() noexcept { - return data_ptrs_.view(); + return impl_->centers_rot(); } template -raft::device_vector_view index::data_ptrs() +raft::device_matrix_view index::centers_rot() const noexcept { - return raft::make_mdspan( - data_ptrs_.data_handle(), data_ptrs_.extents()); + return impl_->centers_rot(); } template -raft::device_vector_view index::inds_ptrs() noexcept +raft::device_matrix_view index::rotation_matrix() noexcept { - return inds_ptrs_.view(); + return impl_->rotation_matrix(); } template -raft::device_vector_view index::inds_ptrs() +raft::device_matrix_view index::rotation_matrix() const noexcept { - return raft::make_mdspan( - inds_ptrs_.data_handle(), inds_ptrs_.extents()); + return impl_->rotation_matrix(); } template -raft::device_matrix_view index::rotation_matrix() noexcept +std::vector>>& index::lists() noexcept { - return rotation_matrix_.view(); + return impl_->lists(); } template -raft::device_matrix_view index::rotation_matrix() - const noexcept +const std::vector>>& index::lists() const noexcept { - return rotation_matrix_.view(); + return impl_->lists(); } template -raft::host_vector_view index::accum_sorted_sizes() noexcept +raft::device_vector_view index::data_ptrs() noexcept { - return accum_sorted_sizes_.view(); + return impl_->data_ptrs(); } template -raft::host_vector_view index::accum_sorted_sizes() +raft::device_vector_view index::data_ptrs() const noexcept { - return accum_sorted_sizes_.view(); + return impl_->data_ptrs(); } template -raft::device_vector_view index::list_sizes() noexcept +raft::device_vector_view index::inds_ptrs() noexcept { - return list_sizes_.view(); + return impl_->inds_ptrs(); } template -raft::device_vector_view index::list_sizes() +raft::device_vector_view index::inds_ptrs() const noexcept { - return list_sizes_.view(); + return impl_->inds_ptrs(); } template -raft::device_matrix_view index::centers() noexcept +raft::host_vector_view index::accum_sorted_sizes() noexcept { - return centers_.view(); + return impl_->accum_sorted_sizes(); } template -raft::device_matrix_view index::centers() +raft::host_vector_view index::accum_sorted_sizes() const noexcept { - return centers_.view(); + return impl_->accum_sorted_sizes(); } template -raft::device_matrix_view index::centers_rot() noexcept +raft::device_vector_view index::list_sizes() noexcept { - return centers_rot_.view(); + return impl_->list_sizes(); } template -raft::device_matrix_view index::centers_rot() +raft::device_vector_view index::list_sizes() const noexcept { - return centers_rot_.view(); + return impl_->list_sizes(); } +// centers() and centers_rot() are now pure virtual and implemented in derived classes + template -uint32_t index::get_list_size_in_bytes(uint32_t label) +uint32_t index::get_list_size_in_bytes(uint32_t label) const noexcept { - RAFT_EXPECTS(label < this->n_lists(), - "Expected label to be less than number of lists in the index"); - auto& list_data = this->lists()[label]->data; - return list_data.size(); + return impl_->get_list_size_in_bytes(label); } template -void index::check_consistency() +void index_impl::check_consistency() { - RAFT_EXPECTS(pq_bits() >= 4 && pq_bits() <= 8, + RAFT_EXPECTS(pq_bits_ >= 4 && pq_bits_ <= 8, "`pq_bits` must be within closed range [4,8], but got %u.", - pq_bits()); - RAFT_EXPECTS((pq_bits() * pq_dim()) % 8 == 0, + pq_bits_); + RAFT_EXPECTS((pq_bits_ * pq_dim_) % 8 == 0, "`pq_bits * pq_dim` must be a multiple of 8, but got %u * %u = %u.", - pq_bits(), - pq_dim(), - pq_bits() * pq_dim()); -} - -template -typename index::pq_centers_extents index::make_pq_centers_extents() -{ - switch (codebook_kind()) { - case codebook_gen::PER_SUBSPACE: - return raft::make_extents(pq_dim(), pq_len(), pq_book_size()); - case codebook_gen::PER_CLUSTER: - return raft::make_extents(n_lists(), pq_len(), pq_book_size()); - default: RAFT_FAIL("Unreachable code"); - } + pq_bits_, + pq_dim_, + pq_bits_ * pq_dim_); } template uint32_t index::calculate_pq_dim(uint32_t dim) { - // If the dimensionality is large enough, we can reduce it to improve performance if (dim >= 128) { dim /= 2; } - // Round it down to 32 to improve performance. auto r = raft::round_down_safe(dim, 32); if (r > 0) return r; - // If the dimensionality is really low, round it to the closest power-of-two r = 1; while ((r << 1) <= dim) { r = r << 1; @@ -337,31 +645,40 @@ uint32_t index::calculate_pq_dim(uint32_t dim) } template -raft::device_matrix_view index::rotation_matrix_int8( - const raft::resources& res) const +uint32_t index_impl::get_list_size_in_bytes(uint32_t label) const noexcept +{ + RAFT_EXPECTS(label < lists_.size(), + "Expected label to be less than number of lists in the index"); + auto& list_data = lists_[label]->data; + return list_data.size(); +} + +template +raft::device_matrix_view +index_impl::rotation_matrix_int8(const raft::resources& res) const { if (!rotation_matrix_int8_.has_value()) { rotation_matrix_int8_.emplace( - raft::make_device_mdarray(res, rotation_matrix().extents())); + raft::make_device_mdarray(res, this->rotation_matrix().extents())); raft::linalg::map(res, rotation_matrix_int8_->view(), cuvs::spatial::knn::detail::utils::mapping{}, - rotation_matrix()); + this->rotation_matrix()); } return rotation_matrix_int8_->view(); } template -raft::device_matrix_view index::centers_int8( +raft::device_matrix_view index_impl::centers_int8( const raft::resources& res) const { if (!centers_int8_.has_value()) { - uint32_t n_lists = this->n_lists(); + uint32_t n_lists = lists().size(); uint32_t dim = this->dim(); - uint32_t dim_ext = this->dim_ext(); + uint32_t dim_ext = raft::round_up_safe(dim + 1, 8u); uint32_t dim_ext_int8 = raft::round_up_safe(dim + 2, 16u); centers_int8_.emplace(raft::make_device_matrix(res, n_lists, dim_ext_int8)); - auto* inputs = centers().data_handle(); + auto* inputs = this->centers().data_handle(); /* NOTE: maximizing the range and the precision of int8_t GEMM int8_t has a very limited range [-128, 127], which is problematic when storing both vectors and @@ -397,48 +714,83 @@ raft::device_matrix_view index::c it is limited by the range we can cover (the squared norm must be within `m * 2` before normalization). */ - raft::linalg::map_offset( - res, centers_int8_->view(), [dim, dim_ext, dim_ext_int8, inputs] __device__(uint32_t ix) { - uint32_t col = ix % dim_ext_int8; - uint32_t row = ix / dim_ext_int8; - if (col < dim) { - return static_cast( - std::clamp(inputs[col + row * dim_ext] * 128.0f, -128.0f, 127.f)); - } - auto x = inputs[row * dim_ext + dim]; - auto c = 64.0f / static_cast(dim_ext_int8 - dim - 1); - auto y = std::clamp(x * c, -128.0f, 127.f); - auto z = std::clamp((y - std::round(y)) * 128.0f, -128.0f, 127.f); - if (col > dim) { return static_cast(std::round(y)); } - return static_cast(z); - }); + raft::linalg::map_offset(res, + this->centers_int8_->view(), + [dim, dim_ext, dim_ext_int8, inputs] __device__(uint32_t ix) { + uint32_t col = ix % dim_ext_int8; + uint32_t row = ix / dim_ext_int8; + if (col < dim) { + return static_cast(std::clamp( + inputs[col + row * dim_ext] * 128.0f, -128.0f, 127.f)); + } + auto x = inputs[row * dim_ext + dim]; + auto c = 64.0f / static_cast(dim_ext_int8 - dim - 1); + auto y = std::clamp(x * c, -128.0f, 127.f); + auto z = std::clamp((y - std::round(y)) * 128.0f, -128.0f, 127.f); + if (col > dim) { return static_cast(std::round(y)); } + return static_cast(z); + }); } return centers_int8_->view(); } template -raft::device_matrix_view index::rotation_matrix_half( - const raft::resources& res) const +raft::device_matrix_view +index_impl::rotation_matrix_half(const raft::resources& res) const { if (!rotation_matrix_half_.has_value()) { rotation_matrix_half_.emplace( - raft::make_device_mdarray(res, rotation_matrix().extents())); - raft::linalg::map(res, rotation_matrix_half_->view(), raft::cast_op{}, rotation_matrix()); + raft::make_device_mdarray(res, this->rotation_matrix().extents())); + raft::linalg::map( + res, rotation_matrix_half_->view(), raft::cast_op{}, this->rotation_matrix()); } return rotation_matrix_half_->view(); } template -raft::device_matrix_view index::centers_half( +raft::device_matrix_view index_impl::centers_half( const raft::resources& res) const { if (!centers_half_.has_value()) { - centers_half_.emplace(raft::make_device_mdarray(res, centers().extents())); - raft::linalg::map(res, centers_half_->view(), raft::cast_op{}, centers()); + centers_half_.emplace( + raft::make_device_mdarray(res, this->centers().extents())); + raft::linalg::map(res, centers_half_->view(), raft::cast_op{}, this->centers()); } return centers_half_->view(); } +template +raft::device_matrix_view index::rotation_matrix_int8( + const raft::resources& res) const +{ + return impl_->rotation_matrix_int8(res); +} + +template +raft::device_matrix_view index::centers_int8( + const raft::resources& res) const +{ + return impl_->centers_int8(res); +} + +template +raft::device_matrix_view index::rotation_matrix_half( + const raft::resources& res) const +{ + return impl_->rotation_matrix_half(res); +} + +template +raft::device_matrix_view index::centers_half( + const raft::resources& res) const +{ + return impl_->centers_half(res); +} + +template class index_iface; +template class index_impl; +template struct owning_impl; +template struct view_impl; template struct index; } // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/tests/neighbors/ann_ivf_pq.cuh b/cpp/tests/neighbors/ann_ivf_pq.cuh index 4660c5d0d3..fd4469072a 100644 --- a/cpp/tests/neighbors/ann_ivf_pq.cuh +++ b/cpp/tests/neighbors/ann_ivf_pq.cuh @@ -266,6 +266,54 @@ class ivf_pq_test : public ::testing::TestWithParam { return index; } + void build_precomputed() + { + auto ipams = ps.index_params; + ipams.add_data_on_build = false; + auto database_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + const auto& base_index = cuvs::neighbors::ivf_pq::build(handle_, ipams, database_view); + + auto view_index = cuvs::neighbors::ivf_pq::build(handle_, + ipams, + base_index.dim(), + base_index.pq_centers(), + base_index.centers(), + base_index.centers_rot(), + base_index.rotation_matrix()); + + ASSERT_EQ(base_index.pq_centers().data_handle(), view_index.pq_centers().data_handle()); + ASSERT_EQ(base_index.centers().data_handle(), view_index.centers().data_handle()); + ASSERT_EQ(base_index.centers_rot().data_handle(), view_index.centers_rot().data_handle()); + ASSERT_EQ(base_index.rotation_matrix().data_handle(), + view_index.rotation_matrix().data_handle()); + + ASSERT_EQ(base_index.pq_centers().extents(), view_index.pq_centers().extents()); + ASSERT_EQ(base_index.centers().extents(), view_index.centers().extents()); + ASSERT_EQ(base_index.centers_rot().extents(), view_index.centers_rot().extents()); + ASSERT_EQ(base_index.rotation_matrix().extents(), view_index.rotation_matrix().extents()); + + auto db_indices = raft::make_device_vector(handle_, ps.num_db_vecs); + raft::linalg::map_offset(handle_, db_indices.view(), raft::identity_op{}); + + auto vecs_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + auto inds_view = + raft::make_device_vector_view(db_indices.data_handle(), ps.num_db_vecs); + + cuvs::neighbors::ivf_pq::extend(handle_, vecs_view, inds_view, &view_index); + cuvs::neighbors::ivf_pq::extend(handle_, + vecs_view, + inds_view, + const_cast*>(&base_index)); + + // Verify that both indices have identical list sizes after extension + ASSERT_TRUE(cuvs::devArrMatch(base_index.list_sizes().data_handle(), + view_index.list_sizes().data_handle(), + base_index.n_lists(), + cuvs::Compare{})); + } + void check_reconstruction(const index& index, double compression_ratio, uint32_t label, @@ -1094,6 +1142,9 @@ inline auto special_cases() -> test_cases_t this->run([this]() { return this->build_serialize(); }); \ } +#define TEST_BUILD_PRECOMPUTED(type) \ + TEST_P(type, build_precomputed) /* NOLINT */ { this->build_precomputed(); } + #define INSTANTIATE(type, vals) \ INSTANTIATE_TEST_SUITE_P(IvfPq, type, ::testing::ValuesIn(vals)); /* NOLINT */ diff --git a/cpp/tests/neighbors/ann_ivf_pq/test_float_int64_t.cu b/cpp/tests/neighbors/ann_ivf_pq/test_float_int64_t.cu index ac7108460d..fd5a8d4842 100644 --- a/cpp/tests/neighbors/ann_ivf_pq/test_float_int64_t.cu +++ b/cpp/tests/neighbors/ann_ivf_pq/test_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -14,6 +14,7 @@ TEST_BUILD_HOST_INPUT_SEARCH(f32_f32_i64) TEST_BUILD_HOST_INPUT_OVERLAP_SEARCH(f32_f32_i64) TEST_BUILD_EXTEND_SEARCH(f32_f32_i64) TEST_BUILD_SERIALIZE_SEARCH(f32_f32_i64) +TEST_BUILD_PRECOMPUTED(f32_f32_i64) INSTANTIATE(f32_f32_i64, defaults() + small_dims() + big_dims_moderate_lut() + enum_variety_l2() + enum_variety_l2sqrt() + enum_variety_ip() + enum_variety_cosine());