@@ -22,47 +22,47 @@ mod reduce_sum_of_x_as_u32_y_as_u32 {
2222 let mut n = lhs. len ( ) ;
2323 let mut a = lhs. as_ptr ( ) ;
2424 let mut b = rhs. as_ptr ( ) ;
25- let lo = _mm512_set1_epi16 ( 0x00ff_u16 as i16 ) ;
26- let hi = _mm512_set1_epi16 ( 0xff00_u16 as i16 ) ;
25+ let lo = _mm512_set1_epi16 ( 0x00ff_i16 ) ;
2726 let mut _0 = _mm512_setzero_si512 ( ) ;
2827 let mut _1 = _mm512_setzero_si512 ( ) ;
2928 let mut _2 = _mm512_setzero_si512 ( ) ;
3029 let mut _3 = _mm512_setzero_si512 ( ) ;
3130 while n >= 64 {
3231 let x = unsafe { _mm512_loadu_epi8 ( a. cast ( ) ) } ;
3332 let y = unsafe { _mm512_loadu_epi8 ( b. cast ( ) ) } ;
34- let x_l = _mm512_and_si512 ( x, lo) ;
35- let x_h = _mm512_and_si512 ( _mm512_srli_epi16 ( _mm512_and_si512 ( x, hi ) , 8 ) , lo ) ;
36- let y_l = _mm512_and_si512 ( y, lo) ;
37- let y_h = _mm512_and_si512 ( _mm512_srli_epi16 ( _mm512_and_si512 ( y, hi ) , 8 ) , lo ) ;
38- let l = _mm512_mullo_epi16 ( x_l , y_l ) ;
39- let h = _mm512_mullo_epi16 ( x_h , y_h ) ;
33+ let x_0 = _mm512_and_si512 ( x, lo) ;
34+ let x_1 = _mm512_srli_epi16 ( x, 8 ) ;
35+ let y_0 = _mm512_and_si512 ( y, lo) ;
36+ let y_1 = _mm512_srli_epi16 ( y, 8 ) ;
37+ let z_0 = _mm512_mullo_epi16 ( x_0 , y_0 ) ;
38+ let z_1 = _mm512_mullo_epi16 ( x_1 , y_1 ) ;
4039 a = unsafe { a. add ( 64 ) } ;
4140 b = unsafe { b. add ( 64 ) } ;
4241 n -= 64 ;
43- _0 = _mm512_add_epi32 ( _0, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( l , 0 ) ) ) ;
44- _1 = _mm512_add_epi32 ( _1, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( l , 1 ) ) ) ;
45- _2 = _mm512_add_epi32 ( _2, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( h , 0 ) ) ) ;
46- _3 = _mm512_add_epi32 ( _3, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( h , 1 ) ) ) ;
42+ _0 = _mm512_add_epi32 ( _0, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( z_0 , 0 ) ) ) ;
43+ _1 = _mm512_add_epi32 ( _1, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( z_0 , 1 ) ) ) ;
44+ _2 = _mm512_add_epi32 ( _2, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( z_1 , 0 ) ) ) ;
45+ _3 = _mm512_add_epi32 ( _3, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( z_1 , 1 ) ) ) ;
4746 }
4847 if n > 0 {
4948 let mask = _bzhi_u64 ( 0xffffffffffffffff , n as u32 ) ;
5049 let x = unsafe { _mm512_maskz_loadu_epi8 ( mask, a. cast ( ) ) } ;
5150 let y = unsafe { _mm512_maskz_loadu_epi8 ( mask, b. cast ( ) ) } ;
52- let x_l = _mm512_and_si512 ( x, lo) ;
53- let x_h = _mm512_and_si512 ( _mm512_srli_epi16 ( _mm512_and_si512 ( x, hi) , 8 ) , lo) ;
54- let y_l = _mm512_and_si512 ( y, lo) ;
55- let y_h = _mm512_and_si512 ( _mm512_srli_epi16 ( _mm512_and_si512 ( y, hi) , 8 ) , lo) ;
56- let l = _mm512_mullo_epi16 ( x_l, y_l) ;
57- let h = _mm512_mullo_epi16 ( x_h, y_h) ;
58- _0 = _mm512_add_epi32 ( _0, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( l, 0 ) ) ) ;
59- _1 = _mm512_add_epi32 ( _1, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( l, 1 ) ) ) ;
60- _2 = _mm512_add_epi32 ( _2, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( h, 0 ) ) ) ;
61- _3 = _mm512_add_epi32 ( _3, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( h, 1 ) ) ) ;
62- }
63- let _5 = _mm512_add_epi32 ( _0, _1) ;
64- let _6 = _mm512_add_epi32 ( _2, _3) ;
65- _mm512_reduce_add_epi32 ( _mm512_add_epi32 ( _5, _6) ) as u32
51+ let x_0 = _mm512_and_si512 ( x, lo) ;
52+ let x_1 = _mm512_srli_epi16 ( x, 8 ) ;
53+ let y_0 = _mm512_and_si512 ( y, lo) ;
54+ let y_1 = _mm512_srli_epi16 ( y, 8 ) ;
55+ let z_0 = _mm512_mullo_epi16 ( x_0, y_0) ;
56+ let z_1 = _mm512_mullo_epi16 ( x_1, y_1) ;
57+ _0 = _mm512_add_epi32 ( _0, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( z_0, 0 ) ) ) ;
58+ _1 = _mm512_add_epi32 ( _1, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( z_0, 1 ) ) ) ;
59+ _2 = _mm512_add_epi32 ( _2, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( z_1, 0 ) ) ) ;
60+ _3 = _mm512_add_epi32 ( _3, _mm512_cvtepu16_epi32 ( _mm512_extracti32x8_epi32 ( z_1, 1 ) ) ) ;
61+ }
62+ let r_0 = _mm512_add_epi32 ( _0, _2) ;
63+ let r_1 = _mm512_add_epi32 ( _1, _3) ;
64+ let r_2 = _mm512_add_epi32 ( r_0, r_1) ;
65+ _mm512_reduce_add_epi32 ( r_2) as u32
6666 }
6767
6868 #[ cfg( all( target_arch = "x86_64" , test, not( miri) ) ) ]
@@ -101,32 +101,31 @@ mod reduce_sum_of_x_as_u32_y_as_u32 {
101101 let mut n = lhs. len ( ) ;
102102 let mut a = lhs. as_ptr ( ) ;
103103 let mut b = rhs. as_ptr ( ) ;
104- let lo = _mm256_set1_epi16 ( 0x00ff_u16 as i16 ) ;
105- let hi = _mm256_set1_epi16 ( 0xff00_u16 as i16 ) ;
104+ let lo = _mm256_set1_epi16 ( 0x00ff_i16 ) ;
106105 let mut _0 = _mm256_setzero_si256 ( ) ;
107106 let mut _1 = _mm256_setzero_si256 ( ) ;
108107 let mut _2 = _mm256_setzero_si256 ( ) ;
109108 let mut _3 = _mm256_setzero_si256 ( ) ;
110109 while n >= 32 {
111110 let x = unsafe { _mm256_loadu_si256 ( a. cast ( ) ) } ;
112111 let y = unsafe { _mm256_loadu_si256 ( b. cast ( ) ) } ;
113- let x_l = _mm256_and_si256 ( x, lo) ;
114- let x_h = _mm256_and_si256 ( _mm256_srli_epi16 ( _mm256_and_si256 ( x, hi ) , 8 ) , lo ) ;
115- let y_l = _mm256_and_si256 ( y, lo) ;
116- let y_h = _mm256_and_si256 ( _mm256_srli_epi16 ( _mm256_and_si256 ( y, hi ) , 8 ) , lo ) ;
117- let l = _mm256_mullo_epi16 ( x_l , y_l ) ;
118- let h = _mm256_mullo_epi16 ( x_h , y_h ) ;
112+ let x_0 = _mm256_and_si256 ( x, lo) ;
113+ let x_1 = _mm256_srli_epi16 ( x, 8 ) ;
114+ let y_0 = _mm256_and_si256 ( y, lo) ;
115+ let y_1 = _mm256_srli_epi16 ( y, 8 ) ;
116+ let z_0 = _mm256_mullo_epi16 ( x_0 , y_0 ) ;
117+ let z_1 = _mm256_mullo_epi16 ( x_1 , y_1 ) ;
119118 a = unsafe { a. add ( 32 ) } ;
120119 b = unsafe { b. add ( 32 ) } ;
121120 n -= 32 ;
122- _0 = _mm256_add_epi32 ( _0, _mm256_cvtepu16_epi32 ( _mm256_extracti128_si256 ( l , 0 ) ) ) ;
123- _1 = _mm256_add_epi32 ( _1, _mm256_cvtepu16_epi32 ( _mm256_extracti128_si256 ( l , 1 ) ) ) ;
124- _2 = _mm256_add_epi32 ( _2, _mm256_cvtepu16_epi32 ( _mm256_extracti128_si256 ( h , 0 ) ) ) ;
125- _3 = _mm256_add_epi32 ( _3, _mm256_cvtepu16_epi32 ( _mm256_extracti128_si256 ( h , 1 ) ) ) ;
121+ _0 = _mm256_add_epi32 ( _0, _mm256_cvtepu16_epi32 ( _mm256_extracti128_si256 ( z_0 , 0 ) ) ) ;
122+ _1 = _mm256_add_epi32 ( _1, _mm256_cvtepu16_epi32 ( _mm256_extracti128_si256 ( z_0 , 1 ) ) ) ;
123+ _2 = _mm256_add_epi32 ( _2, _mm256_cvtepu16_epi32 ( _mm256_extracti128_si256 ( z_1 , 0 ) ) ) ;
124+ _3 = _mm256_add_epi32 ( _3, _mm256_cvtepu16_epi32 ( _mm256_extracti128_si256 ( z_1 , 1 ) ) ) ;
126125 }
127126 let mut sum = emulate_mm256_reduce_add_epi32 ( _mm256_add_epi32 (
128- _mm256_add_epi32 ( _0, _1 ) ,
129- _mm256_add_epi32 ( _2 , _3) ,
127+ _mm256_add_epi32 ( _0, _2 ) ,
128+ _mm256_add_epi32 ( _1 , _3) ,
130129 ) ) as u32 ;
131130 // this hint is used to disable loop unrolling
132131 while std:: hint:: black_box ( n) > 0 {
@@ -176,28 +175,27 @@ mod reduce_sum_of_x_as_u32_y_as_u32 {
176175 let mut n = lhs. len ( ) ;
177176 let mut a = lhs. as_ptr ( ) ;
178177 let mut b = rhs. as_ptr ( ) ;
179- let lo = _mm_set1_epi16 ( 0x00ff_u16 as i16 ) ;
180- let hi = _mm_set1_epi16 ( 0xff00_u16 as i16 ) ;
178+ let lo = _mm_set1_epi16 ( 0x00ff_i16 ) ;
181179 let mut _0 = _mm_setzero_si128 ( ) ;
182180 let mut _1 = _mm_setzero_si128 ( ) ;
183181 let mut _2 = _mm_setzero_si128 ( ) ;
184182 let mut _3 = _mm_setzero_si128 ( ) ;
185183 while n >= 16 {
186184 let x = unsafe { _mm_loadu_si128 ( a. cast ( ) ) } ;
187185 let y = unsafe { _mm_loadu_si128 ( b. cast ( ) ) } ;
188- let x_l = _mm_and_si128 ( x, lo) ;
189- let x_h = _mm_and_si128 ( _mm_srli_epi16 ( _mm_and_si128 ( x, hi ) , 8 ) , lo ) ;
190- let y_l = _mm_and_si128 ( y, lo) ;
191- let y_h = _mm_and_si128 ( _mm_srli_epi16 ( _mm_and_si128 ( y, hi ) , 8 ) , lo ) ;
192- let l = _mm_mullo_epi16 ( x_l , y_l ) ;
193- let h = _mm_mullo_epi16 ( x_h , y_h ) ;
186+ let x_0 = _mm_and_si128 ( x, lo) ;
187+ let x_1 = _mm_srli_epi16 ( x, 8 ) ;
188+ let y_0 = _mm_and_si128 ( y, lo) ;
189+ let y_1 = _mm_srli_epi16 ( y, 8 ) ;
190+ let z_0 = _mm_mullo_epi16 ( x_0 , y_0 ) ;
191+ let z_1 = _mm_mullo_epi16 ( x_1 , y_1 ) ;
194192 a = unsafe { a. add ( 16 ) } ;
195193 b = unsafe { b. add ( 16 ) } ;
196194 n -= 16 ;
197- _0 = _mm_add_epi32 ( _0, _mm_cvtepu16_epi32 ( l ) ) ;
198- _1 = _mm_add_epi32 ( _1, _mm_cvtepu16_epi32 ( _mm_unpackhi_epi64 ( l , l ) ) ) ;
199- _2 = _mm_add_epi32 ( _2, _mm_cvtepu16_epi32 ( h ) ) ;
200- _3 = _mm_add_epi32 ( _3, _mm_cvtepu16_epi32 ( _mm_unpackhi_epi64 ( h , h ) ) ) ;
195+ _0 = _mm_add_epi32 ( _0, _mm_cvtepu16_epi32 ( z_0 ) ) ;
196+ _1 = _mm_add_epi32 ( _1, _mm_cvtepu16_epi32 ( _mm_unpackhi_epi64 ( z_0 , z_0 ) ) ) ;
197+ _2 = _mm_add_epi32 ( _2, _mm_cvtepu16_epi32 ( z_1 ) ) ;
198+ _3 = _mm_add_epi32 ( _3, _mm_cvtepu16_epi32 ( _mm_unpackhi_epi64 ( z_1 , z_1 ) ) ) ;
201199 }
202200 let mut sum = emulate_mm_reduce_add_epi32 ( _mm_add_epi32 (
203201 _mm_add_epi32 ( _0, _1) ,
@@ -256,23 +254,23 @@ mod reduce_sum_of_x_as_u32_y_as_u32 {
256254 let mut _2 = vdupq_n_u32 ( 0 ) ;
257255 let mut _3 = vdupq_n_u32 ( 0 ) ;
258256 while n >= 16 {
259- let x = vreinterpretq_u16_u8 ( unsafe { vld1q_u8 ( a. cast ( ) ) } ) ;
260- let y = vreinterpretq_u16_u8 ( unsafe { vld1q_u8 ( b. cast ( ) ) } ) ;
261- let x_l = vandq_u16 ( x, lo) ;
262- let x_h = vshrq_n_u16 ( x, 8 ) ;
263- let y_l = vandq_u16 ( y, lo) ;
264- let y_h = vshrq_n_u16 ( y, 8 ) ;
265- let l = vmulq_u16 ( x_l , y_l ) ;
266- let h = vmulq_u16 ( x_h , y_h ) ;
257+ let x = unsafe { vld1q_u16 ( a. cast ( ) ) } ;
258+ let y = unsafe { vld1q_u16 ( b. cast ( ) ) } ;
259+ let x_0 = vandq_u16 ( x, lo) ;
260+ let x_1 = vshrq_n_u16 ( x, 8 ) ;
261+ let y_0 = vandq_u16 ( y, lo) ;
262+ let y_1 = vshrq_n_u16 ( y, 8 ) ;
263+ let z_0 = vmulq_u16 ( x_0 , y_0 ) ;
264+ let z_1 = vmulq_u16 ( x_1 , y_1 ) ;
267265 a = unsafe { a. add ( 16 ) } ;
268266 b = unsafe { b. add ( 16 ) } ;
269267 n -= 16 ;
270- _0 = vaddq_u32 ( _0, vmovl_u16 ( vget_low_u16 ( l ) ) ) ;
271- _1 = vaddq_u32 ( _1, vmovl_u16 ( vget_high_u16 ( l ) ) ) ;
272- _2 = vaddq_u32 ( _2, vmovl_u16 ( vget_low_u16 ( h ) ) ) ;
273- _3 = vaddq_u32 ( _3, vmovl_u16 ( vget_high_u16 ( h ) ) ) ;
268+ _0 = vaddq_u32 ( _0, vmovl_u16 ( vget_low_u16 ( z_0 ) ) ) ;
269+ _1 = vaddq_u32 ( _1, vmovl_u16 ( vget_high_u16 ( z_0 ) ) ) ;
270+ _2 = vaddq_u32 ( _2, vmovl_u16 ( vget_low_u16 ( z_1 ) ) ) ;
271+ _3 = vaddq_u32 ( _3, vmovl_u16 ( vget_high_u16 ( z_1 ) ) ) ;
274272 }
275- let mut sum = vaddvq_u32 ( vaddq_u32 ( vaddq_u32 ( _0, _1 ) , vaddq_u32 ( _2 , _3) ) ) ;
273+ let mut sum = vaddvq_u32 ( vaddq_u32 ( vaddq_u32 ( _0, _2 ) , vaddq_u32 ( _1 , _3) ) ) ;
276274 // this hint is used to disable loop unrolling
277275 while std:: hint:: black_box ( n) > 0 {
278276 let x = unsafe { a. read ( ) } ;
0 commit comments