diff --git a/Cargo.toml b/Cargo.toml index 7238fdd..5d4f4e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ faer = { version = "0.22.6", optional = true } float-cmp = "0.10.0" log = "0.4.27" ndarray = { version = "0.16.1", features = ["serde", "rayon"] } -ndarray-linalg = { version = "0.17.0", default-features = false } +ndarray-linalg = { version = "0.17.0", default-features = false, optional = true } ndarray-rand = "0.15.0" rand = "0.8.5" rand_chacha = "0.3.1" @@ -49,7 +49,6 @@ path = "src/lib.rs" [dev-dependencies] dirs = "6.0.0" -reqwest = { version = "0.12.15", features = ["blocking"] } tar = "0.4.44" zstd = "0.13.3" criterion = { version = "0.6.0", features = ["html_reports"] } @@ -63,13 +62,13 @@ lazy_static = "1.5.0" serde_json = "1.0" [features] -default = ["backend_openblas"] -backend_openblas = ["ndarray-linalg/openblas-static"] -backend_openblas_system = ["ndarray-linalg/openblas-system"] -backend_mkl = ["ndarray-linalg/intel-mkl-static"] -backend_mkl_system = ["ndarray-linalg/intel-mkl-system"] +default = ["backend_faer"] +backend_openblas = ["dep:ndarray-linalg", "ndarray-linalg/openblas-static"] +backend_openblas_system = ["dep:ndarray-linalg", "ndarray-linalg/openblas-system"] +backend_mkl = ["dep:ndarray-linalg", "ndarray-linalg/intel-mkl-static"] +backend_mkl_system = ["dep:ndarray-linalg", "ndarray-linalg/intel-mkl-system"] backend_faer = ["dep:faer"] -faer_links_ndarray_static_openblas = ["backend_faer", "ndarray-linalg/openblas-static"] +faer_links_ndarray_static_openblas = ["backend_faer", "backend_openblas"] # --- Utility Features --- jemalloc = ["dep:jemalloc-ctl"] diff --git a/examples/test_backends.rs b/examples/test_backends.rs index 75afa38..3060363 100644 --- a/examples/test_backends.rs +++ b/examples/test_backends.rs @@ -4,12 +4,15 @@ use ndarray::Array2; fn main() { // Create a simple test matrix let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); - + // Create and fit PCA let mut pca = PCA::new(); pca.fit(data, None).expect("PCA fit failed"); - + println!("PCA backend test works!"); println!("Rotation matrix shape: {:?}", pca.rotation().unwrap().dim()); - println!("Explained variance: {:?}", pca.explained_variance().unwrap()); + println!( + "Explained variance: {:?}", + pca.explained_variance().unwrap() + ); } diff --git a/src/eigensnp.rs b/src/eigensnp.rs index 9311c73..ef3bb45 100644 --- a/src/eigensnp.rs +++ b/src/eigensnp.rs @@ -841,8 +841,7 @@ impl EigenSNPCoreAlgorithm { writeln!(writer)?; // Zip the metadata with the rows of the loadings matrix - for (snp_info, loadings_row) in - block_snp_metadata.iter().zip(local_pcs.rows()) + for (snp_info, loadings_row) in block_snp_metadata.iter().zip(local_pcs.rows()) { write!( writer, diff --git a/src/linalg_backends.rs b/src/linalg_backends.rs index 847e864..3b06c66 100644 --- a/src/linalg_backends.rs +++ b/src/linalg_backends.rs @@ -14,7 +14,15 @@ impl LinAlgBackendProvider { } // --- Common imports needed by multiple sections --- -use ndarray::{s, Array1, Array2}; +#[cfg(any( + feature = "backend_openblas", + feature = "backend_openblas_system", + feature = "backend_mkl", + feature = "backend_mkl_system", + feature = "faer_links_ndarray_static_openblas" +))] +use ndarray::s; +use ndarray::{Array1, Array2}; // use num_traits::Float; // No longer needed directly by provider use std::error::Error; use std::marker::PhantomData; @@ -33,7 +41,7 @@ pub struct EighOutput { /// Trait for symmetric eigendecomposition (similar to LAPACK's DSYEVR or DSYEVD). /// Implementers will typically expect `matrix` to be symmetric. -pub trait BackendEigh { +pub trait BackendEigh { fn eigh_upper(&self, matrix: &Array2) -> Result, Box>; } @@ -52,7 +60,7 @@ pub struct SVDOutput { } /// Trait for Singular Value Decomposition. -pub trait BackendSVD { +pub trait BackendSVD { fn svd_into( &self, matrix: Array2, @@ -62,105 +70,122 @@ pub trait BackendSVD { } // --- NdarrayLinAlgBackend Implementation (originally from ndarray_backend.rs) --- -// Specific imports for ndarray-linalg backend -use ndarray_linalg::{Eigh, Lapack, SVDInto, QR, UPLO}; -// use num_traits::AsPrimitive; // Removed as not directly used by trait impls - -// Define a concrete type for ndarray-linalg backend -#[derive(Debug, Default, Copy, Clone)] -pub struct NdarrayLinAlgBackend; +#[cfg(any( + feature = "backend_openblas", + feature = "backend_openblas_system", + feature = "backend_mkl", + feature = "backend_mkl_system", + feature = "faer_links_ndarray_static_openblas" +))] +mod ndarray_backend_impl { + use super::{s, Array2, BackendEigh, BackendQR, BackendSVD, EighOutput, SVDOutput}; + use ndarray_linalg::{Eigh, Lapack, SVDInto, QR, UPLO}; + use std::error::Error; -// Helper to convert ndarray-linalg's error to Box -fn to_dyn_error(e: E) -> Box { - Box::new(e) -} + #[derive(Debug, Default, Copy, Clone)] + pub struct NdarrayLinAlgBackend; -// Single impl block handles f32, f64, complex if you need it -impl BackendEigh for NdarrayLinAlgBackend -where - F: Lapack + 'static + Copy + Send + Sync, -{ - fn eigh_upper( - &self, - matrix: &Array2, - ) -> Result, Box> { - // Use direct Eigh call - let (eigvals, eigvecs) = matrix.eigh(UPLO::Upper).map_err(to_dyn_error)?; - Ok(EighOutput { - eigenvalues: eigvals, - eigenvectors: eigvecs, - }) + fn to_dyn_error(e: E) -> Box { + Box::new(e) } -} -impl BackendQR for NdarrayLinAlgBackend -where - F: Lapack + 'static + Copy + Send + Sync, -{ - fn qr_q_factor(&self, matrix: &Array2) -> Result, Box> { - let (nrows, ncols) = matrix.dim(); - if nrows == 0 { - return Ok(Array2::zeros((0, 0))); + impl BackendEigh for NdarrayLinAlgBackend + where + F: Lapack + 'static + Copy + Send + Sync, + { + fn eigh_upper( + &self, + matrix: &Array2, + ) -> Result, Box> { + let (eigvals, eigvecs) = matrix.eigh(UPLO::Upper).map_err(to_dyn_error)?; + Ok(EighOutput { + eigenvalues: eigvals, + eigenvectors: eigvecs, + }) } - let k = nrows.min(ncols); // Re-introduce k - // Use direct QR call - let (q_full, _) = matrix.qr().map_err(to_dyn_error)?; - Ok(q_full.slice_move(s![.., 0..k])) } -} -impl BackendSVD for NdarrayLinAlgBackend -where - F: Lapack + 'static + Copy + Send + Sync, -{ - fn svd_into( - &self, - matrix: Array2, - compute_u: bool, - compute_v: bool, - ) -> Result, Box> { - let original_rows = matrix.nrows(); - let original_cols = matrix.ncols(); + impl BackendQR for NdarrayLinAlgBackend + where + F: Lapack + 'static + Copy + Send + Sync, + { + fn qr_q_factor( + &self, + matrix: &Array2, + ) -> Result, Box> { + let (nrows, ncols) = matrix.dim(); + if nrows == 0 { + return Ok(Array2::zeros((0, 0))); + } + let k = nrows.min(ncols); + let (q_full, _) = matrix.qr().map_err(to_dyn_error)?; + Ok(q_full.slice_move(s![.., 0..k])) + } + } - // Use direct SVDInto call - let (u_option, s, vt_option) = matrix - .svd_into(compute_u, compute_v) - .map_err(to_dyn_error)?; + impl BackendSVD for NdarrayLinAlgBackend + where + F: Lapack + 'static + Copy + Send + Sync, + { + fn svd_into( + &self, + matrix: Array2, + compute_u: bool, + compute_v: bool, + ) -> Result, Box> { + let original_rows = matrix.nrows(); + let original_cols = matrix.ncols(); + + let (u_option, s, vt_option) = matrix + .svd_into(compute_u, compute_v) + .map_err(to_dyn_error)?; + + let k_effective = s.len(); + + let u_final = if let Some(mut u_mat) = u_option { + if u_mat.ncols() > k_effective { + assert_eq!(u_mat.nrows(), original_rows, "U matrix row count mismatch"); + u_mat = u_mat.slice_move(s![.., 0..k_effective]); + } + Some(u_mat) + } else { + None + }; - let k_effective = s.len(); + let vt_final = if let Some(mut vt_mat) = vt_option { + if vt_mat.nrows() > k_effective { + assert_eq!( + vt_mat.ncols(), + original_cols, + "VT matrix column count mismatch", + ); + vt_mat = vt_mat.slice_move(s![0..k_effective, ..]); + } + Some(vt_mat) + } else { + None + }; - let u_final = if let Some(mut u_mat) = u_option { - if u_mat.ncols() > k_effective { - assert_eq!(u_mat.nrows(), original_rows, "U matrix row count mismatch"); - u_mat = u_mat.slice_move(s![.., 0..k_effective]); - } - Some(u_mat) - } else { - None - }; - - let vt_final = if let Some(mut vt_mat) = vt_option { - if vt_mat.nrows() > k_effective { - assert_eq!( - vt_mat.ncols(), - original_cols, - "VT matrix column count mismatch" - ); - vt_mat = vt_mat.slice_move(s![0..k_effective, ..]); - } - Some(vt_mat) - } else { - None - }; - - Ok(SVDOutput { - u: u_final, - s, - vt: vt_final, - }) + Ok(SVDOutput { + u: u_final, + s, + vt: vt_final, + }) + } } + + pub use NdarrayLinAlgBackend as Backend; } +#[cfg(any( + feature = "backend_openblas", + feature = "backend_openblas_system", + feature = "backend_mkl", + feature = "backend_mkl_system", + feature = "faer_links_ndarray_static_openblas" +))] +use ndarray_backend_impl::Backend as NdarrayLinAlgBackend; + // --- FaerLinAlgBackend Implementation (originally from faer_backend.rs) --- #[cfg(feature = "backend_faer")] mod faer_specific_code { diff --git a/src/pca.rs b/src/pca.rs index 3b8afa0..2d6a285 100644 --- a/src/pca.rs +++ b/src/pca.rs @@ -3,7 +3,9 @@ #![doc = include_str!("../README.md")] use ndarray::parallel::prelude::*; -use ndarray::{s, Array1, Array2, ArrayViewMut1, Axis, ShapeBuilder}; +#[cfg(feature = "backend_faer")] +use ndarray::ShapeBuilder; +use ndarray::{s, Array1, Array2, ArrayViewMut1, Axis}; // UPLO is no longer needed as the backend's eigh_upper handles this. // QR trait for .qr() and SVDInto for .svd_into() are replaced by backend calls. // Eigh trait for .eigh() is replaced by backend calls. diff --git a/tests/eigensnp_diagnostics_tests.rs b/tests/eigensnp_diagnostics_tests.rs index 9214157..c6294f6 100644 --- a/tests/eigensnp_diagnostics_tests.rs +++ b/tests/eigensnp_diagnostics_tests.rs @@ -163,7 +163,14 @@ fn run_diagnostic_test_with_params( }; let algorithm = EigenSNPCoreAlgorithm::new(config.clone()); - let snp_metadata: Vec = (0..mock_data_accessor.num_pca_snps()).map(|i| efficient_pca::eigensnp::PcaSnpMetadata { id: std::sync::Arc::new(format!("snp_{}", i)), chr: std::sync::Arc::new("chr1".to_string()), pos: i as u64 * 1000 + 100000 }).collect(); + let snp_metadata: Vec = (0..mock_data_accessor + .num_pca_snps()) + .map(|i| efficient_pca::eigensnp::PcaSnpMetadata { + id: std::sync::Arc::new(format!("snp_{}", i)), + chr: std::sync::Arc::new("chr1".to_string()), + pos: i as u64 * 1000 + 100000, + }) + .collect(); let pca_result_tuple = algorithm .compute_pca(&mock_data_accessor, &ld_block_specs, &snp_metadata) .map_err(|e| e as Box)?; diff --git a/tests/eigensnp_tests.rs b/tests/eigensnp_tests.rs index edb5b3e..e0d1bf3 100644 --- a/tests/eigensnp_tests.rs +++ b/tests/eigensnp_tests.rs @@ -7,10 +7,15 @@ // exceeds the number of samples. Small test cases or cases where samples >= features // have been deemphasized or removed to better reflect real-world usage scenarios. +#[path = "python_bootstrap.rs"] +mod python_bootstrap; + +use python_bootstrap::ensure_python_packages_installed; + use efficient_pca::eigensnp::{ reorder_array_owned, reorder_columns_owned, EigenSNPCoreAlgorithm, EigenSNPCoreAlgorithmConfig, - EigenSNPCoreOutput, LdBlockSpecification, PcaReadyGenotypeAccessor, PcaSnpId, PcaSnpMetadata, QcSampleId, - ThreadSafeStdError, + EigenSNPCoreOutput, LdBlockSpecification, PcaReadyGenotypeAccessor, PcaSnpId, PcaSnpMetadata, + QcSampleId, ThreadSafeStdError, }; use ndarray::{arr2, s, Array1, Array2, ArrayView1, ArrayView2, Axis}; // ArrayView2 was already added, Array removed use ndarray_rand::rand_distr::{Normal, StandardNormal, Uniform}; // Added Normal, StandardNormal @@ -298,6 +303,7 @@ mod eigensnp_integration_tests { let mut script_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); script_path.push("tests/pca.py"); + ensure_python_packages_installed(); let mut process = Command::new("python3") .arg(script_path.to_str().ok_or("Invalid script path")?) .arg("--generate-reference-pca") @@ -1417,7 +1423,7 @@ mod eigensnp_integration_tests { pca_snp_ids_in_block: (0..num_total_snps).map(PcaSnpId).collect(), }]; - let snp_metadata = create_dummy_snp_metadata(num_total_snps); + let snp_metadata = create_dummy_snp_metadata(num_total_snps); let rust_output_result_tuple = algorithm.compute_pca(&test_data, &ld_blocks, &snp_metadata); let rust_output = match rust_output_result_tuple { @@ -1844,7 +1850,7 @@ pub fn run_pc_correlation_with_truth_set_test( }]; let mut rust_pcs_computed = 0; - let snp_metadata = create_dummy_snp_metadata(num_snps); + let snp_metadata = create_dummy_snp_metadata(num_snps); match algorithm.compute_pca(&test_data_accessor, &ld_blocks, &snp_metadata) { Ok((rust_result, _)) => { rust_pcs_computed = rust_result.num_principal_components_computed; @@ -2089,7 +2095,7 @@ fn test_pc_correlation_structured_1000snps_200samples_5truepcs() { }]; let mut rust_pcs_computed = 0; - let snp_metadata = create_dummy_snp_metadata(num_snps); + let snp_metadata = create_dummy_snp_metadata(num_snps); match algorithm.compute_pca(&test_data_accessor, &ld_blocks, &snp_metadata) { Ok((rust_result, _)) => { rust_pcs_computed = rust_result.num_principal_components_computed; @@ -2306,7 +2312,7 @@ pub fn run_generic_large_matrix_test( let mut rust_pcs_computed = 0; - let snp_metadata = create_dummy_snp_metadata(num_snps); + let snp_metadata = create_dummy_snp_metadata(num_snps); match algorithm.compute_pca(&test_data_accessor, &ld_blocks, &snp_metadata) { Ok((output, _)) => { rust_pcs_computed = output.num_principal_components_computed; @@ -2494,7 +2500,7 @@ pub fn run_sample_projection_accuracy_test( let mut rust_pca_output_option: Option = None; // Now directly in scope let mut k_eff_rust = 0; - let snp_metadata = create_dummy_snp_metadata(num_snps); + let snp_metadata = create_dummy_snp_metadata(num_snps); match algorithm_train.compute_pca(&test_data_accessor_train, &ld_blocks_train, &snp_metadata) { Ok((output_struct, _)) => { k_eff_rust = output_struct.num_principal_components_computed; @@ -2893,9 +2899,10 @@ where let algorithm_b = EigenSNPCoreAlgorithm::new(config_b); let snp_metadata = create_dummy_snp_metadata(standardized_structured_data.nrows()); - + // Run EigenSnp A - let output_a = match algorithm_a.compute_pca(&test_data_accessor, ld_block_specs, &snp_metadata) { + let output_a = match algorithm_a.compute_pca(&test_data_accessor, ld_block_specs, &snp_metadata) + { Ok((out, _)) => { writeln!( outcome_details, @@ -2944,7 +2951,8 @@ where }; // Run EigenSnp B - let output_b = match algorithm_b.compute_pca(&test_data_accessor, ld_block_specs, &snp_metadata) { + let output_b = match algorithm_b.compute_pca(&test_data_accessor, ld_block_specs, &snp_metadata) + { Ok((out, _)) => { writeln!( outcome_details, @@ -3783,7 +3791,11 @@ fn test_refinement_projection_accuracy() { let algorithm = EigenSNPCoreAlgorithm::new(config); let snp_metadata = create_dummy_snp_metadata(d_total_snps); - match algorithm.compute_pca(&test_data_accessor_train, &ld_block_specs_train, &snp_metadata) { + match algorithm.compute_pca( + &test_data_accessor_train, + &ld_block_specs_train, + &snp_metadata, + ) { Ok((eigensnp_train_output_struct, _)) => { save_matrix_to_tsv( &eigensnp_train_output_struct diff --git a/tests/pca_tests.rs b/tests/pca_tests.rs index 1e700c1..61e03d4 100644 --- a/tests/pca_tests.rs +++ b/tests/pca_tests.rs @@ -1,3 +1,7 @@ +#[path = "python_bootstrap.rs"] +mod python_bootstrap; + +use python_bootstrap::ensure_python_packages_installed; // For the crate's PCA use efficient_pca::PCA; @@ -8,7 +12,6 @@ use ndarray::{array, Array2, Axis}; use linfa::dataset::DatasetBase; use linfa::prelude::*; use linfa_reduction::Pca as LinfaPcaModel; // The PCA implementation from Linfa, aliased -use ndarray_linalg::{Eigh, QR, UPLO}; use rand::Rng; use rand::SeedableRng; @@ -21,6 +24,79 @@ fn generate_random_data(n_samples: usize, n_features: usize, seed: u64) -> Array }) } +fn orthonormalize_columns(matrix: &Array2) -> Array2 { + let (rows, cols) = matrix.dim(); + let mut q = Array2::::zeros((rows, cols)); + + for j in 0..cols { + let mut v = matrix.column(j).to_owned(); + for k in 0..j { + let projection = v.dot(&q.column(k)); + for i in 0..rows { + v[i] -= projection * q[[i, k]]; + } + } + + let norm_squared: f64 = v.iter().map(|&val| val * val).sum(); + if norm_squared > 1e-24 { + let norm = norm_squared.sqrt(); + for i in 0..rows { + q[[i, j]] = v[i] / norm; + } + } else { + for i in 0..rows { + q[[i, j]] = 0.0; + } + } + } + + q +} + +#[cfg(feature = "backend_faer")] +fn eigenvalues_descending(matrix: &Array2) -> Vec { + use faer::linalg::solvers::SelfAdjointEigen; + use faer::{MatRef, Side}; + + let (nrows, ncols) = matrix.dim(); + assert_eq!(nrows, ncols, "covariance matrix must be square"); + if nrows == 0 { + return Vec::new(); + } + + let eigen = if let Some(slice) = matrix.as_slice_memory_order() { + let mat_ref = if matrix.is_standard_layout() { + MatRef::from_row_major_slice(slice, nrows, ncols) + } else { + MatRef::from_column_major_slice(slice, nrows, ncols) + }; + SelfAdjointEigen::new(mat_ref, Side::Upper) + .expect("faer self-adjoint eigen decomposition failed") + } else { + let owned = matrix.to_owned(); + let slice = owned + .as_slice_memory_order() + .expect("owned covariance copy should be contiguous"); + let mat_ref = MatRef::from_row_major_slice(slice, nrows, ncols); + SelfAdjointEigen::new(mat_ref, Side::Upper) + .expect("faer self-adjoint eigen decomposition failed") + }; + + let mut values: Vec = eigen.S().column_vector().iter().copied().collect(); + values.sort_by(|a, b| b.partial_cmp(a).expect("eigenvalues must be comparable")); + values +} + +#[cfg(not(feature = "backend_faer"))] +fn eigenvalues_descending(matrix: &Array2) -> Vec { + use ndarray_linalg::{Eigh, UPLO}; + + let (mut eigenvalues, _) = matrix.eigh(UPLO::Upper).expect("eigendecomposition failed"); + let mut values = eigenvalues.to_vec(); + values.sort_by(|a, b| b.partial_cmp(a).expect("eigenvalues must be comparable")); + values +} + #[cfg(test)] mod genome_tests { use super::*; @@ -409,7 +485,7 @@ mod genome_tests { let random_basis = Array2::::from_shape_fn((n_samples, n_real_components), |_| { rng.gen_range(-1.0..1.0) }); - let (q, _) = random_basis.qr().unwrap(); + let q = super::orthonormalize_columns(&random_basis); // Make sure we have orthogonal unit vectors for true factors let true_factors = q.slice(s![.., 0..n_real_components]).to_owned(); @@ -471,6 +547,7 @@ mod genome_tests { } } + super::ensure_python_packages_installed(); let cmd_output = Command::new("python3") .args(&vec![ "tests/pca.py", @@ -496,11 +573,7 @@ mod genome_tests { } } let cov_matrix = centered_data.dot(¢ered_data.t()) / (n_samples as f64 - 1.0); - let (mut eigenvalues, _) = cov_matrix.eigh(UPLO::Upper).unwrap(); - eigenvalues - .as_slice_mut() - .unwrap() - .sort_by(|a, b| b.partial_cmp(a).unwrap()); + let eigenvalues = super::eigenvalues_descending(&cov_matrix); let total_variance: f64 = eigenvalues.iter().take(n_components).sum(); println!("\n[Comparison] Explained Variance:"); @@ -1258,6 +1331,7 @@ mod pca_tests { tol: f64, test_name: &str, ) { + super::ensure_python_packages_installed(); fn parse_transformed_csv_from_python(output_text: &str) -> ndarray::Array2 { println!("[Rust Debug] Entering parse_transformed_csv_from_python..."); let mut lines = Vec::new(); diff --git a/tests/python_bootstrap.rs b/tests/python_bootstrap.rs new file mode 100644 index 0000000..d7d9053 --- /dev/null +++ b/tests/python_bootstrap.rs @@ -0,0 +1,51 @@ +use std::process::Command; +use std::sync::Once; + +static PY_DEPS_INIT: Once = Once::new(); + +pub fn ensure_python_packages_installed() { + PY_DEPS_INIT.call_once(|| { + let packages = [ + ("numpy", "numpy"), + ("scipy", "scipy"), + ("scikit-learn", "sklearn"), + ]; + + let mut missing = Vec::new(); + for (package_name, module_name) in packages.iter() { + let status = Command::new("python3") + .args(["-c", &format!("import {}", module_name)]) + .status() + .expect("failed to invoke python3 to probe optional modules"); + if !status.success() { + missing.push(*package_name); + } + } + + if missing.is_empty() { + return; + } + + println!( + "Installing missing Python packages required for reference PCA: {:?}", + missing + ); + + let mut cmd = Command::new("python3"); + cmd.args(["-m", "pip", "install", "--user"]); + cmd.args(&missing); + + let output = cmd + .output() + .expect("failed to invoke pip to install python dependencies"); + + if !output.status.success() { + panic!( + "Failed to install required python packages. Status: {:?}\nSTDOUT:\n{}\nSTDERR:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + }); +}