|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <stdint.h> |
| 4 | +#include <immintrin.h> |
| 5 | + |
| 6 | +#include "popcount.h" |
| 7 | + |
| 8 | +using namespace cryptanalysislib::popcount; |
| 9 | + |
| 10 | +// TODO tests and benches |
| 11 | + |
/// Per-vector PDEP kernel: deposits the low bits of each 32-bit lane of
/// `data` into the positions of the set bits of the matching lane of `mask`.
///
/// Scalar equivalent per lane:
///     k = 0
///     for m = 0 .. 31:
///         if mask[m] == 1:
///             out[m] = data[k]
///             k = k + 1
///
/// Each loop iteration consumes the lowest remaining set bit of `mask`, so
/// only the lowest MAX_MASK_BITS set bits of a lane are ever filled.
///
/// \tparam MAX_MASK_BITS number of loop iterations (1..32)
/// \tparam EARLY_EXIT    if true, stop as soon as every lane's mask is empty
/// \param data eight 32-bit source words
/// \param mask eight 32-bit deposit masks
/// \return eight 32-bit deposited words
template <const uint32_t MAX_MASK_BITS,
          const bool EARLY_EXIT>
static inline __m256i avx2_pdep_u32_lanes(const __m256i data, __m256i mask) {
    const __m256i one  = _mm256_set1_epi32(1);
    const __m256i zero = _mm256_set1_epi32(0);

    __m256i out = zero;
    __m256i bit = one;   // selects data bit k; k equals the iteration index

    for (uint32_t j = 0; j < MAX_MASK_BITS; j++) {
        // 1. isolate the lowest set bit of mask (at position m)
        //                                                 mask = [0101_1001_1100_0000|0000_1110_1100_1000|...]
        const __m256i m0 = _mm256_sub_epi32(mask, one); // m0   = [0101_1001_1011_1111|0000_1110_1100_0111|...]
        const __m256i m1 = _mm256_and_si256(mask, m0);  // m1   = [0101_1001_1000_0000|0000_1110_1100_0000|...]
        const __m256i m2 = _mm256_xor_si256(mask, m1);  // m2   = [0000_0000_0100_0000|0000_0000_0000_1000|...]

        // 2. isolate the k-th bit of data
        const __m256i d0 = _mm256_and_si256(data, bit);

        // 3. all-ones in every lane whose k-th data bit is *clear*
        const __m256i d1 = _mm256_cmpeq_epi32(d0, zero);

        // 4. keep the isolated mask bit iff data[k] == 1
        const __m256i m3 = _mm256_andnot_si256(d1, m2);

        // 5. accumulate the result and drop the consumed mask bit
        out  = _mm256_or_si256(out, m3);
        mask = m1;

        // 6. advance to the next data bit
        bit = _mm256_add_epi32(bit, bit);

        // 7. testc(zero, mask) is 1 iff (~zero & mask) == 0, i.e. all masks are empty
        if (EARLY_EXIT && _mm256_testc_si256(zero, mask)) {
            break;
        }
    }

    return out;
}

/// Element-wise 32-bit PDEP over arrays using AVX2:
/// `out_arr[i] = pdep(data_arr[i], mask_arr[i])` for every `i < n`.
///
/// Source: https://github.com/WojciechMula/toys/blob/master/simd-pdep-pext/pdep_avx2.cpp
///
/// Full groups of eight elements use unaligned vector loads/stores. A final
/// partial group (when `n` is not a multiple of 8) is handled with masked
/// loads/stores so that nothing beyond index `n - 1` is read or written.
///
/// \tparam MAX_MASK_BITS upper bound on the number of set bits per mask that
///                       are processed (1..32); additional set bits of a mask
///                       are left as zero in the output
/// \tparam EARLY_EXIT    if true, the per-vector loop stops once all masks of
///                       the current vector are exhausted
/// \param data_arr input words, at least `n` elements
/// \param mask_arr deposit masks, at least `n` elements
/// \param out_arr  output words, at least `n` elements
/// \param n        number of elements to process; may be any value, including 0
template <const uint32_t MAX_MASK_BITS,
          const bool EARLY_EXIT>
void avx2_pdep_u32_reference(const uint32_t* data_arr,
                             const uint32_t* mask_arr,
                             uint32_t* out_arr,
                             const size_t n) {
    static_assert(MAX_MASK_BITS > 0);
    static_assert(MAX_MASK_BITS <= 32);

    constexpr size_t LANES = 8;
    size_t i = 0;

    // Full vectors. `i <= n` always holds, so `n - i` cannot wrap.
    for (; n - i >= LANES; i += LANES) {
        const __m256i data = _mm256_loadu_si256((const __m256i*)(&data_arr[i]));
        const __m256i mask = _mm256_loadu_si256((const __m256i*)(&mask_arr[i]));
        const __m256i out  = avx2_pdep_u32_lanes<MAX_MASK_BITS, EARLY_EXIT>(data, mask);
        _mm256_storeu_si256((__m256i*)(&out_arr[i]), out);
    }

    // Tail of 1..7 elements: only the live lanes touch memory; the inactive
    // lanes are loaded as zero (empty mask) and are not stored.
    if (i < n) {
        const __m256i lane_idx = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i live     = _mm256_cmpgt_epi32(_mm256_set1_epi32((int)(n - i)), lane_idx);

        const __m256i data = _mm256_maskload_epi32((const int*)(&data_arr[i]), live);
        const __m256i mask = _mm256_maskload_epi32((const int*)(&mask_arr[i]), live);
        const __m256i out  = avx2_pdep_u32_lanes<MAX_MASK_BITS, EARLY_EXIT>(data, mask);
        _mm256_maskstore_epi32((int*)(&out_arr[i]), live, out);
    }
}
| 71 | + |
/// Per-vector PDEP kernel: deposits the low bits of each 32-bit lane of
/// `data` into the positions of the set bits of the matching lane of `mask`.
///
/// Scalar equivalent per lane:
///     k = 0
///     for m = 0 .. 31:
///         if mask[m] == 1:
///             out[m] = data[k]
///             k = k + 1      -- invariant: k is never greater than m
///
/// Each loop iteration consumes the lowest remaining set bit of `mask`, so
/// only the lowest MAX_MASK_BITS set bits of a lane are ever filled.
///
/// \tparam MAX_MASK_BITS number of loop iterations (1..32)
/// \tparam EARLY_EXIT    if true, stop as soon as every lane's mask is empty
/// \param data sixteen 32-bit source words
/// \param mask sixteen 32-bit deposit masks
/// \return sixteen 32-bit deposited words
template <const uint32_t MAX_MASK_BITS,
          const bool EARLY_EXIT>
static inline __m512i avx512_pdep_u32_lanes(const __m512i data, __m512i mask) {
    const __m512i one  = _mm512_set1_epi32(1);
    const __m512i zero = _mm512_set1_epi32(0);

    __m512i out = zero;
    __m512i bit = one;   // selects data bit k; k equals the iteration index

    for (uint32_t j = 0; j < MAX_MASK_BITS; j++) {
        // 1. isolate the lowest set bit of mask (at position m)
        //                                                 mask = [0101_1001_1100_0000|0000_1110_1100_1000|...]
        const __m512i m0 = _mm512_sub_epi32(mask, one); // m0   = [0101_1001_1011_1111|0000_1110_1100_0111|...]
        const __m512i m1 = _mm512_and_si512(mask, m0);  // m1   = [0101_1001_1000_0000|0000_1110_1100_0000|...]
        const __m512i m2 = _mm512_xor_si512(mask, m1);  // m2   = [0000_0000_0100_0000|0000_0000_0000_1000|...]
        // the above and & xor could be fused into a single ternarylogic instruction

        // 2. isolate the k-th bit of data
        const __m512i d0 = _mm512_and_si512(data, bit);

        // 3. move the k-th bit to position m: m0 has ones in every position
        //    below m, so adding 2^k (k <= m) carries into bit m exactly when
        //    the data bit is set
        const __m512i d1 = _mm512_add_epi32(d0, m0);
        const __m512i d2 = _mm512_and_si512(d1, m2);

        // 4. accumulate the result and drop the consumed mask bit
        out  = _mm512_or_si512(out, d2);
        // the above and & or could be fused into a single ternarylogic instruction
        mask = m1;

        // 5. advance to the next data bit
        bit = _mm512_add_epi32(bit, bit);

        // 6. stop once every lane's mask is empty
        if (EARLY_EXIT && (_mm512_cmpeq_epi32_mask(zero, mask) == 0xffff)) {
            break;
        }
    }

    return out;
}

/// Element-wise 32-bit PDEP over arrays using AVX-512F:
/// `out_arr[i] = pdep(data_arr[i], mask_arr[i])` for every `i < n`.
///
/// Source: https://github.com/WojciechMula/toys/blob/master/simd-pdep-pext/pdep_avx512.cpp
///
/// Full groups of sixteen elements use unaligned vector loads/stores. A final
/// partial group (when `n` is not a multiple of 16) is handled with
/// write-/zero-masked loads and stores so that nothing beyond index `n - 1`
/// is read or written.
///
/// \tparam MAX_MASK_BITS upper bound on the number of set bits per mask that
///                       are processed (1..32); additional set bits of a mask
///                       are left as zero in the output
/// \tparam EARLY_EXIT    if true, the per-vector loop stops once all masks of
///                       the current vector are exhausted
/// \param data_arr input words, at least `n` elements
/// \param mask_arr deposit masks, at least `n` elements
/// \param out_arr  output words, at least `n` elements
/// \param n        number of elements to process; may be any value, including 0
template <const uint32_t MAX_MASK_BITS,
          const bool EARLY_EXIT>
void avx512_pdep_u32_reference(const uint32_t* data_arr,
                               const uint32_t* mask_arr,
                               uint32_t* out_arr,
                               const size_t n) {
    static_assert(MAX_MASK_BITS > 0);
    static_assert(MAX_MASK_BITS <= 32);

    constexpr size_t LANES = 16;
    size_t i = 0;

    // Full vectors. `i <= n` always holds, so `n - i` cannot wrap.
    for (; n - i >= LANES; i += LANES) {
        const __m512i data = _mm512_loadu_si512((const __m512i*)(&data_arr[i]));
        const __m512i mask = _mm512_loadu_si512((const __m512i*)(&mask_arr[i]));
        const __m512i out  = avx512_pdep_u32_lanes<MAX_MASK_BITS, EARLY_EXIT>(data, mask);
        _mm512_storeu_si512((__m512i*)(&out_arr[i]), out);
    }

    // Tail of 1..15 elements: one bit per live lane; inactive lanes are
    // loaded as zero (empty mask) and are not stored.
    if (i < n) {
        const __mmask16 live = (__mmask16)((1u << (unsigned)(n - i)) - 1u);

        const __m512i data = _mm512_maskz_loadu_epi32(live, &data_arr[i]);
        const __m512i mask = _mm512_maskz_loadu_epi32(live, &mask_arr[i]);
        const __m512i out  = avx512_pdep_u32_lanes<MAX_MASK_BITS, EARLY_EXIT>(data, mask);
        _mm512_mask_storeu_epi32(&out_arr[i], live, out);
    }
}
| 133 | + |
| 134 | +unsigned int pdep32_emu(unsigned int v, unsigned int m) { |
| 135 | + unsigned int ret = 0, pc = popcount(m); |
| 136 | + switch (pc) { |
| 137 | + case 0: |
| 138 | + ret = 0; |
| 139 | + break; |
| 140 | + case 1: |
| 141 | + ret = (v & 1) << _tzcnt_u32(m); |
| 142 | + break; |
| 143 | + case 2: |
| 144 | + ret = (((v << (32 - pc)) & 0x80000000) >> _lzcnt_u32(m)) | ((v & 1) << _tzcnt_u32(m)); |
| 145 | + break; |
| 146 | + case 3: |
| 147 | + case 4: |
| 148 | + case 5: |
| 149 | + case 6: |
| 150 | + case 7: |
| 151 | + case 8: |
| 152 | + case 9: |
| 153 | + case 10: |
| 154 | + case 11: |
| 155 | + case 12: |
| 156 | + case 13: { |
| 157 | + unsigned int lsb = 0, msb = 0; |
| 158 | + unsigned int v1 = v << (32 - pc); |
| 159 | + for (unsigned int i = 0; i < pc / 2 ; i++) { |
| 160 | + const unsigned int tz = _tzcnt_u32(m); |
| 161 | + const unsigned int lz = _lzcnt_u32(m); |
| 162 | + m &= ~((0x80000000 >> lz) | (1 << tz)); |
| 163 | + msb = (v1 & 0x80000000) >> lz; |
| 164 | + lsb = (v & 1) << tz; |
| 165 | + ret |= (msb | lsb); |
| 166 | + v >>= 1; |
| 167 | + v1 <<= 1; |
| 168 | + } |
| 169 | + ret |= ((pc & 1) & v) << _tzcnt_u32(m); |
| 170 | + break; |
| 171 | + } |
| 172 | + default: { |
| 173 | + __m128i mtwo = _mm_set1_epi64x((~0ULL) - 1); |
| 174 | + __m128i mm = _mm_cvtsi32_si128(~m); |
| 175 | + __m128i bit0 = _mm_clmulepi64_si128(mm, mtwo, 0); |
| 176 | + mm = _mm_and_si128(mm, bit0); |
| 177 | + __m128i bit1 = _mm_clmulepi64_si128(mm, mtwo, 0); |
| 178 | + mm = _mm_and_si128(mm, bit1); |
| 179 | + __m128i bit2 = _mm_clmulepi64_si128(mm, mtwo, 0); |
| 180 | + mm = _mm_and_si128(mm, bit2); |
| 181 | + __m128i bit3 = _mm_clmulepi64_si128(mm, mtwo, 0); |
| 182 | + mm = _mm_and_si128(mm, bit3); |
| 183 | + __m128i bit4 = _mm_sub_epi64(_mm_setzero_si128(), mm); |
| 184 | + bit4 = _mm_add_epi64(bit4, bit4); |
| 185 | + __m128i a = _mm_cvtsi32_si128(_bzhi_u32(v, pc)); |
| 186 | + |
| 187 | + bit4 = _mm_srli_epi64(bit4, 16); |
| 188 | + a = _mm_add_epi64(_mm_andnot_si128(bit4, a),_mm_slli_epi64(_mm_and_si128(bit4, a), 16)); |
| 189 | + bit3 = _mm_srli_epi64(bit3, 8); |
| 190 | + a = _mm_add_epi64(_mm_andnot_si128(bit3, a),_mm_slli_epi64(_mm_and_si128(bit3, a), 8)); |
| 191 | + bit2 = _mm_srli_epi64(bit2, 4); |
| 192 | + a = _mm_add_epi64(_mm_andnot_si128(bit2, a),_mm_slli_epi64(_mm_and_si128(bit2, a), 4)); |
| 193 | + bit1 = _mm_srli_epi64(bit1, 2); |
| 194 | + a = _mm_add_epi64(_mm_andnot_si128(bit1, a),_mm_slli_epi64(_mm_and_si128(bit1, a), 2)); |
| 195 | + bit0 = _mm_srli_epi64(bit0, 1); |
| 196 | + a = _mm_add_epi64(_mm_andnot_si128(bit0, a),_mm_slli_epi64(_mm_and_si128(bit0, a), 1)); |
| 197 | + ret = _mm_cvtsi128_si32(a); |
| 198 | + } |
| 199 | + break; |
| 200 | + } |
| 201 | + return ret; |
| 202 | +}; |
0 commit comments