Skip to content

Commit b2c1df2

Browse files
committed
add avx512 adler32 implementation
1 parent 21bfecc commit b2c1df2

File tree

6 files changed

+235
-4
lines changed

6 files changed

+235
-4
lines changed

.github/workflows/checks.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,10 @@ jobs:
485485
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=vpclmulqdq crc32::"
486486
env:
487487
RUSTFLAGS: "-Ctarget-feature=+vpclmulqdq,+avx512f"
488+
- name: Test avx512 adler32 implementation
489+
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=avx512 adler32::"
490+
env:
491+
RUSTFLAGS: "-Ctarget-feature=+avx2,+bmi2,+bmi1,+avx512f,+avx512bw"
488492
- name: Test allocator with miri
489493
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} allocate::"
490494
- name: Test gz logic with miri

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

zlib-rs/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ __internal-fuzz = ["arbitrary"]
2424
__internal-fuzz-disable-checksum = [] # disable checksum validation on inflate
2525
__internal-test = ["quickcheck"]
2626
ZLIB_DEBUG = []
27-
vpclmulqdq = [] # use avx512 to speed up crc32. Only stable from 1.89.0 onwards
27+
vpclmulqdq = [] # use avx512 to speed up crc32. Only stable from 1.89.0 onwards.
28+
avx512 = ["vpclmulqdq"] # use avx512 to speed up crc32 and adler32. Only stable from 1.89.0 onwards.
2829

2930

3031
[dependencies]

zlib-rs/src/adler32.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
#[cfg(target_arch = "x86_64")]
22
mod avx2;
3+
#[cfg(feature = "avx512")]
4+
#[cfg(target_arch = "x86_64")]
5+
mod avx512;
36
mod generic;
47
#[cfg(target_arch = "aarch64")]
58
mod neon;
69
#[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))]
710
mod wasm;
811

912
pub fn adler32(start_checksum: u32, data: &[u8]) -> u32 {
13+
#[cfg(feature = "avx512")]
14+
#[cfg(target_arch = "x86_64")]
15+
if cfg!(all(target_feature = "avx512f", target_feature = "avx512bw")) {
16+
return unsafe { avx512::adler32_avx512(start_checksum, data) };
17+
}
18+
1019
#[cfg(target_arch = "x86_64")]
1120
if crate::cpu_features::is_enabled_avx2_and_bmi2() {
1221
return avx2::adler32_avx2(start_checksum, data);

zlib-rs/src/adler32/avx2.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const DOT3V: __m256i = __m256i_literal([
2727
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
2828
]);
2929

30-
const ZERO: __m256i = __m256i_literal([0; 32]);
30+
const ZERO: __m256i = __m256i_literal([0u8; 32]);
3131

3232
/// 32 bit horizontal sum, adapted from Agner Fog's vector library.
3333
#[target_feature(enable = "avx2")]

zlib-rs/src/adler32/avx512.rs

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
use core::arch::x86_64::{
2+
__m512i, _mm256_add_epi32, _mm256_castsi256_si128, _mm256_extracti128_si256, _mm512_add_epi32,
3+
_mm512_castsi512_si256, _mm512_extracti64x4_epi64, _mm512_madd_epi16, _mm512_maddubs_epi16,
4+
_mm512_permutexvar_epi32, _mm512_sad_epu8, _mm512_set1_epi16, _mm512_setr_epi32,
5+
_mm512_slli_epi32, _mm512_zextsi128_si512, _mm_add_epi32, _mm_cvtsi128_si32, _mm_cvtsi32_si128,
6+
_mm_shuffle_epi32, _mm_unpackhi_epi64,
7+
};
8+
9+
use crate::adler32::{BASE, NMAX};
10+
11+
/// Reinterprets 64 raw bytes as a 512-bit integer vector.
///
/// This is a `const fn`, so it can be used to initialise vector constants.
const fn __m512i_literal(bytes: [u8; 64]) -> __m512i {
    // SAFETY: `[u8; 64]` and `__m512i` are both 64 bytes wide, and every bit pattern is a
    // valid `__m512i`.
    unsafe { core::mem::transmute::<[u8; 64], __m512i>(bytes) }
}
15+
16+
const DOT2V: __m512i = __m512i_literal([
17+
64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41,
18+
40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
19+
16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
20+
]);
21+
22+
const ZERO: __m512i = __m512i_literal([0u8; 64]);
23+
24+
#[target_feature(enable = "avx512f")]
25+
#[target_feature(enable = "avx512bw")]
26+
pub fn adler32_avx512(adler: u32, src: &[u8]) -> u32 {
27+
assert!(cfg!(target_feature = "avx512f"));
28+
assert!(cfg!(target_feature = "avx512bw"));
29+
// SAFETY: the assertion above ensures this code is not executed unless the CPU has avx512.
30+
unsafe { adler32_avx512_help(adler, src) }
31+
}
32+
33+
#[target_feature(enable = "avx512f")]
34+
#[target_feature(enable = "avx512bw")]
35+
unsafe fn adler32_avx512_help(adler: u32, src: &[u8]) -> u32 {
36+
if src.is_empty() {
37+
return adler;
38+
}
39+
40+
// SAFETY: [u8; 64] safely transmutes into __m512i.
41+
let (before, middle, after) = unsafe { src.align_to::<__m512i>() };
42+
43+
let adler = if !before.is_empty() {
44+
super::avx2::adler32_avx2(adler, before)
45+
} else {
46+
adler
47+
};
48+
49+
let mut adler1 = (adler >> 16) & 0xffff;
50+
let mut adler0 = adler & 0xffff;
51+
52+
// Use largest step possible (without causing overflow).
53+
for chunk in middle.chunks(NMAX as usize / 64) {
54+
(adler0, adler1) = unsafe { helper_64_bytes(adler0, adler1, chunk) };
55+
}
56+
57+
if after.is_empty() {
58+
adler0 | (adler1 << 16)
59+
} else {
60+
super::avx2::adler32_avx2(adler0 | (adler1 << 16), after)
61+
}
62+
}
63+
64+
unsafe fn helper_64_bytes(mut adler0: u32, mut adler1: u32, src: &[__m512i]) -> (u32, u32) {
65+
unsafe {
66+
let mut vs1 = _mm512_zextsi128_si512(_mm_cvtsi32_si128(adler0 as i32));
67+
let mut vs2 = _mm512_zextsi128_si512(_mm_cvtsi32_si128(adler1 as i32));
68+
69+
let mut vs1_0 = vs1;
70+
let mut vs3 = ZERO;
71+
72+
let dot3v = _mm512_set1_epi16(1);
73+
74+
for vbuf in src.iter().copied() {
75+
let vs1_sad = _mm512_sad_epu8(vbuf, ZERO);
76+
let v_short_sum2 = _mm512_maddubs_epi16(vbuf, DOT2V);
77+
vs1 = _mm512_add_epi32(vs1_sad, vs1);
78+
vs3 = _mm512_add_epi32(vs3, vs1_0);
79+
let vsum2 = _mm512_madd_epi16(v_short_sum2, dot3v);
80+
vs2 = _mm512_add_epi32(vsum2, vs2);
81+
vs1_0 = vs1;
82+
}
83+
84+
vs3 = _mm512_slli_epi32(vs3, 6);
85+
vs2 = _mm512_add_epi32(vs2, vs3);
86+
87+
adler0 = partial_hsum(vs1) % BASE;
88+
adler1 = _mm512_reduce_add_epu32(vs2) % BASE;
89+
90+
(adler0, adler1)
91+
}
92+
}
93+
94+
#[inline(always)]
95+
unsafe fn _mm512_reduce_add_epu32(x: __m512i) -> u32 {
96+
unsafe {
97+
let a = _mm512_extracti64x4_epi64(x, 1);
98+
let b = _mm512_extracti64x4_epi64(x, 0);
99+
100+
let a_plus_b = _mm256_add_epi32(a, b);
101+
let c = _mm256_extracti128_si256(a_plus_b, 1);
102+
let d = _mm256_extracti128_si256(a_plus_b, 0);
103+
let c_plus_d = _mm_add_epi32(c, d);
104+
105+
let sum1 = _mm_unpackhi_epi64(c_plus_d, c_plus_d);
106+
let sum2 = _mm_add_epi32(sum1, c_plus_d);
107+
let sum3 = _mm_shuffle_epi32(sum2, 0x01);
108+
let sum4 = _mm_add_epi32(sum2, sum3);
109+
110+
_mm_cvtsi128_si32(sum4) as u32
111+
}
112+
}
113+
114+
#[inline(always)]
115+
unsafe fn partial_hsum(x: __m512i) -> u32 {
116+
unsafe {
117+
// Permutation vector to extract every other integer.
118+
let perm_vec: __m512i =
119+
_mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 1, 1, 1, 1, 1, 1, 1, 1);
120+
121+
let non_zero = _mm512_permutexvar_epi32(perm_vec, x);
122+
123+
// From here, it's a simple 256 bit wide reduction sum.
124+
let non_zero_avx = _mm512_castsi512_si256(non_zero);
125+
126+
// See Agner Fog's vectorclass for a decent reference. Essentially, phadd is
127+
// pretty slow, much slower than the longer instruction sequence below.
128+
let sum1 = _mm_add_epi32(
129+
_mm256_extracti128_si256(non_zero_avx, 1),
130+
_mm256_castsi256_si128(non_zero_avx),
131+
);
132+
let sum2 = _mm_add_epi32(sum1, _mm_unpackhi_epi64(sum1, sum1));
133+
let sum3 = _mm_add_epi32(sum2, _mm_shuffle_epi32(sum2, 1));
134+
135+
_mm_cvtsi128_si32(sum3) as u32
136+
}
137+
}
138+
139+
// Compares the AVX-512 implementation against the portable Rust implementation. The module
// is only compiled when the build enables avx512f and avx512bw, because `adler32_avx512`
// asserts on those target features.
#[cfg(test)]
#[cfg(target_feature = "avx512f")]
#[cfg(target_feature = "avx512bw")]
mod test {
    use super::*;

    #[test]
    fn empty_input() {
        let avx512 = unsafe { adler32_avx512(0, &[]) };
        let rust = crate::adler32::generic::adler32_rust(0, &[]);

        assert_eq!(rust, avx512);
    }

    // Input shorter than a single 64-byte vector.
    #[test]
    fn zero_chunks() {
        let input = &[
            1u8, 39, 76, 148, 0, 58, 0, 14, 255, 59, 1, 229, 1, 83, 5, 84, 207, 152, 188,
        ];
        let avx512 = unsafe { adler32_avx512(0, input) };
        let rust = crate::adler32::generic::adler32_rust(0, input);

        assert_eq!(rust, avx512);
    }

    // Input longer than one 64-byte vector but shorter than two.
    #[test]
    fn one_chunk() {
        let input: [u8; 85] = core::array::from_fn(|i| i as u8);
        let avx512 = unsafe { adler32_avx512(0, &input) };
        let rust = crate::adler32::generic::adler32_rust(0, &input);

        assert_eq!(rust, avx512);
    }

    quickcheck::quickcheck! {
        fn adler32_avx512_is_adler32_rust(v: Vec<u8>, start: u32) -> bool {
            let avx512 = unsafe { adler32_avx512(start, &v) };
            let rust = crate::adler32::generic::adler32_rust(start, &v);

            rust == avx512
        }
    }

    const INPUT: [u8; 128] = {
        let mut array = [0; 128];
        let mut i = 0;
        while i < array.len() {
            array[i] = i as u8;
            i += 1;
        }

        array
    };

    #[test]
    fn start_alignment() {
        // The SIMD algorithm is sensitive to alignment: the input is split with
        // `align_to::<__m512i>()`, so sweep every start offset within a 64-byte window to
        // cover every possible prefix length.
        for i in 0..64 {
            for start in [crate::ADLER32_INITIAL_VALUE as u32, 42] {
                let avx512 = unsafe { adler32_avx512(start, &INPUT[i..]) };
                let rust = crate::adler32::generic::adler32_rust(start, &INPUT[i..]);

                assert_eq!(avx512, rust, "offset = {i}, start = {start}");
            }
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn large_input() {
        const DEFAULT: &[u8] = include_bytes!("../deflate/test-data/paper-100k.pdf");

        let avx512 = unsafe { adler32_avx512(42, DEFAULT) };
        let rust = crate::adler32::generic::adler32_rust(42, DEFAULT);

        assert_eq!(avx512, rust);
    }
}

0 commit comments

Comments
 (0)