diff --git a/ci/test_python.sh b/ci/test_python.sh
index bf1cfe43a4..c17bce79fb 100755
--- a/ci/test_python.sh
+++ b/ci/test_python.sh
@@ -48,6 +48,7 @@ pytest \
   --cov=cuvs \
   --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/cuvs-coverage.xml" \
   --cov-report=term \
+  -s -v \
   tests
 
 rapids-logger "pytest cuvs-bench"
diff --git a/ci/test_wheel_cuvs.sh b/ci/test_wheel_cuvs.sh
index 427dc69adc..a953d9ad4f 100755
--- a/ci/test_wheel_cuvs.sh
+++ b/ci/test_wheel_cuvs.sh
@@ -18,4 +18,4 @@ rapids-pip-retry install \
   "${LIBCUVS_WHEELHOUSE}"/libcuvs*.whl \
   "$(echo "${CUVS_WHEELHOUSE}"/cuvs*.whl)[test]"
 
-python -m pytest ./python/cuvs/cuvs/tests
+python -m pytest ./python/cuvs/cuvs/tests -s -v
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index af15e4a399..f2ae2152b4 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -462,12 +462,10 @@ if(NOT BUILD_CPU_ONLY)
     src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu
     src/neighbors/ivf_flat/ivf_flat_interleaved_scan_float_int64_t.cu
     src/neighbors/ivf_flat/ivf_flat_interleaved_scan_half_int64_t.cu
-    src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t.cu
-    src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t.cu
+    src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_uint8_t_int64_t.cu
     src/neighbors/ivf_flat/ivf_flat_interleaved_scan_float_int64_t_bitset.cu
     src/neighbors/ivf_flat/ivf_flat_interleaved_scan_half_int64_t_bitset.cu
-    src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t_bitset.cu
-    src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t_bitset.cu
+    src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_uint8_t_int64_t_bitset.cu
     src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu
     src/neighbors/ivf_flat/ivf_flat_serialize_half_int64_t.cu
     src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu
diff --git a/cpp/include/cuvs/core/byte_arithmetic_ptr.hpp b/cpp/include/cuvs/core/byte_arithmetic_ptr.hpp
new file mode 100644
index 0000000000..b88ee61c0f
--- /dev/null
+++ b/cpp/include/cuvs/core/byte_arithmetic_ptr.hpp
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace cuvs::detail {
+
+struct byte_arithmetic_ptr {
+  void* data     = nullptr;
+  bool is_signed = false;
+
+  __host__ __device__ byte_arithmetic_ptr(void* ptr, bool signed_flag)
+    : data(ptr), is_signed(signed_flag)
+  {
+  }
+
+  // Proxy that references an element in the array
+  struct byte {
+    byte_arithmetic_ptr* parent = nullptr;
+    int64_t idx                 = -1;
+    uint8_t value               = 0;  // used for detached proxies
+
+    // Constructor for live proxy
+    __host__ __device__ byte(byte_arithmetic_ptr& p, int64_t i) : parent(&p), idx(i) {}
+
+    // Copy constructor: detached copy stores the current value
+    __host__ __device__ byte(const byte& other)
+      : parent(nullptr), idx(-1), value(static_cast<uint8_t>(other))
+    {
+    }
+
+    // Copy assignment: detached copy stores value
+    __host__ __device__ byte& operator=(const byte& other)
+    {
+      parent = nullptr;
+      idx    = -1;
+      value  = static_cast<uint8_t>(other);
+      return *this;
+    }
+
+    // Deleted move operations
+    __host__ __device__ byte(byte&& other)            = delete;
+    __host__ __device__ byte& operator=(byte&& other) = delete;
+
+    // Conversion to uint8_t
+    __host__ __device__ operator uint8_t() const
+    {
+      if (parent) {
+        if (parent->is_signed) {
+          int8_t val = reinterpret_cast<const int8_t*>(parent->data)[idx];
+          return static_cast<uint8_t>(static_cast<int>(val) + 128);
+        } else {
+          return reinterpret_cast<const uint8_t*>(parent->data)[idx];
+        }
+      } else {
+        return value;  // return local value if detached
+      }
+    }
+
+    // Assignment from uint8_t
+    __host__ __device__ byte& operator=(uint8_t normalized_value)
+    {
+      if (parent) {
+        if (parent->is_signed) {
+          reinterpret_cast<int8_t*>(parent->data)[idx] =
+            static_cast<int8_t>(static_cast<int>(normalized_value) - 128);
+        } else {
+          reinterpret_cast<uint8_t*>(parent->data)[idx] = normalized_value;
+        }
+      } else {
+        value = normalized_value;  // store in local value if detached
+      }
+      return *this;
+    }
+  };
+
+  // Non-const index access: returns live proxy
+  __host__ __device__ byte operator[](int64_t idx) { return byte(*this, idx); }
+
+  // Const index access: returns immediate value
+  __host__ __device__ uint8_t operator[](int64_t idx) const
+  {
+    if (is_signed) {
+      int8_t val = reinterpret_cast<const int8_t*>(data)[idx];
+      return static_cast<uint8_t>(static_cast<int>(val) + 128);
+    } else {
+      return reinterpret_cast<const uint8_t*>(data)[idx];
+    }
+  }
+
+  // Dereference (like *ptr)
+  __host__ __device__ uint8_t operator*() const { return (*this)[0]; }
+  __host__ __device__ byte operator*() { return byte(*this, 0); }
+
+  // Pointer arithmetic
+  __host__ __device__ byte_arithmetic_ptr operator+(int64_t offset) const
+  {
+    if (is_signed)
+      return byte_arithmetic_ptr(static_cast<int8_t*>(data) + offset, true);
+    else
+      return byte_arithmetic_ptr(static_cast<uint8_t*>(data) + offset, false);
+  }
+
+  __host__ __device__ bool operator==(const byte_arithmetic_ptr& other) const
+  {
+    return data == other.data;
+  }
+  __host__ __device__ bool operator!=(const byte_arithmetic_ptr& other) const
+  {
+    return !(*this == other);
+  }
+
+  __host__ __device__ uint8_t* get_uint8_ptr() const { return reinterpret_cast<uint8_t*>(data); }
+};
+
+}  // namespace cuvs::detail
diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh
index 4c0bb3644a..0dea2a8a0c 100644
--- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh
+++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh
@@ -7,13 +7,16 @@
 #include "../ivf_common.cuh"
 #include "../sample_filter.cuh"
+#include
 #include
 #include
 #include "../detail/ann_utils.cuh"
+#include
 #include
 #include
 #include
+#include
 #include
 #include   // RAFT_CUDA_TRY
#include @@ -22,6 +25,7 @@ #include #include +#include namespace cuvs::neighbors::ivf_flat::detail { @@ -75,6 +79,56 @@ __device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) } } +// Specialization for byte_arithmetic_ptr -> uint8_t* (for int8_t normalization) +__device__ inline void copy_vectorized(uint8_t* out, + const cuvs::detail::byte_arithmetic_ptr& in, + uint32_t n) +{ + // For byte_arithmetic_ptr, copy element by element with normalization to uint8_t + for (int i = threadIdx.x; i < n; i += blockDim.x) { + out[i] = in.get_uint8_ptr()[i]; + } +} + +/** + * Pack 4 uint8_t values into a single uint32_t for vectorized operations. + */ +__device__ __forceinline__ uint32_t pack_bytes_uint32(uint8_t b0, + uint8_t b1, + uint8_t b2, + uint8_t b3) +{ + return (static_cast(b0) << 0) | (static_cast(b1) << 8) | + (static_cast(b2) << 16) | (static_cast(b3) << 24); +} + +/** + * Normalize int8_t bytes (stored as raw uint8_t) to actual uint8_t values by adding 128 to each + * byte. This is only needed for Euclidean distance with signed data, as Euclidean distance is + * translation-invariant. + */ +__device__ __forceinline__ uint32_t normalize_int8_packed(uint32_t packed) +{ + // Unpack 4 bytes, reinterpret as signed, add 128 to each, repack + uint8_t b0 = static_cast(static_cast(static_cast(packed >> 0)) + 128); + uint8_t b1 = static_cast(static_cast(static_cast(packed >> 8)) + 128); + uint8_t b2 = static_cast(static_cast(static_cast(packed >> 16)) + 128); + uint8_t b3 = static_cast(static_cast(static_cast(packed >> 24)) + 128); + return pack_bytes_uint32(b0, b1, b2, b3); +} + +/** + * Normalize int8_t bytes packed in uint16_t (2 bytes). + */ +__device__ __forceinline__ uint32_t normalize_int8_packed_u16(uint32_t packed) +{ + // Unpack 2 bytes, reinterpret as signed, add 128 to each, repack (result fits in uint16_t but + // returned as uint32_t) + uint8_t b0 = static_cast(static_cast(static_cast(packed >> 0)) + 128); + uint8_t b1 = static_cast(static_cast(static_cast(packed >> 8)) + 128); + return (static_cast(b0) << 0) | (static_cast(b1) << 8); +} + /** * @brief Load a part of a vector from the index and from query, compute the (part of the) distance * between them, and aggregate it using the provided Lambda; one structure per thread, per query, @@ -205,11 +259,13 @@ struct loadAndComputeDist(data) + loadIndex + j * kIndexGroupSize * veclen_int); + reinterpret_cast(data_ptr) + loadIndex + j * kIndexGroupSize * veclen_int); uint32_t queryRegs[veclen_int]; raft::lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); #pragma unroll for (int k = 0; k < veclen_int; k++) { + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + encV[k] = normalize_int8_packed(encV[k]); + queryRegs[k] = normalize_int8_packed(queryRegs[k]); + } compute_dist(dist, queryRegs[k], encV[k]); if constexpr (ComputeNorm) { norm_query = raft::dp4a(queryRegs[k], queryRegs[k], norm_query); @@ -231,28 +292,37 @@ struct loadAndComputeDist(query + baseLoadIndex)[lane_id] : 0; + (lane_id < 8) ? 
reinterpret_cast(query_ptr + baseLoadIndex)[lane_id] : 0; constexpr int stride = kUnroll * uint8_veclen; + const uint8_t* data_ptr = data.get_uint8_ptr(); #pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { + for (int i = 0; i < raft::WarpSize / stride; ++i, data_ptr += stride * kIndexGroupSize) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { uint32_t encV[veclen_int]; - raft::ldg( - encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + raft::ldg(encV, + reinterpret_cast(data_ptr) + + (lane_id + j * kIndexGroupSize) * veclen_int); const int d = (i * kUnroll + j) * veclen_int; #pragma unroll for (int k = 0; k < veclen_int; ++k) { uint32_t q = raft::shfl(queryReg, d + k, raft::WarpSize); + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + encV[k] = normalize_int8_packed(encV[k]); + q = normalize_int8_packed(q); + } compute_dist(dist, q, encV[k]); if constexpr (ComputeNorm) { norm_query = raft::dp4a(q, q, norm_query); @@ -261,24 +331,37 @@ struct loadAndComputeDist(data_ptr), is_signed); } - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + cuvs::detail::byte_arithmetic_ptr& data, + const cuvs::detail::byte_arithmetic_ptr& query, + const int lane_id, + const int dim, + const int dimBlocks) { + const uint8_t* query_ptr = query.get_uint8_ptr(); + const bool is_signed = data.is_signed; constexpr int veclen_int = uint8_veclen / 4; const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + uint32_t queryReg = + loadDim < dim ? 
reinterpret_cast(query_ptr + loadDim)[0] : 0; + + const uint8_t* data_ptr = data.get_uint8_ptr(); for (int d = 0; d < dim - dimBlocks; - d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { + d += uint8_veclen, data_ptr += kIndexGroupSize * uint8_veclen) { uint32_t enc[veclen_int]; - raft::ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); + raft::ldg(enc, reinterpret_cast(data_ptr) + lane_id * veclen_int); #pragma unroll for (int k = 0; k < veclen_int; k++) { uint32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + enc[k] = normalize_int8_packed(enc[k]); + q = normalize_int8_packed(q); + } compute_dist(dist, q, enc[k]); if constexpr (ComputeNorm) { norm_query = raft::dp4a(q, q, norm_query); @@ -286,6 +369,8 @@ struct loadAndComputeDist(data_ptr), is_signed); } }; @@ -304,15 +389,22 @@ struct loadAndComputeDist { { } - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data, const uint8_t* query_shared, int loadIndex, int shmemIndex) { + const uint8_t* data_ptr = data.get_uint8_ptr(); + const bool is_signed = data.is_signed; #pragma unroll for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t encV = reinterpret_cast(data_ptr)[loadIndex + j * kIndexGroupSize]; uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + encV = normalize_int8_packed(encV); + queryRegs = normalize_int8_packed(queryRegs); + } compute_dist(dist, queryRegs, encV); if constexpr (ComputeNorm) { norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); @@ -320,22 +412,31 @@ struct loadAndComputeDist { } } } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) + __device__ __forceinline__ void runLoadShflAndCompute( + cuvs::detail::byte_arithmetic_ptr& data, + const cuvs::detail::byte_arithmetic_ptr& query, + int baseLoadIndex, + const int lane_id) { + const uint8_t* query_ptr = query.get_uint8_ptr(); + const bool is_signed = data.is_signed; uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + (lane_id < 8) ? 
reinterpret_cast(query_ptr + baseLoadIndex)[lane_id] : 0; constexpr int veclen = 4; constexpr int stride = kUnroll * veclen; + const uint8_t* data_ptr = data.get_uint8_ptr(); #pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { + for (int i = 0; i < raft::WarpSize / stride; ++i, data_ptr += stride * kIndexGroupSize) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t encV = reinterpret_cast(data_ptr)[lane_id + j * kIndexGroupSize]; uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + encV = normalize_int8_packed(encV); + q = normalize_int8_packed(q); + } compute_dist(dist, q, encV); if constexpr (ComputeNorm) { norm_query = raft::dp4a(q, q, norm_query); @@ -343,26 +444,40 @@ struct loadAndComputeDist { } } } + // Update the byte_arithmetic_ptr by the total offset + data = cuvs::detail::byte_arithmetic_ptr(const_cast(data_ptr), is_signed); } - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + cuvs::detail::byte_arithmetic_ptr& data, + const cuvs::detail::byte_arithmetic_ptr& query, + const int lane_id, + const int dim, + const int dimBlocks) { - constexpr int veclen = 4; - const int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; + constexpr int veclen = 4; + const int loadDim = dimBlocks + lane_id; + const uint8_t* query_ptr = query.get_uint8_ptr(); + const bool is_signed = data.is_signed; + uint32_t queryReg = loadDim < dim ? 
reinterpret_cast(query_ptr)[loadDim] : 0; + + const uint8_t* data_ptr = data.get_uint8_ptr(); + for (int d = 0; d < dim - dimBlocks; d += veclen, data_ptr += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data_ptr)[lane_id]; uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + enc = normalize_int8_packed(enc); + q = normalize_int8_packed(q); + } compute_dist(dist, q, enc); if constexpr (ComputeNorm) { norm_query = raft::dp4a(q, q, norm_query); norm_data = raft::dp4a(enc, enc, norm_data); } } + // Update the byte_arithmetic_ptr by the total offset + data = cuvs::detail::byte_arithmetic_ptr(const_cast(data_ptr), is_signed); } }; @@ -379,15 +494,22 @@ struct loadAndComputeDist { { } - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data, const uint8_t* query_shared, int loadIndex, int shmemIndex) { + const uint8_t* data_ptr = data.get_uint8_ptr(); + const bool is_signed = data.is_signed; #pragma unroll for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t encV = reinterpret_cast(data_ptr)[loadIndex + j * kIndexGroupSize]; uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + encV = normalize_int8_packed_u16(encV); + queryRegs = normalize_int8_packed_u16(queryRegs); + } compute_dist(dist, queryRegs, encV); if constexpr (ComputeNorm) { norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); @@ -396,22 +518,31 @@ struct loadAndComputeDist { } } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) + __device__ __forceinline__ void runLoadShflAndCompute( + cuvs::detail::byte_arithmetic_ptr& data, + const cuvs::detail::byte_arithmetic_ptr& query, + int baseLoadIndex, + const int lane_id) { + const uint8_t* query_ptr = query.get_uint8_ptr(); + const bool is_signed = data.is_signed; uint32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + (lane_id < 16) ? 
reinterpret_cast(query_ptr + baseLoadIndex)[lane_id] : 0; constexpr int veclen = 2; constexpr int stride = kUnroll * veclen; + const uint8_t* data_ptr = data.get_uint8_ptr(); #pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { + for (int i = 0; i < raft::WarpSize / stride; ++i, data_ptr += stride * kIndexGroupSize) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t encV = reinterpret_cast(data_ptr)[lane_id + j * kIndexGroupSize]; uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + encV = normalize_int8_packed_u16(encV); + q = normalize_int8_packed_u16(q); + } compute_dist(dist, q, encV); if constexpr (ComputeNorm) { norm_query = raft::dp4a(q, q, norm_query); @@ -419,26 +550,41 @@ struct loadAndComputeDist { } } } + // Update the byte_arithmetic_ptr by the total offset + data = cuvs::detail::byte_arithmetic_ptr(const_cast(data_ptr), is_signed); } - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + cuvs::detail::byte_arithmetic_ptr& data, + const cuvs::detail::byte_arithmetic_ptr& query, + const int lane_id, + const int dim, + const int dimBlocks) { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + const uint8_t* query_ptr = query.get_uint8_ptr(); + const bool is_signed = data.is_signed; + uint32_t queryReg = + loadDim < dim ? 
reinterpret_cast(query_ptr + loadDim)[0] : 0; + + const uint8_t* data_ptr = data.get_uint8_ptr(); + for (int d = 0; d < dim - dimBlocks; d += veclen, data_ptr += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data_ptr)[lane_id]; uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + enc = normalize_int8_packed_u16(enc); + q = normalize_int8_packed_u16(q); + } compute_dist(dist, q, enc); if constexpr (ComputeNorm) { norm_query = raft::dp4a(q, q, norm_query); norm_data = raft::dp4a(enc, enc, norm_data); } } + // Update the byte_arithmetic_ptr by the total offset + data = cuvs::detail::byte_arithmetic_ptr(const_cast(data_ptr), is_signed); } }; @@ -455,15 +601,21 @@ struct loadAndComputeDist { { } - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data, const uint8_t* query_shared, int loadIndex, int shmemIndex) { + const bool is_signed = data.is_signed; #pragma unroll for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[loadIndex + j * kIndexGroupSize]; + uint32_t encV = static_cast(data[loadIndex + j * kIndexGroupSize]); uint32_t queryRegs = query_shared[shmemIndex + j]; + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) + if (is_signed) { + queryRegs = + static_cast(static_cast(static_cast(queryRegs)) + 128); + } compute_dist(dist, queryRegs, encV); if constexpr (ComputeNorm) { norm_query += queryRegs * queryRegs; @@ -472,21 +624,23 @@ struct loadAndComputeDist { } } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) + __device__ __forceinline__ void runLoadShflAndCompute( + cuvs::detail::byte_arithmetic_ptr& data, + const cuvs::detail::byte_arithmetic_ptr& query, + int baseLoadIndex, + const int lane_id) { - uint32_t queryReg = query[baseLoadIndex + lane_id]; + uint32_t queryReg = static_cast(query[baseLoadIndex + lane_id]); constexpr int veclen = 1; constexpr int stride = kUnroll * veclen; #pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { + for (int i = 0; i < raft::WarpSize / stride; ++i, data = data + stride * kIndexGroupSize) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[lane_id + j * kIndexGroupSize]; + uint32_t encV = static_cast(data[lane_id + j * kIndexGroupSize]); uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + // Normalize int8_t bytes to uint8_t for Euclidean distance (translation-invariant) compute_dist(dist, q, encV); if constexpr (ComputeNorm) { norm_query += q * q; @@ -496,17 +650,19 @@ struct loadAndComputeDist { } } - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + cuvs::detail::byte_arithmetic_ptr& data, + const cuvs::detail::byte_arithmetic_ptr& query, + const int lane_id, + const int dim, + const int dimBlocks) { constexpr int veclen = 1; int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = data[lane_id]; + uint32_t queryReg = loadDim < dim ? 
static_cast(query[loadDim]) : 0; + + for (int d = 0; d < dim - dimBlocks; d += veclen, data = data + kIndexGroupSize * veclen) { + uint32_t enc = static_cast(data[lane_id]); uint32_t q = raft::shfl(queryReg, d, raft::WarpSize); compute_dist(dist, q, enc); if constexpr (ComputeNorm) { @@ -796,6 +952,22 @@ using block_sort_t = typename flat_block_sort::typ * @param[out] neighbors * @param[out] distances */ + +// Forward declaration for euclidean_dist +template +struct euclidean_dist; + +// Type trait to detect euclidean_dist +template +struct is_euclidean_dist : std::false_type {}; + +template +struct is_euclidean_dist> : std::true_type {}; + +template +constexpr bool byte_arithmetic_dispatch = + std::is_same_v || (std::is_same_v && is_euclidean_dist::value); + template + typename DataT = std::conditional_t, + cuvs::detail::byte_arithmetic_ptr, + const T*>, + typename ListDataT = std::conditional_t, + cuvs::detail::byte_arithmetic_ptr*, + const T* const*>> RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) interleaved_scan_kernel(Lambda compute_dist, - PostLambda post_process, + cuvs::distance::DistanceType metric, const uint32_t query_smem_elems, - const T* query, + DataT query, const uint32_t* coarse_index, - const T* const* list_data_ptrs, + ListDataT list_data_ptrs, const uint32_t* list_sizes, const uint32_t queries_offset, const uint32_t n_probes, @@ -834,7 +1011,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) // Make the query input and output point to this block's shared query { const int query_id = blockIdx.y; - query += query_id * dim; + query = query + query_id * dim; if constexpr (kManageLocalTopK) { neighbors += query_id * k * gridDim.x + blockIdx.x * k; distances += query_id * k * gridDim.x + blockIdx.x * k; @@ -888,7 +1065,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) AccT norm_query = 0; AccT norm_dataset = 0; // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; + auto data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; // This is the vector a given lane/thread handles const uint32_t vec_id = group_id * raft::WarpSize + lane_id; @@ -900,7 +1077,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) loadAndComputeDist lc( dist, compute_dist, norm_query, norm_dataset); for (int pos = 0; pos < shm_assisted_dim; - pos += raft::WarpSize, data += kIndexGroupSize * raft::WarpSize) { + pos += raft::WarpSize, data = data + kIndexGroupSize * raft::WarpSize) { lc.runLoadShmemCompute(data, query_shared, lane_id, pos); } @@ -917,7 +1094,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT, ComputeNorm> lc( dist, compute_dist, norm_query, norm_dataset); for (int pos = full_warps_along_dim; pos < dim; - pos += Veclen, data += kIndexGroupSize * Veclen) { + pos += Veclen, data = data + kIndexGroupSize * Veclen) { lc.runLoadShmemCompute(data, query_shared, lane_id, pos); } } @@ -954,7 +1131,18 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) if constexpr (kManageLocalTopK) { __syncthreads(); queue.done(interleaved_scan_kernel_smem); - queue.store(distances, neighbors, post_process); + // Apply post-processing based on metric (runtime dispatch for acceptable one-time cost) + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::L2SqrtUnexpanded) { + queue.store(distances, neighbors, raft::sqrt_op{}); + } else if (metric == 
cuvs::distance::DistanceType::CosineExpanded) { + queue.store( + distances, + neighbors, + raft::compose_op(raft::add_const_op{1.0f}, raft::mul_const_op{-1.0f})); + } else { + queue.store(distances, neighbors, raft::identity_op{}); + } } } @@ -977,6 +1165,25 @@ uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMem return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); } +/** + * Functor to convert uint8_t pointers to byte_arithmetic_ptrs. + * Using a functor instead of a lambda ensures the same type is used across + * all template instantiations, avoiding ~40MB of duplicate host code. + */ +struct byte_arithmetic_ptr_converter { + bool is_signed; + + __host__ __device__ explicit byte_arithmetic_ptr_converter(bool is_signed_) + : is_signed(is_signed_) + { + } + + __device__ cuvs::detail::byte_arithmetic_ptr operator()(const uint8_t* ptr) const + { + return cuvs::detail::byte_arithmetic_ptr(const_cast(ptr), is_signed); + } +}; + template + typename Lambda> void launch_kernel(Lambda lambda, - PostLambda post_process, + cuvs::distance::DistanceType metric, const index& index, const T* queries, const uint32_t* coarse_index, @@ -1006,16 +1212,21 @@ void launch_kernel(Lambda lambda, { RAFT_EXPECTS(Veclen == index.veclen(), "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = interleaved_scan_kernel; + // Only dispatch int8_t to uint8_t for euclidean_dist (to unify kernels) + // For inner_prod_dist, keep separate instantiations + using InstantiateT = std::conditional_t, uint8_t, T>; + // Also dispatch AccT to uint32_t for euclidean_dist with byte types (to match loadAndComputeDist + // specializations) + using InstantiateAccT = std::conditional_t, uint32_t, AccT>; + constexpr auto kKernel = interleaved_scan_kernel; const int max_query_smem = 16384; int query_smem_elems = std::min(max_query_smem / sizeof(T), raft::Pow2::roundUp(index.dim())); @@ -1037,6 +1248,23 @@ void launch_kernel(Lambda lambda, return; } + // For int8_t/uint8_t with euclidean_dist, pre-convert data pointers to byte_arithmetic_ptrs + // (only needs to be done once) + std::optional> + byte_arithmetic_ptr_list_data_ptrs; + if constexpr (byte_arithmetic_dispatch) { + constexpr bool is_signed = std::is_same_v; + byte_arithmetic_ptr_list_data_ptrs.emplace(index.data_ptrs().size(), stream); + // Cast to uint8_t* and use functor to ensure identical thrust::transform instantiation + const uint8_t* const* data_ptrs_uint8 = + reinterpret_cast(index.data_ptrs().data_handle()); + thrust::transform(rmm::exec_policy(stream), + data_ptrs_uint8, + data_ptrs_uint8 + index.data_ptrs().size(), + byte_arithmetic_ptr_list_data_ptrs->begin(), + byte_arithmetic_ptr_converter(is_signed)); + } + for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); dim3 grid_dim(grid_dim_x, grid_dim_y, 1); @@ -1049,22 +1277,47 @@ void launch_kernel(Lambda lambda, block_dim.x, n_probes, smem_size); - kKernel<<>>(lambda, - post_process, - query_smem_elems, - queries, - coarse_index, - index.data_ptrs().data_handle(), - index.list_sizes().data_handle(), - queries_offset + query_offset, - n_probes, - k, - max_samples, - chunk_indices, - index.dim(), - sample_filter, - neighbors, - distances); + if constexpr (byte_arithmetic_dispatch) { + // For int8_t/uint8_t with euclidean_dist, wrap in byte_arithmetic_ptr with normalization + constexpr bool is_signed = std::is_same_v; + auto 
byte_arithmetic_ptr_queries = + cuvs::detail::byte_arithmetic_ptr(const_cast(queries), is_signed); + kKernel<<>>( + lambda, + metric, + query_smem_elems, + byte_arithmetic_ptr_queries, + coarse_index, + byte_arithmetic_ptr_list_data_ptrs->data(), + index.list_sizes().data_handle(), + queries_offset + query_offset, + n_probes, + k, + max_samples, + chunk_indices, + index.dim(), + sample_filter, + neighbors, + distances); + } else { + // For inner_prod_dist with int8_t/uint8_t, or other types (float, etc.), use raw pointers + kKernel<<>>(lambda, + metric, + query_smem_elems, + queries, + coarse_index, + index.data_ptrs().data_handle(), + index.list_sizes().data_handle(), + queries_offset + query_offset, + n_probes, + k, + max_samples, + chunk_indices, + index.dim(), + sample_filter, + neighbors, + distances); + } queries += grid_dim_y * index.dim(); if constexpr (Capacity > 0) { neighbors += grid_dim_y * grid_dim_x * k; @@ -1086,13 +1339,14 @@ struct euclidean_dist { } }; +// Specialization for uint8_t (handles both uint8_t and normalized int8_t via byte_arithmetic_ptr) template struct euclidean_dist { __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) { if constexpr (Veclen > 1) { - const auto diff = __vabsdiffu4(x, y); - acc = raft::dp4a(diff, diff, acc); + const uint32_t diff_u32 = __vabsdiffu4(x, y); + acc = raft::dp4a(diff_u32, diff_u32, acc); } else { const auto diff = __usad(x, y, 0u); acc += diff * diff; @@ -1100,24 +1354,6 @@ struct euclidean_dist { } }; -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) - { - if constexpr (Veclen > 1) { - // Note that we enforce here that the unsigned version of dp4a is used, because the difference - // between two int8 numbers can be greater than 127 and therefore represented as a negative - // number in int8. Casting from int8 to int32 would yield incorrect results, while casting - // from uint8 to uint32 is correct. - const auto diff = __vabsdiffs4(x, y); - acc = raft::dp4a(diff, diff, static_cast(acc)); - } else { - const auto diff = x - y; - acc += diff * diff; - } - } -}; - template struct inner_prod_dist { __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) @@ -1141,58 +1377,29 @@ template void launch_with_fixed_consts(cuvs::distance::DistanceType metric, Args&&... args) { - switch (metric) { - case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2Unexpanded: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - case cuvs::distance::DistanceType::L2SqrtExpanded: - case cuvs::distance::DistanceType::L2SqrtUnexpanded: - return launch_kernel, - raft::sqrt_op>({}, {}, std::forward(args)...); - case cuvs::distance::DistanceType::InnerProduct: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - case cuvs::distance::DistanceType::CosineExpanded: - // NB: "Ascending" is reversed because the post-processing step is done after that sort - return launch_kernel>( - {}, - raft::compose_op(raft::add_const_op{1.0f}, raft::mul_const_op{-1.0f}), - std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when - // adding here a new metric. 
- default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); + // Dispatch int8_t to uint8_t for the metric lambda types to match kernel instantiation + using MetricT = std::conditional_t, uint8_t, T>; + // Also dispatch AccT to uint32_t for byte types with Euclidean (to match the kernel's + // InstantiateAccT) + using MetricAccT = + std::conditional_t || std::is_same_v, uint32_t, AccT>; + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2Unexpanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::L2SqrtUnexpanded) { + // For Euclidean distances, use MetricT and MetricAccT for the distance functor + return launch_kernel( + euclidean_dist{}, metric, std::forward(args)...); + } else if (metric == cuvs::distance::DistanceType::InnerProduct) { + return launch_kernel( + inner_prod_dist{}, metric, std::forward(args)...); + } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { + // NB: "Ascending" is reversed because the post-processing step is done after that sort + return launch_kernel( + inner_prod_dist{}, metric, std::forward(args)...); + } else { + RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); } } @@ -1313,6 +1520,7 @@ void ivfflat_interleaved_scan(const index& index, auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter( index.inds_ptrs().data_handle(), sample_filter); + select_interleaved_scan_kernel::run(capacity, index.veclen(), select_min, diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t.cu deleted file mode 100644 index 418f246576..0000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t.cu +++ /dev/null @@ -1,12 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "ivf_flat_interleaved_scan_explicit_inst.cuh" - -namespace cuvs::neighbors::ivf_flat::detail { - -CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t, int64_t, filtering::none_sample_filter); - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t_bitset.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t_bitset.cu deleted file mode 100644 index acf4a05ed3..0000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t_bitset.cu +++ /dev/null @@ -1,14 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. 
- * SPDX-License-Identifier: Apache-2.0 - */ - -#include "ivf_flat_interleaved_scan_explicit_inst.cuh" - -namespace cuvs::neighbors::ivf_flat::detail { - -CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t, - int64_t, - filtering::bitset_filter); - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_uint8_t_int64_t.cu similarity index 66% rename from cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t.cu rename to cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_uint8_t_int64_t.cu index abc0f352aa..af15d43652 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_uint8_t_int64_t.cu @@ -7,6 +7,8 @@ namespace cuvs::neighbors::ivf_flat::detail { +// Combined instantiations for int8_t and uint8_t (same kernel due to byte_array unification) +CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t, int64_t, filtering::none_sample_filter); CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(uint8_t, int64_t, filtering::none_sample_filter); } // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t_bitset.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_uint8_t_int64_t_bitset.cu similarity index 62% rename from cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t_bitset.cu rename to cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_uint8_t_int64_t_bitset.cu index 22a93c1509..ebf80a3a4f 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t_bitset.cu +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_uint8_t_int64_t_bitset.cu @@ -7,6 +7,10 @@ namespace cuvs::neighbors::ivf_flat::detail { +// Combined instantiations for int8_t and uint8_t (same kernel due to byte_array unification) +CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t, + int64_t, + filtering::bitset_filter); CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(uint8_t, int64_t, filtering::bitset_filter); diff --git a/python/cuvs/cuvs/tests/ann_utils.py b/python/cuvs/cuvs/tests/ann_utils.py index 9da93cea92..c0becd8cc2 100644 --- a/python/cuvs/cuvs/tests/ann_utils.py +++ b/python/cuvs/cuvs/tests/ann_utils.py @@ -127,4 +127,5 @@ def run_filtered_search_test( recall = calc_recall(mapped_actual_indices, skl_idx) + print(f"Recall: {recall}") assert recall > 0.7
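
---

The new `cpp/include/cuvs/core/byte_arithmetic_ptr.hpp` above is a proxy-pointer type: `operator[]` returns a `byte` proxy whose read converts signed storage into the unsigned range by adding 128, and whose write applies the inverse shift. Below is a minimal, CUDA-free sketch of that read/write round trip (plain C++, no `__host__ __device__`, no detached-proxy handling); it illustrates the idea only and is not the header's exact implementation.

```cpp
#include <cassert>
#include <cstdint>

// Simplified proxy view over a byte buffer that may hold int8_t or uint8_t data.
struct byte_view {
  void* data;
  bool is_signed;

  struct ref {
    byte_view* parent;
    int64_t idx;

    // Read: signed storage is shifted into the unsigned range [0, 255].
    operator uint8_t() const
    {
      if (parent->is_signed) {
        int8_t v = static_cast<int8_t*>(parent->data)[idx];
        return static_cast<uint8_t>(static_cast<int>(v) + 128);
      }
      return static_cast<uint8_t*>(parent->data)[idx];
    }

    // Write: the inverse shift puts a normalized value back into signed storage.
    ref& operator=(uint8_t normalized)
    {
      if (parent->is_signed) {
        static_cast<int8_t*>(parent->data)[idx] =
          static_cast<int8_t>(static_cast<int>(normalized) - 128);
      } else {
        static_cast<uint8_t*>(parent->data)[idx] = normalized;
      }
      return *this;
    }
  };

  ref operator[](int64_t i) { return ref{this, i}; }
};

int main()
{
  int8_t raw[3] = {-128, 0, 127};
  byte_view view{raw, /*is_signed=*/true};

  // Reads come back shifted by +128, i.e. already "normalized" for the uint8_t kernel.
  assert(uint8_t(view[0]) == 0 && uint8_t(view[1]) == 128 && uint8_t(view[2]) == 255);

  // Writes round-trip: storing a normalized value restores the original int8_t.
  view[2] = 255;
  assert(raw[2] == 127);
  return 0;
}
```

The same pattern is what lets the scan kernel treat both element types as a single `uint8_t` stream while the underlying lists keep their original int8_t encoding.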
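The kernel comments rely on squared Euclidean distance being translation-invariant, which is what makes the shared int8_t/uint8_t instantiation valid: shifting every int8_t lane by +128 leaves the per-byte difference |x - y| unchanged. A small host-side check of that claim follows, including a plain-C++ stand-in for the packed `normalize_int8_packed` helper; the device version in the diff applies the same per-byte shift, but the function here is an illustration, not the CUDA code.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Host-side mirror of the device helper: treat each byte of `packed` as an int8_t
// and shift it by +128 so it becomes a uint8_t with the same relative ordering.
static uint32_t normalize_int8_packed(uint32_t packed)
{
  uint32_t out = 0;
  for (int k = 0; k < 4; ++k) {
    int8_t b = static_cast<int8_t>((packed >> (8 * k)) & 0xFF);
    out |= static_cast<uint32_t>(static_cast<uint8_t>(static_cast<int>(b) + 128)) << (8 * k);
  }
  return out;
}

// Squared L2 distance computed directly on int8_t values.
static long l2sq_int8(const std::vector<int8_t>& a, const std::vector<int8_t>& b)
{
  long acc = 0;
  for (size_t i = 0; i < a.size(); ++i) {
    int d = static_cast<int>(a[i]) - static_cast<int>(b[i]);
    acc += d * d;
  }
  return acc;
}

// Squared L2 distance computed after shifting both operands by +128 into uint8_t.
static long l2sq_shifted(const std::vector<int8_t>& a, const std::vector<int8_t>& b)
{
  long acc = 0;
  for (size_t i = 0; i < a.size(); ++i) {
    int d = static_cast<int>(static_cast<uint8_t>(static_cast<int>(a[i]) + 128)) -
            static_cast<int>(static_cast<uint8_t>(static_cast<int>(b[i]) + 128));
    acc += d * d;
  }
  return acc;
}

int main()
{
  std::vector<int8_t> a = {-128, -1, 0, 127};
  std::vector<int8_t> b = {127, 0, -1, -128};

  // Translation invariance: shifting both operands by +128 does not change any |a - b|.
  assert(l2sq_int8(a, b) == l2sq_shifted(a, b));

  // The packed helper applies the same shift to four bytes at once.
  // Input bytes (low to high): 0x80(-128), 0xFF(-1), 0x00(0), 0x7F(127)
  // Output bytes (low to high): 0x00(0),   0x7F(127), 0x80(128), 0xFF(255)
  assert(normalize_int8_packed(0x7F00FF80u) == 0xFF807F00u);

  std::printf("int8 -> uint8 shift preserves L2: %ld == %ld\n", l2sq_int8(a, b), l2sq_shifted(a, b));
  return 0;
}
```

Note that inner product is not invariant under this shift, which is why the dot-product path keeps separate int8_t and uint8_t instantiations.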
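The `byte_arithmetic_dispatch` trait in the scan header only reroutes int8_t to the uint8_t kernel for the Euclidean functor, for exactly the reason above. The angle-bracket contents of the functor templates were lost in the patch text, so the parameter lists in the sketch below are assumptions; it only shows how the trait and `std::conditional_t` select the instantiated element type.

```cpp
#include <cstdint>
#include <type_traits>

// Stand-ins for the distance functors; the real template parameters are assumed here.
template <int Veclen, typename T, typename AccT>
struct euclidean_dist {};
template <typename AccT>
struct inner_prod_dist {};

// Detect the Euclidean functor regardless of its template arguments.
template <typename Lambda>
struct is_euclidean_dist : std::false_type {};
template <int Veclen, typename T, typename AccT>
struct is_euclidean_dist<euclidean_dist<Veclen, T, AccT>> : std::true_type {};

// int8_t may share the uint8_t kernel only when the distance is Euclidean
// (translation-invariant); uint8_t always takes the byte path.
template <typename T, typename Lambda>
constexpr bool byte_arithmetic_dispatch =
  std::is_same_v<T, uint8_t> || (std::is_same_v<T, int8_t> && is_euclidean_dist<Lambda>::value);

// The element type actually used to instantiate the kernel.
template <typename T, typename Lambda>
using instantiate_t = std::conditional_t<byte_arithmetic_dispatch<T, Lambda>, uint8_t, T>;

static_assert(byte_arithmetic_dispatch<int8_t, euclidean_dist<4, uint8_t, uint32_t>>);
static_assert(!byte_arithmetic_dispatch<int8_t, inner_prod_dist<int32_t>>);
static_assert(std::is_same_v<instantiate_t<int8_t, euclidean_dist<4, uint8_t, uint32_t>>, uint8_t>);
static_assert(std::is_same_v<instantiate_t<int8_t, inner_prod_dist<int32_t>>, int8_t>);
static_assert(std::is_same_v<instantiate_t<float, euclidean_dist<4, float, float>>, float>);

int main() { return 0; }
```

Collapsing the int8_t Euclidean path onto the uint8_t kernel is what allows the two per-type `.cu` instantiation files to be merged in this change.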