 #include <cuvs/neighbors/ivf_flat.hpp>

 #include "../detail/ann_utils.cuh"
-#include <cuvs/core/byte_array.hpp>
+#include <cuvs/core/byte_arithmetic_ptr.hpp>
 #include <cuvs/distance/distance.hpp>
 #include <raft/core/logger.hpp>
 #include <raft/core/operators.hpp>
@@ -90,10 +90,12 @@ __device__ inline void copy_vectorized(T* out, const T* in, uint32_t n)
   }
 }

-// Specialization for byte_array -> uint8_t* (for int8_t normalization)
-__device__ inline void copy_vectorized(uint8_t* out, const cuvs::detail::byte_array in, uint32_t n)
+// Specialization for byte_arithmetic_ptr -> uint8_t* (for int8_t normalization)
+__device__ inline void copy_vectorized(uint8_t* out,
+                                       const cuvs::detail::byte_arithmetic_ptr& in,
+                                       uint32_t n)
 {
-  // For byte_array, copy element by element with normalization to uint8_t
+  // For byte_arithmetic_ptr, copy element by element with normalization to uint8_t
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
     out[i] = static_cast<uint8_t>(in[i]);
   }
@@ -236,7 +238,8 @@ struct is_inner_prod_dist : std::false_type {};
 template <int Veclen, typename DataT, typename AccT>
 struct is_inner_prod_dist<inner_prod_dist<Veclen, DataT, AccT>> : std::true_type {};

-// This handles uint8_t 8, 16 Veclens (also handles int8_t via byte_array with int32 accumulator)
+// This handles uint8_t 8, 16 Veclens (also handles int8_t via byte_arithmetic_ptr with int32
+// accumulator)
 template <int kUnroll, typename Lambda, int uint8_veclen, bool ComputeNorm>
 struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, ComputeNorm> {
   Lambda compute_dist;
@@ -250,7 +253,7 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, Compu
   {
   }

-  __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_array& data,
+  __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data,
                                                       const uint8_t* query_shared,
                                                       int loadIndex,
                                                       int shmemIndex)
@@ -298,10 +301,11 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, Compu
     }
   }

-  __device__ __forceinline__ void runLoadShflAndCompute(cuvs::detail::byte_array data,
-                                                        const cuvs::detail::byte_array& query,
-                                                        int baseLoadIndex,
-                                                        const int lane_id)
+  __device__ __forceinline__ void runLoadShflAndCompute(
+    cuvs::detail::byte_arithmetic_ptr data,
+    const cuvs::detail::byte_arithmetic_ptr& query,
+    int baseLoadIndex,
+    const int lane_id)
   {
     constexpr int veclen_int = uint8_veclen / 4;  // converting uint8_t veclens to int
     const bool is_signed = data.is_signed;
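The `veclen_int = uint8_veclen / 4` conversion reflects that the byte data is loaded as 32-bit words. The diff does not show `compute_dist`'s body, but a common way to reduce four packed byte lanes against four query lanes into an `int32_t` accumulator is the SM 6.1+ `__dp4a` instruction; the sketch below (with a hypothetical helper name, not taken from this diff) illustrates the idea:

```cpp
// Illustrative only: dot product of four packed uint8 lanes per 32-bit word,
// accumulated into int32_t, as one __dp4a on SM 6.1+ with a scalar fallback.
__device__ inline int32_t dot_packed_u8(uint32_t data_word, uint32_t query_word, int32_t acc)
{
#if __CUDA_ARCH__ >= 610
  return static_cast<int32_t>(__dp4a(data_word, query_word, static_cast<uint32_t>(acc)));
#else
  int32_t sum = acc;
  for (int b = 0; b < 4; ++b) {
    // Extract byte lane b from each word and multiply; each product fits
    // comfortably in int32 (max 255 * 255 per lane, four lanes summed).
    sum += static_cast<int32_t>((data_word >> (8 * b)) & 0xFFu) *
           static_cast<int32_t>((query_word >> (8 * b)) & 0xFFu);
  }
  return sum;
#endif
}
```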
@@ -358,8 +362,8 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, Compu
   }

   __device__ __forceinline__ void runLoadShflAndComputeRemainder(
-    cuvs::detail::byte_array data,
-    const cuvs::detail::byte_array& query,
+    cuvs::detail::byte_arithmetic_ptr data,
+    const cuvs::detail::byte_arithmetic_ptr& query,
     const int lane_id,
     const int dim,
     const int dimBlocks)
@@ -412,7 +416,7 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, Compu
 };

 // Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while
-// using above common template of int2/int4 (also handles int8_t via byte_array with int32
+// using above common template of int2/int4 (also handles int8_t via byte_arithmetic_ptr with int32
 // accumulator)
 template <int kUnroll, typename Lambda, bool ComputeNorm>
 struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, int32_t, ComputeNorm> {
@@ -427,7 +431,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, int32_t, ComputeNorm> {
   {
   }

-  __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_array& data,
+  __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data,
                                                       const uint8_t* query_shared,
                                                       int loadIndex,
                                                       int shmemIndex)
@@ -463,10 +467,11 @@ struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, int32_t, ComputeNorm> {
     }
   }

-  __device__ __forceinline__ void runLoadShflAndCompute(cuvs::detail::byte_array data,
-                                                        const cuvs::detail::byte_array& query,
-                                                        int baseLoadIndex,
-                                                        const int lane_id)
+  __device__ __forceinline__ void runLoadShflAndCompute(
+    cuvs::detail::byte_arithmetic_ptr data,
+    const cuvs::detail::byte_arithmetic_ptr& query,
+    int baseLoadIndex,
+    const int lane_id)
   {
     const bool is_signed = data.is_signed;
     constexpr int32_t offset_reg = 0x80808080;  // 128 in each byte position
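`offset_reg` packs the +128 shift into every byte lane so signed data can be normalized four bytes at a time. A plain 32-bit add of `0x80808080` could carry between byte lanes; since `(x + 128) mod 256` equals flipping the byte's sign bit, a carry-free XOR does the same job. A sketch of that normalization (an assumption: the elided loop body that consumes `offset_reg` is not shown in this diff):

```cpp
// Remap four packed int8 lanes into the uint8 domain in one operation.
// XOR with 0x80 per byte == (+128 mod 256) per byte, with no cross-lane carry.
__device__ inline uint32_t normalize_packed_int8(uint32_t word, bool is_signed)
{
  constexpr uint32_t offset_reg = 0x80808080u;  // 128 in each byte position
  return is_signed ? (word ^ offset_reg) : word;
}
```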
@@ -516,8 +521,8 @@ struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, int32_t, ComputeNorm> {
   }

   __device__ __forceinline__ void runLoadShflAndComputeRemainder(
-    cuvs::detail::byte_array data,
-    const cuvs::detail::byte_array& query,
+    cuvs::detail::byte_arithmetic_ptr data,
+    const cuvs::detail::byte_arithmetic_ptr& query,
     const int lane_id,
     const int dim,
     const int dimBlocks)
@@ -573,7 +578,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 2, uint8_t, int32_t, ComputeNorm> {
   {
   }

-  __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_array& data,
+  __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data,
                                                       const uint8_t* query_shared,
                                                       int loadIndex,
                                                       int shmemIndex)
@@ -609,10 +614,11 @@ struct loadAndComputeDist<kUnroll, Lambda, 2, uint8_t, int32_t, ComputeNorm> {
     }
   }

-  __device__ __forceinline__ void runLoadShflAndCompute(cuvs::detail::byte_array data,
-                                                        const cuvs::detail::byte_array& query,
-                                                        int baseLoadIndex,
-                                                        const int lane_id)
+  __device__ __forceinline__ void runLoadShflAndCompute(
+    cuvs::detail::byte_arithmetic_ptr data,
+    const cuvs::detail::byte_arithmetic_ptr& query,
+    int baseLoadIndex,
+    const int lane_id)
   {
     const bool is_signed = data.is_signed;
     constexpr int32_t offset_reg = 0x8080;  // 128 in each of 2 byte positions
@@ -658,8 +664,8 @@ struct loadAndComputeDist<kUnroll, Lambda, 2, uint8_t, int32_t, ComputeNorm> {
   }

   __device__ __forceinline__ void runLoadShflAndComputeRemainder(
-    cuvs::detail::byte_array data,
-    const cuvs::detail::byte_array& query,
+    cuvs::detail::byte_arithmetic_ptr data,
+    const cuvs::detail::byte_arithmetic_ptr& query,
     const int lane_id,
     const int dim,
     const int dimBlocks)
@@ -715,7 +721,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 1, uint8_t, int32_t, ComputeNorm> {
   {
   }

-  __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_array& data,
+  __device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data,
                                                       const uint8_t* query_shared,
                                                       int loadIndex,
                                                       int shmemIndex)
@@ -748,10 +754,11 @@ struct loadAndComputeDist<kUnroll, Lambda, 1, uint8_t, int32_t, ComputeNorm> {
     }
   }

-  __device__ __forceinline__ void runLoadShflAndCompute(cuvs::detail::byte_array data,
-                                                        const cuvs::detail::byte_array& query,
-                                                        int baseLoadIndex,
-                                                        const int lane_id)
+  __device__ __forceinline__ void runLoadShflAndCompute(
+    cuvs::detail::byte_arithmetic_ptr data,
+    const cuvs::detail::byte_arithmetic_ptr& query,
+    int baseLoadIndex,
+    const int lane_id)
   {
     const bool is_signed = data.is_signed;
     constexpr int32_t offset_byte = 128;
@@ -789,8 +796,8 @@ struct loadAndComputeDist<kUnroll, Lambda, 1, uint8_t, int32_t, ComputeNorm> {
   }

   __device__ __forceinline__ void runLoadShflAndComputeRemainder(
-    cuvs::detail::byte_array data,
-    const cuvs::detail::byte_array& query,
+    cuvs::detail::byte_arithmetic_ptr data,
+    const cuvs::detail::byte_arithmetic_ptr& query,
     const int lane_id,
     const int dim,
     const int dimBlocks)
@@ -883,10 +890,10 @@ template <
   typename Lambda,
   typename PostLambda,
   typename DataT = std::conditional_t<std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
-                                      cuvs::detail::byte_array,
+                                      cuvs::detail::byte_arithmetic_ptr,
                                       const T*>,
   typename ListDataT = std::conditional_t<std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
-                                          cuvs::detail::byte_array*,
+                                          cuvs::detail::byte_arithmetic_ptr*,
                                           const T* const*>>
 RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
   interleaved_scan_kernel(Lambda compute_dist,
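The template-default change above routes 8-bit element types through the wrapper while every other type keeps a raw pointer, so the kernel body and its call sites stay type-agnostic. Condensed to its essentials (`device_data_ptr_t` is an alias written here only for illustration, assuming the cuVS header for `byte_arithmetic_ptr` is in scope):

```cpp
#include <cstdint>
#include <type_traits>

// For 8-bit elements the kernel parameter type becomes the normalizing
// wrapper; for all other element types it stays a plain const T*.
template <typename T>
using device_data_ptr_t =
  std::conditional_t<std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
                     cuvs::detail::byte_arithmetic_ptr,
                     const T*>;

static_assert(std::is_same_v<device_data_ptr_t<float>, const float*>);
```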
@@ -1060,18 +1067,21 @@ uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMem
 }

 /**
- * Functor to convert uint8_t pointers to byte_arrays.
+ * Functor to convert uint8_t pointers to byte_arithmetic_ptrs.
  * Using a functor instead of a lambda ensures the same type is used across
  * all template instantiations, avoiding ~40MB of duplicate host code.
  */
-struct byte_array_converter {
+struct byte_arithmetic_ptr_converter {
   bool is_signed;

-  __host__ __device__ explicit byte_array_converter(bool is_signed_) : is_signed(is_signed_) {}
+  __host__ __device__ explicit byte_arithmetic_ptr_converter(bool is_signed_)
+    : is_signed(is_signed_)
+  {
+  }

-  __device__ cuvs::detail::byte_array operator()(const uint8_t* ptr) const
+  __device__ cuvs::detail::byte_arithmetic_ptr operator()(const uint8_t* ptr) const
   {
-    return cuvs::detail::byte_array(const_cast<uint8_t*>(ptr), is_signed);
+    return cuvs::detail::byte_arithmetic_ptr(const_cast<uint8_t*>(ptr), is_signed);
   }
 };

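The doc comment's rationale rests on a core language rule: every lambda expression has a distinct closure type, even with an identical body, so a `thrust::transform` call written with an inline lambda is re-instantiated at each template instantiation point, while a single named functor type collapses those into one instantiation. A minimal demonstration of the rule:

```cpp
#include <cstdint>
#include <type_traits>

void closure_types_differ()
{
  auto l1 = [](const uint8_t* p) { return p; };
  auto l2 = [](const uint8_t* p) { return p; };  // identical body, distinct type
  static_assert(!std::is_same_v<decltype(l1), decltype(l2)>);
}
```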
@@ -1137,19 +1147,21 @@ void launch_kernel(Lambda lambda,
     return;
   }

-  // For int8_t/uint8_t, pre-convert data pointers to byte_arrays (only needs to be done once)
-  std::optional<rmm::device_uvector<cuvs::detail::byte_array>> byte_array_list_data_ptrs;
+  // For int8_t/uint8_t, pre-convert data pointers to byte_arithmetic_ptrs (only needs to be done
+  // once)
+  std::optional<rmm::device_uvector<cuvs::detail::byte_arithmetic_ptr>>
+    byte_arithmetic_ptr_list_data_ptrs;
   if constexpr (std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>) {
     constexpr bool is_signed = std::is_same_v<T, int8_t>;
-    byte_array_list_data_ptrs.emplace(index.data_ptrs().size(), stream);
+    byte_arithmetic_ptr_list_data_ptrs.emplace(index.data_ptrs().size(), stream);
     // Cast to uint8_t* and use functor to ensure identical thrust::transform instantiation
     const uint8_t* const* data_ptrs_uint8 =
       reinterpret_cast<const uint8_t* const*>(index.data_ptrs().data_handle());
     thrust::transform(rmm::exec_policy(stream),
                       data_ptrs_uint8,
                       data_ptrs_uint8 + index.data_ptrs().size(),
-                      byte_array_list_data_ptrs->begin(),
-                      byte_array_converter(is_signed));
+                      byte_arithmetic_ptr_list_data_ptrs->begin(),
+                      byte_arithmetic_ptr_converter(is_signed));
   }

   for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) {
@@ -1165,26 +1177,28 @@ void launch_kernel(Lambda lambda,
                        n_probes,
                        smem_size);
   if constexpr (std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>) {
-    // For both int8_t and uint8_t, wrap query batch in byte_array (is_signed determines
+    // For both int8_t and uint8_t, wrap query batch in byte_arithmetic_ptr (is_signed determines
     // normalization)
     constexpr bool is_signed = std::is_same_v<T, int8_t>;
-    auto byte_array_queries = cuvs::detail::byte_array(const_cast<T*>(queries), is_signed);
-    kKernel<<<grid_dim, block_dim, smem_size, stream>>>(lambda,
-                                                        post_process,
-                                                        query_smem_elems,
-                                                        byte_array_queries,
-                                                        coarse_index,
-                                                        byte_array_list_data_ptrs->data(),
-                                                        index.list_sizes().data_handle(),
-                                                        queries_offset + query_offset,
-                                                        n_probes,
-                                                        k,
-                                                        max_samples,
-                                                        chunk_indices,
-                                                        index.dim(),
-                                                        sample_filter,
-                                                        neighbors,
-                                                        distances);
+    auto byte_arithmetic_ptr_queries =
+      cuvs::detail::byte_arithmetic_ptr(const_cast<T*>(queries), is_signed);
+    kKernel<<<grid_dim, block_dim, smem_size, stream>>>(
+      lambda,
+      post_process,
+      query_smem_elems,
+      byte_arithmetic_ptr_queries,
+      coarse_index,
+      byte_arithmetic_ptr_list_data_ptrs->data(),
+      index.list_sizes().data_handle(),
+      queries_offset + query_offset,
+      n_probes,
+      k,
+      max_samples,
+      chunk_indices,
+      index.dim(),
+      sample_filter,
+      neighbors,
+      distances);
   } else {
     // For other types (float, etc.), use raw pointers directly
     kKernel<<<grid_dim, block_dim, smem_size, stream>>>(lambda,