Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ faer = { version = "0.22.6", optional = true }
float-cmp = "0.10.0"
log = "0.4.27"
ndarray = { version = "0.16.1", features = ["serde", "rayon"] }
ndarray-linalg = { version = "0.17.0", default-features = false }
ndarray-linalg = { version = "0.17.0", default-features = false, optional = true }
ndarray-rand = "0.15.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
Expand All @@ -49,7 +49,6 @@ path = "src/lib.rs"

[dev-dependencies]
dirs = "6.0.0"
reqwest = { version = "0.12.15", features = ["blocking"] }
tar = "0.4.44"
zstd = "0.13.3"
criterion = { version = "0.6.0", features = ["html_reports"] }
Expand All @@ -63,13 +62,13 @@ lazy_static = "1.5.0"
serde_json = "1.0"

[features]
default = ["backend_openblas"]
backend_openblas = ["ndarray-linalg/openblas-static"]
backend_openblas_system = ["ndarray-linalg/openblas-system"]
backend_mkl = ["ndarray-linalg/intel-mkl-static"]
backend_mkl_system = ["ndarray-linalg/intel-mkl-system"]
default = ["backend_faer"]
backend_openblas = ["dep:ndarray-linalg", "ndarray-linalg/openblas-static"]
backend_openblas_system = ["dep:ndarray-linalg", "ndarray-linalg/openblas-system"]
backend_mkl = ["dep:ndarray-linalg", "ndarray-linalg/intel-mkl-static"]
backend_mkl_system = ["dep:ndarray-linalg", "ndarray-linalg/intel-mkl-system"]
backend_faer = ["dep:faer"]
faer_links_ndarray_static_openblas = ["backend_faer", "ndarray-linalg/openblas-static"]
faer_links_ndarray_static_openblas = ["backend_faer", "backend_openblas"]

# --- Utility Features ---
jemalloc = ["dep:jemalloc-ctl"]
Expand Down
9 changes: 6 additions & 3 deletions examples/test_backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ use ndarray::Array2;
fn main() {
// Create a simple test matrix
let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

// Create and fit PCA
let mut pca = PCA::new();
pca.fit(data, None).expect("PCA fit failed");

println!("PCA backend test works!");
println!("Rotation matrix shape: {:?}", pca.rotation().unwrap().dim());
println!("Explained variance: {:?}", pca.explained_variance().unwrap());
println!(
"Explained variance: {:?}",
pca.explained_variance().unwrap()
);
}
3 changes: 1 addition & 2 deletions src/eigensnp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,8 +841,7 @@ impl EigenSNPCoreAlgorithm {
writeln!(writer)?;

// Zip the metadata with the rows of the loadings matrix
for (snp_info, loadings_row) in
block_snp_metadata.iter().zip(local_pcs.rows())
for (snp_info, loadings_row) in block_snp_metadata.iter().zip(local_pcs.rows())
{
write!(
writer,
Expand Down
203 changes: 114 additions & 89 deletions src/linalg_backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ impl<F: 'static + Copy + Send + Sync> LinAlgBackendProvider<F> {
}

// --- Common imports needed by multiple sections ---
use ndarray::{s, Array1, Array2};
#[cfg(any(
feature = "backend_openblas",
feature = "backend_openblas_system",
feature = "backend_mkl",
feature = "backend_mkl_system",
feature = "faer_links_ndarray_static_openblas"
))]
use ndarray::s;
use ndarray::{Array1, Array2};
// use num_traits::Float; // No longer needed directly by provider
use std::error::Error;
use std::marker::PhantomData;
Expand All @@ -33,7 +41,7 @@ pub struct EighOutput<F: 'static> {

/// Trait for symmetric eigendecomposition (similar to LAPACK's DSYEVR or DSYEVD).
/// Implementers will typically expect `matrix` to be symmetric.
pub trait BackendEigh<F: Lapack + 'static + Copy + Send + Sync> {
pub trait BackendEigh<F: 'static + Copy + Send + Sync> {
fn eigh_upper(&self, matrix: &Array2<F>)
-> Result<EighOutput<F>, Box<dyn Error + Send + Sync>>;
}
Expand All @@ -52,7 +60,7 @@ pub struct SVDOutput<F: 'static> {
}

/// Trait for Singular Value Decomposition.
pub trait BackendSVD<F: Lapack + 'static + Copy + Send + Sync> {
pub trait BackendSVD<F: 'static + Copy + Send + Sync> {
fn svd_into(
&self,
matrix: Array2<F>,
Expand All @@ -62,105 +70,122 @@ pub trait BackendSVD<F: Lapack + 'static + Copy + Send + Sync> {
}

// --- NdarrayLinAlgBackend Implementation (originally from ndarray_backend.rs) ---
// Specific imports for ndarray-linalg backend
use ndarray_linalg::{Eigh, Lapack, SVDInto, QR, UPLO};
// use num_traits::AsPrimitive; // Removed as not directly used by trait impls

// Define a concrete type for ndarray-linalg backend
#[derive(Debug, Default, Copy, Clone)]
pub struct NdarrayLinAlgBackend;
#[cfg(any(
feature = "backend_openblas",
feature = "backend_openblas_system",
feature = "backend_mkl",
feature = "backend_mkl_system",
feature = "faer_links_ndarray_static_openblas"
))]
mod ndarray_backend_impl {
use super::{s, Array2, BackendEigh, BackendQR, BackendSVD, EighOutput, SVDOutput};
use ndarray_linalg::{Eigh, Lapack, SVDInto, QR, UPLO};
use std::error::Error;

// Helper to convert ndarray-linalg's error to Box<dyn Error + Send + Sync>
fn to_dyn_error<E: Error + Send + Sync + 'static>(e: E) -> Box<dyn Error + Send + Sync> {
Box::new(e)
}
#[derive(Debug, Default, Copy, Clone)]
pub struct NdarrayLinAlgBackend;

// Single impl block handles f32, f64, complex if you need it
impl<F> BackendEigh<F> for NdarrayLinAlgBackend
where
F: Lapack<Real = F> + 'static + Copy + Send + Sync,
{
fn eigh_upper(
&self,
matrix: &Array2<F>,
) -> Result<EighOutput<F>, Box<dyn Error + Send + Sync>> {
// Use direct Eigh call
let (eigvals, eigvecs) = matrix.eigh(UPLO::Upper).map_err(to_dyn_error)?;
Ok(EighOutput {
eigenvalues: eigvals,
eigenvectors: eigvecs,
})
fn to_dyn_error<E: Error + Send + Sync + 'static>(e: E) -> Box<dyn Error + Send + Sync> {
Box::new(e)
}
}

impl<F> BackendQR<F> for NdarrayLinAlgBackend
where
F: Lapack + 'static + Copy + Send + Sync,
{
fn qr_q_factor(&self, matrix: &Array2<F>) -> Result<Array2<F>, Box<dyn Error + Send + Sync>> {
let (nrows, ncols) = matrix.dim();
if nrows == 0 {
return Ok(Array2::zeros((0, 0)));
impl<F> BackendEigh<F> for NdarrayLinAlgBackend
where
F: Lapack<Real = F> + 'static + Copy + Send + Sync,
{
fn eigh_upper(
&self,
matrix: &Array2<F>,
) -> Result<EighOutput<F>, Box<dyn Error + Send + Sync>> {
let (eigvals, eigvecs) = matrix.eigh(UPLO::Upper).map_err(to_dyn_error)?;
Ok(EighOutput {
eigenvalues: eigvals,
eigenvectors: eigvecs,
})
}
let k = nrows.min(ncols); // Re-introduce k
// Use direct QR call
let (q_full, _) = matrix.qr().map_err(to_dyn_error)?;
Ok(q_full.slice_move(s![.., 0..k]))
}
}

impl<F> BackendSVD<F> for NdarrayLinAlgBackend
where
F: Lapack<Real = F> + 'static + Copy + Send + Sync,
{
fn svd_into(
&self,
matrix: Array2<F>,
compute_u: bool,
compute_v: bool,
) -> Result<SVDOutput<F>, Box<dyn Error + Send + Sync>> {
let original_rows = matrix.nrows();
let original_cols = matrix.ncols();
impl<F> BackendQR<F> for NdarrayLinAlgBackend
where
F: Lapack + 'static + Copy + Send + Sync,
{
fn qr_q_factor(
&self,
matrix: &Array2<F>,
) -> Result<Array2<F>, Box<dyn Error + Send + Sync>> {
let (nrows, ncols) = matrix.dim();
if nrows == 0 {
return Ok(Array2::zeros((0, 0)));
}
let k = nrows.min(ncols);
let (q_full, _) = matrix.qr().map_err(to_dyn_error)?;
Ok(q_full.slice_move(s![.., 0..k]))
}
}

// Use direct SVDInto call
let (u_option, s, vt_option) = matrix
.svd_into(compute_u, compute_v)
.map_err(to_dyn_error)?;
impl<F> BackendSVD<F> for NdarrayLinAlgBackend
where
F: Lapack<Real = F> + 'static + Copy + Send + Sync,
{
fn svd_into(
&self,
matrix: Array2<F>,
compute_u: bool,
compute_v: bool,
) -> Result<SVDOutput<F>, Box<dyn Error + Send + Sync>> {
let original_rows = matrix.nrows();
let original_cols = matrix.ncols();

let (u_option, s, vt_option) = matrix
.svd_into(compute_u, compute_v)
.map_err(to_dyn_error)?;

let k_effective = s.len();

let u_final = if let Some(mut u_mat) = u_option {
if u_mat.ncols() > k_effective {
assert_eq!(u_mat.nrows(), original_rows, "U matrix row count mismatch");
u_mat = u_mat.slice_move(s![.., 0..k_effective]);
}
Some(u_mat)
} else {
None
};

let k_effective = s.len();
let vt_final = if let Some(mut vt_mat) = vt_option {
if vt_mat.nrows() > k_effective {
assert_eq!(
vt_mat.ncols(),
original_cols,
"VT matrix column count mismatch",
);
vt_mat = vt_mat.slice_move(s![0..k_effective, ..]);
}
Some(vt_mat)
} else {
None
};

let u_final = if let Some(mut u_mat) = u_option {
if u_mat.ncols() > k_effective {
assert_eq!(u_mat.nrows(), original_rows, "U matrix row count mismatch");
u_mat = u_mat.slice_move(s![.., 0..k_effective]);
}
Some(u_mat)
} else {
None
};

let vt_final = if let Some(mut vt_mat) = vt_option {
if vt_mat.nrows() > k_effective {
assert_eq!(
vt_mat.ncols(),
original_cols,
"VT matrix column count mismatch"
);
vt_mat = vt_mat.slice_move(s![0..k_effective, ..]);
}
Some(vt_mat)
} else {
None
};

Ok(SVDOutput {
u: u_final,
s,
vt: vt_final,
})
Ok(SVDOutput {
u: u_final,
s,
vt: vt_final,
})
}
}

pub use NdarrayLinAlgBackend as Backend;
}

#[cfg(any(
feature = "backend_openblas",
feature = "backend_openblas_system",
feature = "backend_mkl",
feature = "backend_mkl_system",
feature = "faer_links_ndarray_static_openblas"
))]
use ndarray_backend_impl::Backend as NdarrayLinAlgBackend;

// --- FaerLinAlgBackend Implementation (originally from faer_backend.rs) ---
#[cfg(feature = "backend_faer")]
mod faer_specific_code {
Expand Down
4 changes: 3 additions & 1 deletion src/pca.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#![doc = include_str!("../README.md")]

use ndarray::parallel::prelude::*;
use ndarray::{s, Array1, Array2, ArrayViewMut1, Axis, ShapeBuilder};
#[cfg(feature = "backend_faer")]
use ndarray::ShapeBuilder;
use ndarray::{s, Array1, Array2, ArrayViewMut1, Axis};
// UPLO is no longer needed as the backend's eigh_upper handles this.
// QR trait for .qr() and SVDInto for .svd_into() are replaced by backend calls.
// Eigh trait for .eigh() is replaced by backend calls.
Expand Down
9 changes: 8 additions & 1 deletion tests/eigensnp_diagnostics_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,14 @@ fn run_diagnostic_test_with_params(
};

let algorithm = EigenSNPCoreAlgorithm::new(config.clone());
let snp_metadata: Vec<efficient_pca::eigensnp::PcaSnpMetadata> = (0..mock_data_accessor.num_pca_snps()).map(|i| efficient_pca::eigensnp::PcaSnpMetadata { id: std::sync::Arc::new(format!("snp_{}", i)), chr: std::sync::Arc::new("chr1".to_string()), pos: i as u64 * 1000 + 100000 }).collect();
let snp_metadata: Vec<efficient_pca::eigensnp::PcaSnpMetadata> = (0..mock_data_accessor
.num_pca_snps())
.map(|i| efficient_pca::eigensnp::PcaSnpMetadata {
id: std::sync::Arc::new(format!("snp_{}", i)),
chr: std::sync::Arc::new("chr1".to_string()),
pos: i as u64 * 1000 + 100000,
})
.collect();
let pca_result_tuple = algorithm
.compute_pca(&mock_data_accessor, &ld_block_specs, &snp_metadata)
.map_err(|e| e as Box<dyn std::error::Error>)?;
Expand Down
Loading
Loading