diff --git a/benches/resamplers.rs b/benches/resamplers.rs index caf3231..e41f205 100644 --- a/benches/resamplers.rs +++ b/benches/resamplers.rs @@ -40,26 +40,25 @@ mod bench_asyncro { let resample_ratio = 192000 as f64 / 44100 as f64; let interpolation_type = $ip; - let interpolator = $it::<$ft>::new( - sinc_len, - oversampling_factor, - f_cutoff, - window, - ); - let interpolator = unwrap_helper!($($unwrap)* interpolator); - let interpolator = Box::new(interpolator); - let mut resampler = Async::<$ft>::new_with_sinc_interpolator( - resample_ratio, - 1.1, - interpolation_type, - interpolator, - chunksize, - 1, - FixedAsync::Input, - ).unwrap(); - let buffer_in = InterleavedOwned::new(0.0, 1, chunksize); - let mut buffer_out = InterleavedOwned::new(0.0, 1, resampler.output_frames_max()); - c.bench_function($desc, |b| b.iter(|| resampler.process_into_buffer(black_box(&buffer_in), &mut buffer_out, None).unwrap())); + let mut group = c.benchmark_group($desc); + for (label, channels) in [("1ch", 1usize), ("2ch", 2), ("4ch", 4)] { + let interpolator = $it::<$ft>::new(sinc_len, oversampling_factor, f_cutoff, window); + let interpolator = unwrap_helper!($($unwrap)* interpolator); + let interpolator = Box::new(interpolator); + let mut resampler = Async::<$ft>::new_with_sinc_interpolator( + resample_ratio, + 1.1, + interpolation_type, + interpolator, + chunksize, + channels, + FixedAsync::Input, + ).unwrap(); + let buffer_in = InterleavedOwned::new(0.0, channels, chunksize); + let mut buffer_out = InterleavedOwned::new(0.0, channels, resampler.output_frames_max()); + group.bench_function(label, |b| b.iter(|| resampler.process_into_buffer(black_box(&buffer_in), &mut buffer_out, None).unwrap())); + } + group.finish(); } }; } @@ -266,24 +265,30 @@ mod bench_asyncro { let chunksize = 1024; let interpolation_type = $ip; let resample_ratio = 192000 as f64 / 44100 as f64; - let mut resampler = Async::<$ft>::new_poly( - resample_ratio, - 1.1, - interpolation_type, - chunksize, - 1, - FixedAsync::Input, - ) - .unwrap(); - let buffer_in = InterleavedOwned::new(0.0, 1, chunksize); - let mut buffer_out = InterleavedOwned::new(0.0, 1, resampler.output_frames_max()); - c.bench_function($desc, |b| { - b.iter(|| { - resampler - .process_into_buffer(black_box(&buffer_in), &mut buffer_out, None) - .unwrap() - }) - }); + + let mut group = c.benchmark_group($desc); + for (label, channels) in [("1ch", 1usize), ("2ch", 2), ("4ch", 4)] { + let mut resampler = Async::<$ft>::new_poly( + resample_ratio, + 1.1, + interpolation_type, + chunksize, + channels, + FixedAsync::Input, + ) + .unwrap(); + let buffer_in = InterleavedOwned::new(0.0, channels, chunksize); + let mut buffer_out = + InterleavedOwned::new(0.0, channels, resampler.output_frames_max()); + group.bench_function(label, |b| { + b.iter(|| { + resampler + .process_into_buffer(black_box(&buffer_in), &mut buffer_out, None) + .unwrap() + }) + }); + } + group.finish(); } }; } diff --git a/src/asynchro.rs b/src/asynchro.rs index 3559d0e..b1b9ae4 100644 --- a/src/asynchro.rs +++ b/src/asynchro.rs @@ -28,7 +28,7 @@ pub trait InnerResampler: Send { /// Make the scalar product between the waveform starting at `index` and the sinc of `subindex`. #[allow(clippy::too_many_arguments)] fn process( - &self, + &mut self, index: f64, nbr_frames: usize, channel_mask: &[bool], @@ -298,10 +298,7 @@ where let interpolator_len = interpolator.nbr_points(); - let inner_resampler = InnerSinc { - interpolator, - interpolation: interpolation_type, - }; + let inner_resampler = InnerSinc::new(interpolator, interpolation_type); let last_index = inner_resampler.init_last_index(); let needed_input_size = Self::calculate_input_size( @@ -511,7 +508,7 @@ where let mut idx = self.last_index; // Process - idx = self.inner_resampler.process( + idx = self.inner_resampler.as_mut().process( idx, self.needed_output_size, &self.channel_mask, @@ -901,4 +898,89 @@ mod tests { let resampler = Async::::new_sinc(1.0, 4.0, ¶ms, 1024, 2, fixed).unwrap(); check_relative_ratio_changes_frame_ratio(resampler); } + + /// Run a 1-channel and a 4-channel sinc resampler with identical per-channel input and + /// compare their outputs. The 1-channel resampler uses the direct path (separate dot + /// products per nearest point); the 4-channel resampler uses the combined-sinc path + /// (SIMD SAXPY build then one dot product per channel). They must agree within floating- + /// point rounding tolerance. + fn compare_1ch_4ch_sinc_output( + interpolation: SincInterpolationType, + ratio: f64, + fixed: FixedAsync, + ) { + let params = SincInterpolationParameters { + sinc_len: 64, + f_cutoff: 0.95, + interpolation, + oversampling_factor: 16, + window: WindowFunction::BlackmanHarris2, + }; + let chunk = 256; + + let mut r1 = Async::::new_sinc(ratio, 1.0, ¶ms, chunk, 1, fixed).unwrap(); + let mut r4 = Async::::new_sinc(ratio, 1.0, ¶ms, chunk, 4, fixed).unwrap(); + + let mut phase = 0.0f64; + for _ in 0..20 { + let frames_in = r1.input_frames_next(); + let frames_out = r1.output_frames_next(); + assert_eq!(frames_in, r4.input_frames_next()); + assert_eq!(frames_out, r4.output_frames_next()); + + let wave: Vec = (0..frames_in) + .map(|i| (phase + i as f64 * 0.1).sin()) + .collect(); + phase += frames_in as f64 * 0.1; + + let in1_data = vec![wave.clone()]; + let in4_data = vec![wave.clone(), wave.clone(), wave.clone(), wave.clone()]; + let input_1ch = SequentialSliceOfVecs::new(&in1_data, 1, frames_in).unwrap(); + let input_4ch = SequentialSliceOfVecs::new(&in4_data, 4, frames_in).unwrap(); + + let mut out1_data = vec![vec![0.0f64; frames_out]; 1]; + let mut out4_data = vec![vec![0.0f64; frames_out]; 4]; + let mut out1 = SequentialSliceOfVecs::new_mut(&mut out1_data, 1, frames_out).unwrap(); + let mut out4 = SequentialSliceOfVecs::new_mut(&mut out4_data, 4, frames_out).unwrap(); + + r1.process_into_buffer(&input_1ch, &mut out1, None).unwrap(); + r4.process_into_buffer(&input_4ch, &mut out4, None).unwrap(); + + for frame in 0..frames_out { + let expected = out1_data[0][frame]; + // All 4 channels must be exactly equal (identical input, same combined sinc). + for ch in 1..4 { + assert_eq!( + out4_data[ch][frame], out4_data[0][frame], + "interp={interpolation:?} ratio={ratio} frame={frame}: \ + ch{ch} differs from ch0 inside 4ch resampler" + ); + } + // 4ch output must agree with 1ch output within floating-point tolerance. + for (ch, ch_data) in out4_data.iter().enumerate() { + let diff = (ch_data[frame] - expected).abs(); + assert!( + diff < 1e-10, + "interp={interpolation:?} ratio={ratio} ch={ch} frame={frame}: \ + 4ch={} vs 1ch={expected} diff={diff}", + ch_data[frame] + ); + } + } + } + } + + #[test_log::test(test_matrix( + [ + SincInterpolationType::Cubic, + SincInterpolationType::Quadratic, + SincInterpolationType::Linear, + SincInterpolationType::Nearest, + ], + [0.8f64, 1.2f64], + [FixedAsync::Input, FixedAsync::Output] + ))] + fn sinc_4ch_matches_1ch(interp: SincInterpolationType, ratio: f64, fixed: FixedAsync) { + compare_1ch_4ch_sinc_output(interp, ratio, fixed); + } } diff --git a/src/asynchro_fast.rs b/src/asynchro_fast.rs index abcb50f..be14951 100644 --- a/src/asynchro_fast.rs +++ b/src/asynchro_fast.rs @@ -146,7 +146,7 @@ where T: Sample, { fn process( - &self, + &mut self, idx: f64, nbr_frames: usize, channel_mask: &[bool], diff --git a/src/asynchro_sinc.rs b/src/asynchro_sinc.rs index 38fc424..b1da153 100644 --- a/src/asynchro_sinc.rs +++ b/src/asynchro_sinc.rs @@ -127,6 +127,7 @@ where /// Perform cubic polynomial interpolation to get value at x. /// Input points are assumed to be at x = -1, 0, 1, 2. +#[allow(dead_code)] pub fn interp_cubic(x: T, yvals: &[T; 4]) -> T where T: Sample, @@ -140,8 +141,26 @@ where a0 + a1 * x + a2 * x2 + a3 * x3 } +/// Compute the four blending weights for cubic interpolation at fractional position x. +/// These are the per-point coefficients such that interp_cubic(x, pts) == dot(weights, pts). +/// Input points are assumed to be at x = -1, 0, 1, 2. +pub fn interp_cubic_weights(x: T) -> [T; 4] +where + T: Sample, +{ + let x2 = x * x; + let x3 = x2 * x; + [ + t!(-1.0 / 3.0) * x + t!(0.5) * x2 - t!(1.0 / 6.0) * x3, + t!(1.0) - t!(0.5) * x - x2 + t!(0.5) * x3, + x + t!(0.5) * x2 - t!(0.5) * x3, + -t!(1.0 / 6.0) * x + t!(1.0 / 6.0) * x3, + ] +} + /// Perform quadratic polynomial interpolation to get value at x. /// Input points are assumed to be at x = 0, 1, 2. +#[allow(dead_code)] pub fn interp_quad(x: T, yvals: &[T; 3]) -> T where T: Sample, @@ -153,7 +172,23 @@ where t!(0.5) * (a0 + a1 * x + a2 * x2) } +/// Compute the three blending weights for quadratic interpolation at fractional position x. +/// These are the per-point coefficients such that interp_quad(x, pts) == dot(weights, pts). +/// Input points are assumed to be at x = 0, 1, 2. +pub fn interp_quad_weights(x: T) -> [T; 3] +where + T: Sample, +{ + let x2 = x * x; + [ + t!(0.5) * (t!(2.0) - t!(3.0) * x + x2), + t!(0.5) * (t!(4.0) * x - t!(2.0) * x2), + t!(0.5) * (x2 - x), + ] +} + /// Perform linear interpolation between two points at x=0 and x=1. +#[allow(dead_code)] pub fn interp_lin(x: T, yvals: &[T; 2]) -> T where T: Sample, @@ -161,9 +196,35 @@ where yvals[0] + x * (yvals[1] - yvals[0]) } +/// Compute the two blending weights for linear interpolation at fractional position x. +/// These are the per-point coefficients such that interp_lin(x, pts) == dot(weights, pts). +pub fn interp_lin_weights(x: T) -> [T; 2] +where + T: Sample, +{ + [t!(1.0) - x, x] +} + pub(crate) struct InnerSinc { pub interpolator: Box>, pub interpolation: SincInterpolationType, + // Pre-allocated buffer for the combined sinc (used by the >2 channel path). + // Length is interpolator.nbr_points() + 1. + combined: Vec, +} + +impl InnerSinc { + pub(crate) fn new( + interpolator: Box>, + interpolation: SincInterpolationType, + ) -> Self { + let len = interpolator.nbr_points() + 1; + Self { + interpolator, + interpolation, + combined: vec![T::zero(); len], + } + } } impl InnerResampler for InnerSinc @@ -171,7 +232,7 @@ where T: Sample, { fn process( - &self, + &mut self, idx: f64, nbr_frames: usize, channel_mask: &[bool], @@ -184,10 +245,13 @@ where let mut t_ratio = t_ratio; let mut idx = idx; let interpolator_len = self.interpolator.nbr_points(); + let active_count = channel_mask.iter().filter(|&&a| a).count(); match self.interpolation { SincInterpolationType::Cubic => { + // Cubic has 4 nearest points so the build cost breaks even at 2 channels. + let use_combined = active_count >= 2; let oversampling_factor = self.interpolator.nbr_sincs(); - let mut points = [T::zero(); 4]; + let sincs = self.interpolator.get_sincs(); let mut nearest = [(0isize, 0isize); 4]; for frame in 0..nbr_frames { t_ratio += t_ratio_increment; @@ -196,28 +260,67 @@ where let frac = idx * oversampling_factor as f64 - (idx * oversampling_factor as f64).floor(); let frac_offset = t!(frac); - for (chan, active) in channel_mask.iter().enumerate() { - if *active { - let buf = &wave_in[chan]; - for (n, p) in nearest.iter().zip(points.iter_mut()) { - *p = self.interpolator.get_sinc_interpolated( + let weights = interp_cubic_weights(frac_offset); + if use_combined { + let min_idx = self.interpolator.make_combined_sinc( + &nearest, + &weights, + &mut self.combined, + ); + let base = (min_idx + 2 * interpolator_len as isize) as usize; + for (chan, active) in channel_mask.iter().enumerate() { + if *active { + let buf = &wave_in[chan]; + let result = self.interpolator.get_sinc_dot_product( buf, - (n.0 + 2 * interpolator_len as isize) as usize, - n.1 as usize, - ); + base, + &self.combined[..interpolator_len], + ) + self.combined[interpolator_len] + * buf[base + interpolator_len]; + wave_out.write_sample(chan, frame + output_offset, &result); + } + } + } else { + let bases = nearest.map(|n| (n.0 + 2 * interpolator_len as isize) as usize); + let subs = nearest.map(|n| n.1 as usize); + for (chan, active) in channel_mask.iter().enumerate() { + if *active { + let buf = &wave_in[chan]; + let result = weights[0] + * self.interpolator.get_sinc_dot_product( + buf, + bases[0], + &sincs[subs[0]], + ) + + weights[1] + * self.interpolator.get_sinc_dot_product( + buf, + bases[1], + &sincs[subs[1]], + ) + + weights[2] + * self.interpolator.get_sinc_dot_product( + buf, + bases[2], + &sincs[subs[2]], + ) + + weights[3] + * self.interpolator.get_sinc_dot_product( + buf, + bases[3], + &sincs[subs[3]], + ); + wave_out.write_sample(chan, frame + output_offset, &result); } - wave_out.write_sample( - chan, - frame + output_offset, - &interp_cubic(frac_offset, &points), - ); } } } } SincInterpolationType::Quadratic => { + // Quadratic has 3 nearest points; combined sinc pays off from 3 channels. + let use_combined = active_count > 2; let oversampling_factor = self.interpolator.nbr_sincs(); - let mut points = [T::zero(); 3]; + let sincs = self.interpolator.get_sincs(); let mut nearest = [(0isize, 0isize); 3]; for frame in 0..nbr_frames { t_ratio += t_ratio_increment; @@ -226,28 +329,61 @@ where let frac = idx * oversampling_factor as f64 - (idx * oversampling_factor as f64).floor(); let frac_offset = t!(frac); - for (chan, active) in channel_mask.iter().enumerate() { - if *active { - let buf = &wave_in[chan]; - for (n, p) in nearest.iter().zip(points.iter_mut()) { - *p = self.interpolator.get_sinc_interpolated( + let weights = interp_quad_weights(frac_offset); + if use_combined { + let min_idx = self.interpolator.make_combined_sinc( + &nearest, + &weights, + &mut self.combined, + ); + let base = (min_idx + 2 * interpolator_len as isize) as usize; + for (chan, active) in channel_mask.iter().enumerate() { + if *active { + let buf = &wave_in[chan]; + let result = self.interpolator.get_sinc_dot_product( buf, - (n.0 + 2 * interpolator_len as isize) as usize, - n.1 as usize, - ); + base, + &self.combined[..interpolator_len], + ) + self.combined[interpolator_len] + * buf[base + interpolator_len]; + wave_out.write_sample(chan, frame + output_offset, &result); + } + } + } else { + let bases = nearest.map(|n| (n.0 + 2 * interpolator_len as isize) as usize); + let subs = nearest.map(|n| n.1 as usize); + for (chan, active) in channel_mask.iter().enumerate() { + if *active { + let buf = &wave_in[chan]; + let result = weights[0] + * self.interpolator.get_sinc_dot_product( + buf, + bases[0], + &sincs[subs[0]], + ) + + weights[1] + * self.interpolator.get_sinc_dot_product( + buf, + bases[1], + &sincs[subs[1]], + ) + + weights[2] + * self.interpolator.get_sinc_dot_product( + buf, + bases[2], + &sincs[subs[2]], + ); + wave_out.write_sample(chan, frame + output_offset, &result); } - wave_out.write_sample( - chan, - frame + output_offset, - &interp_quad(frac_offset, &points), - ); } } } } SincInterpolationType::Linear => { + // Linear has 2 nearest points; combined sinc pays off from 3 channels. + let use_combined = active_count > 2; let oversampling_factor = self.interpolator.nbr_sincs(); - let mut points = [T::zero(); 2]; + let sincs = self.interpolator.get_sincs(); let mut nearest = [(0isize, 0isize); 2]; for frame in 0..nbr_frames { t_ratio += t_ratio_increment; @@ -256,21 +392,46 @@ where let frac = idx * oversampling_factor as f64 - (idx * oversampling_factor as f64).floor(); let frac_offset = t!(frac); - for (chan, active) in channel_mask.iter().enumerate() { - if *active { - let buf = &wave_in[chan]; - for (n, p) in nearest.iter().zip(points.iter_mut()) { - *p = self.interpolator.get_sinc_interpolated( + let weights = interp_lin_weights(frac_offset); + if use_combined { + let min_idx = self.interpolator.make_combined_sinc( + &nearest, + &weights, + &mut self.combined, + ); + let base = (min_idx + 2 * interpolator_len as isize) as usize; + for (chan, active) in channel_mask.iter().enumerate() { + if *active { + let buf = &wave_in[chan]; + let result = self.interpolator.get_sinc_dot_product( buf, - (n.0 + 2 * interpolator_len as isize) as usize, - n.1 as usize, - ); + base, + &self.combined[..interpolator_len], + ) + self.combined[interpolator_len] + * buf[base + interpolator_len]; + wave_out.write_sample(chan, frame + output_offset, &result); + } + } + } else { + let bases = nearest.map(|n| (n.0 + 2 * interpolator_len as isize) as usize); + let subs = nearest.map(|n| n.1 as usize); + for (chan, active) in channel_mask.iter().enumerate() { + if *active { + let buf = &wave_in[chan]; + let result = weights[0] + * self.interpolator.get_sinc_dot_product( + buf, + bases[0], + &sincs[subs[0]], + ) + + weights[1] + * self.interpolator.get_sinc_dot_product( + buf, + bases[1], + &sincs[subs[1]], + ); + wave_out.write_sample(chan, frame + output_offset, &result); } - wave_out.write_sample( - chan, - frame + output_offset, - &interp_lin(frac_offset, &points), - ); } } } @@ -311,7 +472,10 @@ where #[cfg(test)] mod tests { - use super::{interp_cubic, interp_lin}; + use super::{ + interp_cubic, interp_cubic_weights, interp_lin, interp_lin_weights, interp_quad, + interp_quad_weights, + }; #[test] fn int_cubic() { @@ -334,10 +498,42 @@ mod tests { assert_eq!(interp, 3.0f32); } + /// Verify that interp_cubic_weights produces the same result as interp_cubic. #[test] - fn int_lin() { - let yvals = [1.0f64, 5.0f64]; - let interp = interp_lin(0.25f64, &yvals); - assert_eq!(interp, 2.0f64); + fn cubic_weights_match_cubic() { + let yvals = [1.3f64, -0.7f64, 2.1f64, 0.4f64]; + for x_int in 0..=10 { + let x = x_int as f64 / 10.0; + let direct = interp_cubic(x, &yvals); + let w = interp_cubic_weights(x); + let blended = w[0] * yvals[0] + w[1] * yvals[1] + w[2] * yvals[2] + w[3] * yvals[3]; + assert!((direct - blended).abs() < 1.0e-12, "mismatch at x={x}"); + } + } + + /// Verify that interp_quad_weights produces the same result as interp_quad. + #[test] + fn quad_weights_match_quad() { + let yvals = [1.3f64, -0.7f64, 2.1f64]; + for x_int in 0..=10 { + let x = x_int as f64 / 10.0; + let direct = interp_quad(x, &yvals); + let w = interp_quad_weights(x); + let blended = w[0] * yvals[0] + w[1] * yvals[1] + w[2] * yvals[2]; + assert!((direct - blended).abs() < 1.0e-12, "mismatch at x={x}"); + } + } + + /// Verify that interp_lin_weights produces the same result as interp_lin. + #[test] + fn lin_weights_match_lin() { + let yvals = [1.3f64, -0.7f64]; + for x_int in 0..=10 { + let x = x_int as f64 / 10.0; + let direct = interp_lin(x, &yvals); + let w = interp_lin_weights(x); + let blended = w[0] * yvals[0] + w[1] * yvals[1]; + assert!((direct - blended).abs() < 1.0e-12, "mismatch at x={x}"); + } } } diff --git a/src/sinc_interpolator/mod.rs b/src/sinc_interpolator/mod.rs index 813a05d..d5aa363 100644 --- a/src/sinc_interpolator/mod.rs +++ b/src/sinc_interpolator/mod.rs @@ -51,14 +51,63 @@ interpolator! { /// Functions for making the scalar product with a sinc. #[cfg_attr(feature = "bench_asyncro", visibility::make(pub))] pub(crate) trait SincInterpolator: Send { - /// Make the scalar product between the waveform starting at `index` and the sinc of `subindex`. - fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T; + /// Compute the dot product of the wave starting at `index` with the provided sinc slice. + fn get_sinc_dot_product(&self, wave: &[T], index: usize, sinc: &[T]) -> T; + + /// Expose the raw sinc table for use in combined-sinc building. + fn get_sincs(&self) -> &[Vec]; /// Get sinc length. fn nbr_points(&self) -> usize; /// Get number of sincs used for oversampling. fn nbr_sincs(&self) -> usize; + + /// Make the scalar product between the waveform starting at `index` and the sinc of `subindex`. + fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T { + assert!( + (index + self.nbr_points()) < wave.len(), + "Tried to interpolate for index {}, max for the given input is {}", + index, + wave.len() - self.nbr_points() - 1 + ); + assert!( + subindex < self.nbr_sincs(), + "Tried to use sinc subindex {}, max is {}", + subindex, + self.nbr_sincs() - 1 + ); + self.get_sinc_dot_product(wave, index, &self.get_sincs()[subindex]) + } + + /// Build a combined sinc by blending the sincs indicated by `nearest` with `weights`. + /// + /// The 4 (or fewer) nearest points may span two consecutive integer indices, so `combined` + /// must have length `nbr_points() + 1`. The extra element at position `nbr_points()` holds + /// the contribution of any points at the higher index, which the caller adds as a scalar + /// multiply after the main SIMD dot-product loop. + /// + /// Returns the minimum integer index found in `nearest`, used by the caller to compute the + /// base buffer offset. + fn make_combined_sinc( + &self, + nearest: &[(isize, isize)], + weights: &[T], + combined: &mut [T], + ) -> isize + where + T: Sample, + { + let min_idx = nearest.iter().map(|n| n.0).min().unwrap(); + combined.iter_mut().for_each(|x| *x = T::zero()); + for (n, &w) in nearest.iter().zip(weights.iter()) { + let shift = (n.0 - min_idx) as usize; + for (k, &s) in self.get_sincs()[n.1 as usize].iter().enumerate() { + combined[shift + k] += w * s; + } + } + min_idx + } } /// A plain scalar interpolator. @@ -73,22 +122,8 @@ impl SincInterpolator for ScalarInterpolator where T: Sample, { - /// Calculate the scalar produt of an input wave and the selected sinc filter. - fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T { - assert!( - (index + self.length) < wave.len(), - "Tried to interpolate for index {}, max for the given input is {}", - index, - wave.len() - self.length - 1 - ); - assert!( - subindex < self.nbr_sincs, - "Tried to use sinc subindex {}, max is {}", - subindex, - self.nbr_sincs - 1 - ); - let wave_cut = &wave[index..(index + self.sincs[subindex].len())]; - let sinc = &self.sincs[subindex]; + fn get_sinc_dot_product(&self, wave: &[T], index: usize, sinc: &[T]) -> T { + let wave_cut = &wave[index..(index + sinc.len())]; unsafe { let mut acc0 = T::zero(); let mut acc1 = T::zero(); @@ -114,6 +149,10 @@ where } } + fn get_sincs(&self) -> &[Vec] { + &self.sincs + } + fn nbr_points(&self) -> usize { self.length } diff --git a/src/sinc_interpolator/sinc_interpolator_avx.rs b/src/sinc_interpolator/sinc_interpolator_avx.rs index f2265ba..0a5d50c 100644 --- a/src/sinc_interpolator/sinc_interpolator_avx.rs +++ b/src/sinc_interpolator/sinc_interpolator_avx.rs @@ -8,11 +8,12 @@ use core::arch::x86_64::{ _mm256_extractf128_ps, }; use core::arch::x86_64::{ - _mm256_add_pd, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm_add_pd, _mm_hadd_pd, - _mm_store_sd, + _mm256_add_pd, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_set1_pd, _mm256_setzero_pd, + _mm256_storeu_pd, _mm_add_pd, _mm_hadd_pd, _mm_store_sd, }; use core::arch::x86_64::{ - _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_hadd_ps, _mm_store_ss, + _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, + _mm_add_ps, _mm_hadd_ps, _mm_store_ss, }; /// Collection of CPU features required for this interpolator. @@ -20,64 +21,59 @@ static FEATURES: &[CpuFeature] = &[CpuFeature::Avx, CpuFeature::Fma]; /// Trait governing what can be done with an AvxSample. pub trait AvxSample: Sized + Send { - type Sinc: Send; - - /// Pack sincs into a vector. - /// - /// # Safety - /// - /// This is unsafe because it uses target_enable dispatching. There are no - /// special requirements from the caller. - unsafe fn pack_sincs(sincs: Vec>) -> Vec>; - - /// Interpolate a sinc sample. + /// Compute the dot product of `wave[index..]` with `sinc` using AVX instructions. /// /// # Safety /// - /// The caller must ensure that the various indexes are not out of bounds - /// in the collection of sincs. - unsafe fn get_sinc_interpolated_unsafe( + /// The caller must ensure that `wave[index..index+length]` and `sinc[..length]` are + /// valid, and that `length` is a multiple of 8. + unsafe fn get_sinc_dot_product_unsafe( wave: &[Self], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[Self], length: usize, ) -> Self; + + /// Compute `out[..length] += scale * input[..length]` using AVX instructions. + /// + /// # Safety + /// + /// The caller must ensure that `out[..length]` and `input[..length]` are valid, + /// and that `length` is a multiple of 8. + unsafe fn saxpy_unsafe(out: &mut [Self], scale: Self, input: &[Self], length: usize); } impl AvxSample for f32 { - type Sinc = __m256; - #[target_feature(enable = "avx", enable = "fma")] - unsafe fn pack_sincs(sincs: Vec>) -> Vec> { - let mut packed_sincs = Vec::new(); - for sinc in sincs.iter() { - let mut packed = Vec::new(); - for elements in sinc.chunks(8) { - let packed_elems = _mm256_loadu_ps(&elements[0]); - packed.push(packed_elems); - } - packed_sincs.push(packed); + unsafe fn saxpy_unsafe(out: &mut [f32], scale: f32, input: &[f32], length: usize) { + let scale_vec = _mm256_set1_ps(scale); + let mut idx = 0; + for _ in 0..length / 8 { + let x = _mm256_loadu_ps(input.get_unchecked(idx)); + let y = _mm256_loadu_ps(out.get_unchecked(idx)); + _mm256_storeu_ps( + out.get_unchecked_mut(idx) as *mut f32, + _mm256_fmadd_ps(scale_vec, x, y), + ); + idx += 8; } - packed_sincs } #[target_feature(enable = "avx", enable = "fma")] - unsafe fn get_sinc_interpolated_unsafe( + unsafe fn get_sinc_dot_product_unsafe( wave: &[f32], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[f32], length: usize, ) -> f32 { - let sinc = sincs.get_unchecked(subindex); let wave_cut = &wave[index..(index + length)]; let mut acc = _mm256_setzero_ps(); - let mut w_idx = 0; - for s_idx in 0..length / 8 { - let w = _mm256_loadu_ps(wave_cut.get_unchecked(w_idx)); - acc = _mm256_fmadd_ps(w, *sinc.get_unchecked(s_idx), acc); - w_idx += 8; + let mut idx = 0; + for _ in 0..length / 8 { + let w = _mm256_loadu_ps(wave_cut.get_unchecked(idx)); + let s = _mm256_loadu_ps(sinc.get_unchecked(idx)); + acc = _mm256_fmadd_ps(w, s, acc); + idx += 8; } let acc_high = _mm256_extractf128_ps(acc, 1); let acc_low = _mm_add_ps(acc_high, _mm256_castps256_ps128(acc)); @@ -90,43 +86,40 @@ impl AvxSample for f32 { } impl AvxSample for f64 { - type Sinc = __m256d; - #[target_feature(enable = "avx", enable = "fma")] - unsafe fn pack_sincs(sincs: Vec>) -> Vec> { - let mut packed_sincs = Vec::new(); - for sinc in sincs.iter() { - let mut packed = Vec::new(); - for elements in sinc.chunks(4) { - let packed_elems = _mm256_loadu_pd(&elements[0]); - packed.push(packed_elems); - } - packed_sincs.push(packed); + unsafe fn saxpy_unsafe(out: &mut [f64], scale: f64, input: &[f64], length: usize) { + let scale_vec = _mm256_set1_pd(scale); + let mut idx = 0; + for _ in 0..length / 4 { + let x = _mm256_loadu_pd(input.get_unchecked(idx)); + let y = _mm256_loadu_pd(out.get_unchecked(idx)); + _mm256_storeu_pd( + out.get_unchecked_mut(idx) as *mut f64, + _mm256_fmadd_pd(scale_vec, x, y), + ); + idx += 4; } - packed_sincs } #[target_feature(enable = "avx", enable = "fma")] - unsafe fn get_sinc_interpolated_unsafe( + unsafe fn get_sinc_dot_product_unsafe( wave: &[f64], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[f64], length: usize, ) -> f64 { - let sinc = sincs.get_unchecked(subindex); let wave_cut = &wave[index..(index + length)]; let mut acc0 = _mm256_setzero_pd(); let mut acc1 = _mm256_setzero_pd(); - let mut w_idx = 0; - let mut s_idx = 0; - for _ in 0..wave_cut.len() / 8 { - let w0 = _mm256_loadu_pd(wave_cut.get_unchecked(w_idx)); - let w1 = _mm256_loadu_pd(wave_cut.get_unchecked(w_idx + 4)); - acc0 = _mm256_fmadd_pd(w0, *sinc.get_unchecked(s_idx), acc0); - acc1 = _mm256_fmadd_pd(w1, *sinc.get_unchecked(s_idx + 1), acc1); - w_idx += 8; - s_idx += 2; + let mut idx = 0; + for _ in 0..length / 8 { + let w0 = _mm256_loadu_pd(wave_cut.get_unchecked(idx)); + let w1 = _mm256_loadu_pd(wave_cut.get_unchecked(idx + 4)); + let s0 = _mm256_loadu_pd(sinc.get_unchecked(idx)); + let s1 = _mm256_loadu_pd(sinc.get_unchecked(idx + 4)); + acc0 = _mm256_fmadd_pd(w0, s0, acc0); + acc1 = _mm256_fmadd_pd(w1, s1, acc1); + idx += 8; } let acc_all = _mm256_add_pd(acc0, acc1); let acc_high = _mm256_extractf128_pd(acc_all, 1); @@ -144,7 +137,7 @@ pub(crate) struct AvxInterpolator where T: AvxSample, { - sincs: Vec>, + sincs: Vec>, length: usize, nbr_sincs: usize, } @@ -153,21 +146,12 @@ impl SincInterpolator for AvxInterpolator where T: AvxSample, { - /// Calculate the scalar produt of an input wave and the selected sinc filter. - fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T { - assert!( - (index + self.length) < wave.len(), - "Tried to interpolate for index {}, max for the given input is {}", - index, - wave.len() - self.length - 1 - ); - assert!( - subindex < self.nbr_sincs, - "Tried to use sinc subindex {}, max is {}", - subindex, - self.nbr_sincs - 1 - ); - unsafe { T::get_sinc_interpolated_unsafe(wave, index, subindex, &self.sincs, self.length) } + fn get_sinc_dot_product(&self, wave: &[T], index: usize, sinc: &[T]) -> T { + unsafe { T::get_sinc_dot_product_unsafe(wave, index, sinc, self.length) } + } + + fn get_sincs(&self) -> &[Vec] { + &self.sincs } fn nbr_points(&self) -> usize { @@ -177,11 +161,36 @@ where fn nbr_sincs(&self) -> usize { self.nbr_sincs } + + fn make_combined_sinc( + &self, + nearest: &[(isize, isize)], + weights: &[T], + combined: &mut [T], + ) -> isize + where + T: crate::Sample, + { + let min_idx = nearest.iter().map(|n| n.0).min().unwrap(); + combined.iter_mut().for_each(|x| *x = T::zero()); + for (n, &w) in nearest.iter().zip(weights.iter()) { + let shift = (n.0 - min_idx) as usize; + unsafe { + ::saxpy_unsafe( + &mut combined[shift..shift + self.length], + w, + &self.sincs[n.1 as usize], + self.length, + ); + } + } + min_idx + } } impl AvxInterpolator where - T: Sample, + T: AvxSample + Sample, { /// Create a new AvxInterpolator. /// @@ -202,7 +211,6 @@ where assert!(sinc_len % 8 == 0, "Sinc length must be a multiple of 8."); let sincs = make_sincs(sinc_len, oversampling_factor, f_cutoff, window); - let sincs = unsafe { ::pack_sincs(sincs) }; Ok(Self { sincs, @@ -212,6 +220,14 @@ where } } +// Suppress dead_code warning for __m256/__m256d: they are used only in the +// target_feature-gated functions above and the compiler can't see through that. +#[allow(dead_code)] +const _: () = { + let _ = core::mem::size_of::<__m256>(); + let _ = core::mem::size_of::<__m256d>(); +}; + #[cfg(test)] mod tests { use crate::sinc::make_sincs; @@ -278,6 +294,6 @@ mod tests { let value = interpolator.get_sinc_interpolated(&wave, 333, 123); let check = get_sinc_interpolated(&wave, 333, &sincs[123]); - assert!((value - check).abs() < 1.0e-5); + assert!((value - check).abs() < 1.0e-6); } } diff --git a/src/sinc_interpolator/sinc_interpolator_neon.rs b/src/sinc_interpolator/sinc_interpolator_neon.rs index 5b37065..01036d2 100644 --- a/src/sinc_interpolator/sinc_interpolator_neon.rs +++ b/src/sinc_interpolator/sinc_interpolator_neon.rs @@ -5,7 +5,8 @@ use crate::windows::WindowFunction; use crate::Sample; use core::arch::aarch64::{float32x4_t, float64x2_t}; use core::arch::aarch64::{ - vadd_f32, vaddq_f32, vfmaq_f32, vget_high_f32, vget_low_f32, vld1q_f32, vmovq_n_f32, vst1_f32, + vadd_f32, vaddq_f32, vfmaq_f32, vget_high_f32, vget_low_f32, vld1q_f32, vmovq_n_f32, + vst1_f32, vst1q_f32, }; use core::arch::aarch64::{vaddq_f64, vfmaq_f64, vld1q_f64, vmovq_n_f64, vst1q_f64}; @@ -14,69 +15,60 @@ static FEATURES: &[CpuFeature] = &[CpuFeature::Neon]; /// Trait governing what can be done with an NeonSample. pub trait NeonSample: Sized + Send { - type Sinc: Send; - - /// Pack sincs into a vector. - /// - /// # Safety - /// - /// This is unsafe because it uses target_enable dispatching. There are no - /// special requirements from the caller. - unsafe fn pack_sincs(sincs: Vec>) -> Vec>; - - /// Interpolate a sinc sample. + /// Compute the dot product of `wave[index..]` with `sinc` using NEON instructions. /// /// # Safety /// - /// The caller must ensure that the various indexes are not out of bounds - /// in the collection of sincs. - unsafe fn get_sinc_interpolated_unsafe( + /// The caller must ensure that `wave[index..index+length]` and `sinc[..length]` are + /// valid, and that `length` is a multiple of 8. + unsafe fn get_sinc_dot_product_unsafe( wave: &[Self], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[Self], length: usize, ) -> Self; + + /// Compute `out[..length] += scale * input[..length]` using NEON instructions. + /// + /// # Safety + /// + /// The caller must ensure that `out[..length]` and `input[..length]` are valid, + /// and that `length` is a multiple of 8. + unsafe fn saxpy_unsafe(out: &mut [Self], scale: Self, input: &[Self], length: usize); } impl NeonSample for f32 { - type Sinc = float32x4_t; - #[target_feature(enable = "neon")] - unsafe fn pack_sincs(sincs: Vec>) -> Vec> { - let mut packed_sincs = Vec::new(); - for sinc in sincs.iter() { - let mut packed = Vec::new(); - for elements in sinc.chunks(4) { - let packed_elems = vld1q_f32(&elements[0]); - packed.push(packed_elems); - } - packed_sincs.push(packed); + unsafe fn saxpy_unsafe(out: &mut [f32], scale: f32, input: &[f32], length: usize) { + let scale_vec = vmovq_n_f32(scale); + let mut idx = 0; + for _ in 0..length / 4 { + let x = vld1q_f32(input.get_unchecked(idx)); + let y = vld1q_f32(out.get_unchecked(idx)); + vst1q_f32(out.get_unchecked_mut(idx) as *mut f32, vfmaq_f32(y, scale_vec, x)); + idx += 4; } - packed_sincs } #[target_feature(enable = "neon")] - unsafe fn get_sinc_interpolated_unsafe( + unsafe fn get_sinc_dot_product_unsafe( wave: &[f32], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[f32], length: usize, ) -> f32 { - let sinc = sincs.get_unchecked(subindex); let wave_cut = &wave[index..(index + length)]; let mut acc0 = vmovq_n_f32(0.0); let mut acc1 = vmovq_n_f32(0.0); - let mut w_idx = 0; - let mut s_idx = 0; - for _ in 0..wave_cut.len() / 8 { - let w0 = vld1q_f32(wave_cut.get_unchecked(w_idx)); - let w1 = vld1q_f32(wave_cut.get_unchecked(w_idx + 4)); - acc0 = vfmaq_f32(acc0, w0, *sinc.get_unchecked(s_idx)); - acc1 = vfmaq_f32(acc1, w1, *sinc.get_unchecked(s_idx + 1)); - w_idx += 8; - s_idx += 2; + let mut idx = 0; + for _ in 0..length / 8 { + let w0 = vld1q_f32(wave_cut.get_unchecked(idx)); + let w1 = vld1q_f32(wave_cut.get_unchecked(idx + 4)); + let s0 = vld1q_f32(sinc.get_unchecked(idx)); + let s1 = vld1q_f32(sinc.get_unchecked(idx + 4)); + acc0 = vfmaq_f32(acc0, w0, s0); + acc1 = vfmaq_f32(acc1, w1, s1); + idx += 8; } let sum4 = vaddq_f32(acc0, acc1); let high = vget_high_f32(sum4); @@ -89,49 +81,37 @@ impl NeonSample for f32 { } impl NeonSample for f64 { - type Sinc = float64x2_t; - #[target_feature(enable = "neon")] - unsafe fn pack_sincs(sincs: Vec>) -> Vec> { - let mut packed_sincs = Vec::new(); - for sinc in sincs.iter() { - let mut packed = Vec::new(); - for elements in sinc.chunks(2) { - let packed_elems = vld1q_f64(&elements[0]); - packed.push(packed_elems); - } - packed_sincs.push(packed); + unsafe fn saxpy_unsafe(out: &mut [f64], scale: f64, input: &[f64], length: usize) { + let scale_vec = vmovq_n_f64(scale); + let mut idx = 0; + for _ in 0..length / 2 { + let x = vld1q_f64(input.get_unchecked(idx)); + let y = vld1q_f64(out.get_unchecked(idx)); + vst1q_f64(out.get_unchecked_mut(idx) as *mut f64, vfmaq_f64(y, scale_vec, x)); + idx += 2; } - packed_sincs } #[target_feature(enable = "neon")] - unsafe fn get_sinc_interpolated_unsafe( + unsafe fn get_sinc_dot_product_unsafe( wave: &[f64], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[f64], length: usize, ) -> f64 { - let sinc = sincs.get_unchecked(subindex); let wave_cut = &wave[index..(index + length)]; let mut acc0 = vmovq_n_f64(0.0); let mut acc1 = vmovq_n_f64(0.0); let mut acc2 = vmovq_n_f64(0.0); let mut acc3 = vmovq_n_f64(0.0); - let mut w_idx = 0; - let mut s_idx = 0; - for _ in 0..wave_cut.len() / 8 { - let w0 = vld1q_f64(wave_cut.get_unchecked(w_idx)); - let w1 = vld1q_f64(wave_cut.get_unchecked(w_idx + 2)); - let w2 = vld1q_f64(wave_cut.get_unchecked(w_idx + 4)); - let w3 = vld1q_f64(wave_cut.get_unchecked(w_idx + 6)); - acc0 = vfmaq_f64(acc0, w0, *sinc.get_unchecked(s_idx)); - acc1 = vfmaq_f64(acc1, w1, *sinc.get_unchecked(s_idx + 1)); - acc2 = vfmaq_f64(acc2, w2, *sinc.get_unchecked(s_idx + 2)); - acc3 = vfmaq_f64(acc3, w3, *sinc.get_unchecked(s_idx + 3)); - w_idx += 8; - s_idx += 4; + let mut idx = 0; + for _ in 0..length / 8 { + acc0 = vfmaq_f64(acc0, vld1q_f64(wave_cut.get_unchecked(idx)), vld1q_f64(sinc.get_unchecked(idx))); + acc1 = vfmaq_f64(acc1, vld1q_f64(wave_cut.get_unchecked(idx + 2)), vld1q_f64(sinc.get_unchecked(idx + 2))); + acc2 = vfmaq_f64(acc2, vld1q_f64(wave_cut.get_unchecked(idx + 4)), vld1q_f64(sinc.get_unchecked(idx + 4))); + acc3 = vfmaq_f64(acc3, vld1q_f64(wave_cut.get_unchecked(idx + 6)), vld1q_f64(sinc.get_unchecked(idx + 6))); + idx += 8; } let packedsum0 = vaddq_f64(acc0, acc1); let packedsum1 = vaddq_f64(acc2, acc3); @@ -142,36 +122,27 @@ impl NeonSample for f64 { } } -/// A SSE accelerated interpolator. +/// A NEON accelerated interpolator. #[cfg_attr(feature = "bench_asyncro", visibility::make(pub))] pub(crate) struct NeonInterpolator where T: NeonSample, { - sincs: Vec>, + sincs: Vec>, length: usize, nbr_sincs: usize, } impl SincInterpolator for NeonInterpolator where - T: Sample, + T: NeonSample, { - /// Calculate the scalar produt of an input wave and the selected sinc filter. - fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T { - assert!( - (index + self.length) < wave.len(), - "Tried to interpolate for index {}, max for the given input is {}", - index, - wave.len() - self.length - 1 - ); - assert!( - subindex < self.nbr_sincs, - "Tried to use sinc subindex {}, max is {}", - subindex, - self.nbr_sincs - 1 - ); - unsafe { T::get_sinc_interpolated_unsafe(wave, index, subindex, &self.sincs, self.length) } + fn get_sinc_dot_product(&self, wave: &[T], index: usize, sinc: &[T]) -> T { + unsafe { T::get_sinc_dot_product_unsafe(wave, index, sinc, self.length) } + } + + fn get_sincs(&self) -> &[Vec] { + &self.sincs } fn nbr_points(&self) -> usize { @@ -181,11 +152,36 @@ where fn nbr_sincs(&self) -> usize { self.nbr_sincs } + + fn make_combined_sinc( + &self, + nearest: &[(isize, isize)], + weights: &[T], + combined: &mut [T], + ) -> isize + where + T: crate::Sample, + { + let min_idx = nearest.iter().map(|n| n.0).min().unwrap(); + combined.iter_mut().for_each(|x| *x = T::zero()); + for (n, &w) in nearest.iter().zip(weights.iter()) { + let shift = (n.0 - min_idx) as usize; + unsafe { + T::saxpy_unsafe( + &mut combined[shift..shift + self.length], + w, + &self.sincs[n.1 as usize], + self.length, + ); + } + } + min_idx + } } impl NeonInterpolator where - T: Sample, + T: NeonSample + Sample, { /// Create a new NeonInterpolator. /// @@ -206,7 +202,6 @@ where assert!(sinc_len % 8 == 0, "Sinc length must be a multiple of 8."); let sincs = make_sincs(sinc_len, oversampling_factor, f_cutoff, window); - let sincs = unsafe { ::pack_sincs(sincs) }; Ok(Self { sincs, @@ -216,6 +211,14 @@ where } } +// Suppress dead_code warnings for float32x4_t/float64x2_t used only in +// target_feature-gated functions. +#[allow(dead_code)] +const _: () = { + let _ = core::mem::size_of::(); + let _ = core::mem::size_of::(); +}; + #[cfg(test)] mod tests { use crate::sinc::make_sincs; diff --git a/src/sinc_interpolator/sinc_interpolator_sse.rs b/src/sinc_interpolator/sinc_interpolator_sse.rs index e2025eb..7a776aa 100644 --- a/src/sinc_interpolator/sinc_interpolator_sse.rs +++ b/src/sinc_interpolator/sinc_interpolator_sse.rs @@ -5,10 +5,12 @@ use crate::windows::WindowFunction; use crate::Sample; use core::arch::x86_64::{__m128, __m128d}; use core::arch::x86_64::{ - _mm_add_pd, _mm_hadd_pd, _mm_loadu_pd, _mm_mul_pd, _mm_setzero_pd, _mm_store_sd, + _mm_add_pd, _mm_hadd_pd, _mm_loadu_pd, _mm_mul_pd, _mm_set1_pd, _mm_setzero_pd, + _mm_store_sd, _mm_storeu_pd, }; use core::arch::x86_64::{ - _mm_add_ps, _mm_hadd_ps, _mm_loadu_ps, _mm_mul_ps, _mm_setzero_ps, _mm_store_ss, + _mm_add_ps, _mm_hadd_ps, _mm_loadu_ps, _mm_mul_ps, _mm_set1_ps, _mm_setzero_ps, + _mm_store_ss, _mm_storeu_ps, }; /// Collection of CPU features required for this interpolator. @@ -16,71 +18,64 @@ static FEATURES: &[CpuFeature] = &[CpuFeature::Sse3]; /// Trait governing what can be done with an SseSample. pub trait SseSample: Sized + Send { - type Sinc: Send; - - /// Pack sincs into a vector. - /// - /// # Safety - /// - /// This is unsafe because it uses target_enable dispatching. There are no - /// special requirements from the caller. - unsafe fn pack_sincs(sincs: Vec>) -> Vec>; - - /// Interpolate a sinc sample. + /// Compute the dot product of `wave[index..]` with `sinc` using SSE instructions. /// /// # Safety /// - /// The caller must ensure that the various indexes are not out of bounds - /// in the collection of sincs. - unsafe fn get_sinc_interpolated_unsafe( + /// The caller must ensure that `wave[index..index+length]` and `sinc[..length]` are + /// valid, and that `length` is a multiple of 8. + unsafe fn get_sinc_dot_product_unsafe( wave: &[Self], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[Self], length: usize, ) -> Self; + + /// Compute `out[..length] += scale * input[..length]` using SSE instructions. + /// + /// # Safety + /// + /// The caller must ensure that `out[..length]` and `input[..length]` are valid, + /// and that `length` is a multiple of 8. + unsafe fn saxpy_unsafe(out: &mut [Self], scale: Self, input: &[Self], length: usize); } impl SseSample for f32 { - type Sinc = __m128; - #[target_feature(enable = "sse3")] - unsafe fn pack_sincs(sincs: Vec>) -> Vec> { - let mut packed_sincs = Vec::new(); - for sinc in sincs.iter() { - let mut packed = Vec::new(); - for elements in sinc.chunks(4) { - let packed_elems = _mm_loadu_ps(&elements[0]); - packed.push(packed_elems); - } - packed_sincs.push(packed); + unsafe fn saxpy_unsafe(out: &mut [f32], scale: f32, input: &[f32], length: usize) { + let scale_vec = _mm_set1_ps(scale); + let mut idx = 0; + for _ in 0..length / 4 { + let x = _mm_loadu_ps(input.get_unchecked(idx)); + let y = _mm_loadu_ps(out.get_unchecked(idx)); + _mm_storeu_ps( + out.get_unchecked_mut(idx) as *mut f32, + _mm_add_ps(y, _mm_mul_ps(scale_vec, x)), + ); + idx += 4; } - packed_sincs } #[target_feature(enable = "sse3")] - unsafe fn get_sinc_interpolated_unsafe( + unsafe fn get_sinc_dot_product_unsafe( wave: &[f32], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[f32], length: usize, ) -> f32 { - let sinc = sincs.get_unchecked(subindex); let wave_cut = &wave[index..(index + length)]; let mut acc0 = _mm_setzero_ps(); let mut acc1 = _mm_setzero_ps(); - let mut w_idx = 0; - let mut s_idx = 0; - for _ in 0..wave_cut.len() / 8 { - let w0 = _mm_loadu_ps(wave_cut.get_unchecked(w_idx)); - let w1 = _mm_loadu_ps(wave_cut.get_unchecked(w_idx + 4)); - let s0 = _mm_mul_ps(w0, *sinc.get_unchecked(s_idx)); - let s1 = _mm_mul_ps(w1, *sinc.get_unchecked(s_idx + 1)); - acc0 = _mm_add_ps(acc0, s0); - acc1 = _mm_add_ps(acc1, s1); - w_idx += 8; - s_idx += 2; + let mut idx = 0; + for _ in 0..length / 8 { + let w0 = _mm_loadu_ps(wave_cut.get_unchecked(idx)); + let w1 = _mm_loadu_ps(wave_cut.get_unchecked(idx + 4)); + acc0 = _mm_add_ps(acc0, _mm_mul_ps(w0, _mm_loadu_ps(sinc.get_unchecked(idx)))); + acc1 = _mm_add_ps( + acc1, + _mm_mul_ps(w1, _mm_loadu_ps(sinc.get_unchecked(idx + 4))), + ); + idx += 8; } let temp4 = _mm_add_ps(acc0, acc1); let temp2 = _mm_hadd_ps(temp4, temp4); @@ -92,53 +87,64 @@ impl SseSample for f32 { } impl SseSample for f64 { - type Sinc = __m128d; - #[target_feature(enable = "sse3")] - unsafe fn pack_sincs(sincs: Vec>) -> Vec> { - let mut packed_sincs = Vec::new(); - for sinc in sincs.iter() { - let mut packed = Vec::new(); - for elements in sinc.chunks(2) { - let packed_elems = _mm_loadu_pd(&elements[0]); - packed.push(packed_elems); - } - packed_sincs.push(packed); + unsafe fn saxpy_unsafe(out: &mut [f64], scale: f64, input: &[f64], length: usize) { + let scale_vec = _mm_set1_pd(scale); + let mut idx = 0; + for _ in 0..length / 2 { + let x = _mm_loadu_pd(input.get_unchecked(idx)); + let y = _mm_loadu_pd(out.get_unchecked(idx)); + _mm_storeu_pd( + out.get_unchecked_mut(idx) as *mut f64, + _mm_add_pd(y, _mm_mul_pd(scale_vec, x)), + ); + idx += 2; } - packed_sincs } #[target_feature(enable = "sse3")] - unsafe fn get_sinc_interpolated_unsafe( + unsafe fn get_sinc_dot_product_unsafe( wave: &[f64], index: usize, - subindex: usize, - sincs: &[Vec], + sinc: &[f64], length: usize, ) -> f64 { - let sinc = sincs.get_unchecked(subindex); let wave_cut = &wave[index..(index + length)]; let mut acc0 = _mm_setzero_pd(); let mut acc1 = _mm_setzero_pd(); let mut acc2 = _mm_setzero_pd(); let mut acc3 = _mm_setzero_pd(); - let mut w_idx = 0; - let mut s_idx = 0; - for _ in 0..wave_cut.len() / 8 { - let w0 = _mm_loadu_pd(wave_cut.get_unchecked(w_idx)); - let w1 = _mm_loadu_pd(wave_cut.get_unchecked(w_idx + 2)); - let w2 = _mm_loadu_pd(wave_cut.get_unchecked(w_idx + 4)); - let w3 = _mm_loadu_pd(wave_cut.get_unchecked(w_idx + 6)); - let s0 = _mm_mul_pd(w0, *sinc.get_unchecked(s_idx)); - let s1 = _mm_mul_pd(w1, *sinc.get_unchecked(s_idx + 1)); - let s2 = _mm_mul_pd(w2, *sinc.get_unchecked(s_idx + 2)); - let s3 = _mm_mul_pd(w3, *sinc.get_unchecked(s_idx + 3)); - acc0 = _mm_add_pd(acc0, s0); - acc1 = _mm_add_pd(acc1, s1); - acc2 = _mm_add_pd(acc2, s2); - acc3 = _mm_add_pd(acc3, s3); - w_idx += 8; - s_idx += 4; + let mut idx = 0; + for _ in 0..length / 8 { + acc0 = _mm_add_pd( + acc0, + _mm_mul_pd( + _mm_loadu_pd(wave_cut.get_unchecked(idx)), + _mm_loadu_pd(sinc.get_unchecked(idx)), + ), + ); + acc1 = _mm_add_pd( + acc1, + _mm_mul_pd( + _mm_loadu_pd(wave_cut.get_unchecked(idx + 2)), + _mm_loadu_pd(sinc.get_unchecked(idx + 2)), + ), + ); + acc2 = _mm_add_pd( + acc2, + _mm_mul_pd( + _mm_loadu_pd(wave_cut.get_unchecked(idx + 4)), + _mm_loadu_pd(sinc.get_unchecked(idx + 4)), + ), + ); + acc3 = _mm_add_pd( + acc3, + _mm_mul_pd( + _mm_loadu_pd(wave_cut.get_unchecked(idx + 6)), + _mm_loadu_pd(sinc.get_unchecked(idx + 6)), + ), + ); + idx += 8; } let temp2_0 = _mm_add_pd(acc0, acc1); let temp2_1 = _mm_add_pd(acc2, acc3); @@ -156,7 +162,7 @@ pub(crate) struct SseInterpolator where T: SseSample, { - sincs: Vec>, + sincs: Vec>, length: usize, nbr_sincs: usize, } @@ -165,21 +171,12 @@ impl SincInterpolator for SseInterpolator where T: SseSample, { - /// Calculate the scalar produt of an input wave and the selected sinc filter. - fn get_sinc_interpolated(&self, wave: &[T], index: usize, subindex: usize) -> T { - assert!( - (index + self.length) < wave.len(), - "Tried to interpolate for index {}, max for the given input is {}", - index, - wave.len() - self.length - 1 - ); - assert!( - subindex < self.nbr_sincs, - "Tried to use sinc subindex {}, max is {}", - subindex, - self.nbr_sincs - 1 - ); - unsafe { T::get_sinc_interpolated_unsafe(wave, index, subindex, &self.sincs, self.length) } + fn get_sinc_dot_product(&self, wave: &[T], index: usize, sinc: &[T]) -> T { + unsafe { T::get_sinc_dot_product_unsafe(wave, index, sinc, self.length) } + } + + fn get_sincs(&self) -> &[Vec] { + &self.sincs } fn nbr_points(&self) -> usize { @@ -189,11 +186,36 @@ where fn nbr_sincs(&self) -> usize { self.nbr_sincs } + + fn make_combined_sinc( + &self, + nearest: &[(isize, isize)], + weights: &[T], + combined: &mut [T], + ) -> isize + where + T: crate::Sample, + { + let min_idx = nearest.iter().map(|n| n.0).min().unwrap(); + combined.iter_mut().for_each(|x| *x = T::zero()); + for (n, &w) in nearest.iter().zip(weights.iter()) { + let shift = (n.0 - min_idx) as usize; + unsafe { + ::saxpy_unsafe( + &mut combined[shift..shift + self.length], + w, + &self.sincs[n.1 as usize], + self.length, + ); + } + } + min_idx + } } impl SseInterpolator where - T: Sample, + T: SseSample + Sample, { /// Create a new SseInterpolator. /// @@ -214,7 +236,6 @@ where assert!(sinc_len % 8 == 0, "Sinc length must be a multiple of 8."); let sincs = make_sincs(sinc_len, oversampling_factor, f_cutoff, window); - let sincs = unsafe { ::pack_sincs(sincs) }; Ok(Self { sincs, @@ -224,6 +245,13 @@ where } } +// Suppress dead_code warnings for __m128/__m128d used only in target_feature-gated functions. +#[allow(dead_code)] +const _: () = { + let _ = core::mem::size_of::<__m128>(); + let _ = core::mem::size_of::<__m128d>(); +}; + #[cfg(test)] mod tests { use crate::sinc::make_sincs;