diff --git a/encodings/fastlanes/Cargo.toml b/encodings/fastlanes/Cargo.toml index 5c99b1e2157..ac30b2032b4 100644 --- a/encodings/fastlanes/Cargo.toml +++ b/encodings/fastlanes/Cargo.toml @@ -55,3 +55,8 @@ required-features = ["_test-harness"] name = "compute_between" harness = false required-features = ["_test-harness"] + +[[bench]] +name = "bit_transpose" +harness = false +required-features = ["_test-harness"] diff --git a/encodings/fastlanes/benches/bit_transpose.rs b/encodings/fastlanes/benches/bit_transpose.rs new file mode 100644 index 00000000000..4bc9027cc28 --- /dev/null +++ b/encodings/fastlanes/benches/bit_transpose.rs @@ -0,0 +1,312 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::unwrap_used)] + +use divan::Bencher; +use vortex_fastlanes::bit_transpose::scalar::transpose_bits_scalar; +use vortex_fastlanes::bit_transpose::scalar::untranspose_bits_scalar; + +fn main() { + divan::main(); +} + +/// Generate deterministic test data. +#[allow(clippy::cast_possible_truncation)] +fn generate_test_data(seed: usize) -> [u8; 128] { + let mut data = [0u8; 128]; + for (i, byte) in data.iter_mut().enumerate() { + *byte = seed.wrapping_mul(17).wrapping_add(i).wrapping_mul(31) as u8; + } + data +} + +const BATCH_SIZE: usize = 1000; + +// ============================================================================ +// Transpose: single array +// ============================================================================ + +#[divan::bench] +fn transpose_scalar(bencher: Bencher) { + let input = generate_test_data(42); + + bencher + .with_inputs(|| (&input, [0u8; 128])) + .bench_refs(|(input, output)| { + transpose_bits_scalar(input, output); + }); +} + +// ============================================================================ +// Transpose: throughput (1000 arrays) +// ============================================================================ + +#[divan::bench] +fn transpose_scalar_throughput(bencher: Bencher) { + let inputs: Vec<[u8; 128]> = (0..BATCH_SIZE).map(generate_test_data).collect(); + + bencher + .with_inputs(|| (&inputs, vec![[0u8; 128]; BATCH_SIZE])) + .bench_refs(|(inputs, outputs)| { + for (input, output) in inputs.iter().zip(outputs.iter_mut()) { + transpose_bits_scalar(input, output); + } + }); +} + +// ============================================================================ +// Untranspose: single array +// ============================================================================ + +#[divan::bench] +fn untranspose_scalar(bencher: Bencher) { + let input = generate_test_data(42); + + bencher + .with_inputs(|| (&input, [0u8; 128])) + .bench_refs(|(input, output)| { + untranspose_bits_scalar(input, output); + }); +} + +// ============================================================================ +// Untranspose: throughput (1000 arrays) +// ============================================================================ + +#[divan::bench] +fn untranspose_scalar_throughput(bencher: Bencher) { + let inputs: Vec<[u8; 128]> = (0..BATCH_SIZE).map(generate_test_data).collect(); + + bencher + .with_inputs(|| (&inputs, vec![[0u8; 128]; BATCH_SIZE])) + .bench_refs(|(inputs, outputs)| { + for (input, output) in inputs.iter().zip(outputs.iter_mut()) { + untranspose_bits_scalar(input, output); + } + }); +} + +// ============================================================================ +// x86_64 benchmarks +// ============================================================================ + 
+#[cfg(target_arch = "x86_64")] +mod x86 { + use divan::Bencher; + use vortex_fastlanes::bit_transpose::x86::has_bmi2; + use vortex_fastlanes::bit_transpose::x86::has_vbmi; + use vortex_fastlanes::bit_transpose::x86::transpose_bits_bmi2; + use vortex_fastlanes::bit_transpose::x86::transpose_bits_vbmi; + use vortex_fastlanes::bit_transpose::x86::untranspose_bits_bmi2; + use vortex_fastlanes::bit_transpose::x86::untranspose_bits_vbmi; + + use super::BATCH_SIZE; + use super::generate_test_data; + + // --- Transpose: single array --- + + #[divan::bench] + fn transpose_bmi2(bencher: Bencher) { + if !has_bmi2() { + return; + } + + let input = generate_test_data(42); + + bencher + .with_inputs(|| (&input, [0u8; 128])) + .bench_refs(|(input, output)| { + unsafe { transpose_bits_bmi2(input, output) }; + }); + } + + #[divan::bench] + fn transpose_vbmi(bencher: Bencher) { + if !has_vbmi() { + return; + } + + let input = generate_test_data(42); + + bencher + .with_inputs(|| (&input, [0u8; 128])) + .bench_refs(|(input, output)| { + unsafe { transpose_bits_vbmi(input, output) }; + }); + } + + // --- Untranspose: single array --- + + #[divan::bench] + fn untranspose_bmi2(bencher: Bencher) { + if !has_bmi2() { + return; + } + + let input = generate_test_data(42); + + bencher + .with_inputs(|| (&input, [0u8; 128])) + .bench_refs(|(input, output)| { + unsafe { untranspose_bits_bmi2(input, output) }; + }); + } + + #[divan::bench] + fn untranspose_vbmi(bencher: Bencher) { + if !has_vbmi() { + return; + } + + let input = generate_test_data(42); + + bencher + .with_inputs(|| (&input, [0u8; 128])) + .bench_refs(|(input, output)| { + unsafe { untranspose_bits_vbmi(input, output) }; + }); + } + + // --- Transpose: throughput (1000 arrays) --- + + #[divan::bench] + fn transpose_bmi2_throughput(bencher: Bencher) { + if !has_bmi2() { + return; + } + + let inputs: Vec<[u8; 128]> = (0..BATCH_SIZE).map(generate_test_data).collect(); + + bencher + .with_inputs(|| (&inputs, vec![[0u8; 128]; BATCH_SIZE])) + .bench_refs(|(inputs, outputs)| { + for (input, output) in inputs.iter().zip(outputs.iter_mut()) { + unsafe { transpose_bits_bmi2(input, output) }; + } + }); + } + + #[divan::bench] + fn transpose_vbmi_throughput(bencher: Bencher) { + if !has_vbmi() { + return; + } + + let inputs: Vec<[u8; 128]> = (0..BATCH_SIZE).map(generate_test_data).collect(); + + bencher + .with_inputs(|| (&inputs, vec![[0u8; 128]; BATCH_SIZE])) + .bench_refs(|(inputs, outputs)| { + for (input, output) in inputs.iter().zip(outputs.iter_mut()) { + unsafe { transpose_bits_vbmi(input, output) }; + } + }); + } + + // --- Untranspose: throughput (1000 arrays) --- + + #[divan::bench] + fn untranspose_bmi2_throughput(bencher: Bencher) { + if !has_bmi2() { + return; + } + + let inputs: Vec<[u8; 128]> = (0..BATCH_SIZE).map(generate_test_data).collect(); + + bencher + .with_inputs(|| (&inputs, vec![[0u8; 128]; BATCH_SIZE])) + .bench_refs(|(inputs, outputs)| { + for (input, output) in inputs.iter().zip(outputs.iter_mut()) { + unsafe { untranspose_bits_bmi2(input, output) }; + } + }); + } + + #[divan::bench] + fn untranspose_vbmi_throughput(bencher: Bencher) { + if !has_vbmi() { + return; + } + + let inputs: Vec<[u8; 128]> = (0..BATCH_SIZE).map(generate_test_data).collect(); + + bencher + .with_inputs(|| (&inputs, vec![[0u8; 128]; BATCH_SIZE])) + .bench_refs(|(inputs, outputs)| { + for (input, output) in inputs.iter().zip(outputs.iter_mut()) { + unsafe { untranspose_bits_vbmi(input, output) }; + } + }); + } +} + +// 
============================================================================ +// aarch64 benchmarks +// ============================================================================ + +#[cfg(target_arch = "aarch64")] +mod aarch64 { + use vortex_fastlanes::bit_transpose::aarch64::transpose_bits_neon; + use vortex_fastlanes::bit_transpose::aarch64::untranspose_bits_neon; + + use super::BATCH_SIZE; + use super::Bencher; + use super::generate_test_data; + + // --- Transpose: single array --- + + #[divan::bench] + fn transpose_neon(bencher: Bencher) { + let input = generate_test_data(42); + + bencher + .with_inputs(|| (&input, [0u8; 128])) + .bench_refs(|(input, output)| { + unsafe { transpose_bits_neon(input, output) }; + }); + } + + // --- Untranspose: single array --- + + #[divan::bench] + fn untranspose_neon(bencher: Bencher) { + let input = generate_test_data(42); + + bencher + .with_inputs(|| (&input, [0u8; 128])) + .bench_refs(|(input, output)| { + unsafe { untranspose_bits_neon(input, output) }; + }); + } + + // --- Transpose: throughput (1000 arrays) --- + + #[divan::bench] + fn transpose_neon_throughput(bencher: Bencher) { + let inputs: Vec<[u8; 128]> = (0..BATCH_SIZE).map(generate_test_data).collect(); + + bencher + .with_inputs(|| (&inputs, vec![[0u8; 128]; BATCH_SIZE])) + .bench_refs(|(inputs, outputs)| { + for (input, output) in inputs.iter().zip(outputs.iter_mut()) { + unsafe { transpose_bits_neon(input, output) }; + } + }); + } + + // --- Untranspose: throughput (1000 arrays) --- + + #[divan::bench] + fn untranspose_neon_throughput(bencher: Bencher) { + let inputs: Vec<[u8; 128]> = (0..BATCH_SIZE).map(generate_test_data).collect(); + + bencher + .with_inputs(|| (&inputs, vec![[0u8; 128]; BATCH_SIZE])) + .bench_refs(|(inputs, outputs)| { + for (input, output) in inputs.iter().zip(outputs.iter_mut()) { + unsafe { untranspose_bits_neon(input, output) }; + } + }); + } +} diff --git a/encodings/fastlanes/public-api.lock b/encodings/fastlanes/public-api.lock index 28fad23e177..8b6cd3acb17 100644 --- a/encodings/fastlanes/public-api.lock +++ b/encodings/fastlanes/public-api.lock @@ -1,5 +1,11 @@ pub mod vortex_fastlanes +pub mod vortex_fastlanes::bit_transpose + +pub fn vortex_fastlanes::bit_transpose::transpose_bits(input: &[u8; 128], output: &mut [u8; 128]) + +pub fn vortex_fastlanes::bit_transpose::untranspose_bits(input: &[u8; 128], output: &mut [u8; 128]) + pub mod vortex_fastlanes::bitpack_compress pub fn vortex_fastlanes::bitpack_compress::bit_width_histogram(array: &vortex_array::arrays::primitive::array::PrimitiveArray) -> vortex_error::VortexResult> diff --git a/encodings/fastlanes/src/bit_transpose/aarch64.rs b/encodings/fastlanes/src/bit_transpose/aarch64.rs new file mode 100644 index 00000000000..87fca0b5ba4 --- /dev/null +++ b/encodings/fastlanes/src/bit_transpose/aarch64.rs @@ -0,0 +1,300 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![cfg(target_arch = "aarch64")] + +use core::arch::aarch64::uint64x2_t; +use core::arch::aarch64::vandq_u64; +use core::arch::aarch64::vdupq_n_u64; +use core::arch::aarch64::veorq_u64; +use core::arch::aarch64::vgetq_lane_u64; +use core::arch::aarch64::vld1q_u8; +use core::arch::aarch64::vld1q_u8_x4; +use core::arch::aarch64::vorrq_u8; +use core::arch::aarch64::vqtbl4q_u8; +use core::arch::aarch64::vreinterpretq_u8_u64; +use core::arch::aarch64::vreinterpretq_u64_u8; +use core::arch::aarch64::vshlq_n_u64; +use core::arch::aarch64::vshrq_n_u64; +use core::arch::aarch64::vst1q_u8; + 
+use crate::bit_transpose::BASE_PATTERN_FIRST; +use crate::bit_transpose::BASE_PATTERN_SECOND; +use crate::bit_transpose::TRANSPOSE_2X2; +use crate::bit_transpose::TRANSPOSE_4X4; +use crate::bit_transpose::TRANSPOSE_8X8; + +/// Gather indices for the first half from input[0..64]. +/// Each group needs 4 bytes at stride 16 (the low half of the stride pattern). +/// Layout: [`g0_from_lo(4` bytes), pad(4 bytes), `g1_from_lo(4` bytes), pad(4 bytes), ...] +/// Two groups per 16-byte NEON register. +static GATHER_FIRST_LO: [[u8; 16]; 4] = [ + // Groups 0,1 from BASE_PATTERN_FIRST: bases 0, 8 + [ + 0, 16, 32, 48, 0xFF, 0xFF, 0xFF, 0xFF, 8, 24, 40, 56, 0xFF, 0xFF, 0xFF, 0xFF, + ], + // Groups 2,3: bases 4, 12 + [ + 4, 20, 36, 52, 0xFF, 0xFF, 0xFF, 0xFF, 12, 28, 44, 60, 0xFF, 0xFF, 0xFF, 0xFF, + ], + // Groups 4,5: bases 2, 10 + [ + 2, 18, 34, 50, 0xFF, 0xFF, 0xFF, 0xFF, 10, 26, 42, 58, 0xFF, 0xFF, 0xFF, 0xFF, + ], + // Groups 6,7: bases 6, 14 + [ + 6, 22, 38, 54, 0xFF, 0xFF, 0xFF, 0xFF, 14, 30, 46, 62, 0xFF, 0xFF, 0xFF, 0xFF, + ], +]; + +/// Gather indices for the first half from input[64..128]. +/// These fill in bytes 4-7 of each u64 (the high half of the stride pattern). +static GATHER_FIRST_HI: [[u8; 16]; 4] = [ + // Groups 0,1: bases 0, 8 (offset by -64 since table starts at input[64]) + [ + 0xFF, 0xFF, 0xFF, 0xFF, 0, 16, 32, 48, 0xFF, 0xFF, 0xFF, 0xFF, 8, 24, 40, 56, + ], + // Groups 2,3: bases 4, 12 + [ + 0xFF, 0xFF, 0xFF, 0xFF, 4, 20, 36, 52, 0xFF, 0xFF, 0xFF, 0xFF, 12, 28, 44, 60, + ], + // Groups 4,5: bases 2, 10 + [ + 0xFF, 0xFF, 0xFF, 0xFF, 2, 18, 34, 50, 0xFF, 0xFF, 0xFF, 0xFF, 10, 26, 42, 58, + ], + // Groups 6,7: bases 6, 14 + [ + 0xFF, 0xFF, 0xFF, 0xFF, 6, 22, 38, 54, 0xFF, 0xFF, 0xFF, 0xFF, 14, 30, 46, 62, + ], +]; + +/// Gather indices for the second half from input[0..64]. +/// Uses `BASE_PATTERN_SECOND`: bases [1, 9, 5, 13, 3, 11, 7, 15] +static GATHER_SECOND_LO: [[u8; 16]; 4] = [ + [ + 1, 17, 33, 49, 0xFF, 0xFF, 0xFF, 0xFF, 9, 25, 41, 57, 0xFF, 0xFF, 0xFF, 0xFF, + ], + [ + 5, 21, 37, 53, 0xFF, 0xFF, 0xFF, 0xFF, 13, 29, 45, 61, 0xFF, 0xFF, 0xFF, 0xFF, + ], + [ + 3, 19, 35, 51, 0xFF, 0xFF, 0xFF, 0xFF, 11, 27, 43, 59, 0xFF, 0xFF, 0xFF, 0xFF, + ], + [ + 7, 23, 39, 55, 0xFF, 0xFF, 0xFF, 0xFF, 15, 31, 47, 63, 0xFF, 0xFF, 0xFF, 0xFF, + ], +]; + +/// Gather indices for the second half from input[64..128]. +static GATHER_SECOND_HI: [[u8; 16]; 4] = [ + [ + 0xFF, 0xFF, 0xFF, 0xFF, 1, 17, 33, 49, 0xFF, 0xFF, 0xFF, 0xFF, 9, 25, 41, 57, + ], + [ + 0xFF, 0xFF, 0xFF, 0xFF, 5, 21, 37, 53, 0xFF, 0xFF, 0xFF, 0xFF, 13, 29, 45, 61, + ], + [ + 0xFF, 0xFF, 0xFF, 0xFF, 3, 19, 35, 51, 0xFF, 0xFF, 0xFF, 0xFF, 11, 27, 43, 59, + ], + [ + 0xFF, 0xFF, 0xFF, 0xFF, 7, 23, 39, 55, 0xFF, 0xFF, 0xFF, 0xFF, 15, 31, 47, 63, + ], +]; + +/// 8x8 byte transpose (scatter) permutation split into 4 × 16-byte chunks for NEON TBL. +/// Input layout: [g0b0..g0b7, g1b0..g1b7, ..., g7b0..g7b7] (64 bytes, group-major) +/// Output layout: [g0b0,g1b0,..,g7b0, g0b1,g1b1,..,g7b1, ...] (64 bytes, row-major) +/// Same permutation as x86 `SCATTER_8X8`, split for 16-byte NEON registers. +static SCATTER_8X8_NEON: [[u8; 16]; 4] = [ + [0, 8, 16, 24, 32, 40, 48, 56, 1, 9, 17, 25, 33, 41, 49, 57], + [2, 10, 18, 26, 34, 42, 50, 58, 3, 11, 19, 27, 35, 43, 51, 59], + [4, 12, 20, 28, 36, 44, 52, 60, 5, 13, 21, 29, 37, 45, 53, 61], + [6, 14, 22, 30, 38, 46, 54, 62, 7, 15, 23, 31, 39, 47, 55, 63], +]; + +/// Perform 8x8 bit transpose on two u64s packed in a `uint64x2_t`. 
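+///
+/// Each 64-bit lane is treated as an 8x8 bit matrix (one byte per row) and is
+/// transposed with three masked shift/XOR ("delta swap") steps driven by the
+/// `TRANSPOSE_2X2`, `TRANSPOSE_4X4` and `TRANSPOSE_8X8` masks.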
+#[allow(unsafe_op_in_unsafe_fn)] +#[inline] +unsafe fn bit_transpose_8x8_neon(mut v: uint64x2_t) -> uint64x2_t { + let mask1 = vdupq_n_u64(TRANSPOSE_2X2); + let t = vandq_u64(veorq_u64(v, vshrq_n_u64::<7>(v)), mask1); + v = veorq_u64(veorq_u64(v, t), vshlq_n_u64::<7>(t)); + + let mask2 = vdupq_n_u64(TRANSPOSE_4X4); + let t = vandq_u64(veorq_u64(v, vshrq_n_u64::<14>(v)), mask2); + v = veorq_u64(veorq_u64(v, t), vshlq_n_u64::<14>(t)); + + let mask3 = vdupq_n_u64(TRANSPOSE_8X8); + let t = vandq_u64(veorq_u64(v, vshrq_n_u64::<28>(v)), mask3); + veorq_u64(veorq_u64(v, t), vshlq_n_u64::<28>(t)) +} + +/// Transpose 1024 bits using ARM NEON with TBL-based vectorized gather and scatter. +/// +/// Uses `vqtbl4q_u8` to gather bytes from the 128-byte input in parallel, +/// avoiding scalar byte-by-byte loads. Then uses `vqtbl4q_u8` again to perform +/// the 8x8 byte transpose for scatter. This is the NEON analog of x86 VBMI's +/// `vpermb`/`vpermi2b` byte permutation instructions. +/// +/// # Safety +/// Requires `AArch64` with NEON (always available on `AArch64`). +#[allow(unsafe_op_in_unsafe_fn)] +#[inline(never)] +pub unsafe fn transpose_bits_neon(input: &[u8; 128], output: &mut [u8; 128]) { + // Load all 128 input bytes into two uint8x16x4_t tables (64 bytes each) + let tbl_lo = vld1q_u8_x4(input.as_ptr()); + let tbl_hi = vld1q_u8_x4(input.as_ptr().add(64)); + + // Load scatter permutation indices (4 × 16 bytes) + let scatter0 = vld1q_u8(SCATTER_8X8_NEON[0].as_ptr()); + let scatter1 = vld1q_u8(SCATTER_8X8_NEON[1].as_ptr()); + let scatter2 = vld1q_u8(SCATTER_8X8_NEON[2].as_ptr()); + let scatter3 = vld1q_u8(SCATTER_8X8_NEON[3].as_ptr()); + + // Process first 64 output bytes (8 groups from BASE_PATTERN_FIRST) + // Gather and bit-transpose all 4 pairs, then scatter the full 64 bytes + let mut buf = [0u8; 64]; + for (i, (gather_lo, gather_high)) in [ + (GATHER_FIRST_LO, GATHER_FIRST_HI), + (GATHER_SECOND_LO, GATHER_SECOND_HI), + ] + .iter() + .enumerate() + { + for pair in 0..4 { + let idx_lo = vld1q_u8(gather_lo[pair].as_ptr()); + let idx_hi = vld1q_u8(gather_high[pair].as_ptr()); + + let from_lo = vqtbl4q_u8(tbl_lo, idx_lo); + let from_hi = vqtbl4q_u8(tbl_hi, idx_hi); + let gathered = vorrq_u8(from_lo, from_hi); + + let v = bit_transpose_8x8_neon(vreinterpretq_u64_u8(gathered)); + vst1q_u8(buf.as_mut_ptr().add(pair * 16), vreinterpretq_u8_u64(v)); + } + + // Load the 64-byte result as a TBL table and apply 8x8 byte transpose + let result_tbl = vld1q_u8_x4(buf.as_ptr()); + vst1q_u8( + output.as_mut_ptr().add(i * 64), + vqtbl4q_u8(result_tbl, scatter0), + ); + vst1q_u8( + output.as_mut_ptr().add(i * 64 + 16), + vqtbl4q_u8(result_tbl, scatter1), + ); + vst1q_u8( + output.as_mut_ptr().add(i * 64 + 32), + vqtbl4q_u8(result_tbl, scatter2), + ); + vst1q_u8( + output.as_mut_ptr().add(i * 64 + 48), + vqtbl4q_u8(result_tbl, scatter3), + ); + } +} + +/// Untranspose 1024 bits using ARM NEON with TBL-based vectorized operations. +/// +/// # Safety +/// Requires `AArch64` with NEON (always available on `AArch64`). 
+#[allow(unsafe_op_in_unsafe_fn)] +#[inline(never)] +pub unsafe fn untranspose_bits_neon(input: &[u8; 128], output: &mut [u8; 128]) { + // Load scatter indices (SCATTER_8X8 is self-inverse, so same table un-scatters) + let scatter0 = vld1q_u8(SCATTER_8X8_NEON[0].as_ptr()); + let scatter1 = vld1q_u8(SCATTER_8X8_NEON[1].as_ptr()); + let scatter2 = vld1q_u8(SCATTER_8X8_NEON[2].as_ptr()); + let scatter3 = vld1q_u8(SCATTER_8X8_NEON[3].as_ptr()); + + // Each iteration un-scatters the 64-byte input block to group-major order + let mut buf = [0u8; 64]; + for (i, base_pattern) in [BASE_PATTERN_FIRST, BASE_PATTERN_SECOND].iter().enumerate() { + let in_tbl = vld1q_u8_x4(input.as_ptr().add(i * 64)); + vst1q_u8(buf.as_mut_ptr(), vqtbl4q_u8(in_tbl, scatter0)); + vst1q_u8(buf.as_mut_ptr().add(16), vqtbl4q_u8(in_tbl, scatter1)); + vst1q_u8(buf.as_mut_ptr().add(32), vqtbl4q_u8(in_tbl, scatter2)); + vst1q_u8(buf.as_mut_ptr().add(48), vqtbl4q_u8(in_tbl, scatter3)); + + // Bit-transpose each pair and scatter to stride-16 output + for pair in 0..4 { + let base_group_0 = pair * 2; + let base_group_1 = pair * 2 + 1; + + let gathered = vld1q_u8(buf.as_ptr().add(pair * 16)); + let v = bit_transpose_8x8_neon(vreinterpretq_u64_u8(gathered)); + + let result_0 = vgetq_lane_u64::<0>(v); + let result_1 = vgetq_lane_u64::<1>(v); + + let out_base_0 = base_pattern[base_group_0]; + let out_base_1 = base_pattern[base_group_1]; + for i in 0..8 { + output[out_base_0 + i * 16] = (result_0 >> (i * 8)) as u8; + output[out_base_1 + i * 16] = (result_1 >> (i * 8)) as u8; + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::bit_transpose::aarch64::transpose_bits_neon; + use crate::bit_transpose::aarch64::untranspose_bits_neon; + use crate::bit_transpose::generate_test_data; + use crate::bit_transpose::transpose_bits_baseline; + use crate::bit_transpose::untranspose_bits_baseline; + + #[test] + fn test_neon_matches_baseline() { + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut baseline_out = [0u8; 128]; + let mut tbl_out = [0u8; 128]; + + transpose_bits_baseline(&input, &mut baseline_out); + unsafe { transpose_bits_neon(&input, &mut tbl_out) }; + + assert_eq!( + baseline_out, tbl_out, + "NEON TBL transpose doesn't match baseline for seed {seed}" + ); + } + } + + #[test] + fn test_neon_roundtrip() { + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut transposed = [0u8; 128]; + let mut roundtrip = [0u8; 128]; + + unsafe { + transpose_bits_neon(&input, &mut transposed); + untranspose_bits_neon(&transposed, &mut roundtrip); + } + + assert_eq!( + input, roundtrip, + "NEON TBL roundtrip failed for seed {seed}" + ); + } + } + + #[test] + fn test_untranspose_neon_matches_baseline() { + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut baseline_out = [0u8; 128]; + let mut tbl_out = [0u8; 128]; + + untranspose_bits_baseline(&input, &mut baseline_out); + unsafe { untranspose_bits_neon(&input, &mut tbl_out) }; + + assert_eq!( + baseline_out, tbl_out, + "NEON TBL untranspose doesn't match baseline for seed {seed}" + ); + } + } +} diff --git a/encodings/fastlanes/src/bit_transpose/mod.rs b/encodings/fastlanes/src/bit_transpose/mod.rs new file mode 100644 index 00000000000..864591af9d0 --- /dev/null +++ b/encodings/fastlanes/src/bit_transpose/mod.rs @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Fast implementations of the `FastLanes` 1024-bit transpose. +//! +//! 
The `FastLanes` transpose is a fixed permutation of 1024 bits (128 bytes) that
+//! enables SIMD parallelism for encodings like delta and RLE. This module provides
+//! optimized implementations for x86-64 (BMI2, AVX-512 VBMI) and `AArch64` (NEON),
+//! plus a fast scalar fallback.
+//!
+//! The key insight is that each output byte is formed by extracting the SAME bit
+//! position from 8 different input bytes at stride 16. The input byte groups follow
+//! the `FL_ORDER` permutation pattern.
+
+#[cfg(feature = "_test-harness")]
+pub mod aarch64;
+#[cfg(feature = "_test-harness")]
+pub mod scalar;
+#[cfg(feature = "_test-harness")]
+pub mod x86;
+
+#[cfg(not(feature = "_test-harness"))]
+mod aarch64;
+#[cfg(not(feature = "_test-harness"))]
+mod scalar;
+#[cfg(not(feature = "_test-harness"))]
+mod x86;
+
+/// Base indices for the first 64 output bytes (lanes 0-7).
+/// Each entry indicates the starting input byte index for that output byte group.
+/// Pattern: [0*2, 4*2, 2*2, 6*2, 1*2, 5*2, 3*2, 7*2] = [0, 8, 4, 12, 2, 10, 6, 14]
+const BASE_PATTERN_FIRST: [usize; 8] = [0, 8, 4, 12, 2, 10, 6, 14];
+
+/// Base indices for the second 64 output bytes (lanes 8-15).
+/// Pattern: first pattern + 1 = [1, 9, 5, 13, 3, 11, 7, 15]
+const BASE_PATTERN_SECOND: [usize; 8] = [1, 9, 5, 13, 3, 11, 7, 15];
+
+/// Masks for transposing 8x8 bit blocks.
+const TRANSPOSE_2X2: u64 = 0x00AA_00AA_00AA_00AA;
+const TRANSPOSE_4X4: u64 = 0x0000_CCCC_0000_CCCC;
+const TRANSPOSE_8X8: u64 = 0x0000_0000_F0F0_F0F0;
+
+/// Dispatch to the best available implementation at runtime.
+#[inline]
+pub fn transpose_bits(input: &[u8; 128], output: &mut [u8; 128]) {
+    #[cfg(target_arch = "x86_64")]
+    {
+        // VBMI is fastest
+        if x86::has_vbmi() {
+            unsafe { x86::transpose_bits_vbmi(input, output) };
+            return;
+        }
+        if x86::has_bmi2() {
+            unsafe { x86::transpose_bits_bmi2(input, output) };
+            return;
+        }
+        // Fall back to scalar
+        scalar::transpose_bits_scalar(input, output);
+    }
+    #[cfg(target_arch = "aarch64")]
+    {
+        unsafe { aarch64::transpose_bits_neon(input, output) };
+    }
+    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
+    scalar::transpose_bits_scalar(input, output);
+}
+
+/// Dispatch untranspose to the best available implementation at runtime.
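+///
+/// Mirrors [`transpose_bits`]: prefers AVX-512 VBMI, then BMI2 on x86-64 (falling
+/// back to scalar), uses NEON on `AArch64`, and the scalar implementation on other
+/// targets.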
+#[inline] +pub fn untranspose_bits(input: &[u8; 128], output: &mut [u8; 128]) { + #[cfg(target_arch = "x86_64")] + { + // VBMI is fastest + if x86::has_vbmi() { + unsafe { x86::untranspose_bits_vbmi(input, output) }; + return; + } + if x86::has_bmi2() { + unsafe { x86::untranspose_bits_bmi2(input, output) }; + return; + } + // Fall back to scalar + scalar::untranspose_bits_scalar(input, output); + } + #[cfg(target_arch = "aarch64")] + { + unsafe { aarch64::untranspose_bits_neon(input, output) }; + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + scalar::untranspose_bits_scalar(input, output); +} + +#[cfg(test)] +#[allow(clippy::cast_possible_truncation)] +fn generate_test_data(seed: u8) -> [u8; 128] { + let mut data = [0u8; 128]; + for (i, byte) in data.iter_mut().enumerate() { + *byte = seed.wrapping_mul(17).wrapping_add(i as u8).wrapping_mul(31); + } + data +} + +#[cfg(test)] +pub fn transpose_bits_baseline(input: &[u8; 128], output: &mut [u8; 128]) { + for in_bit in 0..1024 { + let out_bit = fastlanes::transpose(in_bit); + let in_byte = in_bit / 8; + let in_bit_pos = in_bit % 8; + let out_byte = out_bit / 8; + let out_bit_pos = out_bit % 8; + let bit_val = (input[in_byte] >> in_bit_pos) & 1; + output[out_byte] |= bit_val << out_bit_pos; + } +} + +#[cfg(test)] +pub fn untranspose_bits_baseline(input: &[u8; 128], output: &mut [u8; 128]) { + for out_bit in 0..1024 { + let in_bit = fastlanes::transpose(out_bit); + let in_byte = in_bit / 8; + let in_bit_pos = in_bit % 8; + let out_byte = out_bit / 8; + let out_bit_pos = out_bit % 8; + let bit_val = (input[in_byte] >> in_bit_pos) & 1; + output[out_byte] |= bit_val << out_bit_pos; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_transpose_baseline_roundtrip() { + let input = generate_test_data(42); + let mut transposed = [0u8; 128]; + let mut roundtrip = [0u8; 128]; + + transpose_bits_baseline(&input, &mut transposed); + untranspose_bits_baseline(&transposed, &mut roundtrip); + + assert_eq!(input, roundtrip); + } + + #[test] + fn test_dispatch_matches_baseline() { + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut baseline_out = [0u8; 128]; + let mut out = [0u8; 128]; + + transpose_bits_baseline(&input, &mut baseline_out); + transpose_bits(&input, &mut out); + + assert_eq!( + baseline_out, out, + "best dispatch doesn't match baseline for seed {seed}" + ); + } + } + + #[test] + fn test_untranspose_dispatch_matches_baseline() { + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut baseline_out = [0u8; 128]; + let mut out = [0u8; 128]; + + untranspose_bits_baseline(&input, &mut baseline_out); + untranspose_bits(&input, &mut out); + + assert_eq!( + baseline_out, out, + "best untranspose dispatch doesn't match baseline for seed {seed}" + ); + } + } +} diff --git a/encodings/fastlanes/src/bit_transpose/scalar.rs b/encodings/fastlanes/src/bit_transpose/scalar.rs new file mode 100644 index 00000000000..c7a197202cd --- /dev/null +++ b/encodings/fastlanes/src/bit_transpose/scalar.rs @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::bit_transpose::BASE_PATTERN_FIRST; +use crate::bit_transpose::BASE_PATTERN_SECOND; +use crate::bit_transpose::TRANSPOSE_2X2; +use crate::bit_transpose::TRANSPOSE_4X4; +use crate::bit_transpose::TRANSPOSE_8X8; + +/// Fast scalar transpose using the 8x8 bit matrix transpose algorithm. 
+/// +/// This version uses 64-bit gather + parallel bit operations instead of +/// extracting bits one by one. Typically 5-10x faster than the basic scalar version. +#[inline(never)] +#[allow(dead_code)] +pub fn transpose_bits_scalar(input: &[u8; 128], output: &mut [u8; 128]) { + // Helper to perform 8x8 bit transpose on a u64 (each byte becomes a row) + #[inline] + fn transpose_8x8(mut x: u64) -> u64 { + // Step 1: Transpose 2x2 bit blocks + let t = (x ^ (x >> 7)) & TRANSPOSE_2X2; + x = x ^ t ^ (t << 7); + // Step 2: Transpose 4x4 bit blocks + let t = (x ^ (x >> 14)) & TRANSPOSE_4X4; + x = x ^ t ^ (t << 14); + // Step 3: Transpose 8x8 bit blocks + let t = (x ^ (x >> 28)) & TRANSPOSE_8X8; + x ^ t ^ (t << 28) + } + + // Helper to gather 8 bytes at stride 16 into a u64 + #[inline] + fn gather(input: &[u8; 128], base: usize) -> u64 { + u64::from(input[base]) + | (u64::from(input[base + 16]) << 8) + | (u64::from(input[base + 32]) << 16) + | (u64::from(input[base + 48]) << 24) + | (u64::from(input[base + 64]) << 32) + | (u64::from(input[base + 80]) << 40) + | (u64::from(input[base + 96]) << 48) + | (u64::from(input[base + 112]) << 56) + } + + // Process first half (8 base groups, fully unrolled) + let r0 = transpose_8x8(gather(input, BASE_PATTERN_FIRST[0])); + let r1 = transpose_8x8(gather(input, BASE_PATTERN_FIRST[1])); + let r2 = transpose_8x8(gather(input, BASE_PATTERN_FIRST[2])); + let r3 = transpose_8x8(gather(input, BASE_PATTERN_FIRST[3])); + let r4 = transpose_8x8(gather(input, BASE_PATTERN_FIRST[4])); + let r5 = transpose_8x8(gather(input, BASE_PATTERN_FIRST[5])); + let r6 = transpose_8x8(gather(input, BASE_PATTERN_FIRST[6])); + let r7 = transpose_8x8(gather(input, BASE_PATTERN_FIRST[7])); + + // Write first 64 output bytes (unrolled) + for bit_pos in 0..8 { + output[bit_pos * 8] = (r0 >> (bit_pos * 8)) as u8; + output[bit_pos * 8 + 1] = (r1 >> (bit_pos * 8)) as u8; + output[bit_pos * 8 + 2] = (r2 >> (bit_pos * 8)) as u8; + output[bit_pos * 8 + 3] = (r3 >> (bit_pos * 8)) as u8; + output[bit_pos * 8 + 4] = (r4 >> (bit_pos * 8)) as u8; + output[bit_pos * 8 + 5] = (r5 >> (bit_pos * 8)) as u8; + output[bit_pos * 8 + 6] = (r6 >> (bit_pos * 8)) as u8; + output[bit_pos * 8 + 7] = (r7 >> (bit_pos * 8)) as u8; + } + + // Process second half + let r0 = transpose_8x8(gather(input, BASE_PATTERN_SECOND[0])); + let r1 = transpose_8x8(gather(input, BASE_PATTERN_SECOND[1])); + let r2 = transpose_8x8(gather(input, BASE_PATTERN_SECOND[2])); + let r3 = transpose_8x8(gather(input, BASE_PATTERN_SECOND[3])); + let r4 = transpose_8x8(gather(input, BASE_PATTERN_SECOND[4])); + let r5 = transpose_8x8(gather(input, BASE_PATTERN_SECOND[5])); + let r6 = transpose_8x8(gather(input, BASE_PATTERN_SECOND[6])); + let r7 = transpose_8x8(gather(input, BASE_PATTERN_SECOND[7])); + + for bit_pos in 0..8 { + output[64 + bit_pos * 8] = (r0 >> (bit_pos * 8)) as u8; + output[64 + bit_pos * 8 + 1] = (r1 >> (bit_pos * 8)) as u8; + output[64 + bit_pos * 8 + 2] = (r2 >> (bit_pos * 8)) as u8; + output[64 + bit_pos * 8 + 3] = (r3 >> (bit_pos * 8)) as u8; + output[64 + bit_pos * 8 + 4] = (r4 >> (bit_pos * 8)) as u8; + output[64 + bit_pos * 8 + 5] = (r5 >> (bit_pos * 8)) as u8; + output[64 + bit_pos * 8 + 6] = (r6 >> (bit_pos * 8)) as u8; + output[64 + bit_pos * 8 + 7] = (r7 >> (bit_pos * 8)) as u8; + } +} + +/// Fast scalar untranspose using the 8x8 bit matrix transpose algorithm. 
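+///
+/// For each group, gathers the eight bytes at stride 8 within a 64-byte half of
+/// the transposed input, applies the same 8x8 bit transpose, and scatters the
+/// result back to its stride-16 output positions given by `BASE_PATTERN_FIRST`
+/// and `BASE_PATTERN_SECOND`.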
+#[inline(never)]
+#[allow(dead_code)]
+pub fn untranspose_bits_scalar(input: &[u8; 128], output: &mut [u8; 128]) {
+    #[inline]
+    fn transpose_8x8(mut x: u64) -> u64 {
+        let t = (x ^ (x >> 7)) & TRANSPOSE_2X2;
+        x = x ^ t ^ (t << 7);
+        let t = (x ^ (x >> 14)) & TRANSPOSE_4X4;
+        x = x ^ t ^ (t << 14);
+        let t = (x ^ (x >> 28)) & TRANSPOSE_8X8;
+        x ^ t ^ (t << 28)
+    }
+
+    #[inline]
+    fn gather_transposed(input: &[u8; 128], base_group: usize, offset: usize) -> u64 {
+        let mut result: u64 = 0;
+        for bit_pos in 0..8 {
+            result |= u64::from(input[offset + bit_pos * 8 + base_group]) << (bit_pos * 8);
+        }
+        result
+    }
+
+    #[inline]
+    fn scatter(output: &mut [u8; 128], base: usize, val: u64) {
+        output[base] = val as u8;
+        output[base + 16] = (val >> 8) as u8;
+        output[base + 32] = (val >> 16) as u8;
+        output[base + 48] = (val >> 24) as u8;
+        output[base + 64] = (val >> 32) as u8;
+        output[base + 80] = (val >> 40) as u8;
+        output[base + 96] = (val >> 48) as u8;
+        output[base + 112] = (val >> 56) as u8;
+    }
+
+    // First half (unrolled)
+    let r0 = transpose_8x8(gather_transposed(input, 0, 0));
+    let r1 = transpose_8x8(gather_transposed(input, 1, 0));
+    let r2 = transpose_8x8(gather_transposed(input, 2, 0));
+    let r3 = transpose_8x8(gather_transposed(input, 3, 0));
+    let r4 = transpose_8x8(gather_transposed(input, 4, 0));
+    let r5 = transpose_8x8(gather_transposed(input, 5, 0));
+    let r6 = transpose_8x8(gather_transposed(input, 6, 0));
+    let r7 = transpose_8x8(gather_transposed(input, 7, 0));
+
+    scatter(output, BASE_PATTERN_FIRST[0], r0);
+    scatter(output, BASE_PATTERN_FIRST[1], r1);
+    scatter(output, BASE_PATTERN_FIRST[2], r2);
+    scatter(output, BASE_PATTERN_FIRST[3], r3);
+    scatter(output, BASE_PATTERN_FIRST[4], r4);
+    scatter(output, BASE_PATTERN_FIRST[5], r5);
+    scatter(output, BASE_PATTERN_FIRST[6], r6);
+    scatter(output, BASE_PATTERN_FIRST[7], r7);
+
+    // Second half
+    let r0 = transpose_8x8(gather_transposed(input, 0, 64));
+    let r1 = transpose_8x8(gather_transposed(input, 1, 64));
+    let r2 = transpose_8x8(gather_transposed(input, 2, 64));
+    let r3 = transpose_8x8(gather_transposed(input, 3, 64));
+    let r4 = transpose_8x8(gather_transposed(input, 4, 64));
+    let r5 = transpose_8x8(gather_transposed(input, 5, 64));
+    let r6 = transpose_8x8(gather_transposed(input, 6, 64));
+    let r7 = transpose_8x8(gather_transposed(input, 7, 64));
+
+    scatter(output, BASE_PATTERN_SECOND[0], r0);
+    scatter(output, BASE_PATTERN_SECOND[1], r1);
+    scatter(output, BASE_PATTERN_SECOND[2], r2);
+    scatter(output, BASE_PATTERN_SECOND[3], r3);
+    scatter(output, BASE_PATTERN_SECOND[4], r4);
+    scatter(output, BASE_PATTERN_SECOND[5], r5);
+    scatter(output, BASE_PATTERN_SECOND[6], r6);
+    scatter(output, BASE_PATTERN_SECOND[7], r7);
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::bit_transpose::generate_test_data;
+    use crate::bit_transpose::transpose_bits_baseline;
+
+    #[test]
+    fn test_scalar_matches_baseline() {
+        for seed in [0, 42, 123, 255] {
+            let input = generate_test_data(seed);
+            let mut baseline_out = [0u8; 128];
+            let mut fast_out = [0u8; 128];
+
+            // Compare the fast scalar kernel against the bit-by-bit baseline.
+            transpose_bits_baseline(&input, &mut baseline_out);
+            transpose_bits_scalar(&input, &mut fast_out);
+
+            assert_eq!(
+                baseline_out, fast_out,
+                "scalar_fast transpose doesn't match baseline for seed {seed}"
+            );
+        }
+    }
+
+    #[test]
+    fn test_scalar_roundtrip() {
+        for seed in [0, 42, 123, 255] {
+            let input = generate_test_data(seed);
+            let mut transposed = [0u8; 128];
+            let mut roundtrip = [0u8; 128];
+
+            transpose_bits_scalar(&input, &mut transposed);
+            untranspose_bits_scalar(&transposed, &mut roundtrip);
+
+            
assert_eq!( + input, roundtrip, + "scalar_fast roundtrip failed for seed {seed}" + ); + } + } + + #[test] + fn test_all_zeros() { + let input = [0u8; 128]; + let mut output = [0xFFu8; 128]; + + transpose_bits_scalar(&input, &mut output); + assert_eq!(output, [0u8; 128]); + + untranspose_bits_scalar(&input, &mut output); + assert_eq!(output, [0u8; 128]); + } + + #[test] + fn test_all_ones() { + let input = [0xFFu8; 128]; + let mut output = [0u8; 128]; + + transpose_bits_scalar(&input, &mut output); + assert_eq!(output, [0xFFu8; 128]); + + untranspose_bits_scalar(&input, &mut output); + assert_eq!(output, [0xFFu8; 128]); + } +} diff --git a/encodings/fastlanes/src/bit_transpose/x86.rs b/encodings/fastlanes/src/bit_transpose/x86.rs new file mode 100644 index 00000000000..8b2bd968eca --- /dev/null +++ b/encodings/fastlanes/src/bit_transpose/x86.rs @@ -0,0 +1,719 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +#![cfg(target_arch = "x86_64")] + +use core::arch::x86_64::__m512i; +use core::arch::x86_64::_mm512_and_si512; +use core::arch::x86_64::_mm512_loadu_si512; +use core::arch::x86_64::_mm512_permutex2var_epi8; +use core::arch::x86_64::_mm512_permutexvar_epi8; +use core::arch::x86_64::_mm512_set1_epi64; +use core::arch::x86_64::_mm512_slli_epi64; +use core::arch::x86_64::_mm512_srli_epi64; +use core::arch::x86_64::_mm512_storeu_si512; +use core::arch::x86_64::_mm512_xor_si512; +use core::arch::x86_64::_pdep_u64; +use core::arch::x86_64::_pext_u64; +use std::is_x86_feature_detected; + +use crate::bit_transpose::BASE_PATTERN_FIRST; +use crate::bit_transpose::BASE_PATTERN_SECOND; +use crate::bit_transpose::TRANSPOSE_2X2; +use crate::bit_transpose::TRANSPOSE_4X4; +use crate::bit_transpose::TRANSPOSE_8X8; + +/// Check if BMI2 is available. +#[inline] +#[must_use] +pub fn has_bmi2() -> bool { + is_x86_feature_detected!("bmi2") +} + +/// Check if AVX-512 VBMI is available (for byte permutation). +#[inline] +#[must_use] +pub fn has_vbmi() -> bool { + is_x86_feature_detected!("avx512vbmi") +} + +/// Transpose 1024 bits using BMI2 PEXT instruction. +/// +/// PEXT extracts bits at positions specified by a mask into contiguous low bits. +/// Fully unrolled for ~12% better performance vs looped version. +/// +/// # Safety +/// Requires BMI2 support. Check with `has_bmi2()` before calling. 
+#[target_feature(enable = "bmi2")] +#[inline(never)] +#[allow(clippy::too_many_lines)] +#[allow(unsafe_op_in_unsafe_fn)] +pub unsafe fn transpose_bits_bmi2(input: &[u8; 128], output: &mut [u8; 128]) { + // Helper to gather 8 bytes at stride 16 into a u64 + #[inline] + fn gather(input: &[u8; 128], base: usize) -> u64 { + (input[base] as u64) + | ((input[base + 16] as u64) << 8) + | ((input[base + 32] as u64) << 16) + | ((input[base + 48] as u64) << 24) + | ((input[base + 64] as u64) << 32) + | ((input[base + 80] as u64) << 40) + | ((input[base + 96] as u64) << 48) + | ((input[base + 112] as u64) << 56) + } + + // Gather all 16 groups (fully unrolled) + // First half: BASE_PATTERN_FIRST = [0, 8, 4, 12, 2, 10, 6, 14] + let g0 = gather(input, 0); + let g1 = gather(input, 8); + let g2 = gather(input, 4); + let g3 = gather(input, 12); + let g4 = gather(input, 2); + let g5 = gather(input, 10); + let g6 = gather(input, 6); + let g7 = gather(input, 14); + // Second half: BASE_PATTERN_SECOND = [1, 9, 5, 13, 3, 11, 7, 15] + let g8 = gather(input, 1); + let g9 = gather(input, 9); + let g10 = gather(input, 5); + let g11 = gather(input, 13); + let g12 = gather(input, 3); + let g13 = gather(input, 11); + let g14 = gather(input, 7); + let g15 = gather(input, 15); + + // Masks for each bit position + let m0: u64 = 0x0101_0101_0101_0101; + let m1: u64 = 0x0202_0202_0202_0202; + let m2: u64 = 0x0404_0404_0404_0404; + let m3: u64 = 0x0808_0808_0808_0808; + let m4: u64 = 0x1010_1010_1010_1010; + let m5: u64 = 0x2020_2020_2020_2020; + let m6: u64 = 0x4040_4040_4040_4040; + let m7: u64 = 0x8080_8080_8080_8080; + + // First half - 64 PEXT operations (fully unrolled) + output[0] = _pext_u64(g0, m0) as u8; + output[1] = _pext_u64(g1, m0) as u8; + output[2] = _pext_u64(g2, m0) as u8; + output[3] = _pext_u64(g3, m0) as u8; + output[4] = _pext_u64(g4, m0) as u8; + output[5] = _pext_u64(g5, m0) as u8; + output[6] = _pext_u64(g6, m0) as u8; + output[7] = _pext_u64(g7, m0) as u8; + output[8] = _pext_u64(g0, m1) as u8; + output[9] = _pext_u64(g1, m1) as u8; + output[10] = _pext_u64(g2, m1) as u8; + output[11] = _pext_u64(g3, m1) as u8; + output[12] = _pext_u64(g4, m1) as u8; + output[13] = _pext_u64(g5, m1) as u8; + output[14] = _pext_u64(g6, m1) as u8; + output[15] = _pext_u64(g7, m1) as u8; + output[16] = _pext_u64(g0, m2) as u8; + output[17] = _pext_u64(g1, m2) as u8; + output[18] = _pext_u64(g2, m2) as u8; + output[19] = _pext_u64(g3, m2) as u8; + output[20] = _pext_u64(g4, m2) as u8; + output[21] = _pext_u64(g5, m2) as u8; + output[22] = _pext_u64(g6, m2) as u8; + output[23] = _pext_u64(g7, m2) as u8; + output[24] = _pext_u64(g0, m3) as u8; + output[25] = _pext_u64(g1, m3) as u8; + output[26] = _pext_u64(g2, m3) as u8; + output[27] = _pext_u64(g3, m3) as u8; + output[28] = _pext_u64(g4, m3) as u8; + output[29] = _pext_u64(g5, m3) as u8; + output[30] = _pext_u64(g6, m3) as u8; + output[31] = _pext_u64(g7, m3) as u8; + output[32] = _pext_u64(g0, m4) as u8; + output[33] = _pext_u64(g1, m4) as u8; + output[34] = _pext_u64(g2, m4) as u8; + output[35] = _pext_u64(g3, m4) as u8; + output[36] = _pext_u64(g4, m4) as u8; + output[37] = _pext_u64(g5, m4) as u8; + output[38] = _pext_u64(g6, m4) as u8; + output[39] = _pext_u64(g7, m4) as u8; + output[40] = _pext_u64(g0, m5) as u8; + output[41] = _pext_u64(g1, m5) as u8; + output[42] = _pext_u64(g2, m5) as u8; + output[43] = _pext_u64(g3, m5) as u8; + output[44] = _pext_u64(g4, m5) as u8; + output[45] = _pext_u64(g5, m5) as u8; + output[46] = _pext_u64(g6, m5) as u8; + output[47] 
= _pext_u64(g7, m5) as u8; + output[48] = _pext_u64(g0, m6) as u8; + output[49] = _pext_u64(g1, m6) as u8; + output[50] = _pext_u64(g2, m6) as u8; + output[51] = _pext_u64(g3, m6) as u8; + output[52] = _pext_u64(g4, m6) as u8; + output[53] = _pext_u64(g5, m6) as u8; + output[54] = _pext_u64(g6, m6) as u8; + output[55] = _pext_u64(g7, m6) as u8; + output[56] = _pext_u64(g0, m7) as u8; + output[57] = _pext_u64(g1, m7) as u8; + output[58] = _pext_u64(g2, m7) as u8; + output[59] = _pext_u64(g3, m7) as u8; + output[60] = _pext_u64(g4, m7) as u8; + output[61] = _pext_u64(g5, m7) as u8; + output[62] = _pext_u64(g6, m7) as u8; + output[63] = _pext_u64(g7, m7) as u8; + + // Second half - 64 PEXT operations (fully unrolled) + output[64] = _pext_u64(g8, m0) as u8; + output[65] = _pext_u64(g9, m0) as u8; + output[66] = _pext_u64(g10, m0) as u8; + output[67] = _pext_u64(g11, m0) as u8; + output[68] = _pext_u64(g12, m0) as u8; + output[69] = _pext_u64(g13, m0) as u8; + output[70] = _pext_u64(g14, m0) as u8; + output[71] = _pext_u64(g15, m0) as u8; + output[72] = _pext_u64(g8, m1) as u8; + output[73] = _pext_u64(g9, m1) as u8; + output[74] = _pext_u64(g10, m1) as u8; + output[75] = _pext_u64(g11, m1) as u8; + output[76] = _pext_u64(g12, m1) as u8; + output[77] = _pext_u64(g13, m1) as u8; + output[78] = _pext_u64(g14, m1) as u8; + output[79] = _pext_u64(g15, m1) as u8; + output[80] = _pext_u64(g8, m2) as u8; + output[81] = _pext_u64(g9, m2) as u8; + output[82] = _pext_u64(g10, m2) as u8; + output[83] = _pext_u64(g11, m2) as u8; + output[84] = _pext_u64(g12, m2) as u8; + output[85] = _pext_u64(g13, m2) as u8; + output[86] = _pext_u64(g14, m2) as u8; + output[87] = _pext_u64(g15, m2) as u8; + output[88] = _pext_u64(g8, m3) as u8; + output[89] = _pext_u64(g9, m3) as u8; + output[90] = _pext_u64(g10, m3) as u8; + output[91] = _pext_u64(g11, m3) as u8; + output[92] = _pext_u64(g12, m3) as u8; + output[93] = _pext_u64(g13, m3) as u8; + output[94] = _pext_u64(g14, m3) as u8; + output[95] = _pext_u64(g15, m3) as u8; + output[96] = _pext_u64(g8, m4) as u8; + output[97] = _pext_u64(g9, m4) as u8; + output[98] = _pext_u64(g10, m4) as u8; + output[99] = _pext_u64(g11, m4) as u8; + output[100] = _pext_u64(g12, m4) as u8; + output[101] = _pext_u64(g13, m4) as u8; + output[102] = _pext_u64(g14, m4) as u8; + output[103] = _pext_u64(g15, m4) as u8; + output[104] = _pext_u64(g8, m5) as u8; + output[105] = _pext_u64(g9, m5) as u8; + output[106] = _pext_u64(g10, m5) as u8; + output[107] = _pext_u64(g11, m5) as u8; + output[108] = _pext_u64(g12, m5) as u8; + output[109] = _pext_u64(g13, m5) as u8; + output[110] = _pext_u64(g14, m5) as u8; + output[111] = _pext_u64(g15, m5) as u8; + output[112] = _pext_u64(g8, m6) as u8; + output[113] = _pext_u64(g9, m6) as u8; + output[114] = _pext_u64(g10, m6) as u8; + output[115] = _pext_u64(g11, m6) as u8; + output[116] = _pext_u64(g12, m6) as u8; + output[117] = _pext_u64(g13, m6) as u8; + output[118] = _pext_u64(g14, m6) as u8; + output[119] = _pext_u64(g15, m6) as u8; + output[120] = _pext_u64(g8, m7) as u8; + output[121] = _pext_u64(g9, m7) as u8; + output[122] = _pext_u64(g10, m7) as u8; + output[123] = _pext_u64(g11, m7) as u8; + output[124] = _pext_u64(g12, m7) as u8; + output[125] = _pext_u64(g13, m7) as u8; + output[126] = _pext_u64(g14, m7) as u8; + output[127] = _pext_u64(g15, m7) as u8; +} + +/// Untranspose 1024 bits using BMI2 PDEP instruction. 
+/// +/// Structured per-output-group: for each group of 8 output bytes at stride 16, +/// PDEP 8 input bytes into different bit positions, OR in registers, then +/// scatter-store once. Each output byte is written exactly once (no read-modify-write). +/// +/// # Safety +/// Requires BMI2 support. Check with `has_bmi2()` before calling. +#[target_feature(enable = "bmi2")] +#[inline(never)] +#[allow(clippy::too_many_lines)] +#[allow(unsafe_op_in_unsafe_fn)] +pub unsafe fn untranspose_bits_bmi2(input: &[u8; 128], output: &mut [u8; 128]) { + // Helper: scatter a u64 to 8 output bytes at stride 16 + #[inline] + fn scatter(output: &mut [u8; 128], base: usize, val: u64) { + output[base] = val as u8; + output[base + 16] = (val >> 8) as u8; + output[base + 32] = (val >> 16) as u8; + output[base + 48] = (val >> 24) as u8; + output[base + 64] = (val >> 32) as u8; + output[base + 80] = (val >> 40) as u8; + output[base + 96] = (val >> 48) as u8; + output[base + 112] = (val >> 56) as u8; + } + + // Masks for each bit position + let m0: u64 = 0x0101_0101_0101_0101; + let m1: u64 = 0x0202_0202_0202_0202; + let m2: u64 = 0x0404_0404_0404_0404; + let m3: u64 = 0x0808_0808_0808_0808; + let m4: u64 = 0x1010_1010_1010_1010; + let m5: u64 = 0x2020_2020_2020_2020; + let m6: u64 = 0x4040_4040_4040_4040; + let m7: u64 = 0x8080_8080_8080_8080; + + // For each output group, the input bytes that contribute are at + // input[bit_pos * 8 + group_idx] for bit_pos 0..8. + // PDEP deposits the 8 bits of the input byte into the bit_pos position + // of each byte in the u64. + + // First half: 8 groups using BASE_PATTERN_FIRST + // Group 0 (base=0): input bytes [0, 8, 16, 24, 32, 40, 48, 56] + let v = _pdep_u64(input[0] as u64, m0) + | _pdep_u64(input[8] as u64, m1) + | _pdep_u64(input[16] as u64, m2) + | _pdep_u64(input[24] as u64, m3) + | _pdep_u64(input[32] as u64, m4) + | _pdep_u64(input[40] as u64, m5) + | _pdep_u64(input[48] as u64, m6) + | _pdep_u64(input[56] as u64, m7); + scatter(output, 0, v); + + // Group 1 (base=8) + let v = _pdep_u64(input[1] as u64, m0) + | _pdep_u64(input[9] as u64, m1) + | _pdep_u64(input[17] as u64, m2) + | _pdep_u64(input[25] as u64, m3) + | _pdep_u64(input[33] as u64, m4) + | _pdep_u64(input[41] as u64, m5) + | _pdep_u64(input[49] as u64, m6) + | _pdep_u64(input[57] as u64, m7); + scatter(output, 8, v); + + // Group 2 (base=4) + let v = _pdep_u64(input[2] as u64, m0) + | _pdep_u64(input[10] as u64, m1) + | _pdep_u64(input[18] as u64, m2) + | _pdep_u64(input[26] as u64, m3) + | _pdep_u64(input[34] as u64, m4) + | _pdep_u64(input[42] as u64, m5) + | _pdep_u64(input[50] as u64, m6) + | _pdep_u64(input[58] as u64, m7); + scatter(output, 4, v); + + // Group 3 (base=12) + let v = _pdep_u64(input[3] as u64, m0) + | _pdep_u64(input[11] as u64, m1) + | _pdep_u64(input[19] as u64, m2) + | _pdep_u64(input[27] as u64, m3) + | _pdep_u64(input[35] as u64, m4) + | _pdep_u64(input[43] as u64, m5) + | _pdep_u64(input[51] as u64, m6) + | _pdep_u64(input[59] as u64, m7); + scatter(output, 12, v); + + // Group 4 (base=2) + let v = _pdep_u64(input[4] as u64, m0) + | _pdep_u64(input[12] as u64, m1) + | _pdep_u64(input[20] as u64, m2) + | _pdep_u64(input[28] as u64, m3) + | _pdep_u64(input[36] as u64, m4) + | _pdep_u64(input[44] as u64, m5) + | _pdep_u64(input[52] as u64, m6) + | _pdep_u64(input[60] as u64, m7); + scatter(output, 2, v); + + // Group 5 (base=10) + let v = _pdep_u64(input[5] as u64, m0) + | _pdep_u64(input[13] as u64, m1) + | _pdep_u64(input[21] as u64, m2) + | _pdep_u64(input[29] as u64, m3) + 
| _pdep_u64(input[37] as u64, m4) + | _pdep_u64(input[45] as u64, m5) + | _pdep_u64(input[53] as u64, m6) + | _pdep_u64(input[61] as u64, m7); + scatter(output, 10, v); + + // Group 6 (base=6) + let v = _pdep_u64(input[6] as u64, m0) + | _pdep_u64(input[14] as u64, m1) + | _pdep_u64(input[22] as u64, m2) + | _pdep_u64(input[30] as u64, m3) + | _pdep_u64(input[38] as u64, m4) + | _pdep_u64(input[46] as u64, m5) + | _pdep_u64(input[54] as u64, m6) + | _pdep_u64(input[62] as u64, m7); + scatter(output, 6, v); + + // Group 7 (base=14) + let v = _pdep_u64(input[7] as u64, m0) + | _pdep_u64(input[15] as u64, m1) + | _pdep_u64(input[23] as u64, m2) + | _pdep_u64(input[31] as u64, m3) + | _pdep_u64(input[39] as u64, m4) + | _pdep_u64(input[47] as u64, m5) + | _pdep_u64(input[55] as u64, m6) + | _pdep_u64(input[63] as u64, m7); + scatter(output, 14, v); + + // Second half: 8 groups using BASE_PATTERN_SECOND + // Group 0 (base=1) + let v = _pdep_u64(input[64] as u64, m0) + | _pdep_u64(input[72] as u64, m1) + | _pdep_u64(input[80] as u64, m2) + | _pdep_u64(input[88] as u64, m3) + | _pdep_u64(input[96] as u64, m4) + | _pdep_u64(input[104] as u64, m5) + | _pdep_u64(input[112] as u64, m6) + | _pdep_u64(input[120] as u64, m7); + scatter(output, 1, v); + + // Group 1 (base=9) + let v = _pdep_u64(input[65] as u64, m0) + | _pdep_u64(input[73] as u64, m1) + | _pdep_u64(input[81] as u64, m2) + | _pdep_u64(input[89] as u64, m3) + | _pdep_u64(input[97] as u64, m4) + | _pdep_u64(input[105] as u64, m5) + | _pdep_u64(input[113] as u64, m6) + | _pdep_u64(input[121] as u64, m7); + scatter(output, 9, v); + + // Group 2 (base=5) + let v = _pdep_u64(input[66] as u64, m0) + | _pdep_u64(input[74] as u64, m1) + | _pdep_u64(input[82] as u64, m2) + | _pdep_u64(input[90] as u64, m3) + | _pdep_u64(input[98] as u64, m4) + | _pdep_u64(input[106] as u64, m5) + | _pdep_u64(input[114] as u64, m6) + | _pdep_u64(input[122] as u64, m7); + scatter(output, 5, v); + + // Group 3 (base=13) + let v = _pdep_u64(input[67] as u64, m0) + | _pdep_u64(input[75] as u64, m1) + | _pdep_u64(input[83] as u64, m2) + | _pdep_u64(input[91] as u64, m3) + | _pdep_u64(input[99] as u64, m4) + | _pdep_u64(input[107] as u64, m5) + | _pdep_u64(input[115] as u64, m6) + | _pdep_u64(input[123] as u64, m7); + scatter(output, 13, v); + + // Group 4 (base=3) + let v = _pdep_u64(input[68] as u64, m0) + | _pdep_u64(input[76] as u64, m1) + | _pdep_u64(input[84] as u64, m2) + | _pdep_u64(input[92] as u64, m3) + | _pdep_u64(input[100] as u64, m4) + | _pdep_u64(input[108] as u64, m5) + | _pdep_u64(input[116] as u64, m6) + | _pdep_u64(input[124] as u64, m7); + scatter(output, 3, v); + + // Group 5 (base=11) + let v = _pdep_u64(input[69] as u64, m0) + | _pdep_u64(input[77] as u64, m1) + | _pdep_u64(input[85] as u64, m2) + | _pdep_u64(input[93] as u64, m3) + | _pdep_u64(input[101] as u64, m4) + | _pdep_u64(input[109] as u64, m5) + | _pdep_u64(input[117] as u64, m6) + | _pdep_u64(input[125] as u64, m7); + scatter(output, 11, v); + + // Group 6 (base=7) + let v = _pdep_u64(input[70] as u64, m0) + | _pdep_u64(input[78] as u64, m1) + | _pdep_u64(input[86] as u64, m2) + | _pdep_u64(input[94] as u64, m3) + | _pdep_u64(input[102] as u64, m4) + | _pdep_u64(input[110] as u64, m5) + | _pdep_u64(input[118] as u64, m6) + | _pdep_u64(input[126] as u64, m7); + scatter(output, 7, v); + + // Group 7 (base=15) + let v = _pdep_u64(input[71] as u64, m0) + | _pdep_u64(input[79] as u64, m1) + | _pdep_u64(input[87] as u64, m2) + | _pdep_u64(input[95] as u64, m3) + | _pdep_u64(input[103] as u64, 
m4) + | _pdep_u64(input[111] as u64, m5) + | _pdep_u64(input[119] as u64, m6) + | _pdep_u64(input[127] as u64, m7); + scatter(output, 15, v); +} + +// Static permutation tables for VBMI gather operations +static GATHER_FIRST: [u8; 64] = [ + // Gather bytes at stride 16 for first 8 groups (bases from BASE_PATTERN_FIRST) + // Group 0: base=0 + 0, 16, 32, 48, 64, 80, 96, 112, // Group 1: base=8 + 8, 24, 40, 56, 72, 88, 104, 120, // Group 2: base=4 + 4, 20, 36, 52, 68, 84, 100, 116, // Group 3: base=12 + 12, 28, 44, 60, 76, 92, 108, 124, // Group 4: base=2 + 2, 18, 34, 50, 66, 82, 98, 114, // Group 5: base=10 + 10, 26, 42, 58, 74, 90, 106, 122, // Group 6: base=6 + 6, 22, 38, 54, 70, 86, 102, 118, // Group 7: base=14 + 14, 30, 46, 62, 78, 94, 110, 126, +]; + +static GATHER_SECOND: [u8; 64] = [ + // Gather bytes at stride 16 for second 8 groups (bases from BASE_PATTERN_SECOND) + // Group 0: base=1 + 1, 17, 33, 49, 65, 81, 97, 113, // Group 1: base=9 + 9, 25, 41, 57, 73, 89, 105, 121, // Group 2: base=5 + 5, 21, 37, 53, 69, 85, 101, 117, // Group 3: base=13 + 13, 29, 45, 61, 77, 93, 109, 125, // Group 4: base=3 + 3, 19, 35, 51, 67, 83, 99, 115, // Group 5: base=11 + 11, 27, 43, 59, 75, 91, 107, 123, // Group 6: base=7 + 7, 23, 39, 55, 71, 87, 103, 119, // Group 7: base=15 + 15, 31, 47, 63, 79, 95, 111, 127, +]; + +// 8x8 byte transpose permutation for scatter phase +// Input: [g0b0..g0b7, g1b0..g1b7, ..., g7b0..g7b7] (8 groups of 8 bytes) +// Output: [g0b0,g1b0,..,g7b0, g0b1,g1b1,..,g7b1, ...] (8 rows of 8 bytes) +static SCATTER_8X8: [u8; 64] = [ + 0, 8, 16, 24, 32, 40, 48, 56, // byte 0 from each group + 1, 9, 17, 25, 33, 41, 49, 57, // byte 1 from each group + 2, 10, 18, 26, 34, 42, 50, 58, // byte 2 from each group + 3, 11, 19, 27, 35, 43, 51, 59, // byte 3 from each group + 4, 12, 20, 28, 36, 44, 52, 60, // byte 4 from each group + 5, 13, 21, 29, 37, 45, 53, 61, // byte 5 from each group + 6, 14, 22, 30, 38, 46, 54, 62, // byte 6 from each group + 7, 15, 23, 31, 39, 47, 55, 63, // byte 7 from each group +]; + +/// Transpose 1024 bits using AVX-512 VBMI for vectorized gather and scatter. +/// +/// Uses vpermi2b to gather bytes from stride-16 positions in parallel, +/// and vpermb for the final 8x8 byte transpose to output format. +/// +/// # Safety +/// Requires AVX-512F, AVX-512BW, and AVX-512VBMI support. 
+#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vbmi")] +#[inline(never)] +#[allow(clippy::cast_possible_wrap)] +#[allow(clippy::cast_ptr_alignment)] +#[allow(unsafe_op_in_unsafe_fn)] +pub unsafe fn transpose_bits_vbmi(input: &[u8; 128], output: &mut [u8; 128]) { + // Load all 128 input bytes into two ZMM registers + let in_lo = _mm512_loadu_si512(input.as_ptr().cast::<__m512i>()); + let in_hi = _mm512_loadu_si512(input.as_ptr().add(64).cast::<__m512i>()); + + // Load permutation indices (static tables) + let idx_first = _mm512_loadu_si512(GATHER_FIRST.as_ptr().cast::<__m512i>()); + let idx_second = _mm512_loadu_si512(GATHER_SECOND.as_ptr().cast::<__m512i>()); + let idx_scatter = _mm512_loadu_si512(SCATTER_8X8.as_ptr().cast::<__m512i>()); + + // Masks for 8x8 bit transpose + let mask1 = _mm512_set1_epi64(TRANSPOSE_2X2 as i64); + let mask2 = _mm512_set1_epi64(TRANSPOSE_4X4 as i64); + let mask3 = _mm512_set1_epi64(TRANSPOSE_8X8 as i64); + + // Process first half + let gathered = _mm512_permutex2var_epi8(in_lo, idx_first, in_hi); + + // 8x8 bit transpose on all 8 groups in parallel + let mut v = gathered; + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<7>(v)), mask1); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<7>(t)); + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<14>(v)), mask2); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<14>(t)); + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<28>(v)), mask3); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<28>(t)); + + // 8x8 byte transpose for scatter using vpermb + let scattered = _mm512_permutexvar_epi8(idx_scatter, v); + _mm512_storeu_si512(output.as_mut_ptr().cast::<__m512i>(), scattered); + + // Process second half + let gathered = _mm512_permutex2var_epi8(in_lo, idx_second, in_hi); + + let mut v = gathered; + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<7>(v)), mask1); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<7>(t)); + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<14>(v)), mask2); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<14>(t)); + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<28>(v)), mask3); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<28>(t)); + + let scattered = _mm512_permutexvar_epi8(idx_scatter, v); + _mm512_storeu_si512(output.as_mut_ptr().add(64).cast::<__m512i>(), scattered); +} + +/// Untranspose 1024 bits using AVX-512 VBMI for vectorized scatter. +/// +/// # Safety +/// Requires AVX-512F, AVX-512BW, and AVX-512VBMI support. 
+#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vbmi")] +#[inline(never)] +#[allow(clippy::cast_possible_wrap)] +#[allow(clippy::cast_ptr_alignment)] +#[allow(unsafe_op_in_unsafe_fn)] +pub unsafe fn untranspose_bits_vbmi(input: &[u8; 128], output: &mut [u8; 128]) { + // For untranspose, we gather consecutive bytes from transposed layout, + // then scatter back to stride-16 positions + + // Gather indices for first half - collect 8 bytes per group from transposed layout + // In transposed layout, bytes for group 0 are at: [0, 8, 16, 24, 32, 40, 48, 56] + let gather_indices: [u8; 64] = [ + 0, 8, 16, 24, 32, 40, 48, 56, // Group 0 + 1, 9, 17, 25, 33, 41, 49, 57, // Group 1 + 2, 10, 18, 26, 34, 42, 50, 58, // Group 2 + 3, 11, 19, 27, 35, 43, 51, 59, // Group 3 + 4, 12, 20, 28, 36, 44, 52, 60, // Group 4 + 5, 13, 21, 29, 37, 45, 53, 61, // Group 5 + 6, 14, 22, 30, 38, 46, 54, 62, // Group 6 + 7, 15, 23, 31, 39, 47, 55, 63, // Group 7 + ]; + + let in_first = _mm512_loadu_si512(input.as_ptr().cast::<__m512i>()); + let idx = _mm512_loadu_si512(gather_indices.as_ptr().cast::<__m512i>()); + let gathered = _mm512_permutexvar_epi8(idx, in_first); + + // 8x8 bit transpose + let mask1 = _mm512_set1_epi64(TRANSPOSE_2X2 as i64); + let mask2 = _mm512_set1_epi64(TRANSPOSE_4X4 as i64); + let mask3 = _mm512_set1_epi64(TRANSPOSE_8X8 as i64); + + let mut v = gathered; + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<7>(v)), mask1); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<7>(t)); + + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<14>(v)), mask2); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<14>(t)); + + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<28>(v)), mask3); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<28>(t)); + + // Scatter to output at stride 16 - need to use scalar stores for now + // (AVX-512 scatter is available but complex for this pattern) + let mut result = [0u64; 8]; + _mm512_storeu_si512(result.as_mut_ptr().cast::<__m512i>(), v); + + for base_group in 0..8 { + let out_base = BASE_PATTERN_FIRST[base_group]; + for i in 0..8 { + output[out_base + i * 16] = (result[base_group] >> (i * 8)) as u8; + } + } + + // Second half + let in_second = _mm512_loadu_si512(input.as_ptr().add(64).cast::<__m512i>()); + let gathered = _mm512_permutexvar_epi8(idx, in_second); + + let mut v = gathered; + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<7>(v)), mask1); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<7>(t)); + + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<14>(v)), mask2); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<14>(t)); + + let t = _mm512_and_si512(_mm512_xor_si512(v, _mm512_srli_epi64::<28>(v)), mask3); + v = _mm512_xor_si512(_mm512_xor_si512(v, t), _mm512_slli_epi64::<28>(t)); + + _mm512_storeu_si512(result.as_mut_ptr().cast::<__m512i>(), v); + + for base_group in 0..8 { + let out_base = BASE_PATTERN_SECOND[base_group]; + for i in 0..8 { + output[out_base + i * 16] = (result[base_group] >> (i * 8)) as u8; + } + } +} + +#[cfg(test)] +mod tests { + use crate::bit_transpose::generate_test_data; + use crate::bit_transpose::transpose_bits_baseline; + use crate::bit_transpose::x86::has_bmi2; + use crate::bit_transpose::x86::has_vbmi; + use crate::bit_transpose::x86::transpose_bits_bmi2; + use crate::bit_transpose::x86::transpose_bits_vbmi; + use 
crate::bit_transpose::x86::untranspose_bits_bmi2; + use crate::bit_transpose::x86::untranspose_bits_vbmi; + + #[test] + fn test_bmi2_matches_baseline() { + if !has_bmi2() { + return; + } + + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut baseline_out = [0u8; 128]; + let mut bmi2_out = [0u8; 128]; + + transpose_bits_baseline(&input, &mut baseline_out); + unsafe { transpose_bits_bmi2(&input, &mut bmi2_out) }; + + assert_eq!( + baseline_out, bmi2_out, + "BMI2 transpose doesn't match baseline for seed {seed}" + ); + } + } + + #[test] + fn test_bmi2_roundtrip() { + if !has_bmi2() { + return; + } + + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut transposed = [0u8; 128]; + let mut roundtrip = [0u8; 128]; + + unsafe { + transpose_bits_bmi2(&input, &mut transposed); + untranspose_bits_bmi2(&transposed, &mut roundtrip); + } + + assert_eq!(input, roundtrip, "BMI2 roundtrip failed for seed {seed}"); + } + } + + #[test] + fn test_vbmi_matches_baseline() { + if !has_vbmi() { + return; + } + + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut baseline_out = [0u8; 128]; + let mut vbmi_out = [0u8; 128]; + + transpose_bits_baseline(&input, &mut baseline_out); + unsafe { transpose_bits_vbmi(&input, &mut vbmi_out) }; + + assert_eq!( + baseline_out, vbmi_out, + "VBMI transpose doesn't match baseline for seed {seed}" + ); + } + } + + #[test] + fn test_vbmi_roundtrip() { + if !has_vbmi() { + return; + } + + for seed in [0, 42, 123, 255] { + let input = generate_test_data(seed); + let mut transposed = [0u8; 128]; + let mut roundtrip = [0u8; 128]; + + unsafe { + transpose_bits_vbmi(&input, &mut transposed); + untranspose_bits_vbmi(&transposed, &mut roundtrip); + } + + assert_eq!(input, roundtrip, "VBMI roundtrip failed for seed {seed}"); + } + } +} diff --git a/encodings/fastlanes/src/lib.rs b/encodings/fastlanes/src/lib.rs index f8625cd4dd2..b981175319a 100644 --- a/encodings/fastlanes/src/lib.rs +++ b/encodings/fastlanes/src/lib.rs @@ -8,6 +8,7 @@ pub use delta::*; pub use r#for::*; pub use rle::*; +pub mod bit_transpose; mod bitpacking; mod delta; mod r#for;
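The new public entry points exported above (`vortex_fastlanes::bit_transpose::transpose_bits` and `untranspose_bits`) operate on fixed 128-byte (1024-bit) blocks and dispatch to the fastest available implementation at runtime. A minimal round-trip sketch of how a caller might use them (hypothetical example; the input data is arbitrary):

use vortex_fastlanes::bit_transpose::transpose_bits;
use vortex_fastlanes::bit_transpose::untranspose_bits;

fn main() {
    // Arbitrary 1024-bit (128-byte) block.
    let input: [u8; 128] = core::array::from_fn(|i| i as u8);
    let mut transposed = [0u8; 128];
    let mut roundtrip = [0u8; 128];

    // Apply the FastLanes transpose, then invert it.
    transpose_bits(&input, &mut transposed);
    untranspose_bits(&transposed, &mut roundtrip);
    assert_eq!(input, roundtrip);
}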