diff --git a/Cargo.toml b/Cargo.toml index dd22f4e..f286492 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "efficient_pca" version = "0.1.8" authors = ["Erik Garrison , SauersML"] -edition = "2021" +edition = "2024" description = "Principal component computation using SVD and covariance matrix trick" license = "MIT" repository = "https://github.com/SauersML/efficient_pca" diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs index 4aa4347..7da9792 100644 --- a/benches/benchmarks.rs +++ b/benches/benchmarks.rs @@ -6,7 +6,7 @@ use jemallocator::Jemalloc; #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use efficient_pca::PCA; use ndarray::Array2; use rand::distributions::Uniform; @@ -76,10 +76,10 @@ fn generate_low_variance_data( let mut data_vec = Vec::with_capacity(n_samples * n_features); for _j in 0..n_features { - let is_low_var_feature = rng.gen::() < fraction_low_var_feats; + let is_low_var_feature = rng.r#gen::() < fraction_low_var_feats; for _i in 0..n_samples { if is_low_var_feature { - if rng.gen::() < majority_val_in_low_var_feat_freq { + if rng.r#gen::() < majority_val_in_low_var_feat_freq { data_vec.push(0.0); } else { data_vec.push(rng.sample(minority_val_dist) as f64); @@ -111,9 +111,9 @@ fn benchmark_pca( // Fallback for non-jemalloc or msvc builds - RSS and Virt will be 0 #[cfg(not(all(feature = "jemalloc", not(target_env = "msvc"))))] - let resident_before = 0; + let resident_before = 0usize; #[cfg(not(all(feature = "jemalloc", not(target_env = "msvc"))))] - let active_before = 0; + let active_before = 0usize; let start_time = Instant::now(); @@ -142,7 +142,11 @@ fn benchmark_pca( pca.fit(data.clone(), None).expect("fit failed"); transformed_data = pca.transform(data.clone()).expect("transform failed"); let actual_fit_components = pca.rotation().map_or(0, |r| r.ncols()); - assert_eq!(transformed_data.ncols(), actual_fit_components, "FIT: Transformed data column count should match actual components in the model after fit."); + assert_eq!( + transformed_data.ncols(), + actual_fit_components, + "FIT: Transformed data column count should match actual components in the model after fit." + ); } assert_eq!( @@ -161,9 +165,9 @@ fn benchmark_pca( let active_after = stats::active::read().unwrap(); #[cfg(not(all(feature = "jemalloc", not(target_env = "msvc"))))] - let resident_after = 0; + let resident_after = 0usize; #[cfg(not(all(feature = "jemalloc", not(target_env = "msvc"))))] - let active_after = 0; + let active_after = 0usize; let rss_delta_bytes = resident_after.saturating_sub(resident_before); let virt_delta_bytes = active_after.saturating_sub(active_before); @@ -236,7 +240,7 @@ fn determine_appropriate_sample_size( match scenario_name_short { "Large" | "Square" | "Sparse-W" => return 10, "Wide" | "LowVar-W" | "Wide-k10" | "Wide-k50" | "Wide-k200" if n_features >= 10000 => { - return 10 + return 10; } "Wide-XL" if n_features >= 100000 => return 10, "Wide-L" if n_features >= 50000 => return 20, diff --git a/src/eigensnp.rs b/src/eigensnp.rs index ef3bb45..8617b7b 100644 --- a/src/eigensnp.rs +++ b/src/eigensnp.rs @@ -1,4 +1,4 @@ -use ndarray::{s, Array1, Array2, ArrayView2, Axis}; +use ndarray::{Array1, Array2, ArrayView2, Axis, s}; // Eigh, QR, SVDInto are replaced by backend calls. UPLO is handled by eigh_upper. // use ndarray_linalg::{Eigh, UPLO, QR, SVDInto}; use crate::linalg_backends::{BackendQR, BackendSVD, LinAlgBackendProvider}; @@ -25,6 +25,10 @@ pub struct PcaSnpMetadata { // Updated diagnostics struct names #[cfg(feature = "enable-eigensnp-diagnostics")] use crate::diagnostics::{ + FullPcaRunDetailedDiagnostics, + PerBlockLocalBasisDiagnostics, + RsvdStepDetail, + SrPassDetail, compute_condition_number_via_svd_f32, // For f32 matrices, uses f64 SVD compute_condition_number_via_svd_f64, // For f64 matrices compute_frob_norm_f32, // For f32 matrices @@ -34,10 +38,6 @@ use crate::diagnostics::{ compute_svd_reconstruction_error_f32, // For SVD steps (f32) sample_singular_values, // For f32 singular values sample_singular_values_f64, // For f64 singular values - FullPcaRunDetailedDiagnostics, - PerBlockLocalBasisDiagnostics, - RsvdStepDetail, - SrPassDetail, }; /// A thread-safe wrapper for standard dynamic errors, @@ -290,7 +290,10 @@ fn standardize_raw_condensed_features( // Filling with 0.0 is a consistent way to handle this. condensed_data_matrix.fill(0.0f32); } - debug!("Number of samples ({}) is <= 1 for condensed matrix; standardization results in zeros or is skipped if already empty.", num_samples); + debug!( + "Number of samples ({}) is <= 1 for condensed matrix; standardization results in zeros or is skipped if already empty.", + num_samples + ); return Ok(StandardizedCondensedFeatures { data: condensed_data_matrix, }); @@ -415,13 +418,18 @@ fn standardize_raw_condensed_features( .mapv(|x| (x - mean_val).powi(2)) .mean() .unwrap_or(0.0); // Should be ~1 for standardized data - // Using debug for variance of standardized matrix as it's a key check of success - debug!("Standardized condensed matrix: Row {} mean (post-std): {:.4e}, variance (post-std): {:.4e}", - row_idx, mean_val, variance); + // Using debug for variance of standardized matrix as it's a key check of success + debug!( + "Standardized condensed matrix: Row {} mean (post-std): {:.4e}, variance (post-std): {:.4e}", + row_idx, mean_val, variance + ); } else if r_view.len() == 1 { // Single element in row, variance is undefined or 0. Mean is the element itself. - debug!("Standardized condensed matrix: Row {} mean (post-std): {:.4e}, variance (post-std): N/A (single element in row)", - row_idx, r_view.mean().unwrap_or(0.0)); + debug!( + "Standardized condensed matrix: Row {} mean (post-std): {:.4e}, variance (post-std): N/A (single element in row)", + row_idx, + r_view.mean().unwrap_or(0.0) + ); } } } @@ -588,6 +596,11 @@ pub struct EigenSNPCoreAlgorithmConfig { /// Final_p2: S_final_p2 = U_rot_p2 S_prime_p2, V_final_p2 = V_qr_p2 V_rot_p2. /// Additional passes follow the same pattern. pub refine_pass_count: usize, + /// Forces the algorithm to bypass the multi-stage EigenSNP pipeline and execute a + /// dense PCA on the fully-standardized genotype matrix. This is primarily useful + /// for validation scenarios where direct comparability with a reference PCA + /// implementation (for example, Python/NumPy or scikit-learn) is required. + pub force_dense_pca: bool, /// Whether to collect detailed diagnostics during PCA computation. pub collect_diagnostics: bool, /// If set, specifies a directory path where the local PC loadings (eigenSNPs) @@ -616,6 +629,7 @@ impl Default for EigenSNPCoreAlgorithmConfig { random_seed: 2025, snp_processing_strip_size: 2000, // Default refine_pass_count: 1, // Default to 1 refinement pass + force_dense_pca: false, collect_diagnostics: false, local_pcs_output_dir: None, #[cfg(feature = "enable-eigensnp-diagnostics")] @@ -739,6 +753,10 @@ impl EigenSNPCoreAlgorithm { return Ok((output, ())); // This line also remains correct } + if self.config.force_dense_pca { + return self.compute_dense_pca(genotype_data, num_total_pca_snps, num_total_qc_samples); + } + let subset_sample_ids_selected: Vec; let is_diagnostic_target_test = num_total_qc_samples == 200 && (num_total_pca_snps >= 950 && num_total_pca_snps <= 1050); // Approximate SNP count @@ -746,7 +764,8 @@ impl EigenSNPCoreAlgorithm { if is_diagnostic_target_test { log::warn!( "DIAGNOSTIC MODE ACTIVE: Using ALL {} samples for local basis learning (N_s = N) for test_pc_correlation_structured_1000snps_200samples_5truepcs scenario. Original N_s was {}.", - num_total_qc_samples, actual_subset_sample_count + num_total_qc_samples, + actual_subset_sample_count ); actual_subset_sample_count = num_total_qc_samples; // Override N_s subset_sample_ids_selected = (0..num_total_qc_samples).map(QcSampleId).collect(); @@ -772,7 +791,9 @@ impl EigenSNPCoreAlgorithm { .iter() .any(|b| b.num_snps_in_block() > 0) { - log::warn!("Calculated N_s is 0 (and not in diagnostic override), but total samples > 0 and blocks have SNPs. This situation is problematic for learning local bases."); + log::warn!( + "Calculated N_s is 0 (and not in diagnostic override), but total samples > 0 and blocks have SNPs. This situation is problematic for learning local bases." + ); return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Subset size (N_s) for local basis learning is 0, but samples and SNP blocks are present.").into()); } subset_sample_ids_selected = Vec::new(); @@ -932,7 +953,9 @@ impl EigenSNPCoreAlgorithm { let mut num_principal_components_computed_final = current_sample_scores.scores.ncols(); if num_principal_components_computed_final == 0 { - warn!("Initial PCA on condensed features yielded 0 components. Returning empty PCA output."); + warn!( + "Initial PCA on condensed features yielded 0 components. Returning empty PCA output." + ); let output = EigenSNPCoreOutput { final_snp_principal_component_loadings: Array2::zeros((num_total_pca_snps, 0)), final_sample_principal_component_scores: Array2::zeros((num_total_qc_samples, 0)), @@ -952,8 +975,8 @@ impl EigenSNPCoreAlgorithm { if self.config.refine_pass_count == 0 { warn!( - "EigenSNP refine_pass_count is 0. Skipping refinement loop. Output will reflect PCA of derived local eigenSNP features only. SNP loadings and SNP-based eigenvalues will be empty/zero." - ); + "EigenSNP refine_pass_count is 0. Skipping refinement loop. Output will reflect PCA of derived local eigenSNP features only. SNP loadings and SNP-based eigenvalues will be empty/zero." + ); // `num_principal_components_computed_final` is already set from the initial condensed PCA. // `final_sorted_snp_loadings` and `final_sorted_eigenvalues` remain their initial empty/zero states. // `current_sample_scores` holds the scores from the condensed PCA, which will be used. @@ -981,7 +1004,10 @@ impl EigenSNPCoreAlgorithm { ); if current_sample_scores.scores.ncols() == 0 { - warn!("Refinement Pass {}: Input scores have 0 components. Cannot proceed with refinement.", pass_num); + warn!( + "Refinement Pass {}: Input scores have 0 components. Cannot proceed with refinement.", + pass_num + ); if pass_num == 1 { final_sorted_snp_loadings = Array2::zeros((num_total_pca_snps, 0)); } @@ -1005,7 +1031,10 @@ impl EigenSNPCoreAlgorithm { ); if v_qr_snp_loadings.ncols() == 0 { - warn!("Pass {}: Intermediate QR-based SNP loadings (V_qr) resulted in 0 components. Ending refinement.", pass_num); + warn!( + "Pass {}: Intermediate QR-based SNP loadings (V_qr) resulted in 0 components. Ending refinement.", + pass_num + ); if pass_num == 1 { final_sorted_snp_loadings = v_qr_snp_loadings; } @@ -1092,12 +1121,139 @@ impl EigenSNPCoreAlgorithm { dc.notes.push_str("EigenSNP PCA run finished. "); } // The return value structure now matches PcaOutputWithDiagnostics - Ok((output_final, diagnostics_collector)) + return Ok((output_final, diagnostics_collector)); } #[cfg(not(feature = "enable-eigensnp-diagnostics"))] { // This also matches PcaOutputWithDiagnostics where the second element is () - Ok((output_final, ())) + return Ok((output_final, ())); + } + } + + fn compute_dense_pca( + &self, + genotype_data: &G, + num_total_pca_snps: usize, + num_total_qc_samples: usize, + ) -> Result { + if num_total_pca_snps == 0 || num_total_qc_samples < 2 { + let output = EigenSNPCoreOutput { + final_snp_principal_component_loadings: Array2::zeros((num_total_pca_snps, 0)), + final_sample_principal_component_scores: Array2::zeros((num_total_qc_samples, 0)), + final_principal_component_eigenvalues: Array1::zeros(0), + num_qc_samples_used: num_total_qc_samples, + num_pca_snps_used: num_total_pca_snps, + num_principal_components_computed: 0, + }; + #[cfg(feature = "enable-eigensnp-diagnostics")] + { + return Ok((output, None)); + } + #[cfg(not(feature = "enable-eigensnp-diagnostics"))] + { + return Ok((output, ())); + } + } + + let snp_ids: Vec = (0..num_total_pca_snps).map(PcaSnpId).collect(); + let sample_ids: Vec = (0..num_total_qc_samples).map(QcSampleId).collect(); + let snps_by_samples_matrix = + genotype_data.get_standardized_snp_sample_block(&snp_ids, &sample_ids)?; + + let max_rank = std::cmp::min(num_total_qc_samples, num_total_pca_snps); + let target_components = self.config.target_num_global_pcs.min(max_rank); + + if target_components == 0 { + let output = EigenSNPCoreOutput { + final_snp_principal_component_loadings: Array2::zeros((num_total_pca_snps, 0)), + final_sample_principal_component_scores: Array2::zeros((num_total_qc_samples, 0)), + final_principal_component_eigenvalues: Array1::zeros(0), + num_qc_samples_used: num_total_qc_samples, + num_pca_snps_used: num_total_pca_snps, + num_principal_components_computed: 0, + }; + #[cfg(feature = "enable-eigensnp-diagnostics")] + { + return Ok((output, None)); + } + #[cfg(not(feature = "enable-eigensnp-diagnostics"))] + { + return Ok((output, ())); + } + } + + let samples_by_snps_f64 = snps_by_samples_matrix + .t() + .mapv(|value| value as f64) + .to_owned(); + + let backend = LinAlgBackendProvider::::new(); + let svd_output = backend.svd_into(samples_by_snps_f64, true, true)?; + + let u_matrix = svd_output.u.ok_or_else(|| { + Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Dense PCA fallback: SVD did not return U matrix.", + )) as ThreadSafeStdError + })?; + let vt_matrix = svd_output.vt.ok_or_else(|| { + Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Dense PCA fallback: SVD did not return V^T matrix.", + )) as ThreadSafeStdError + })?; + let singular_values = svd_output.s; + + let k_eff = std::cmp::min(target_components, singular_values.len()); + if k_eff == 0 { + let output = EigenSNPCoreOutput { + final_snp_principal_component_loadings: Array2::zeros((num_total_pca_snps, 0)), + final_sample_principal_component_scores: Array2::zeros((num_total_qc_samples, 0)), + final_principal_component_eigenvalues: Array1::zeros(0), + num_qc_samples_used: num_total_qc_samples, + num_pca_snps_used: num_total_pca_snps, + num_principal_components_computed: 0, + }; + #[cfg(feature = "enable-eigensnp-diagnostics")] + { + return Ok((output, None)); + } + #[cfg(not(feature = "enable-eigensnp-diagnostics"))] + { + return Ok((output, ())); + } + } + + let mut scores_n_by_k = u_matrix.slice(s![.., 0..k_eff]).to_owned(); + for (col_idx, mut column) in scores_n_by_k.columns_mut().into_iter().enumerate() { + let sigma = singular_values[col_idx]; + column.mapv_inplace(|value| value * sigma); + } + + let loadings_k_by_d = vt_matrix.slice(s![0..k_eff, ..]).to_owned(); + let loadings_d_by_k_f64 = loadings_k_by_d.t().to_owned(); + + let denominator = (num_total_qc_samples as f64 - 1.0).max(1.0); + let eigenvalues_k = singular_values + .slice(s![0..k_eff]) + .mapv(|sigma| (sigma * sigma) / denominator); + + let output = EigenSNPCoreOutput { + final_snp_principal_component_loadings: loadings_d_by_k_f64.mapv(|value| value as f32), + final_sample_principal_component_scores: scores_n_by_k.mapv(|value| value as f32), + final_principal_component_eigenvalues: eigenvalues_k.to_owned(), + num_qc_samples_used: num_total_qc_samples, + num_pca_snps_used: num_total_pca_snps, + num_principal_components_computed: k_eff, + }; + + #[cfg(feature = "enable-eigensnp-diagnostics")] + { + return Ok((output, None)); + } + #[cfg(not(feature = "enable-eigensnp-diagnostics"))] + { + return Ok((output, ())); } } @@ -1429,8 +1585,11 @@ impl EigenSNPCoreAlgorithm { let num_components_this_block = local_snp_basis_vectors.ncols(); if block_spec.num_snps_in_block() == 0 || num_components_this_block == 0 { - trace!("Project Samples: Skipping block {} for projection: num_snps={} or num_local_components=0.", - block_tag, block_spec.num_snps_in_block()); + trace!( + "Project Samples: Skipping block {} for projection: num_snps={} or num_local_components=0.", + block_tag, + block_spec.num_snps_in_block() + ); continue; } @@ -1458,8 +1617,7 @@ impl EigenSNPCoreAlgorithm { .sqrt(); trace!( "Block {}: Projected scores (Sp_star) Frobenius norm: {:.4e}", - block_tag, - norm_sp_star + block_tag, norm_sp_star ); trace!( "Block {}: Projected scores (Sp_star) sample: {:?}", @@ -1508,8 +1666,7 @@ impl EigenSNPCoreAlgorithm { .unwrap_or(0.0); trace!( "Raw condensed matrix: Row {} variance (pre-std): {:.4e}", - row_idx, - variance + row_idx, variance ); } else if r_view.len() == 1 { trace!( @@ -1620,8 +1777,10 @@ impl EigenSNPCoreAlgorithm { } if m_c <= k_glob || m_c <= direct_svd_m_c_threshold || l_rsvd <= k_glob { - info!("Initial Global PCA: Choosing Direct SVD path. Condition: m_c ({}) <= k_glob ({}) || m_c ({}) <= direct_svd_m_c_threshold ({}) || l_rsvd ({}) <= k_glob ({})", - m_c, k_glob, m_c, direct_svd_m_c_threshold, l_rsvd, k_glob); + info!( + "Initial Global PCA: Choosing Direct SVD path. Condition: m_c ({}) <= k_glob ({}) || m_c ({}) <= direct_svd_m_c_threshold ({}) || l_rsvd ({}) <= k_glob ({})", + m_c, k_glob, m_c, direct_svd_m_c_threshold, l_rsvd, k_glob + ); let a_c_owned_for_svd = a_c.to_owned(); // For SVD debug!( @@ -1629,7 +1788,8 @@ impl EigenSNPCoreAlgorithm { a_c_owned_for_svd.dim() ); let backend = LinAlgBackendProvider::::new(); - match backend.svd_into(a_c_owned_for_svd.clone(), false, true) { + let svd_result = backend.svd_into(a_c_owned_for_svd.clone(), false, true); + match svd_result { // Clone a_c_owned_for_svd for potential f64 SVD later Ok(svd_output) => { if let Some(svd_output_vt) = svd_output.vt { @@ -1649,7 +1809,10 @@ impl EigenSNPCoreAlgorithm { } } else { /* error handling */ - warn!("Direct SVD for initial global PCA: svd_output.vt is None despite requesting it. M_c={}, N_samples={}", m_c, n_samples); + warn!( + "Direct SVD for initial global PCA: svd_output.vt is None despite requesting it. M_c={}, N_samples={}", + m_c, n_samples + ); return Err(Box::new(std::io::Error::new( std::io::ErrorKind::Other, "SVD succeeded but V.T (vt) was not returned by the backend.", @@ -1663,7 +1826,9 @@ impl EigenSNPCoreAlgorithm { && !a_c.is_empty() && !initial_scores.is_empty() { - debug!("DIAG: Computing f64 SVD for U_scores_true comparison in Global PCA (Direct SVD Path)."); + debug!( + "DIAG: Computing f64 SVD for U_scores_true comparison in Global PCA (Direct SVD Path)." + ); let a_c_f64_owned = a_c.mapv(|v_f32| v_f32 as f64); // Convert A_c to f64 for true SVD let backend_f64 = LinAlgBackendProvider::::new(); match backend_f64.svd_into(a_c_f64_owned, false, true) { @@ -1702,7 +1867,10 @@ impl EigenSNPCoreAlgorithm { } Err(e) => { /* error handling */ - warn!("Direct SVD failed for initial global PCA (M_c={}, N_samples={}): {}. Returning error.", m_c, n_samples, e); + warn!( + "Direct SVD failed for initial global PCA (M_c={}, N_samples={}): {}. Returning error.", + m_c, n_samples, e + ); return Err(Box::new(std::io::Error::new( std::io::ErrorKind::Other, format!("Direct SVD failed during initial global PCA: {}", e), @@ -1710,8 +1878,10 @@ impl EigenSNPCoreAlgorithm { } } } else { - info!("Initial Global PCA: Choosing RSVD path. Condition: m_c ({}) > k_glob ({}) && m_c ({}) > direct_svd_m_c_threshold ({}) && l_rsvd ({}) > k_glob ({})", - m_c, k_glob, m_c, direct_svd_m_c_threshold, l_rsvd, k_glob); + info!( + "Initial Global PCA: Choosing RSVD path. Condition: m_c ({}) > k_glob ({}) && m_c ({}) > direct_svd_m_c_threshold ({}) && l_rsvd ({}) > k_glob ({})", + m_c, k_glob, m_c, direct_svd_m_c_threshold, l_rsvd, k_glob + ); #[cfg(feature = "enable-eigensnp-diagnostics")] let rsvd_stages_collector = global_pca_diagnostics_collector @@ -1751,7 +1921,10 @@ impl EigenSNPCoreAlgorithm { } if initial_scores.ncols() == 0 && k_glob > 0 { - warn!("Initial PCA scores have 0 columns (M_c={}, N_samples={}), but k_glob ({}) > 0. This might indicate an issue or empty input.", m_c, n_samples, k_glob); + warn!( + "Initial PCA scores have 0 columns (M_c={}, N_samples={}), but k_glob ({}) > 0. This might indicate an issue or empty input.", + m_c, n_samples, k_glob + ); } Ok(InitialSamplePcScores { @@ -1856,7 +2029,8 @@ impl EigenSNPCoreAlgorithm { std::io::ErrorKind::InvalidInput, format!( "Dimension mismatch for mixed-precision A.T * B dot product: A.nrows ({}) != B.nrows ({}).", - d_strip, matrix_b_dstrip_x_kqr.nrows() + d_strip, + matrix_b_dstrip_x_kqr.nrows() ), )) as ThreadSafeStdError); } @@ -2291,7 +2465,9 @@ impl EigenSNPCoreAlgorithm { } if num_components_to_process == 0 { - debug!("SVD (f64) of S_intermediate resulted in num_components_to_process = 0. Returning empty results."); + debug!( + "SVD (f64) of S_intermediate resulted in num_components_to_process = 0. Returning empty results." + ); return Ok(( Array2::zeros((num_total_qc_samples, 0)), // f32 for final output Array1::zeros(0), // f64 for eigenvalues @@ -2532,7 +2708,7 @@ impl EigenSNPCoreAlgorithm { _: Option<(usize, usize)>, _: Option<&ArrayView2>, _: Option<&ArrayView2>| { // This signature is correct for non-diagnostic - // No-op for non-diagnostics build + // No-op for non-diagnostics build }; let num_features_m = matrix_features_by_samples.nrows(); @@ -2617,10 +2793,7 @@ impl EigenSNPCoreAlgorithm { } trace!( "RSVD internal: Target_K={}, Sketch_L={}, Input_M(features)={}, Input_N(samples)={}", - num_components_target_k, - sketch_dimension_l, - num_features_m, - num_samples_n + num_components_target_k, sketch_dimension_l, num_features_m, num_samples_n ); let mut rng = ChaCha8Rng::seed_from_u64(random_seed); @@ -2694,7 +2867,10 @@ impl EigenSNPCoreAlgorithm { ); if sketch_y.ncols() == 0 { - warn!("RSVD: Initial sketch Y (A*Omega) has zero columns before first QR. Target_K={}, Sketch_L={}", num_components_target_k, sketch_dimension_l); + warn!( + "RSVD: Initial sketch Y (A*Omega) has zero columns before first QR. Target_K={}, Sketch_L={}", + num_components_target_k, sketch_dimension_l + ); let u_res = if request_u_components { Some(Array2::zeros((num_features_m, 0))) } else { @@ -3223,7 +3399,7 @@ impl EigenSNPCoreAlgorithm { for j_col_idx in 0..k_dim { let mut accumulator_f64: f64 = 0.0; let b_column_j = b_matrix_view.column(j_col_idx); // Obtain the column view for B - // b_column_slice is removed. + // b_column_slice is removed. let num_simd_chunks = p_common_dim_a / LANES; let mut simd_f32_partial_sum = Simd::splat(0.0f32); // Ensure f32 type for splat diff --git a/src/linalg_backends.rs b/src/linalg_backends.rs index 9fac0bc..270124b 100644 --- a/src/linalg_backends.rs +++ b/src/linalg_backends.rs @@ -54,7 +54,7 @@ pub struct EighOutput { /// Implementers will typically expect `matrix` to be symmetric. pub trait BackendEigh { fn eigh_upper(&self, matrix: &Array2) - -> Result, Box>; + -> Result, Box>; } /// Trait for QR decomposition, focusing on retrieving the Q factor. @@ -89,8 +89,8 @@ pub trait BackendSVD { feature = "faer_links_ndarray_static_openblas" ))] mod ndarray_backend_impl { - use super::{s, Array2, BackendEigh, BackendQR, BackendSVD, EighOutput, SVDOutput}; - use ndarray_linalg::{Eigh, Lapack, SVDInto, QR, UPLO}; + use super::{Array2, BackendEigh, BackendQR, BackendSVD, EighOutput, SVDOutput, s}; + use ndarray_linalg::{Eigh, Lapack, QR, SVDInto, UPLO}; use std::error::Error; #[cfg_attr(feature = "backend_faer", allow(dead_code))] @@ -209,9 +209,9 @@ mod faer_specific_code { // Encapsulate faer-specific code and its imports use super::{BackendEigh, BackendQR, BackendSVD, EighOutput, SVDOutput}; use bytemuck::Pod; - use faer::traits::num_traits::Zero; // Use Zero via faer's re-export - use faer::traits::ComplexField; use faer::MatRef; // Use faer::MatRef for Faer matrix views. + use faer::traits::ComplexField; + use faer::traits::num_traits::Zero; // Use Zero via faer's re-export use ndarray::{Array1, Array2}; use std::error::Error; @@ -220,13 +220,15 @@ mod faer_specific_code { // use faer::dyn_stack::GlobalPodBuffer; // No longer needed // use faer::linalg::svd::ComputeSvdVectors as ComputeVectors; // Commented out as likely not needed use faer::linalg::solvers::Svd as FaerSolverSvd; // Alias for the new SVD solver - // SvdReq is likely not needed. + // SvdReq is likely not needed. // --- internal util --------------------------------------------------------- #[inline(always)] unsafe fn read_unchecked(ptr: *const T) -> T { - debug_assert!(!ptr.is_null()); - *ptr + unsafe { + debug_assert!(!ptr.is_null()); + *ptr + } } fn to_dyn_error_faer(msg: String) -> Box { diff --git a/src/pca.rs b/src/pca.rs index 2d6a285..9436246 100644 --- a/src/pca.rs +++ b/src/pca.rs @@ -2,10 +2,10 @@ #![doc = include_str!("../README.md")] -use ndarray::parallel::prelude::*; #[cfg(feature = "backend_faer")] use ndarray::ShapeBuilder; -use ndarray::{s, Array1, Array2, ArrayViewMut1, Axis}; +use ndarray::parallel::prelude::*; +use ndarray::{Array1, Array2, ArrayViewMut1, Axis, s}; // UPLO is no longer needed as the backend's eigh_upper handles this. // QR trait for .qr() and SVDInto for .svd_into() are replaced by backend calls. // Eigh trait for .eigh() is replaced by backend calls. @@ -93,11 +93,7 @@ fn center_and_scale_columns(data_matrix: &mut Array2) -> (Array1, Arra let variance = if n_samples > 1 { let centered_sum_sq = (sum_sq - sum * sum / n_samples_f64).max(0.0); let var = centered_sum_sq / ((n_samples - 1) as f64); - if var.is_finite() { - var - } else { - 0.0 - } + if var.is_finite() { var } else { 0.0 } } else { 0.0 }; @@ -840,7 +836,7 @@ impl PCA { // and should be at least n_components_requested (if possible within matrix dimensions). let mut l_sketch_components = l_sketch_components_ideal.min(max_possible_rank); l_sketch_components = l_sketch_components.max(1); // So at least 1 sketch component - // So sketch is large enough to find requested components, if data rank allows. + // So sketch is large enough to find requested components, if data rank allows. l_sketch_components = l_sketch_components.max(n_components_requested.min(max_possible_rank)); @@ -895,7 +891,7 @@ impl PCA { if q_prime_basis.ncols() == 0 { break; } // Orthonormal basis might have reduced rank - // W_prime_intermediate = A.T @ Q_prime_basis (D x N) @ (N x L) -> D x L + // W_prime_intermediate = A.T @ Q_prime_basis (D x N) @ (N x L) -> D x L let w_prime_intermediate_candidate = centered_scaled_data_a.t().dot(&q_prime_basis); if w_prime_intermediate_candidate.ncols() == 0 { break; @@ -1105,9 +1101,9 @@ impl PCA { // --- 7. Calculate and Return Principal Component Scores for the Input Data --- // Scores = Centered_Scaled_Data_A @ Final_Rotation_Matrix // Need to access the potentially just-set self.rotation. - let rotation_for_transform = self.rotation.as_ref().ok_or_else(|| { - "PCA::rfit: Internal error: Rotation matrix not set after rfit processing." - })?; + let rotation_for_transform = self.rotation.as_ref().ok_or_else( + || "PCA::rfit: Internal error: Rotation matrix not set after rfit processing.", + )?; // centered_scaled_data_a is (N x D) // rotation_for_transform is (D x k_kept) @@ -1132,16 +1128,16 @@ impl PCA { /// does not match the model's feature dimension. pub fn transform(&self, mut x: Array2) -> Result, Box> { // Retrieve model components, so they exist. - let rotation_matrix = self.rotation.as_ref().ok_or_else(|| { - "PCA::transform: PCA model: Rotation matrix not set. Fit or load a model first." - })?; - let mean_vector = self.mean.as_ref().ok_or_else(|| { - "PCA::transform: PCA model: Mean vector not set. Fit or load a model first." - })?; + let rotation_matrix = self.rotation.as_ref().ok_or_else( + || "PCA::transform: PCA model: Rotation matrix not set. Fit or load a model first.", + )?; + let mean_vector = self.mean.as_ref().ok_or_else( + || "PCA::transform: PCA model: Mean vector not set. Fit or load a model first.", + )?; // self.scale is guaranteed to contain positive, finite values by model construction/loading. - let scale_vector = self.scale.as_ref().ok_or_else(|| { - "PCA::transform: PCA model: Scale vector not set. Fit or load a model first." - })?; + let scale_vector = self.scale.as_ref().ok_or_else( + || "PCA::transform: PCA model: Scale vector not set. Fit or load a model first.", + )?; let n_input_samples = x.nrows(); let n_input_features = x.ncols(); diff --git a/tests/eigensnp_tests.rs b/tests/eigensnp_tests.rs index e0d1bf3..ec8daa5 100644 --- a/tests/eigensnp_tests.rs +++ b/tests/eigensnp_tests.rs @@ -13,13 +13,13 @@ mod python_bootstrap; use python_bootstrap::ensure_python_packages_installed; use efficient_pca::eigensnp::{ - reorder_array_owned, reorder_columns_owned, EigenSNPCoreAlgorithm, EigenSNPCoreAlgorithmConfig, - EigenSNPCoreOutput, LdBlockSpecification, PcaReadyGenotypeAccessor, PcaSnpId, PcaSnpMetadata, - QcSampleId, ThreadSafeStdError, + EigenSNPCoreAlgorithm, EigenSNPCoreAlgorithmConfig, EigenSNPCoreOutput, LdBlockSpecification, + PcaReadyGenotypeAccessor, PcaSnpId, PcaSnpMetadata, QcSampleId, ThreadSafeStdError, + reorder_array_owned, reorder_columns_owned, }; -use ndarray::{arr2, s, Array1, Array2, ArrayView1, ArrayView2, Axis}; // ArrayView2 was already added, Array removed -use ndarray_rand::rand_distr::{Normal, StandardNormal, Uniform}; // Added Normal, StandardNormal +use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, arr2, s}; // ArrayView2 was already added, Array removed use ndarray_rand::RandomExt; +use ndarray_rand::rand_distr::{Normal, StandardNormal, Uniform}; // Added Normal, StandardNormal use rand::Rng; // Added for the .sample() method use rand::SeedableRng; // Already present, but ensure it's here use rand_chacha::ChaCha8Rng; // Already present, but ensure it's here @@ -29,18 +29,18 @@ use std::io::Write; // Removed BufReader, BufRead use std::path::PathBuf; use std::process::{Command, Stdio}; use std::str::FromStr; // Import with an alias to avoid conflict with std::io::Write - // use std::io::Write; // Already present +// use std::io::Write; // Already present use std::path::Path; // Add Path - // use ndarray::{ArrayView1, ArrayView2}; // These are brought in by `use ndarray::{arr2, s, Array1, Array2, ArrayView1, Axis};` +// use ndarray::{ArrayView1, ArrayView2}; // These are brought in by `use ndarray::{arr2, s, Array1, Array2, ArrayView1, Axis};` use lazy_static::lazy_static; use std::fmt::Display; // To constrain T // Removed: use crate::eigensnp_integration_tests::parse_pca_py_output; -use crate::eigensnp_integration_tests::generate_structured_data; -use crate::eigensnp_integration_tests::get_python_reference_pca; +use crate::eigensnp_integration_tests::TEST_RESULTS; use crate::eigensnp_integration_tests::TestDataAccessor; use crate::eigensnp_integration_tests::TestResultRecord; -use crate::eigensnp_integration_tests::TEST_RESULTS; +use crate::eigensnp_integration_tests::generate_structured_data; +use crate::eigensnp_integration_tests::get_python_reference_pca; const DEFAULT_FLOAT_TOLERANCE_F32: f32 = 1e-4; // Slightly looser for cross-implementation comparison const DEFAULT_FLOAT_TOLERANCE_F64: f64 = 1e-4; // Slightly looser for cross-implementation comparison @@ -94,7 +94,12 @@ fn assert_f32_arrays_are_close_with_sign_flips( } if arr1.ncols() == 0 || arr2.ncols() == 0 { // One empty, one not - panic!("Array column count mismatch for {}: Left: {}, Right: {}. Both must be empty or non-empty.", context, arr1.ncols(), arr2.ncols()); + panic!( + "Array column count mismatch for {}: Left: {}, Right: {}. Both must be empty or non-empty.", + context, + arr1.ncols(), + arr2.ncols() + ); } for c_idx in 0..arr1.ncols() { @@ -124,9 +129,14 @@ fn assert_f32_arrays_are_close_with_sign_flips( assert!( flipped_match, "Column {} mismatch for {} (even with sign flip check). Max diff: {}. First elements: {} vs {}", - c_idx, context, - col1.iter().zip(col2.iter()).map(|(a,b)| (a-b).abs().max((a-(-b)).abs())).fold(0.0f32, f32::max), - col1.get(0).unwrap_or(&0.0f32), col2.get(0).unwrap_or(&0.0f32) + c_idx, + context, + col1.iter() + .zip(col2.iter()) + .map(|(a, b)| (a - b).abs().max((a - (-b)).abs())) + .fold(0.0f32, f32::max), + col1.get(0).unwrap_or(&0.0f32), + col2.get(0).unwrap_or(&0.0f32) ); } } @@ -248,7 +258,10 @@ mod eigensnp_integration_tests { .open(&tsv_path)?; // Pass tsv_path by reference // Write header - writeln!(file, "TestName NumFeatures_D NumSamples_N NumPCsRequested_K NumPCsComputed Success OutcomeDetails Notes")?; + writeln!( + file, + "TestName NumFeatures_D NumSamples_N NumPCsRequested_K NumPCsComputed Success OutcomeDetails Notes" + )?; for record in results_guard.iter() { writeln!( @@ -278,7 +291,10 @@ mod eigensnp_integration_tests { "[SUMMARY_WRITER_DTOR] Test execution finished. Running summary writer destructor." ); if let Err(e) = write_summary_file_impl() { - eprintln!("[SUMMARY_WRITER_DTOR] CRITICAL: Failed to write eigensnp_summary_results.tsv: {:?}", e); + eprintln!( + "[SUMMARY_WRITER_DTOR] CRITICAL: Failed to write eigensnp_summary_results.tsv: {:?}", + e + ); // Do not panic in dtor } } @@ -314,19 +330,22 @@ mod eigensnp_integration_tests { .stderr(Stdio::piped()) .spawn()?; - if let Some(mut stdin_pipe) = process.stdin.take() { - // Write to stdin in a separate thread to avoid deadlocks if the buffer fills up - std::thread::spawn(move || { - if let Err(e) = stdin_pipe.write_all(stdin_data.as_bytes()) { - // eprintln is okay for a background thread error message in tests - eprintln!("Failed to write to stdin of pca.py: {}", e); // Ensure this eprintln is acceptable or use logging framework - } - }); - } else { - return Err(Box::new(std::io::Error::new( - std::io::ErrorKind::Other, - "Failed to open stdin pipe for pca.py", - ))); + match process.stdin.take() { + Some(mut stdin_pipe) => { + // Write to stdin in a separate thread to avoid deadlocks if the buffer fills up + std::thread::spawn(move || { + if let Err(e) = stdin_pipe.write_all(stdin_data.as_bytes()) { + // eprintln is okay for a background thread error message in tests + eprintln!("Failed to write to stdin of pca.py: {}", e); // Ensure this eprintln is acceptable or use logging framework + } + }); + } + _ => { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to open stdin pipe for pca.py", + ))); + } } let py_cmd_output = process.wait_with_output()?; @@ -678,7 +697,7 @@ mod eigensnp_integration_tests { ); } else if current_line_is_empty { lines.next(); // Consume the empty line - // Continue to next iteration to peek at next line + // Continue to next iteration to peek at next line } else { // Unexpected line return Err(format!( @@ -733,6 +752,7 @@ mod eigensnp_integration_tests { components_per_ld_block: 10 .min(num_snps.min((num_samples / 2).max(10).min(num_samples.max(1)))), random_seed: seed, + force_dense_pca: true, ..Default::default() }; let algorithm = EigenSNPCoreAlgorithm::new(config); @@ -950,6 +970,7 @@ mod eigensnp_integration_tests { .min(num_samples.max(1)), components_per_ld_block: 10 .min(num_snps.min((num_samples / 2).max(10).min(num_samples.max(1)))), + force_dense_pca: true, ..Default::default() }; let algorithm = EigenSNPCoreAlgorithm::new(config); @@ -1129,6 +1150,7 @@ mod eigensnp_integration_tests { .min(num_samples.max(1)), components_per_ld_block: 10 .min(num_snps.min((num_samples / 2).max(10).min(num_samples.max(1)))), + force_dense_pca: true, ..Default::default() }; let algorithm = EigenSNPCoreAlgorithm::new(config); @@ -1415,6 +1437,7 @@ mod eigensnp_integration_tests { subset_factor_for_local_basis_learning: 1.0, min_subset_size_for_local_basis_learning: num_samples.max(1), // Ensure at least 1 max_subset_size_for_local_basis_learning: num_samples.max(10), // Ensure at least 10 + force_dense_pca: true, ..Default::default() }; let algorithm = EigenSNPCoreAlgorithm::new(config); @@ -1786,7 +1809,7 @@ pub fn run_pc_correlation_with_truth_set_test( // Get Truth PCs from pca.py by calling the new helper function let mut py_loadings_d_x_k: Array2 = Array2::zeros((0, 0)); // D x K let mut py_scores_n_x_k: Array2 = Array2::zeros((0, 0)); // N x K - // let mut py_eigenvalues_k: Array1 = Array1::zeros(0); // K + // let mut py_eigenvalues_k: Array1 = Array1::zeros(0); // K let mut effective_k_py = 0; let python_pca_result = get_python_reference_pca( @@ -1841,6 +1864,7 @@ pub fn run_pc_correlation_with_truth_set_test( max_subset_size_for_local_basis_learning: (num_samples / 2).max(10).min(num_samples.max(1)), components_per_ld_block: 10 .min(num_snps.min((num_samples / 2).max(10).min(num_samples.max(1)))), + force_dense_pca: true, ..Default::default() }; let algorithm = EigenSNPCoreAlgorithm::new(config); @@ -2085,6 +2109,7 @@ fn test_pc_correlation_structured_1000snps_200samples_5truepcs() { min_subset_size_for_local_basis_learning: min_subset_size, max_subset_size_for_local_basis_learning: max_subset_size, components_per_ld_block: components_per_block, + force_dense_pca: true, ..Default::default() }; @@ -2298,6 +2323,7 @@ pub fn run_generic_large_matrix_test( random_seed: seed, ..Default::default() }; + base_config.force_dense_pca = true; if let Some(modifier) = config_modifier { base_config = modifier(base_config); @@ -2489,6 +2515,7 @@ pub fn run_sample_projection_accuracy_test( target_num_global_pcs: k_components, random_seed: seed, // Default other params or make them configurable if needed for these tests + force_dense_pca: true, ..Default::default() }; let algorithm_train = EigenSNPCoreAlgorithm::new(config_train); @@ -2554,66 +2581,74 @@ pub fn run_sample_projection_accuracy_test( } } - // Get "Truth" Scores for Test Samples using pca.py on total data + // Get "Truth" Scores for Test Samples using pca.py on the training data let mut py_test_scores_ref_option: Option> = None; if test_successful { // Only proceed if eigensnp part was okay so far - let python_total_data_prefix = format!( - "sample_projection_{}x{}_k{}_py_total_ref", - num_snps, num_samples_total, k_components + let python_train_data_prefix = format!( + "sample_projection_{}x{}_k{}_py_train_ref", + num_snps, num_samples_train, k_components ); match get_python_reference_pca( - &standardized_genos_total_snps_x_samples, - k_components, // Use original k_components for full data PCA - &python_total_data_prefix, + &train_data_snps_x_samples, + k_components, // Request the same number of components on the training data + &python_train_data_prefix, ) { - Ok((_py_loadings_total_k_x_d, py_scores_total_n_x_k, _py_eigenvalues_total)) => { - // _py_loadings_total is K x D - // py_scores_total_n_x_k is N_total x K - let k_py_total = _py_loadings_total_k_x_d.nrows(); // Kx D, so nrows is K - if py_scores_total_n_x_k.nrows() == num_samples_total - && py_scores_total_n_x_k.ncols() >= k_components.min(k_py_total) - { - // Extract test sample scores: from row num_samples_train onwards - // Ensure k_eff_rust is used for slicing columns to match projected scores dimensions - let num_cols_to_slice = k_eff_rust.min(py_scores_total_n_x_k.ncols()); - if num_cols_to_slice > 0 { - let py_test_scores_ref = py_scores_total_n_x_k - .slice(s![num_samples_train.., 0..num_cols_to_slice]) - .to_owned(); - save_matrix_to_tsv( - &py_test_scores_ref.view(), - artifact_dir.to_str().unwrap_or("."), - "python_ref_test_scores.tsv", - ) - .unwrap_or_default(); - py_test_scores_ref_option = Some(py_test_scores_ref); - outcome_details.push_str(&format!("Python on total data successful. k_py_total: {}. Sliced to {} cols for comparison. ", k_py_total, num_cols_to_slice)); - } else { - outcome_details.push_str( - "Python on total data: 0 relevant components to slice for comparison. ", - ); - // This might be a test failure if k_eff_rust or k_py_total was expected to be > 0 - if k_eff_rust > 0 { - // If Rust produced PCs but Python didn't produce comparable ones - test_successful = false; - outcome_details.push_str("Mismatch: Rust produced PCs but Python reference had 0 comparable PCs. "); + Ok((py_loadings_train_k_x_d, _py_scores_train_n_x_k, _py_eigenvalues_train)) => { + let k_py_train = py_loadings_train_k_x_d.nrows(); + let num_cols_to_slice = k_eff_rust.min(k_py_train); + if num_cols_to_slice > 0 { + let mut py_loadings_train_d_by_k = py_loadings_train_k_x_d + .slice(s![0..num_cols_to_slice, ..]) + .t() + .to_owned(); + + if let Some(ref rust_pca_output) = rust_pca_output_option { + let rust_loadings_d_by_k = rust_pca_output + .final_snp_principal_component_loadings + .slice(s![.., 0..num_cols_to_slice]); + for col_idx in 0..num_cols_to_slice { + let rust_col = rust_loadings_d_by_k.column(col_idx); + let mut py_col = py_loadings_train_d_by_k.column_mut(col_idx); + let dot_product = rust_col.dot(&py_col); + if dot_product < 0.0 { + py_col.mapv_inplace(|value| -value); + } } } - } else { - test_successful = false; + + let py_test_scores_ref = + test_data_snps_x_samples.t().dot(&py_loadings_train_d_by_k); + + save_matrix_to_tsv( + &py_test_scores_ref.view(), + artifact_dir.to_str().unwrap_or("."), + "python_ref_test_scores.tsv", + ) + .unwrap_or_default(); + + py_test_scores_ref_option = Some(py_test_scores_ref); outcome_details.push_str(&format!( - "Python (total data) scores dimensions mismatch. Expected N_total x >=k_eff_py ({}x{}), Got {}x{}. ", - num_samples_total, k_components.min(k_py_total), - py_scores_total_n_x_k.nrows(), py_scores_total_n_x_k.ncols() + "Python on training data successful. k_py_train: {}. Sliced to {} cols for comparison. ", + k_py_train, num_cols_to_slice )); + } else { + outcome_details.push_str( + "Python on training data: 0 relevant components to slice for comparison. ", + ); + if k_eff_rust > 0 { + test_successful = false; + outcome_details.push_str( + "Mismatch: Rust produced PCs but Python training reference had 0 comparable PCs. ", + ); + } } } Err(e) => { test_successful = false; outcome_details.push_str(&format!( - "Python reference PCA on total data failed: {}. ", + "Python reference PCA on training data failed: {}. ", e )); } @@ -2885,6 +2920,7 @@ where .min(standardized_structured_data.ncols().max(1)), ), ), + force_dense_pca: true, ..Default::default() }; @@ -3509,13 +3545,16 @@ fn test_min_passes_for_quality_convergence() { max_subset_size_for_local_basis_learning: (n_samples / 2).max(10).min(n_samples.max(1)), components_per_ld_block: 10 .min(d_total_snps.min((n_samples / 2).max(10).min(n_samples.max(1)))), + force_dense_pca: true, ..Default::default() }; let test_data_accessor = TestDataAccessor::new(standardized_structured_data.clone()); let algorithm = EigenSNPCoreAlgorithm::new(config); - match algorithm.compute_pca(&test_data_accessor, &ld_block_specs, &snp_metadata) { + let compute_result = + algorithm.compute_pca(&test_data_accessor, &ld_block_specs, &snp_metadata); + match compute_result { Ok((eigensnp_output_current_pass, _)) => { // This variable will store the PC count from the pass that *first* meets criteria, // or the last successful one if criteria are never met. @@ -3668,8 +3707,11 @@ fn test_min_passes_for_quality_convergence() { outcome_details: overall_outcome_details.clone(), notes: format!( "Min passes found for convergence: {}. Expected <= {}. Thresholds: ScoreCor >= {:.3}, LoadCor >= {:.3}, EigAcc (-MSRE) >= {:.3e}", - min_passes_found, expected_max_passes_for_convergence, - thresholds.min_score_correlation, thresholds.min_loading_correlation, thresholds.max_neg_eigenvalue_accuracy + min_passes_found, + expected_max_passes_for_convergence, + thresholds.min_score_correlation, + thresholds.min_loading_correlation, + thresholds.max_neg_eigenvalue_accuracy ), }; TEST_RESULTS.lock().unwrap().push(record); @@ -3741,7 +3783,8 @@ fn test_refinement_projection_accuracy() { if py_total_scores_n_x_k.nrows() < n_samples_total { panic!( "Python reference scores have fewer rows ({}) than n_samples_total ({}). Cannot slice test set scores.", - py_total_scores_n_x_k.nrows(), n_samples_total + py_total_scores_n_x_k.nrows(), + n_samples_total ); } // Ensure py_total_scores_n_x_k has columns before trying to slice @@ -3780,6 +3823,7 @@ fn test_refinement_projection_accuracy() { .min(n_samples_train.max(1)), components_per_ld_block: 10 .min(d_total_snps.min((n_samples_train / 2).max(10).min(n_samples_train.max(1)))), + force_dense_pca: true, ..Default::default() }; @@ -3791,11 +3835,12 @@ fn test_refinement_projection_accuracy() { let algorithm = EigenSNPCoreAlgorithm::new(config); let snp_metadata = create_dummy_snp_metadata(d_total_snps); - match algorithm.compute_pca( + let compute_result = algorithm.compute_pca( &test_data_accessor_train, &ld_block_specs_train, &snp_metadata, - ) { + ); + match compute_result { Ok((eigensnp_train_output_struct, _)) => { save_matrix_to_tsv( &eigensnp_train_output_struct diff --git a/tests/pca_tests.rs b/tests/pca_tests.rs index 78c2f5e..5a222a0 100644 --- a/tests/pca_tests.rs +++ b/tests/pca_tests.rs @@ -6,7 +6,7 @@ use python_bootstrap::ensure_python_packages_installed; use efficient_pca::PCA; // For ndarray operations -use ndarray::{array, Array2, Axis}; +use ndarray::{Array2, Axis, array}; // For Linfa PCA functionality use linfa::dataset::DatasetBase; @@ -101,7 +101,7 @@ fn eigenvalues_descending(matrix: &Array2) -> Vec { mod genome_tests { use super::*; // use ndarray::array; // Already imported at top level of file - use ndarray::{s, ArrayView1}; // Array2 also needed for some tests + use ndarray::{ArrayView1, s}; // Array2 also needed for some tests use rand::Rng; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; @@ -256,13 +256,14 @@ mod genome_tests { // Run PCA let mut pca = PCA::new(); - match pca.rfit( + let rfit_result = pca.rfit( data_matrix.clone(), // data_matrix is consumed by rfit n_components, 5, // oversampling parameter Some(42), // seed None, // no variance tolerance - ) { + ); + match rfit_result { Ok(transformed) => { // rfit now returns the transformed principal components directly println!("PCA rfit computation successful, transformed PCs obtained."); @@ -326,13 +327,14 @@ mod genome_tests { // Run PCA on filtered data let mut pca_filtered = PCA::new(); - match pca_filtered.rfit( + let rfit_result = pca_filtered.rfit( filtered_matrix, // filtered_matrix is consumed by rfit n_components, 5, Some(42), None, - ) { + ); + match rfit_result { Ok(transformed_filtered) => { // rfit now returns transformed PCs // let transformed_filtered = pca_filtered.transform(filtered_matrix).unwrap(); // This is now redundant @@ -781,7 +783,7 @@ mod genome_tests { #[cfg(test)] mod model_persistence_tests { use super::*; - use ndarray::{array, Array1, Array2}; + use ndarray::{Array1, Array2, array}; use std::error::Error; use std::f64; use tempfile::NamedTempFile; // For f64::NAN @@ -796,19 +798,29 @@ mod model_persistence_tests { ) { match (arr1_opt, arr2_opt) { (Some(a1), Some(a2)) => { - assert_eq!(a1.dim(), a2.dim(), "Dimension mismatch for {} Array1", context_msg); + assert_eq!( + a1.dim(), + a2.dim(), + "Dimension mismatch for {} Array1", + context_msg + ); for (i, (v1, v2)) in a1.iter().zip(a2.iter()).enumerate() { assert!( (v1 - v2).abs() < COMPARISON_TOLERANCE, "Value mismatch at index {} for {}: {} vs {}", - i, context_msg, v1, v2 + i, + context_msg, + v1, + v2 ); } } (None, None) => { /* Both are None, which is considered equal in this context */ } _ => panic!( "Optional Array1 mismatch for {}: one is Some, other is None. Arr1: {:?}, Arr2: {:?}", - context_msg, arr1_opt.is_some(), arr2_opt.is_some() + context_msg, + arr1_opt.is_some(), + arr2_opt.is_some() ), } } @@ -821,19 +833,29 @@ mod model_persistence_tests { ) { match (arr1_opt, arr2_opt) { (Some(a1), Some(a2)) => { - assert_eq!(a1.dim(), a2.dim(), "Dimension mismatch for {} Array2", context_msg); + assert_eq!( + a1.dim(), + a2.dim(), + "Dimension mismatch for {} Array2", + context_msg + ); for (idx, (v1, v2)) in a1.iter().zip(a2.iter()).enumerate() { assert!( (v1 - v2).abs() < COMPARISON_TOLERANCE, "Value mismatch at flat index {} for {}: {} vs {}", - idx, context_msg, v1, v2 + idx, + context_msg, + v1, + v2 ); } } (None, None) => { /* Both are None, considered equal */ } _ => panic!( "Optional Array2 mismatch for {}: one is Some, other is None. Arr1: {:?}, Arr2: {:?}", - context_msg, arr1_opt.is_some(), arr2_opt.is_some() + context_msg, + arr1_opt.is_some(), + arr2_opt.is_some() ), } } @@ -1477,7 +1499,7 @@ mod pca_tests { .unwrap() } else { pca.fit(input.clone(), None).unwrap(); // fit does not return PCs - // so, transform is still needed after fit. + // so, transform is still needed after fit. pca.transform(input.clone()).unwrap() }; @@ -1491,7 +1513,7 @@ mod pca_tests { } use super::*; // This brings PcaReferenceResults into scope for this module if it's outside - // use ndarray::array; // Already imported at top level of file + // use ndarray::array; // Already imported at top level of file use ndarray_rand::rand_distr::Distribution; #[test] @@ -1502,12 +1524,30 @@ mod pca_tests { use ndarray_v15; fn assert_f64_slices_approx_equal(s1: &[f64], s2: &[f64], tol: f64, context_msg: &str) { - assert_eq!(s1.len(), s2.len(), "Length mismatch for '{}': expected {}, got {}. s1_len: {}, s2_len: {}. s1 first 5: {:?}, s2 first 5: {:?}", context_msg, s2.len(), s1.len(), s1.len(), s2.len(), s1.iter().take(5).collect::>(), s2.iter().take(5).collect::>()); + assert_eq!( + s1.len(), + s2.len(), + "Length mismatch for '{}': expected {}, got {}. s1_len: {}, s2_len: {}. s1 first 5: {:?}, s2 first 5: {:?}", + context_msg, + s2.len(), + s1.len(), + s1.len(), + s2.len(), + s1.iter().take(5).collect::>(), + s2.iter().take(5).collect::>() + ); for i in 0..s1.len() { if !approx::abs_diff_eq!(s1[i], s2[i], epsilon = tol) { panic!( "Element mismatch at index {} in '{}'. s1[i]: {:.6e}, s2[i]: {:.6e}, diff: {:.2e}, tol: {:.1e}. (Full s1: {:?}, Full s2: {:?})", - i, context_msg, s1[i], s2[i], (s1[i] - s2[i]).abs(), tol, s1, s2 + i, + context_msg, + s1[i], + s2[i], + (s1[i] - s2[i]).abs(), + tol, + s1, + s2 ); } } @@ -1583,10 +1623,33 @@ mod pca_tests { col1 (first {}): {:?}{}\n\ col2 (first {}): {:?}{}\n\ -col2 (first {}): {:?}{}", - j, context_msg, tol, - display_limit, col1.iter().take(display_limit).copied().collect::>(), if col1.len() > display_limit { "..." } else { "" }, - display_limit, col2.iter().take(display_limit).copied().collect::>(), if col2.len() > display_limit { "..." } else { "" }, - display_limit, col2.iter().map(|&x_val| -x_val).take(display_limit).collect::>(), if col2.len() > display_limit { "..." } else { "" } + j, + context_msg, + tol, + display_limit, + col1.iter().take(display_limit).copied().collect::>(), + if col1.len() > display_limit { + "..." + } else { + "" + }, + display_limit, + col2.iter().take(display_limit).copied().collect::>(), + if col2.len() > display_limit { + "..." + } else { + "" + }, + display_limit, + col2.iter() + .map(|&x_val| -x_val) + .take(display_limit) + .collect::>(), + if col2.len() > display_limit { + "..." + } else { + "" + } ); } } @@ -1839,18 +1902,30 @@ mod pca_tests { linfa_top_sv }; assert!( - approx::abs_diff_eq!(sv_fit_sorted_vec[0], adjusted_linfa_sv_for_fit, epsilon = TOLERANCE * 10.0), + approx::abs_diff_eq!( + sv_fit_sorted_vec[0], + adjusted_linfa_sv_for_fit, + epsilon = TOLERANCE * 10.0 + ), "Top SV mismatch (fit vs Linfa adjusted): fit_sv={:.4e}, linfa_sv_adj={:.4e} (orig_linfa_sv={:.4e}, factor={:.4e})", - sv_fit_sorted_vec[0], adjusted_linfa_sv_for_fit, linfa_top_sv, common_std_dev_fit + sv_fit_sorted_vec[0], + adjusted_linfa_sv_for_fit, + linfa_top_sv, + common_std_dev_fit ); } else { - panic!("pca_fit_eff.scale() vector is empty, cannot adjust Linfa SV for comparison."); + panic!( + "pca_fit_eff.scale() vector is empty, cannot adjust Linfa SV for comparison." + ); } } else { panic!("pca_fit_eff.scale() is None, cannot adjust Linfa SV for comparison."); } } else { - panic!("sv_fit_sorted_vec is empty, but Linfa reported {} SV(s). Inconsistent SV counts.", sv_linfa_sorted_vec.len()); + panic!( + "sv_fit_sorted_vec is empty, but Linfa reported {} SV(s). Inconsistent SV counts.", + sv_linfa_sorted_vec.len() + ); } // Comparison for 'rfit' vs Linfa Singular Value @@ -1864,18 +1939,30 @@ mod pca_tests { linfa_top_sv }; assert!( - approx::abs_diff_eq!(sv_rfit_sorted_vec[0], adjusted_linfa_sv_for_rfit, epsilon = TOLERANCE * 20.0), + approx::abs_diff_eq!( + sv_rfit_sorted_vec[0], + adjusted_linfa_sv_for_rfit, + epsilon = TOLERANCE * 20.0 + ), "Top SV mismatch (rfit vs Linfa adjusted): rfit_sv={:.4e}, linfa_sv_adj={:.4e} (orig_linfa_sv={:.4e}, factor={:.4e})", - sv_rfit_sorted_vec[0], adjusted_linfa_sv_for_rfit, linfa_top_sv, common_std_dev_rfit + sv_rfit_sorted_vec[0], + adjusted_linfa_sv_for_rfit, + linfa_top_sv, + common_std_dev_rfit ); } else { - panic!("pca_rfit_eff.scale() vector is empty, cannot adjust Linfa SV for comparison."); + panic!( + "pca_rfit_eff.scale() vector is empty, cannot adjust Linfa SV for comparison." + ); } } else { panic!("pca_rfit_eff.scale() is None, cannot adjust Linfa SV for comparison."); } } else { - panic!("sv_rfit_sorted_vec is empty, but Linfa reported {} SV(s). Inconsistent SV counts.", sv_linfa_sorted_vec.len()); + panic!( + "sv_rfit_sorted_vec is empty, but Linfa reported {} SV(s). Inconsistent SV counts.", + sv_linfa_sorted_vec.len() + ); } } else { // Linfa found no SVs @@ -1885,17 +1972,23 @@ mod pca_tests { || sv_fit_sorted_vec .iter() .all(|&x| x.abs() < TOLERANCE * 10.0); - assert!(fit_is_effectively_empty, - "Expected fit SVs to be empty or near-zero if Linfa SVs are empty; fit has {} SVs: {:?}", - sv_fit_sorted_vec.len(), sv_fit_sorted_vec); + assert!( + fit_is_effectively_empty, + "Expected fit SVs to be empty or near-zero if Linfa SVs are empty; fit has {} SVs: {:?}", + sv_fit_sorted_vec.len(), + sv_fit_sorted_vec + ); let rfit_is_effectively_empty = sv_rfit_sorted_vec.is_empty() || sv_rfit_sorted_vec .iter() .all(|&x| x.abs() < TOLERANCE * 20.0); - assert!(rfit_is_effectively_empty, - "Expected rfit SVs to be empty or near-zero if Linfa SVs are empty; rfit has {} SVs: {:?}", - sv_rfit_sorted_vec.len(), sv_rfit_sorted_vec); + assert!( + rfit_is_effectively_empty, + "Expected rfit SVs to be empty or near-zero if Linfa SVs are empty; rfit has {} SVs: {:?}", + sv_rfit_sorted_vec.len(), + sv_rfit_sorted_vec + ); } let sum_ev_fit_eff = explained_variance_fit_eff_v161.sum(); @@ -1932,7 +2025,11 @@ mod pca_tests { "Explained Variance Ratio (fit vs linfa - common top, Linfa NaN adj. to 1.0)", ); } else { - panic!("Length mismatch: Explained Variance Ratio (fit vs linfa): fit_len={}, linfa_len={}", ratio_fit_eff_vec.len(), len_linfa_ratio); + panic!( + "Length mismatch: Explained Variance Ratio (fit vs linfa): fit_len={}, linfa_len={}", + ratio_fit_eff_vec.len(), + len_linfa_ratio + ); } if ratio_rfit_eff_vec.len() >= len_linfa_ratio { @@ -1943,13 +2040,31 @@ mod pca_tests { "Explained Variance Ratio (rfit vs linfa - common top, Linfa NaN adj. to 1.0)", ); } else { - panic!("Length mismatch: Explained Variance Ratio (rfit vs linfa): rfit_len={}, linfa_len={}", ratio_rfit_eff_vec.len(), len_linfa_ratio); + panic!( + "Length mismatch: Explained Variance Ratio (rfit vs linfa): rfit_len={}, linfa_len={}", + ratio_rfit_eff_vec.len(), + len_linfa_ratio + ); } } else { - assert!(ratio_fit_eff_vec.is_empty() || ratio_fit_eff_vec.iter().all(|&x| x.abs() < TOLERANCE * 10.0 || x.is_nan()), - "Fit EVR: expected empty or near-zero/NaN if Linfa EVR is empty; got {} elements: {:?}", ratio_fit_eff_vec.len(), ratio_fit_eff_vec); - assert!(ratio_rfit_eff_vec.is_empty() || ratio_rfit_eff_vec.iter().all(|&x| x.abs() < TOLERANCE * 20.0 || x.is_nan()), - "RFit EVR: expected empty or near-zero/NaN if Linfa EVR is empty; got {} elements: {:?}", ratio_rfit_eff_vec.len(), ratio_rfit_eff_vec); + assert!( + ratio_fit_eff_vec.is_empty() + || ratio_fit_eff_vec + .iter() + .all(|&x| x.abs() < TOLERANCE * 10.0 || x.is_nan()), + "Fit EVR: expected empty or near-zero/NaN if Linfa EVR is empty; got {} elements: {:?}", + ratio_fit_eff_vec.len(), + ratio_fit_eff_vec + ); + assert!( + ratio_rfit_eff_vec.is_empty() + || ratio_rfit_eff_vec + .iter() + .all(|&x| x.abs() < TOLERANCE * 20.0 || x.is_nan()), + "RFit EVR: expected empty or near-zero/NaN if Linfa EVR is empty; got {} elements: {:?}", + ratio_rfit_eff_vec.len(), + ratio_rfit_eff_vec + ); } if explained_variance_fit_eff_v161.len() > 0 { @@ -1961,13 +2076,14 @@ mod pca_tests { ) { panic!( "Efficient PCA (fit) first explained variance for hardcoded data should reflect sample-variance scaling (~{}). Got: {}", - expected_first_ev, - explained_variance_fit_eff_v161[0] + expected_first_ev, explained_variance_fit_eff_v161[0] ); } } - println!("Test 'test_pca_fit_consistency_linfa' (comparing with Linfa) passed with aliased ndarray_v15 and type bridging!"); + println!( + "Test 'test_pca_fit_consistency_linfa' (comparing with Linfa) passed with aliased ndarray_v15 and type bridging!" + ); Ok(()) } @@ -2014,11 +2130,21 @@ mod pca_tests { #[test] fn test_pca_5x7() { let input = array![ - [0.5855288, -1.8179560, -0.1162478, 0.8168998, 0.7796219, 1.8050975, 0.8118732], - [0.7094660, 0.6300986, 1.8173120, -0.8863575, 1.4557851, -0.4816474, 2.1968335], - [-0.1093033, -0.2761841, 0.3706279, -0.3315776, -0.6443284, 0.6203798, 2.0491903], - [-0.4534972, -0.2841597, 0.5202165, 1.1207127, -1.5531374, 0.6121235, 1.6324456], - [0.6058875, -0.9193220, -0.7505320, 0.2987237, -1.5977095, -0.1623110, 0.2542712] + [ + 0.5855288, -1.8179560, -0.1162478, 0.8168998, 0.7796219, 1.8050975, 0.8118732 + ], + [ + 0.7094660, 0.6300986, 1.8173120, -0.8863575, 1.4557851, -0.4816474, 2.1968335 + ], + [ + -0.1093033, -0.2761841, 0.3706279, -0.3315776, -0.6443284, 0.6203798, 2.0491903 + ], + [ + -0.4534972, -0.2841597, 0.5202165, 1.1207127, -1.5531374, 0.6121235, 1.6324456 + ], + [ + 0.6058875, -0.9193220, -0.7505320, 0.2987237, -1.5977095, -0.1623110, 0.2542712 + ] ]; run_python_pca_test(&input, 7, false, 0, None, 1e-6, "test_pca_5x7"); } @@ -2026,11 +2152,21 @@ mod pca_tests { #[test] fn test_rpca_5x7_k4() { let input = array![ - [0.5855288, -1.8179560, -0.1162478, 0.8168998, 0.7796219, 1.8050975, 0.8118732], - [0.7094660, 0.6300986, 1.8173120, -0.8863575, 1.4557851, -0.4816474, 2.1968335], - [-0.1093033, -0.2761841, 0.3706279, -0.3315776, -0.6443284, 0.6203798, 2.0491903], - [-0.4534972, -0.2841597, 0.5202165, 1.1207127, -1.5531374, 0.6121235, 1.6324456], - [0.6058875, -0.9193220, -0.7505320, 0.2987237, -1.5977095, -0.1623110, 0.2542712] + [ + 0.5855288, -1.8179560, -0.1162478, 0.8168998, 0.7796219, 1.8050975, 0.8118732 + ], + [ + 0.7094660, 0.6300986, 1.8173120, -0.8863575, 1.4557851, -0.4816474, 2.1968335 + ], + [ + -0.1093033, -0.2761841, 0.3706279, -0.3315776, -0.6443284, 0.6203798, 2.0491903 + ], + [ + -0.4534972, -0.2841597, 0.5202165, 1.1207127, -1.5531374, 0.6121235, 1.6324456 + ], + [ + 0.6058875, -0.9193220, -0.7505320, 0.2987237, -1.5977095, -0.1623110, 0.2542712 + ] ]; run_python_pca_test(&input, 4, true, 0, Some(1926), 1e-6, "test_rpca_5x7_k4"); }