Skip to content

Commit 6c27435

Browse files
committed
add avx512 adler32 implementation
1 parent d645ed7 commit 6c27435

File tree

6 files changed

+159
-19
lines changed

6 files changed

+159
-19
lines changed

.github/workflows/checks.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,10 @@ jobs:
484484
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=vpclmulqdq crc32::"
485485
env:
486486
RUSTFLAGS: "-Ctarget-feature=+vpclmulqdq,+avx512f"
487+
- name: Test avx512 adler32 implementation
488+
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=avx512 adler32::"
489+
env:
490+
RUSTFLAGS: "-Ctarget-feature=+avx2,+bmi2,+bmi1,+avx512f,+avx512bw"
487491
- name: Test allocator with miri
488492
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} allocate::"
489493
- name: Test gz logic with miri

Cargo.lock

Lines changed: 0 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

zlib-rs/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ __internal-fuzz = ["arbitrary"]
2424
__internal-fuzz-disable-checksum = [] # disable checksum validation on inflate
2525
__internal-test = ["quickcheck"]
2626
ZLIB_DEBUG = []
27-
vpclmulqdq = [] # use avx512 to speed up crc32. Only stable from 1.89.0 onwards
27+
vpclmulqdq = [] # use avx512 to speed up crc32. Only stable from 1.89.0 onwards.
28+
avx512 = ["vpclmulqdq"] # use avx512 to speed up crc32 and adler32. Only stable from 1.89.0 onwards.
2829

2930

3031
[dependencies]
@@ -33,5 +34,5 @@ quickcheck = { workspace = true, optional = true }
3334

3435
[dev-dependencies]
3536
crc32fast = "1.3.2"
36-
memoffset = "0.9.1"
37+
# memoffset = "0.9.1"
3738
quickcheck.workspace = true

zlib-rs/src/adler32.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
#[cfg(target_arch = "x86_64")]
22
mod avx2;
3+
#[cfg(feature = "avx512")]
4+
#[cfg(target_arch = "x86_64")]
5+
mod avx512;
36
mod generic;
47
#[cfg(target_arch = "aarch64")]
58
mod neon;
69
#[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))]
710
mod wasm;
811

912
pub fn adler32(start_checksum: u32, data: &[u8]) -> u32 {
13+
#[cfg(feature = "avx512")]
14+
#[cfg(target_arch = "x86_64")]
15+
if cfg!(all(target_feature = "avx512f", target_feature = "avx512bw")) {
16+
return unsafe { avx512::adler32_avx512(start_checksum, data) };
17+
}
18+
1019
#[cfg(target_arch = "x86_64")]
1120
if crate::cpu_features::is_enabled_avx2_and_bmi2() {
1221
return avx2::adler32_avx2(start_checksum, data);

zlib-rs/src/adler32/avx2.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const DOT3V: __m256i = __m256i_literal([
2727
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
2828
]);
2929

30-
const ZERO: __m256i = __m256i_literal([0; 32]);
30+
const ZERO: __m256i = __m256i_literal([0u8; 32]);
3131

3232
/// 32 bit horizontal sum, adapted from Agner Fog's vector library.
3333
#[target_feature(enable = "avx2")]

zlib-rs/src/adler32/avx512.rs

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
use core::arch::x86_64::{
2+
__m512i, _mm256_add_epi32, _mm256_castsi256_si128, _mm256_extracti128_si256, _mm512_add_epi32,
3+
_mm512_castsi512_si256, _mm512_extracti64x4_epi64, _mm512_madd_epi16, _mm512_maddubs_epi16,
4+
_mm512_permutexvar_epi32, _mm512_sad_epu8, _mm512_set1_epi16, _mm512_setr_epi32,
5+
_mm512_slli_epi32, _mm512_zextsi128_si512, _mm_add_epi32, _mm_cvtsi128_si32, _mm_cvtsi32_si128,
6+
_mm_shuffle_epi32, _mm_unpackhi_epi64,
7+
};
8+
9+
use crate::adler32::{BASE, NMAX};
10+
11+
/// Builds a `__m512i` constant from 64 explicit bytes, given in memory order.
const fn __m512i_literal(raw: [u8; 64]) -> __m512i {
    // SAFETY: `__m512i` has no validity invariant beyond being 64 initialized
    // bytes, so reinterpreting any `[u8; 64]` value is sound.
    unsafe { core::mem::transmute(raw) }
}
15+
16+
const DOT2V: __m512i = __m512i_literal([
17+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
18+
27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
19+
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
20+
]);
21+
22+
const ZERO: __m512i = __m512i_literal([0u8; 64]);
23+
24+
/// Adler-32 checksum of `src`, continuing from `adler`, using AVX-512.
///
/// Safe entry point: the assertions below guarantee the required target
/// features were enabled at compile time before any AVX-512 code can run.
#[target_feature(enable = "avx512f")]
#[target_feature(enable = "avx512bw")]
pub fn adler32_avx512(adler: u32, src: &[u8]) -> u32 {
    // `cfg!` expands to a compile-time boolean: when the features are enabled
    // these asserts are optimized away; when they are not, the function panics
    // before any AVX-512 instruction executes.
    assert!(cfg!(target_feature = "avx512f"));
    assert!(cfg!(target_feature = "avx512bw"));
    // SAFETY: the assertions above ensure this code is not executed unless the CPU has AVX512F and AVX512BW.
    unsafe { adler32_avx512_help(adler, src) }
}
32+
33+
#[target_feature(enable = "avx512f")]
34+
#[target_feature(enable = "avx512bw")]
35+
unsafe fn adler32_avx512_help(adler: u32, src: &[u8]) -> u32 {
36+
if src.is_empty() {
37+
return adler;
38+
}
39+
40+
// SAFETY: [u8; 32] safely transmutes into __m256i.
41+
let (before, middle, after) = unsafe { src.align_to::<__m512i>() };
42+
43+
let adler = if !before.is_empty() {
44+
super::avx2::adler32_avx2(adler, src)
45+
} else {
46+
adler
47+
};
48+
49+
let mut adler1 = (adler >> 16) & 0xffff;
50+
let mut adler0 = adler & 0xffff;
51+
52+
// use largest step possible (without causing overflow)
53+
for chunk in middle.chunks(NMAX as usize / 64) {
54+
(adler0, adler1) = unsafe { helper_64_bytes(adler0, adler1, chunk) };
55+
}
56+
57+
if after.is_empty() {
58+
adler0 | (adler1 << 16)
59+
} else {
60+
super::avx2::adler32_avx2(adler0 | (adler1 << 16), src)
61+
}
62+
}
63+
64+
#[target_feature(enable = "avx2")]
65+
unsafe fn helper_64_bytes(mut adler0: u32, mut adler1: u32, src: &[__m512i]) -> (u32, u32) {
66+
unsafe {
67+
let mut vs1 = _mm512_zextsi128_si512(_mm_cvtsi32_si128(adler0 as i32));
68+
let mut vs2 = _mm512_zextsi128_si512(_mm_cvtsi32_si128(adler1 as i32));
69+
70+
let mut vs1_0 = vs1;
71+
let mut vs3 = ZERO;
72+
73+
let dot3v = _mm512_set1_epi16(1);
74+
75+
for vbuf in src.iter().copied() {
76+
let vs1_sad = _mm512_sad_epu8(vbuf, ZERO);
77+
let v_short_sum2 = _mm512_maddubs_epi16(vbuf, DOT2V);
78+
vs1 = _mm512_add_epi32(vs1_sad, vs1);
79+
vs3 = _mm512_add_epi32(vs3, vs1_0);
80+
let vsum2 = _mm512_madd_epi16(v_short_sum2, dot3v);
81+
vs2 = _mm512_add_epi32(vsum2, vs2);
82+
vs1_0 = vs1;
83+
}
84+
85+
/* Defer the multiplication with 32 to outside of the loop */
86+
vs3 = _mm512_slli_epi32(vs3, 6);
87+
vs2 = _mm512_add_epi32(vs2, vs3);
88+
89+
adler0 = partial_hsum(vs1) % BASE;
90+
adler1 = _mm512_reduce_add_epu32(vs2) % BASE;
91+
92+
(adler0, adler1)
93+
}
94+
}
95+
96+
#[inline(always)]
97+
unsafe fn _mm512_reduce_add_epu32(x: __m512i) -> u32 {
98+
unsafe {
99+
let a = _mm512_extracti64x4_epi64(x, 1);
100+
let b = _mm512_extracti64x4_epi64(x, 0);
101+
102+
let a_plus_b = _mm256_add_epi32(a, b);
103+
let c = _mm256_extracti128_si256(a_plus_b, 1);
104+
let d = _mm256_extracti128_si256(a_plus_b, 0);
105+
let c_plus_d = _mm_add_epi32(c, d);
106+
107+
let sum1 = _mm_unpackhi_epi64(c_plus_d, c_plus_d);
108+
let sum2 = _mm_add_epi32(sum1, c_plus_d);
109+
let sum3 = _mm_shuffle_epi32(sum2, 0x01);
110+
let sum4 = _mm_add_epi32(sum2, sum3);
111+
112+
_mm_cvtsi128_si32(sum4) as u32
113+
}
114+
}
115+
116+
#[inline(always)]
117+
unsafe fn partial_hsum(x: __m512i) -> u32 {
118+
unsafe {
119+
// We need a permutation vector to extract every other integer. The
120+
// rest are going to be zeros. Marking this const so the compiler stands
121+
// a better chance of keeping this resident in a register through entire
122+
// loop execution. We certainly have enough zmm registers (32) */
123+
let perm_vec: __m512i =
124+
_mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 1, 1, 1, 1, 1, 1, 1, 1);
125+
126+
let non_zero = _mm512_permutexvar_epi32(perm_vec, x);
127+
128+
/* From here, it's a simple 256 bit wide reduction sum */
129+
let non_zero_avx = _mm512_castsi512_si256(non_zero);
130+
131+
/* See Agner Fog's vectorclass for a decent reference. Essentially, phadd is
132+
* pretty slow, much slower than the longer instruction sequence below */
133+
let sum1 = _mm_add_epi32(
134+
_mm256_extracti128_si256(non_zero_avx, 1),
135+
_mm256_castsi256_si128(non_zero_avx),
136+
);
137+
let sum2 = _mm_add_epi32(sum1, _mm_unpackhi_epi64(sum1, sum1));
138+
let sum3 = _mm_add_epi32(sum2, _mm_shuffle_epi32(sum2, 1));
139+
140+
_mm_cvtsi128_si32(sum3) as u32
141+
}
142+
}

0 commit comments

Comments
 (0)