Skip to content

Commit 81de1b9

Browse files
committed
feat: add simd::halfbyte
Signed-off-by: usamoi <[email protected]>
1 parent 408ed05 commit 81de1b9

File tree

6 files changed

+474
-71
lines changed

6 files changed

+474
-71
lines changed

crates/index/src/accessor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ impl Accessor2<u8, u8, [f32; 4], [f32; 4]> for DimensionDistanceAccessor<Rabitq8
409409
#[inline(always)]
410410
fn push(&mut self, target: &[u8], input: &[u8]) {
411411
self.0 += target.len() as u32;
412-
self.1 += simd::u8::reduce_sum_of_x_as_u32_y_as_u32(target, input);
412+
self.1 += simd::byte::reduce_sum_of_x_as_u32_y_as_u32(target, input);
413413
}
414414

415415
#[inline(always)]
@@ -432,7 +432,7 @@ impl Accessor2<u8, u8, [f32; 4], [f32; 4]> for DimensionDistanceAccessor<Rabitq8
432432
#[inline(always)]
433433
fn push(&mut self, target: &[u8], input: &[u8]) {
434434
self.0 += target.len() as u32;
435-
self.1 += simd::u8::reduce_sum_of_x_as_u32_y_as_u32(target, input);
435+
self.1 += simd::byte::reduce_sum_of_x_as_u32_y_as_u32(target, input);
436436
}
437437

438438
#[inline(always)]

crates/rabitq/src/bit.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,9 @@ pub mod binary {
176176
pub(crate) fn preprocess_with_distance(vector: &[f32], dis_v_2: f32) -> BinaryLut {
177177
let (k, b, qvector) = simd::quantize::quantize(vector, ((1 << BITS) - 1) as f32);
178178
let qvector_sum = if vector.len() <= (65535_usize / ((1 << BITS) - 1)) {
179-
simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32
179+
simd::byte::reduce_sum_of_x_as_u16(&qvector) as f32
180180
} else {
181-
simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32
181+
simd::byte::reduce_sum_of_x_as_u32(&qvector) as f32
182182
};
183183
(
184184
BinaryLutMetadata {

crates/rabitq/src/byte.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pub mod binary {
4444
}
4545

4646
pub fn accumulate(x: &[u8], y: &[u8]) -> u32 {
47-
simd::u8::reduce_sum_of_x_as_u32_y_as_u32(x, y)
47+
simd::byte::reduce_sum_of_x_as_u32_y_as_u32(x, y)
4848
}
4949

5050
pub fn half_process_dot(

crates/simd/src/u8.rs renamed to crates/simd/src/byte.rs

Lines changed: 63 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,47 @@ mod reduce_sum_of_x_as_u32_y_as_u32 {
2222
let mut n = lhs.len();
2323
let mut a = lhs.as_ptr();
2424
let mut b = rhs.as_ptr();
25-
let lo = _mm512_set1_epi16(0x00ff_u16 as i16);
26-
let hi = _mm512_set1_epi16(0xff00_u16 as i16);
25+
let lo = _mm512_set1_epi16(0x00ff_i16);
2726
let mut _0 = _mm512_setzero_si512();
2827
let mut _1 = _mm512_setzero_si512();
2928
let mut _2 = _mm512_setzero_si512();
3029
let mut _3 = _mm512_setzero_si512();
3130
while n >= 64 {
3231
let x = unsafe { _mm512_loadu_epi8(a.cast()) };
3332
let y = unsafe { _mm512_loadu_epi8(b.cast()) };
34-
let x_l = _mm512_and_si512(x, lo);
35-
let x_h = _mm512_and_si512(_mm512_srli_epi16(_mm512_and_si512(x, hi), 8), lo);
36-
let y_l = _mm512_and_si512(y, lo);
37-
let y_h = _mm512_and_si512(_mm512_srli_epi16(_mm512_and_si512(y, hi), 8), lo);
38-
let l = _mm512_mullo_epi16(x_l, y_l);
39-
let h = _mm512_mullo_epi16(x_h, y_h);
33+
let x_0 = _mm512_and_si512(x, lo);
34+
let x_1 = _mm512_srli_epi16(x, 8);
35+
let y_0 = _mm512_and_si512(y, lo);
36+
let y_1 = _mm512_srli_epi16(y, 8);
37+
let z_0 = _mm512_mullo_epi16(x_0, y_0);
38+
let z_1 = _mm512_mullo_epi16(x_1, y_1);
4039
a = unsafe { a.add(64) };
4140
b = unsafe { b.add(64) };
4241
n -= 64;
43-
_0 = _mm512_add_epi32(_0, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(l, 0)));
44-
_1 = _mm512_add_epi32(_1, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(l, 1)));
45-
_2 = _mm512_add_epi32(_2, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(h, 0)));
46-
_3 = _mm512_add_epi32(_3, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(h, 1)));
42+
_0 = _mm512_add_epi32(_0, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(z_0, 0)));
43+
_1 = _mm512_add_epi32(_1, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(z_0, 1)));
44+
_2 = _mm512_add_epi32(_2, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(z_1, 0)));
45+
_3 = _mm512_add_epi32(_3, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(z_1, 1)));
4746
}
4847
if n > 0 {
4948
let mask = _bzhi_u64(0xffffffffffffffff, n as u32);
5049
let x = unsafe { _mm512_maskz_loadu_epi8(mask, a.cast()) };
5150
let y = unsafe { _mm512_maskz_loadu_epi8(mask, b.cast()) };
52-
let x_l = _mm512_and_si512(x, lo);
53-
let x_h = _mm512_and_si512(_mm512_srli_epi16(_mm512_and_si512(x, hi), 8), lo);
54-
let y_l = _mm512_and_si512(y, lo);
55-
let y_h = _mm512_and_si512(_mm512_srli_epi16(_mm512_and_si512(y, hi), 8), lo);
56-
let l = _mm512_mullo_epi16(x_l, y_l);
57-
let h = _mm512_mullo_epi16(x_h, y_h);
58-
_0 = _mm512_add_epi32(_0, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(l, 0)));
59-
_1 = _mm512_add_epi32(_1, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(l, 1)));
60-
_2 = _mm512_add_epi32(_2, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(h, 0)));
61-
_3 = _mm512_add_epi32(_3, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(h, 1)));
62-
}
63-
let _5 = _mm512_add_epi32(_0, _1);
64-
let _6 = _mm512_add_epi32(_2, _3);
65-
_mm512_reduce_add_epi32(_mm512_add_epi32(_5, _6)) as u32
51+
let x_0 = _mm512_and_si512(x, lo);
52+
let x_1 = _mm512_srli_epi16(x, 8);
53+
let y_0 = _mm512_and_si512(y, lo);
54+
let y_1 = _mm512_srli_epi16(y, 8);
55+
let z_0 = _mm512_mullo_epi16(x_0, y_0);
56+
let z_1 = _mm512_mullo_epi16(x_1, y_1);
57+
_0 = _mm512_add_epi32(_0, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(z_0, 0)));
58+
_1 = _mm512_add_epi32(_1, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(z_0, 1)));
59+
_2 = _mm512_add_epi32(_2, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(z_1, 0)));
60+
_3 = _mm512_add_epi32(_3, _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(z_1, 1)));
61+
}
62+
let r_0 = _mm512_add_epi32(_0, _2);
63+
let r_1 = _mm512_add_epi32(_1, _3);
64+
let r_2 = _mm512_add_epi32(r_0, r_1);
65+
_mm512_reduce_add_epi32(r_2) as u32
6666
}
6767

6868
#[cfg(all(target_arch = "x86_64", test, not(miri)))]
@@ -101,32 +101,31 @@ mod reduce_sum_of_x_as_u32_y_as_u32 {
101101
let mut n = lhs.len();
102102
let mut a = lhs.as_ptr();
103103
let mut b = rhs.as_ptr();
104-
let lo = _mm256_set1_epi16(0x00ff_u16 as i16);
105-
let hi = _mm256_set1_epi16(0xff00_u16 as i16);
104+
let lo = _mm256_set1_epi16(0x00ff_i16);
106105
let mut _0 = _mm256_setzero_si256();
107106
let mut _1 = _mm256_setzero_si256();
108107
let mut _2 = _mm256_setzero_si256();
109108
let mut _3 = _mm256_setzero_si256();
110109
while n >= 32 {
111110
let x = unsafe { _mm256_loadu_si256(a.cast()) };
112111
let y = unsafe { _mm256_loadu_si256(b.cast()) };
113-
let x_l = _mm256_and_si256(x, lo);
114-
let x_h = _mm256_and_si256(_mm256_srli_epi16(_mm256_and_si256(x, hi), 8), lo);
115-
let y_l = _mm256_and_si256(y, lo);
116-
let y_h = _mm256_and_si256(_mm256_srli_epi16(_mm256_and_si256(y, hi), 8), lo);
117-
let l = _mm256_mullo_epi16(x_l, y_l);
118-
let h = _mm256_mullo_epi16(x_h, y_h);
112+
let x_0 = _mm256_and_si256(x, lo);
113+
let x_1 = _mm256_srli_epi16(x, 8);
114+
let y_0 = _mm256_and_si256(y, lo);
115+
let y_1 = _mm256_srli_epi16(y, 8);
116+
let z_0 = _mm256_mullo_epi16(x_0, y_0);
117+
let z_1 = _mm256_mullo_epi16(x_1, y_1);
119118
a = unsafe { a.add(32) };
120119
b = unsafe { b.add(32) };
121120
n -= 32;
122-
_0 = _mm256_add_epi32(_0, _mm256_cvtepu16_epi32(_mm256_extracti128_si256(l, 0)));
123-
_1 = _mm256_add_epi32(_1, _mm256_cvtepu16_epi32(_mm256_extracti128_si256(l, 1)));
124-
_2 = _mm256_add_epi32(_2, _mm256_cvtepu16_epi32(_mm256_extracti128_si256(h, 0)));
125-
_3 = _mm256_add_epi32(_3, _mm256_cvtepu16_epi32(_mm256_extracti128_si256(h, 1)));
121+
_0 = _mm256_add_epi32(_0, _mm256_cvtepu16_epi32(_mm256_extracti128_si256(z_0, 0)));
122+
_1 = _mm256_add_epi32(_1, _mm256_cvtepu16_epi32(_mm256_extracti128_si256(z_0, 1)));
123+
_2 = _mm256_add_epi32(_2, _mm256_cvtepu16_epi32(_mm256_extracti128_si256(z_1, 0)));
124+
_3 = _mm256_add_epi32(_3, _mm256_cvtepu16_epi32(_mm256_extracti128_si256(z_1, 1)));
126125
}
127126
let mut sum = emulate_mm256_reduce_add_epi32(_mm256_add_epi32(
128-
_mm256_add_epi32(_0, _1),
129-
_mm256_add_epi32(_2, _3),
127+
_mm256_add_epi32(_0, _2),
128+
_mm256_add_epi32(_1, _3),
130129
)) as u32;
131130
// this hint is used to disable loop unrolling
132131
while std::hint::black_box(n) > 0 {
@@ -176,28 +175,27 @@ mod reduce_sum_of_x_as_u32_y_as_u32 {
176175
let mut n = lhs.len();
177176
let mut a = lhs.as_ptr();
178177
let mut b = rhs.as_ptr();
179-
let lo = _mm_set1_epi16(0x00ff_u16 as i16);
180-
let hi = _mm_set1_epi16(0xff00_u16 as i16);
178+
let lo = _mm_set1_epi16(0x00ff_i16);
181179
let mut _0 = _mm_setzero_si128();
182180
let mut _1 = _mm_setzero_si128();
183181
let mut _2 = _mm_setzero_si128();
184182
let mut _3 = _mm_setzero_si128();
185183
while n >= 16 {
186184
let x = unsafe { _mm_loadu_si128(a.cast()) };
187185
let y = unsafe { _mm_loadu_si128(b.cast()) };
188-
let x_l = _mm_and_si128(x, lo);
189-
let x_h = _mm_and_si128(_mm_srli_epi16(_mm_and_si128(x, hi), 8), lo);
190-
let y_l = _mm_and_si128(y, lo);
191-
let y_h = _mm_and_si128(_mm_srli_epi16(_mm_and_si128(y, hi), 8), lo);
192-
let l = _mm_mullo_epi16(x_l, y_l);
193-
let h = _mm_mullo_epi16(x_h, y_h);
186+
let x_0 = _mm_and_si128(x, lo);
187+
let x_1 = _mm_srli_epi16(x, 8);
188+
let y_0 = _mm_and_si128(y, lo);
189+
let y_1 = _mm_srli_epi16(y, 8);
190+
let z_0 = _mm_mullo_epi16(x_0, y_0);
191+
let z_1 = _mm_mullo_epi16(x_1, y_1);
194192
a = unsafe { a.add(16) };
195193
b = unsafe { b.add(16) };
196194
n -= 16;
197-
_0 = _mm_add_epi32(_0, _mm_cvtepu16_epi32(l));
198-
_1 = _mm_add_epi32(_1, _mm_cvtepu16_epi32(_mm_unpackhi_epi64(l, l)));
199-
_2 = _mm_add_epi32(_2, _mm_cvtepu16_epi32(h));
200-
_3 = _mm_add_epi32(_3, _mm_cvtepu16_epi32(_mm_unpackhi_epi64(h, h)));
195+
_0 = _mm_add_epi32(_0, _mm_cvtepu16_epi32(z_0));
196+
_1 = _mm_add_epi32(_1, _mm_cvtepu16_epi32(_mm_unpackhi_epi64(z_0, z_0)));
197+
_2 = _mm_add_epi32(_2, _mm_cvtepu16_epi32(z_1));
198+
_3 = _mm_add_epi32(_3, _mm_cvtepu16_epi32(_mm_unpackhi_epi64(z_1, z_1)));
201199
}
202200
let mut sum = emulate_mm_reduce_add_epi32(_mm_add_epi32(
203201
_mm_add_epi32(_0, _1),
@@ -256,23 +254,23 @@ mod reduce_sum_of_x_as_u32_y_as_u32 {
256254
let mut _2 = vdupq_n_u32(0);
257255
let mut _3 = vdupq_n_u32(0);
258256
while n >= 16 {
259-
let x = vreinterpretq_u16_u8(unsafe { vld1q_u8(a.cast()) });
260-
let y = vreinterpretq_u16_u8(unsafe { vld1q_u8(b.cast()) });
261-
let x_l = vandq_u16(x, lo);
262-
let x_h = vshrq_n_u16(x, 8);
263-
let y_l = vandq_u16(y, lo);
264-
let y_h = vshrq_n_u16(y, 8);
265-
let l = vmulq_u16(x_l, y_l);
266-
let h = vmulq_u16(x_h, y_h);
257+
let x = unsafe { vld1q_u16(a.cast()) };
258+
let y = unsafe { vld1q_u16(b.cast()) };
259+
let x_0 = vandq_u16(x, lo);
260+
let x_1 = vshrq_n_u16(x, 8);
261+
let y_0 = vandq_u16(y, lo);
262+
let y_1 = vshrq_n_u16(y, 8);
263+
let z_0 = vmulq_u16(x_0, y_0);
264+
let z_1 = vmulq_u16(x_1, y_1);
267265
a = unsafe { a.add(16) };
268266
b = unsafe { b.add(16) };
269267
n -= 16;
270-
_0 = vaddq_u32(_0, vmovl_u16(vget_low_u16(l)));
271-
_1 = vaddq_u32(_1, vmovl_u16(vget_high_u16(l)));
272-
_2 = vaddq_u32(_2, vmovl_u16(vget_low_u16(h)));
273-
_3 = vaddq_u32(_3, vmovl_u16(vget_high_u16(h)));
268+
_0 = vaddq_u32(_0, vmovl_u16(vget_low_u16(z_0)));
269+
_1 = vaddq_u32(_1, vmovl_u16(vget_high_u16(z_0)));
270+
_2 = vaddq_u32(_2, vmovl_u16(vget_low_u16(z_1)));
271+
_3 = vaddq_u32(_3, vmovl_u16(vget_high_u16(z_1)));
274272
}
275-
let mut sum = vaddvq_u32(vaddq_u32(vaddq_u32(_0, _1), vaddq_u32(_2, _3)));
273+
let mut sum = vaddvq_u32(vaddq_u32(vaddq_u32(_0, _2), vaddq_u32(_1, _3)));
276274
// this hint is used to disable loop unrolling
277275
while std::hint::black_box(n) > 0 {
278276
let x = unsafe { a.read() };

0 commit comments

Comments
 (0)