Skip to content

Commit ef65e5a

Browse files
committed
rename byte_array to byte_arithmetic_ptr
1 parent 591f9d6 commit ef65e5a

File tree

2 files changed

+98
-75
lines changed

2 files changed

+98
-75
lines changed

cpp/include/cuvs/core/byte_array.hpp renamed to cpp/include/cuvs/core/byte_arithmetic_ptr.hpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,23 @@
2020

2121
namespace cuvs::detail {
2222

23-
struct byte_array {
23+
struct byte_arithmetic_ptr {
2424
void* data = nullptr;
2525
bool is_signed = false;
2626

27-
__host__ __device__ byte_array(void* ptr, bool signed_flag) : data(ptr), is_signed(signed_flag) {}
27+
__host__ __device__ byte_arithmetic_ptr(void* ptr, bool signed_flag)
28+
: data(ptr), is_signed(signed_flag)
29+
{
30+
}
2831

2932
// Proxy that references an element in the array
3033
struct byte {
31-
byte_array* parent = nullptr;
32-
int64_t idx = -1;
33-
uint8_t value = 0; // used for detached proxies
34+
byte_arithmetic_ptr* parent = nullptr;
35+
int64_t idx = -1;
36+
uint8_t value = 0; // used for detached proxies
3437

3538
// Constructor for live proxy
36-
__host__ __device__ byte(byte_array& p, int64_t i) : parent(&p), idx(i) {}
39+
__host__ __device__ byte(byte_arithmetic_ptr& p, int64_t i) : parent(&p), idx(i) {}
3740

3841
// Copy constructor: detached copy stores the current value
3942
__host__ __device__ byte(const byte& other)
@@ -105,16 +108,22 @@ struct byte_array {
105108
__host__ __device__ byte operator*() { return byte(*this, 0); }
106109

107110
// Pointer arithmetic
108-
__host__ __device__ byte_array operator+(int64_t offset) const
111+
__host__ __device__ byte_arithmetic_ptr operator+(int64_t offset) const
109112
{
110113
if (is_signed)
111-
return byte_array(static_cast<int8_t*>(data) + offset, true);
114+
return byte_arithmetic_ptr(static_cast<int8_t*>(data) + offset, true);
112115
else
113-
return byte_array(static_cast<uint8_t*>(data) + offset, false);
116+
return byte_arithmetic_ptr(static_cast<uint8_t*>(data) + offset, false);
114117
}
115118

116-
__host__ __device__ bool operator==(const byte_array& other) const { return data == other.data; }
117-
__host__ __device__ bool operator!=(const byte_array& other) const { return !(*this == other); }
119+
__host__ __device__ bool operator==(const byte_arithmetic_ptr& other) const
120+
{
121+
return data == other.data;
122+
}
123+
__host__ __device__ bool operator!=(const byte_arithmetic_ptr& other) const
124+
{
125+
return !(*this == other);
126+
}
118127
};
119128

120129
} // namespace cuvs::detail

cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh

Lines changed: 78 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include <cuvs/neighbors/ivf_flat.hpp>
2424

2525
#include "../detail/ann_utils.cuh"
26-
#include <cuvs/core/byte_array.hpp>
26+
#include <cuvs/core/byte_arithmetic_ptr.hpp>
2727
#include <cuvs/distance/distance.hpp>
2828
#include <raft/core/logger.hpp>
2929
#include <raft/core/operators.hpp>
@@ -90,10 +90,12 @@ __device__ inline void copy_vectorized(T* out, const T* in, uint32_t n)
9090
}
9191
}
9292

93-
// Specialization for byte_array -> uint8_t* (for int8_t normalization)
94-
__device__ inline void copy_vectorized(uint8_t* out, const cuvs::detail::byte_array in, uint32_t n)
93+
// Specialization for byte_arithmetic_ptr -> uint8_t* (for int8_t normalization)
94+
__device__ inline void copy_vectorized(uint8_t* out,
95+
const cuvs::detail::byte_arithmetic_ptr& in,
96+
uint32_t n)
9597
{
96-
// For byte_array, copy element by element with normalization to uint8_t
98+
// For byte_arithmetic_ptr, copy element by element with normalization to uint8_t
9799
for (int i = threadIdx.x; i < n; i += blockDim.x) {
98100
out[i] = static_cast<uint8_t>(in[i]);
99101
}
@@ -236,7 +238,8 @@ struct is_inner_prod_dist : std::false_type {};
236238
template <int Veclen, typename DataT, typename AccT>
237239
struct is_inner_prod_dist<inner_prod_dist<Veclen, DataT, AccT>> : std::true_type {};
238240

239-
// This handles uint8_t 8, 16 Veclens (also handles int8_t via byte_array with int32 accumulator)
241+
// This handles uint8_t 8, 16 Veclens (also handles int8_t via byte_arithmetic_ptr with int32
242+
// accumulator)
240243
template <int kUnroll, typename Lambda, int uint8_veclen, bool ComputeNorm>
241244
struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, ComputeNorm> {
242245
Lambda compute_dist;
@@ -250,7 +253,7 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, Compu
250253
{
251254
}
252255

253-
__device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_array& data,
256+
__device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data,
254257
const uint8_t* query_shared,
255258
int loadIndex,
256259
int shmemIndex)
@@ -298,10 +301,11 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, Compu
298301
}
299302
}
300303

301-
__device__ __forceinline__ void runLoadShflAndCompute(cuvs::detail::byte_array data,
302-
const cuvs::detail::byte_array& query,
303-
int baseLoadIndex,
304-
const int lane_id)
304+
__device__ __forceinline__ void runLoadShflAndCompute(
305+
cuvs::detail::byte_arithmetic_ptr data,
306+
const cuvs::detail::byte_arithmetic_ptr& query,
307+
int baseLoadIndex,
308+
const int lane_id)
305309
{
306310
constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int
307311
const bool is_signed = data.is_signed;
@@ -358,8 +362,8 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, Compu
358362
}
359363

360364
__device__ __forceinline__ void runLoadShflAndComputeRemainder(
361-
cuvs::detail::byte_array data,
362-
const cuvs::detail::byte_array& query,
365+
cuvs::detail::byte_arithmetic_ptr data,
366+
const cuvs::detail::byte_arithmetic_ptr& query,
363367
const int lane_id,
364368
const int dim,
365369
const int dimBlocks)
@@ -412,7 +416,7 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, int32_t, Compu
412416
};
413417

414418
// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while
415-
// using above common template of int2/int4 (also handles int8_t via byte_array with int32
419+
// using above common template of int2/int4 (also handles int8_t via byte_arithmetic_ptr with int32
416420
// accumulator)
417421
template <int kUnroll, typename Lambda, bool ComputeNorm>
418422
struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, int32_t, ComputeNorm> {
@@ -427,7 +431,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, int32_t, ComputeNorm> {
427431
{
428432
}
429433

430-
__device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_array& data,
434+
__device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data,
431435
const uint8_t* query_shared,
432436
int loadIndex,
433437
int shmemIndex)
@@ -463,10 +467,11 @@ struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, int32_t, ComputeNorm> {
463467
}
464468
}
465469

466-
__device__ __forceinline__ void runLoadShflAndCompute(cuvs::detail::byte_array data,
467-
const cuvs::detail::byte_array& query,
468-
int baseLoadIndex,
469-
const int lane_id)
470+
__device__ __forceinline__ void runLoadShflAndCompute(
471+
cuvs::detail::byte_arithmetic_ptr data,
472+
const cuvs::detail::byte_arithmetic_ptr& query,
473+
int baseLoadIndex,
474+
const int lane_id)
470475
{
471476
const bool is_signed = data.is_signed;
472477
constexpr int32_t offset_reg = 0x80808080; // 128 in each byte position
@@ -516,8 +521,8 @@ struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, int32_t, ComputeNorm> {
516521
}
517522

518523
__device__ __forceinline__ void runLoadShflAndComputeRemainder(
519-
cuvs::detail::byte_array data,
520-
const cuvs::detail::byte_array& query,
524+
cuvs::detail::byte_arithmetic_ptr data,
525+
const cuvs::detail::byte_arithmetic_ptr& query,
521526
const int lane_id,
522527
const int dim,
523528
const int dimBlocks)
@@ -573,7 +578,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 2, uint8_t, int32_t, ComputeNorm> {
573578
{
574579
}
575580

576-
__device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_array& data,
581+
__device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data,
577582
const uint8_t* query_shared,
578583
int loadIndex,
579584
int shmemIndex)
@@ -609,10 +614,11 @@ struct loadAndComputeDist<kUnroll, Lambda, 2, uint8_t, int32_t, ComputeNorm> {
609614
}
610615
}
611616

612-
__device__ __forceinline__ void runLoadShflAndCompute(cuvs::detail::byte_array data,
613-
const cuvs::detail::byte_array& query,
614-
int baseLoadIndex,
615-
const int lane_id)
617+
__device__ __forceinline__ void runLoadShflAndCompute(
618+
cuvs::detail::byte_arithmetic_ptr data,
619+
const cuvs::detail::byte_arithmetic_ptr& query,
620+
int baseLoadIndex,
621+
const int lane_id)
616622
{
617623
const bool is_signed = data.is_signed;
618624
constexpr int32_t offset_reg = 0x8080; // 128 in each of 2 byte positions
@@ -658,8 +664,8 @@ struct loadAndComputeDist<kUnroll, Lambda, 2, uint8_t, int32_t, ComputeNorm> {
658664
}
659665

660666
__device__ __forceinline__ void runLoadShflAndComputeRemainder(
661-
cuvs::detail::byte_array data,
662-
const cuvs::detail::byte_array& query,
667+
cuvs::detail::byte_arithmetic_ptr data,
668+
const cuvs::detail::byte_arithmetic_ptr& query,
663669
const int lane_id,
664670
const int dim,
665671
const int dimBlocks)
@@ -715,7 +721,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 1, uint8_t, int32_t, ComputeNorm> {
715721
{
716722
}
717723

718-
__device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_array& data,
724+
__device__ __forceinline__ void runLoadShmemCompute(const cuvs::detail::byte_arithmetic_ptr& data,
719725
const uint8_t* query_shared,
720726
int loadIndex,
721727
int shmemIndex)
@@ -748,10 +754,11 @@ struct loadAndComputeDist<kUnroll, Lambda, 1, uint8_t, int32_t, ComputeNorm> {
748754
}
749755
}
750756

751-
__device__ __forceinline__ void runLoadShflAndCompute(cuvs::detail::byte_array data,
752-
const cuvs::detail::byte_array& query,
753-
int baseLoadIndex,
754-
const int lane_id)
757+
__device__ __forceinline__ void runLoadShflAndCompute(
758+
cuvs::detail::byte_arithmetic_ptr data,
759+
const cuvs::detail::byte_arithmetic_ptr& query,
760+
int baseLoadIndex,
761+
const int lane_id)
755762
{
756763
const bool is_signed = data.is_signed;
757764
constexpr int32_t offset_byte = 128;
@@ -789,8 +796,8 @@ struct loadAndComputeDist<kUnroll, Lambda, 1, uint8_t, int32_t, ComputeNorm> {
789796
}
790797

791798
__device__ __forceinline__ void runLoadShflAndComputeRemainder(
792-
cuvs::detail::byte_array data,
793-
const cuvs::detail::byte_array& query,
799+
cuvs::detail::byte_arithmetic_ptr data,
800+
const cuvs::detail::byte_arithmetic_ptr& query,
794801
const int lane_id,
795802
const int dim,
796803
const int dimBlocks)
@@ -883,10 +890,10 @@ template <
883890
typename Lambda,
884891
typename PostLambda,
885892
typename DataT = std::conditional_t<std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
886-
cuvs::detail::byte_array,
893+
cuvs::detail::byte_arithmetic_ptr,
887894
const T*>,
888895
typename ListDataT = std::conditional_t<std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
889-
cuvs::detail::byte_array*,
896+
cuvs::detail::byte_arithmetic_ptr*,
890897
const T* const*>>
891898
RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
892899
interleaved_scan_kernel(Lambda compute_dist,
@@ -1060,18 +1067,21 @@ uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMem
10601067
}
10611068

10621069
/**
1063-
* Functor to convert uint8_t pointers to byte_arrays.
1070+
* Functor to convert uint8_t pointers to byte_arithmetic_ptrs.
10641071
* Using a functor instead of a lambda ensures the same type is used across
10651072
* all template instantiations, avoiding ~40MB of duplicate host code.
10661073
*/
1067-
struct byte_array_converter {
1074+
struct byte_arithmetic_ptr_converter {
10681075
bool is_signed;
10691076

1070-
__host__ __device__ explicit byte_array_converter(bool is_signed_) : is_signed(is_signed_) {}
1077+
__host__ __device__ explicit byte_arithmetic_ptr_converter(bool is_signed_)
1078+
: is_signed(is_signed_)
1079+
{
1080+
}
10711081

1072-
__device__ cuvs::detail::byte_array operator()(const uint8_t* ptr) const
1082+
__device__ cuvs::detail::byte_arithmetic_ptr operator()(const uint8_t* ptr) const
10731083
{
1074-
return cuvs::detail::byte_array(const_cast<uint8_t*>(ptr), is_signed);
1084+
return cuvs::detail::byte_arithmetic_ptr(const_cast<uint8_t*>(ptr), is_signed);
10751085
}
10761086
};
10771087

@@ -1137,19 +1147,21 @@ void launch_kernel(Lambda lambda,
11371147
return;
11381148
}
11391149

1140-
// For int8_t/uint8_t, pre-convert data pointers to byte_arrays (only needs to be done once)
1141-
std::optional<rmm::device_uvector<cuvs::detail::byte_array>> byte_array_list_data_ptrs;
1150+
// For int8_t/uint8_t, pre-convert data pointers to byte_arithmetic_ptrs (only needs to be done
1151+
// once)
1152+
std::optional<rmm::device_uvector<cuvs::detail::byte_arithmetic_ptr>>
1153+
byte_arithmetic_ptr_list_data_ptrs;
11421154
if constexpr (std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>) {
11431155
constexpr bool is_signed = std::is_same_v<T, int8_t>;
1144-
byte_array_list_data_ptrs.emplace(index.data_ptrs().size(), stream);
1156+
byte_arithmetic_ptr_list_data_ptrs.emplace(index.data_ptrs().size(), stream);
11451157
// Cast to uint8_t* and use functor to ensure identical thrust::transform instantiation
11461158
const uint8_t* const* data_ptrs_uint8 =
11471159
reinterpret_cast<const uint8_t* const*>(index.data_ptrs().data_handle());
11481160
thrust::transform(rmm::exec_policy(stream),
11491161
data_ptrs_uint8,
11501162
data_ptrs_uint8 + index.data_ptrs().size(),
1151-
byte_array_list_data_ptrs->begin(),
1152-
byte_array_converter(is_signed));
1163+
byte_arithmetic_ptr_list_data_ptrs->begin(),
1164+
byte_arithmetic_ptr_converter(is_signed));
11531165
}
11541166

11551167
for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) {
@@ -1165,26 +1177,28 @@ void launch_kernel(Lambda lambda,
11651177
n_probes,
11661178
smem_size);
11671179
if constexpr (std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>) {
1168-
// For both int8_t and uint8_t, wrap query batch in byte_array (is_signed determines
1180+
// For both int8_t and uint8_t, wrap query batch in byte_arithmetic_ptr (is_signed determines
11691181
// normalization)
11701182
constexpr bool is_signed = std::is_same_v<T, int8_t>;
1171-
auto byte_array_queries = cuvs::detail::byte_array(const_cast<T*>(queries), is_signed);
1172-
kKernel<<<grid_dim, block_dim, smem_size, stream>>>(lambda,
1173-
post_process,
1174-
query_smem_elems,
1175-
byte_array_queries,
1176-
coarse_index,
1177-
byte_array_list_data_ptrs->data(),
1178-
index.list_sizes().data_handle(),
1179-
queries_offset + query_offset,
1180-
n_probes,
1181-
k,
1182-
max_samples,
1183-
chunk_indices,
1184-
index.dim(),
1185-
sample_filter,
1186-
neighbors,
1187-
distances);
1183+
auto byte_arithmetic_ptr_queries =
1184+
cuvs::detail::byte_arithmetic_ptr(const_cast<T*>(queries), is_signed);
1185+
kKernel<<<grid_dim, block_dim, smem_size, stream>>>(
1186+
lambda,
1187+
post_process,
1188+
query_smem_elems,
1189+
byte_arithmetic_ptr_queries,
1190+
coarse_index,
1191+
byte_arithmetic_ptr_list_data_ptrs->data(),
1192+
index.list_sizes().data_handle(),
1193+
queries_offset + query_offset,
1194+
n_probes,
1195+
k,
1196+
max_samples,
1197+
chunk_indices,
1198+
index.dim(),
1199+
sample_filter,
1200+
neighbors,
1201+
distances);
11881202
} else {
11891203
// For other types (float, etc.), use raw pointers directly
11901204
kKernel<<<grid_dim, block_dim, smem_size, stream>>>(lambda,

0 commit comments

Comments
 (0)