diff --git a/CHANGELOG.md b/CHANGELOG.md index eae1069c..b99276e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,19 +1,54 @@ # CHANGELOG -## v0.8.7 (2025-04-03) +## v0.9.0 (2025-05-23) + +### Step + +- Bumping minor version + ([`e333641`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/e3336417a09b4ef26e71bde1b54da840f0980ab9)) + + +## v0.8.11 (2025-05-23) ### Bug Fixes +- **_ripley**: Fixed conflicts + ([`fa4c06f`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/fa4c06f697ebe95438c3fc583e7767399b72dcf7)) + +- **_ripley**: Removed old call + ([`e89835b`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/e89835b339034d6c543bd4b6231508811828c26d)) + +- **core**: Return hist_data instead of original data + ([`2734421`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/27344216b2d0f1fef43ac0e66fc1613ddfbf9349)) + - **core**: Specify weights for all histplot calls ([`b661495`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/b66149509fa4aa2280d14dfa6e83567c95c87cf8)) - **docstring**: Add returned df to doctstring ([`9caddca`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/9caddca51e69e2213da866d9f74c6abe9ab7c181)) +- **interactive_spatial_plot**: Fixed typos in api name and arguments + ([`b833123`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/b833123dc9f124abf34e539be553a8388048ee1b)) + +- **knn clustering**: Fixed conflict in imports + ([`e6d5f1e`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/e6d5f1eb4668d8e25ed09469525e128abcabbb6a)) + +- **knn_clustering**: Fixed format of the error message + ([`ddf92e6`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/ddf92e63c760c56e035071c0edc2ce245aa3fb7c)) + +- **present_summary_as_figure**: Fixed json conversion when non python types are used + ([`61c8480`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/61c848051cf86f6cf1cff5fe9bf012bb1c12c9d2)) + - **relational_heatmap**: Adjusted the flipped axis labels 
([`5d950bb`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/5d950bb0bb3dd2f45e32d4bd7c4aa15f939922ac)) +- **select_values**: Added support when observation are numerical + ([`4bcbaa2`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/4bcbaa249e70c51860ae55356c5b6ab2bf8961bc)) + +- **summarize_dataframe**: Remove duplicated missing index + ([`2c0d907`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/2c0d9077ea9414172090d8c10488f75360430e10)) + - **tests**: Add tests for figure_type and showfliers params ([`3488a59`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/3488a597ead551285f8791d4028fe800a3c822d4)) @@ -29,43 +64,38 @@ - **tests**: Update tests to match new return param ([`da83e3d`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/da83e3dcef1b7ccb7de6e6854613f63bef7e8780)) -### Features - -- **core**: Change histogram/boxplot return types to dicts - ([`bf160c2`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/bf160c24c8fca82ad9057fb2be049b200f4f7139)) - -- **core**: Changed histogram to precompute data. 
- ([`0201c6f`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/0201c6f153b440c7b93494fc5355dcf1fe446c28)) - -- **core**: Changed how boxplot return type is handled - ([`714bf98`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/714bf98e2acf8b42b39588ec866f87ac38979dd0)) +- **visualize_nearest_neighbor**: Add comments on function and refactor unit tests + ([`fa20694`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/fa206948af9db292a6389ddfedd142ac8563f5f9)) -- **core**: Use plotly figure for static plot instead of png - ([`5738234`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/5738234fc7e16cc8f9c1e421343028976f73ecc2)) +- **visualize_nearest_neighbor**: Rewrite unit tests and move up library in function + ([`4b49fac`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/4b49fac0e56949a2a66a2c7c5a415733ba5e7aef)) +### Build System -## v0.8.6 (2025-03-18) +- Restored docker file to FNLCR organization + ([`739b4b3`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/739b4b3221efec4af2e9228aa8b3260f74a541bf)) -### Bug Fixes +### Code Style -- **_ripley**: Fixed conflicts - ([`fa4c06f`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/fa4c06f697ebe95438c3fc583e7767399b72dcf7)) +- **knn_clustering**: Adjusted to code style standar + ([`2b02538`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/2b02538d4d15a748ce399eb3837dc51eb99108f6)) -- **_ripley**: Removed old call - ([`e89835b`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/e89835b339034d6c543bd4b6231508811828c26d)) +### Continuous Integration -- **interactive_spatial_plot**: Fixed typos in api name and arguments - ([`b833123`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/b833123dc9f124abf34e539be553a8388048ee1b)) +- **version**: Automatic development release + ([`195761d`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/195761de5563e80a60a7ea43ecb73e6105dc7d1d)) -- **present_summary_as_figure**: Fixed json conversion when non python types are used - 
([`61c8480`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/61c848051cf86f6cf1cff5fe9bf012bb1c12c9d2)) +- **version**: Automatic development release + ([`f1b0ab2`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/f1b0ab271b53758668f5256bdc63004ae007741a)) -### Build System +- **version**: Automatic development release + ([`70c004e`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/70c004e2079086aa46e38db7803c7c52e6c3356b)) -- Restored docker file to FNLCR organization - ([`739b4b3`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/739b4b3221efec4af2e9228aa8b3260f74a541bf)) +- **version**: Automatic development release + ([`b2921d7`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/b2921d70a7c2d8928f9b48d047a625aebea2d55b)) -### Continuous Integration +- **version**: Automatic development release + ([`1ad51bb`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/1ad51bb329ec95c14de24e276bbf36c7375081b5)) - **version**: Automatic development release ([`627b384`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/627b3846a0d913318c846fc73f11d6141fc6a64e)) @@ -93,12 +123,34 @@ - **_ripley_l_multiple**: Enabled edget correction to remove center cell near border ([`87324fd`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/87324fd1f168df6a809cf94321d10c632d2b9448)) +- **core**: Change histogram/boxplot return types to dicts + ([`bf160c2`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/bf160c24c8fca82ad9057fb2be049b200f4f7139)) + +- **core**: Changed histogram to precompute data. 
+ ([`0201c6f`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/0201c6f153b440c7b93494fc5355dcf1fe446c28)) + +- **core**: Changed how boxplot return type is handled + ([`714bf98`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/714bf98e2acf8b42b39588ec866f87ac38979dd0)) + +- **core**: Use plotly figure for static plot instead of png + ([`5738234`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/5738234fc7e16cc8f9c1e421343028976f73ecc2)) + +- **pin_color**: Add pin color feature to visualize nearest neighbor and helpers and unit test files + individually + ([`6e0adc9`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/6e0adc9aaf63c443c615e27dd3d6a6ded4a4b391)) + - **ripley_l**: Added edge correction parameter to the high level function ([`9a54f15`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/9a54f150c88cd4e86c0559d22ecf3f663bc6afd9)) - **summarize dataframe**: Added two visualization functions for the summary ([`66fdd4e`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/66fdd4e4e7c136d4aec3e4c7b49958a5af020858)) +- **visualize_nearest_neighbor**: Add pin-color and corresponding unit test + ([`015692f`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/015692f87e75e5c4581638dedd4d277c63c6ddec)) + +- **visualize_nn**: Add pin-color feature and enhance layout of fig and ax + ([`22e047a`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/22e047aa3d765469c8d0acd32f8ea3d8be5b27bd)) + ### Performance Improvements - Updating the default colormap for interactive spatial to rainbow @@ -123,6 +175,12 @@ ### Testing +- **comments**: Add extensive comments for complex data set generation in utag tests + ([`ef95276`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/ef952769f8d20310903f8a9269772a67cb1057d4)) + +- **comments**: Add extensive comments for complex data set generation in utag tests + ([`cd8cb25`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/cd8cb25764a3a5f4718692dea3585bae17f88105)) + - **compute_box_plot_metric**: Added verbose error message and 
check ([`fb3a372`](https://github.com/FNLCR-DMAP/SCSAWorkflow/commit/fb3a372d9f0cb2613ab77bda8e3c2407a6dd0e4c)) diff --git a/README.md b/README.md index 0d79f7e3..11ca2f00 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ This Python-based package leverages the anndata framework for easy integration w ## Installing SPAC with Conda Run the following command to establish the Conda environment supporting usage and contribution to spac package: +Latest released version is v0.9.0 at 5/23/2025 ```bash cd # If conda is not activate diff --git a/paper/figure.tif b/paper/figure.tif new file mode 100644 index 00000000..e78e70c8 Binary files /dev/null and b/paper/figure.tif differ diff --git a/paper/paper.bib b/paper/paper.bib new file mode 100644 index 00000000..1f902e1d --- /dev/null +++ b/paper/paper.bib @@ -0,0 +1,190 @@ +@article{Gerdes:2013, + abstract = {Limitations on the number of unique protein and DNA molecules that can be characterized microscopically in a single tissue specimen impede advances in understanding the biological basis of health and disease. Here we present a multiplexed fluorescence microscopy method (MxIF) for quantitative, single-cell, and subcellular characterization of multiple analytes in formalin-fixed paraffinembedded tissue. Chemical inactivation of fluorescent dyes after each image acquisition round allows reuse of common dyes in iterative staining and imaging cycles. The mild inactivation chemistry is compatible with total and phosphoprotein detection, as well as DNA FISH. Accurate computational registration of sequential images is achieved by aligning nuclear counterstain-derived fiducial points. Individual cells, plasma membrane, cytoplasm, nucleus, tumor, and stromal regions are segmented to achieve cellular and subcellular quantification of multiplexed targets. 
In a comparison of pathologist scoring of diaminobenzidine staining of serial sections and automated MxIF scoring of a single section, human epidermal growth factor receptor 2, estrogen receptor, p53, and androgen receptor staining by diaminobenzidine and MxIF methods yielded similar results. Single-cell staining patterns of 61 protein antigens by MxIF in 747 colorectal cancer subjects reveals extensive tumor heterogeneity, and cluster analysis of divergent signaling through ERK1/2, S6 kinase 1, and 4E binding protein 1 provides insights into the spatial organization of mechanistic target of rapamycin and MAPK signal transduction. Our results suggest MxIF should be broadly applicable to problems in the fields of basic biological research, drug discovery and development, and clinical diagnostics.}, + author = {Gerdes, Michael J and Sevinsky, Christopher J and Sood, Anup and Adak, Sudeshna and Bello, Musodiq O and Bordwell, Alexander and Can, Ali and Corwin, Alex and Dinn, Sean and Filkins, Robert J and Hollman, Denise and Kamath, Vidya and Kaanumalle, Sireesha and Kenny, Kevin and Larsen, Melinda and Lazare, Michael and Li, Qing and Lowes, Christina and McCulloch, Colin C and McDonough, Elizabeth and Montalto, Michael C and Pang, Zhengyu and Rittscher, Jens and Santamaria-Pang, Alberto and Sarachan, Brion D and Seel, Maximilian L and Seppo, Antti and Shaikh, Kashan and Sui, Yunxia and Zhang, Jingyu and Ginty, Fiona}, + doi = {10.1073/pnas.1300136110}, + issn = {00278424}, + journal = {Proceedings of the National Academy of Sciences of the United States of America}, + keywords = {Cancer diagnostics,High-content cellular analysis,Image analysis,MTOR,Multiplexing}, + month = {jul}, + number = {29}, + pages = {11982--11987}, + pmid = {23818604}, + title = {{Highly multiplexed single-cell analysis of formalinfixed, paraffin-embedded cancer tissue}}, + volume = {110}, + year = {2013} +} + +@article{Nirmal:2024, + abstract = {Multiplexed imaging data are revolutionizing 
our understanding of the composition and organization of tissues and tumors. A critical aspect of such tissue profiling is quantifying the spatial relationship relationships among cells at different scales from the interaction of neighboring cells to recurrent communities of cells of multiple types. This often involves statistical analysis of 10^7 or more cells in which up to 100 biomolecules (commonly proteins) have been measured. While software tools currently cater to the analysis of spatial transcriptomics data, there remains a need for toolkits explicitly tailored to the complexities of multiplexed imaging data including the need to seamlessly integrate image visualization with data analysis and exploration. We introduce SCIMAP, a Python package specifically crafted to address these challenges. With SCIMAP, users can efficiently preprocess, analyze, and visualize large datasets, facilitating the exploration of spatial relationships and their statistical significance. SCIMAP's modular design enables the integration of new algorithms, enhancing its capabilities for spatial analysis.}, + author = {Nirmal, Ajit J and Sorger, Peter K}, + doi = {10.21105/joss.06604}, + journal = {Journal of Open Source Software}, + month = {may}, + number = {97}, + pages = {6604}, + publisher = {The Open Journal}, + title = {{SCIMAP: A Python Toolkit for Integrated Spatial Analysis of Multiplexed Imaging Data}}, + volume = {9}, + year = {2024} +} + +@article{Goltsev:2018, + abstract = {A highly multiplexed cytometric imaging approach, termed co-detection by indexing (CODEX), is used here to create multiplexed datasets of normal and lupus (MRL/lpr) murine spleens. CODEX iteratively visualizes antibody binding events using DNA barcodes, fluorescent dNTP analogs, and an in situ polymerization-based indexing procedure. 
An algorithmic pipeline for single-cell antigen quantification in tightly packed tissues was developed and used to overlay well-known morphological features with de novo characterization of lymphoid tissue architecture at a single-cell and cellular neighborhood levels. We observed an unexpected, profound impact of the cellular neighborhood on the expression of protein receptors on immune cells. By comparing normal murine spleen to spleens from animals with systemic autoimmune disease (MRL/lpr), extensive and previously uncharacterized splenic cell-interaction dynamics in the healthy versus diseased state was observed. The fidelity of multiplexed spatial cytometry demonstrated here allows for quantitative systemic characterization of tissue architecture in normal and clinically aberrant samples. A DNA barcoding-based imaging technique uses multiplexed tissue antigen staining to enable the characterization of cell types and dynamics in a model of autoimmune disease.}, + author = {Goltsev, Yury and Samusik, Nikolay and Kennedy-Darling, Julia and Bhate, Salil and Hale, Matthew and Vazquez, Gustavo and Black, Sarah and Nolan, Garry P}, + doi = {10.1016/j.cell.2018.07.010}, + issn = {10974172}, + journal = {Cell}, + keywords = {CODEX,autoimmunity,immune tissue,microenvironment,multidimensional imaging,multiplexed imaging,niche,tissue architecture}, + month = {aug}, + number = {4}, + pages = {968--981.e15}, + pmid = {30078711}, + publisher = {Cell Press}, + title = {{Deep Profiling of Mouse Splenic Architecture with CODEX Multiplexed Imaging}}, + volume = {174}, + year = {2018} +} + +@article{Lin:2018, + author = {Lin, Jia-Ren and Izar, Benjamin and Sorger, Peter K}, + doi = {10.7554/eLife.31657.002}, + journal = {eLife}, + title = {{Highly multiplexed immunofluorescence imaging of human tissues and tumors using t-CyCIF and conventional optical microscopes}}, + year = {2018} +} + +@article{Palla:2022, + abstract = {Spatial omics data are advancing the study of tissue 
organization and cellular communication at an unprecedented scale. Flexible tools are required to store, integrate and visualize the large diversity of spatial omics data. Here, we present Squidpy, a Python framework that brings together tools from omics and image analysis to enable scalable description of spatial molecular data, such as transcriptome or multivariate proteins. Squidpy provides efficient infrastructure and numerous analysis methods that allow to efficiently store, manipulate and interactively visualize spatial omics data. Squidpy is extensible and can be interfaced with a variety of already existing libraries for the scalable analysis of spatial omics data.}, + author = {Palla, Giovanni and Spitzer, Hannah and Klein, Michal and Fischer, David and Schaar, Anna Christina and Kuemmerle, Louis Benedikt and Rybakov, Sergei and Ibarra, Ignacio L and Holmberg, Olle and Virshup, Isaac and Lotfollahi, Mohammad and Richter, Sabrina and Theis, Fabian J}, + doi = {10.1038/s41592-021-01358-2}, + issn = {15487105}, + journal = {Nature Methods}, + month = {feb}, + number = {2}, + pages = {171--178}, + pmid = {35102346}, + publisher = {Nature Research}, + title = {{Squidpy: a scalable framework for spatial omics analysis}}, + volume = {19}, + year = {2022} +} + +@article{Dries:2021, + abstract = {Spatial transcriptomic and proteomic technologies have provided new opportunities to investigate cells in their native microenvironment. Here we present Giotto, a comprehensive and open-source toolbox for spatial data analysis and visualization. The analysis module provides end-to-end analysis by implementing a wide range of algorithms for characterizing tissue composition, spatial expression patterns, and cellular interactions. Furthermore, single-cell RNAseq data can be integrated for spatial cell-type enrichment analysis. The visualization module allows users to interactively visualize analysis outputs and imaging features. 
To demonstrate its general applicability, we apply Giotto to a wide range of datasets encompassing diverse technologies and platforms.}, + author = {Dries, Ruben and Zhu, Qian and Dong, Rui and Eng, Chee Huat Linus and Li, Huipeng and Liu, Kan and Fu, Yuntian and Zhao, Tianxiao and Sarkar, Arpan and Bao, Feng and George, Rani E and Pierson, Nico and Cai, Long and Yuan, Guo Cheng}, + doi = {10.1186/s13059-021-02286-2}, + issn = {1474760X}, + journal = {Genome Biology}, + month = {dec}, + number = {1}, + pmid = {33685491}, + publisher = {BioMed Central Ltd}, + title = {{Giotto: a toolbox for integrative analysis and visualization of spatial expression data}}, + volume = {22}, + year = {2021} +} + +@article{Giraldo:2021, + abstract = {Multiplex immunofluorescence (mIF) can detail spatial relationships and complex cell phenotypes in the tumor microenvironment (TME). However, the analysis and visualization of mIF data can be complex and time-consuming. Here, we used tumor specimens from 93 patients with metastatic melanoma to develop and validate a mIF data analysis pipeline using established flow cytometry workflows (image cytometry). Unlike flow cytometry, spatial information from the TME was conserved at single-cell resolution. A spatial uniform manifold approximation and projection (UMAP) was constructed using the image cytometry output. Spatial UMAP subtraction analysis (survivors vs. nonsurvivors at 5 years) was used to identify topographic and coexpression signatures with positive or negative prognostic impact. Cell densities and proportions identified by image cytometry showed strong correlations when compared with those obtained using gold-standard, digital pathology software (R2 > 0.8). The associated spatial UMAP highlighted “immune neighborhoods” and associated topographic immunoactive protein expression patterns. 
We found that PD-L1 and PD-1 expression intensity was spatially encoded—the highest PD-L1 expression intensity was observed on CD163+ cells in neighborhoods with high CD8+ cell density, and the highest PD-1 expression intensity was observed on CD8+ cells in neighborhoods with dense arrangements of tumor cells. Spatial UMAP subtraction analysis revealed numerous spatial clusters associated with clinical outcome. The variables represented in the key clusters from the unsupervised UMAP analysis were validated using established, supervised approaches. In conclusion, image cytometry and the spatial UMAPs presented herein are powerful tools for the visualization and interpretation of single-cell, spatially resolved mIF data and associated topographic biomarker development.}, + author = {Giraldo, Nicolas A and Berry, Sneha and Becht, Etienne and Ates, Deniz and Schenk, Kara M and Engle, Elizabeth L and Green, Benjamin and Nguyen, Peter and Soni, Abha and Stein, Julie E and Succaria, Farah and Ogurtsova, Aleksandra and Xu, Haiying and Gottardo, Raphael and Anders, Robert A and Lipson, Evan J and Danilova, Ludmila and Baras, Alexander S and Taube, Janis M}, + doi = {10.1158/2326-6066.CIR-21-0015}, + issn = {23266074}, + journal = {Cancer Immunology Research}, + month = {nov}, + number = {11}, + pages = {1262--1269}, + pmid = {34433588}, + publisher = {American Association for Cancer Research Inc.}, + title = {{Spatial UMAP and image cytometry for topographic immuno-oncology biomarker discovery}}, + volume = {9}, + year = {2021} +} + +@article{Long:2023, + abstract = {Spatial transcriptomics technologies generate gene expression profiles with spatial context, requiring spatially informed analysis tools for three key tasks, spatial clustering, multisample integration, and cell-type deconvolution. We present GraphST, a graph self-supervised contrastive learning method that fully exploits spatial transcriptomics data to outperform existing methods. 
It combines graph neural networks with self-supervised contrastive learning to learn informative and discriminative spot representations by minimizing the embedding distance between spatially adjacent spots and vice versa. We demonstrated GraphST on multiple tissue types and technology platforms. GraphST achieved 10% higher clustering accuracy and better delineated fine-grained tissue structures in brain and embryo tissues. GraphST is also the only method that can jointly analyze multiple tissue slices in vertical or horizontal integration while correcting batch effects. Lastly, GraphST demonstrated superior cell-type deconvolution to capture spatial niches like lymph node germinal centers and exhausted tumor infiltrating T cells in breast tumor tissue.}, + author = {Long, Yahui and Ang, Kok Siong and Li, Mengwei and Chong, Kian Long Kelvin and Sethi, Raman and Zhong, Chengwei and Xu, Hang and Ong, Zhiwei and Sachaphibulkij, Karishma and Chen, Ao and Zeng, Li and Fu, Huazhu and Wu, Min and Lim, Lina Hsiu Kim and Liu, Longqi and Chen, Jinmiao}, + doi = {10.1038/s41467-023-36796-3}, + issn = {20411723}, + journal = {Nature Communications}, + month = {dec}, + number = {1}, + pmid = {36859400}, + publisher = {Nature Research}, + title = {{Spatially informed clustering, integration, and deconvolution of spatial transcriptomics with GraphST}}, + volume = {14}, + year = {2023} +} + +@article{Hao:2021, + abstract = {The simultaneous measurement of multiple modalities represents an exciting frontier for single-cell genomics and necessitates computational methods that can define cellular states based on multimodal data. Here, we introduce “weighted-nearest neighbor” analysis, an unsupervised framework to learn the relative utility of each data type in each cell, enabling an integrative analysis of multiple modalities. 
We apply our procedure to a CITE-seq dataset of 211,000 human peripheral blood mononuclear cells (PBMCs) with panels extending to 228 antibodies to construct a multimodal reference atlas of the circulating immune system. Multimodal analysis substantially improves our ability to resolve cell states, allowing us to identify and validate previously unreported lymphoid subpopulations. Moreover, we demonstrate how to leverage this reference to rapidly map new datasets and to interpret immune responses to vaccination and coronavirus disease 2019 (COVID-19). Our approach represents a broadly applicable strategy to analyze single-cell multimodal datasets and to look beyond the transcriptome toward a unified and multimodal definition of cellular identity.}, + author = {Hao, Yuhan and Hao, Stephanie and Andersen-Nissen, Erica and Mauck, William M and Zheng, Shiwei and Butler, Andrew and Lee, Maddie J and Wilk, Aaron J and Darby, Charlotte and Zager, Michael and Hoffman, Paul and Stoeckius, Marlon and Papalexi, Efthymia and Mimitou, Eleni P and Jain, Jaison and Srivastava, Avi and Stuart, Tim and Fleming, Lamar M and Yeung, Bertrand and Rogers, Angela J and McElrath, Juliana M and Blish, Catherine A and Gottardo, Raphael and Smibert, Peter and Satija, Rahul}, + doi = {10.1016/j.cell.2021.04.048}, + issn = {10974172}, + journal = {Cell}, + keywords = {CITE-seq,COVID-19,T cell,immune system,multimodal analysis,reference mapping,single cell genomics}, + month = {jun}, + number = {13}, + pages = {3573--3587.e29}, + pmid = {34062119}, + publisher = {Elsevier B.V.}, + title = {{Integrated analysis of multimodal single-cell data}}, + volume = {184}, + year = {2021} +} + +@article{Mah:2024, + abstract = {The spatial organization of molecules in a cell is essential for their functions. While current methods focus on discerning tissue architecture, cell–cell interactions, and spatial expression patterns, they are limited to the multicellular scale. 
We present Bento, a Python toolkit that takes advantage of single-molecule information to enable spatial analysis at the subcellular scale. Bento ingests molecular coordinates and segmentation boundaries to perform three analyses: defining subcellular domains, annotating localization patterns, and quantifying gene–gene colocalization. We demonstrate MERFISH, seqFISH +, Molecular Cartography, and Xenium datasets. Bento is part of the open-source Scverse ecosystem, enabling integration with other single-cell analysis tools.}, + author = {Mah, Clarence K and Ahmed, Noorsher and Lopez, Nicole A and Lam, Dylan C and Pong, Avery and Monell, Alexander and Kern, Colin and Han, Yuanyuan and Prasad, Gino and Cesnik, Anthony J and Lundberg, Emma and Zhu, Quan and Carter, Hannah and Yeo, Gene W}, + doi = {10.1186/s13059-024-03217-7}, + issn = {1474760X}, + journal = {Genome Biology}, + month = {dec}, + number = {1}, + publisher = {BioMed Central Ltd}, + title = {{Bento: a toolkit for subcellular analysis of spatial transcriptomics data}}, + volume = {25}, + year = {2024} +} + +@article{Feng:2023, + abstract = {Spatial proteomics technologies have revealed an underappreciated link between the location of cells in tissue microenvironments and the underlying biology and clinical features, but there is significant lag in the development of downstream analysis methods and benchmarking tools. Here we present SPIAT (spatial image analysis of tissues), a spatial-platform agnostic toolkit with a suite of spatial analysis algorithms, and spaSim (spatial simulator), a simulator of tissue spatial data. SPIAT includes multiple colocalization, neighborhood and spatial heterogeneity metrics to characterize the spatial patterns of cells. Ten spatial metrics of SPIAT are benchmarked using simulated data generated with spaSim. We show how SPIAT can uncover cancer immune subtypes correlated with prognosis in cancer and characterize cell dysfunction in diabetes. 
Our results suggest SPIAT and spaSim as useful tools for quantifying spatial patterns, identifying and validating correlates of clinical outcomes and supporting method development.}, + author = {Feng, Yuzhou and Yang, Tianpei and Zhu, John and Li, Mabel and Doyle, Maria and Ozcoban, Volkan and Bass, Greg T and Pizzolla, Angela and Cain, Lachlan and Weng, Sirui and Pasam, Anupama and Kocovski, Nikolce and Huang, Yu Kuan and Keam, Simon P and Speed, Terence P and Neeson, Paul J and Pearson, Richard B and Sandhu, Shahneen and Goode, David L and Trigos, Anna S}, + doi = {10.1038/s41467-023-37822-0}, + issn = {20411723}, + journal = {Nature Communications}, + month = {dec}, + number = {1}, + pmid = {37188662}, + publisher = {Nature Research}, + title = {{Spatial analysis with SPIAT and spaSim to characterize and simulate tissue microenvironments}}, + volume = {14}, + year = {2023} +} + +@misc{Keretsu:2022, + abstract = {

Glioblastoma (GBM) is the most aggressive primary brain cancer in adults and remains incurable. Our study revealed an immunosuppressive role of mucosal-associated invariant T (MAIT) cells in GBM. In bulk RNA sequencing data analysis of GBM tissues, MAIT cell gene signature significantly correlated with poor patient survival. A scRNA-seq of CD45 + cells from 23 GBM tissue samples showed 15 (65.2%) were positive for MAIT cells and the enrichment of MAIT17. The MAIT cell signature significantly correlated with the activity of tumor-associated neutrophils (TANs) and myeloid-derived suppressor cells (MDSCs). Multiple immune suppressive genes known to be used by TANs/MDSCs were upregulated in MAIT-positive tumors. Spatial imaging analysis of GBM tissues showed that all specimens were positive for both MAIT cells and TANs and localized enrichment of TANs. These findings highlight the MAIT-TAN/MDSC axis as a novel therapeutic target to modulate GBM's immunosuppressive tumor microenvironment.

}, + author = {Keretsu, Seketoulie and Hana, Taijun and Lee, Alexander and Kedei, Noemi and Malik, Nargis and Kim, Hye and Spurgeon, Jo and Khayrullina, Guzal and Ruf, Benjamin and Hara, Ayaka and Coombs, Morgan and Watowich, Matthew and Hari, Ananth and Ford, Michael K B and Sahinalp, Cenk and Watanabe, Masashi and Zaki, George and Gilbert, Mark R and Cimino, Patrick. J and Prins, Robert and Terabe, Masaki}, + doi = {10.1101/2022.07.17.499189}, + institution = {bioRxiv}, + month = {jul}, + title = {{MAIT cells have a negative impact on glioblastoma}}, + url = {http://biorxiv.org/lookup/doi/10.1101/2022.07.17.499189}, + year = {2022} +} + +@misc{CodeOcean, + author = {{Code Ocean}}, + title = {{Code Ocean}}, + url = {https://codeocean.com/}, + urldate = {2025-04-01}, + year = {2019} +} + +@misc{PalantirTechnologies, + author = {{Palantir Technologies}}, + booktitle = {Palantir Technologies}, + title = {{Palantir Foundry Documentation}}, + url = {https://palantir.com/docs/foundry/}, + urldate = {2025-04-01}, + year = {2003} +} diff --git a/paper/paper.md b/paper/paper.md new file mode 100644 index 00000000..6e7ca005 --- /dev/null +++ b/paper/paper.md @@ -0,0 +1,71 @@ +--- +title: 'SPAC: A Python Package for Spatial Single-Cell Analysis of Multiplexed Imaging' +tags: + - multiplexed imaging + - spatial proteomics + - single-cell analysis + - tumor microenvironment +authors: + - name: Fang Liu + orcid: 0000-0002-4283-8325 + affiliation: 1 + - name: Rui He + affiliation: 2 + - name: Andrei Bombin + affiliation: 3 + - name: Ahmad B. Abdallah + affiliation: 4 + - name: Omar Eldaghar + affiliation: 4 + - name: Tommy R. Sheeley + affiliation: 4 + - name: Sam E. 
Ying + affiliation: 4 + - name: George Zaki + orcid: 0000-0002-2740-3307 + corresponding: true + affiliation: 1 +affiliations: + - index: 1 + name: Frederick National Laboratory for Cancer Research, United States + - index: 2 + name: Essential Software Inc., United States + - index: 3 + name: Axle Informatics, United States + - index: 4 + name: Purdue University, United States +date: 12 April 2025 +bibliography: paper.bib +--- + +# Summary + +Multiplexed immunofluorescence microscopy captures detailed measurements of spatially resolved, multiple biomarkers simultaneously, revealing tissue composition and cellular interactions in situ among single cells. The growing scale and dimensional complexity of these datasets demand reproducible, comprehensive and user-friendly computational tools. To address this need, we developed SPAC **(SPA**tial single-**C**ell analysis), a Python-based package and a corresponding shiny application within an integrated, modular SPAC ecosystem designed specifically for biologists without extensive coding expertise. Following image segmentation and extraction of spatially resolved single-cell data, SPAC streamlines downstream phenotyping and spatial analysis, facilitating characterization of cellular heterogeneity and spatial organization within tissues. Through scalable performance, specialized spatial statistics, highly customizable visualizations, and seamless workflows from dataset to insights, SPAC significantly lowers barriers to sophisticated spatial analyses. + +# Statement of Need + +Advanced multiplex imaging technologies, such as CODEX [@Goltsev:2018], MxIF [@Gerdes:2013], CyCIF [@Lin:2018], generate high dimensional dataset capable of profiling up to dozens of biomarkers simultaneously. 
Analyzing and interpreting these complex spatial protein data pose significant computational challenges, especially given that high-resolution whole-slide imaging data can reach hundreds of gigabytes in size and contain millions of cells across extensive tissue areas. Currently, many spatial biology tools (e.g., Seurat [@Hao:2021], GraphST [@Long:2023], and bento [@Mah:2024]), primarily address spatial transcriptomics and cannot directly handle multiplexed protein imaging data. Other specialized software such as SPIA [@Feng:2023], Giotto [@Dries:2021], Squidpy [@Palla:2022], and SCIMAP [@Nirmal:2024] provides valuable capabilities tailored for spatial protein analyses. However, these tools lack sufficient flexibility and customization options necessary to meet the diverse scalable analysis and visualization needs of non-technical users. + +To address this gap, we developed the SPAC Python package and the web-based SPAC Shiny application, which together enhance analytical capabilities through intuitive terminology, optimized computational performance, specialized spatial statistics, and extensive visualization configurations. Results computed using the SPAC Python package are stored as AnnData objects, which can be interactively explored in real time via the SPAC Shiny web application, enabling researchers to dynamically visualize data, toggle annotations, inspect cell populations, and compare experimental conditions without requiring extensive computational expertise. + +Specifically, SPAC uses biologist-friendly terminology to simplify technical AnnData concepts. In SPAC, \"cells\" are rows in the data matrix, \"features\" denote protein expression levels, \"tables\" contain transformed data layers, \"associated tables\" store spatial coordinates or dimensional reductions (e.g., UMAP embeddings), and \"annotations\" indicate cell phenotypes, experimental labels, slide identifiers, and other categorical data. 
+ +To address real-time scalability challenges in analyzing large multiplex imaging datasets (exceeding 10 million cells), SPAC enhances computational efficiency by over 5x by integrating optimized numerical routines from NumPy\'s compiled C-based backend. Traditional visualization methods, such as seaborn, were computationally inefficient at this scale. SPAC's modified routines reduce visualization processing times from tens of seconds to a few seconds for generating histograms, box plots, and other visualizations involving millions of cells. + +SPAC introduces specialized functions that enhance conventional spatial analyses. For example, SPAC implements a specialized variant of Ripley's L statistic to evaluate clustering or dispersion between predefined cell phenotype pairs---a "center" phenotype relative to a "neighbor" phenotype. Unlike generalized Ripley's implementations (e.g., Squidpy), SPAC explicitly distinguishes phenotype pairings and employs edge correction by excluding cells located near the region\'s borders within the analytical radius, mitigating edge-effect biases and enhancing statistical reliability. Furthermore, SPAC supports flexible phenotyping methods, accommodating both manual and unsupervised approaches tailored to diverse experimental designs and biological questions. It also implements efficient neighborhood profiling via a KDTree‐based approach, quantifying the distribution of neighboring cell phenotypes within user-defined distance bins. The resulting three-dimensional array, capturing the local cellular microenvironment, is stored in the AnnData object and supports dimensionality reduction methods like spatial UMAP [@Giraldo:2021]. This enhances comparative analysis and visualization of complex spatial relationships across multiple slides and phenotype combinations. + +SPAC provides customizable visualization methods, leveraging Plotly\'s interactive capabilities for dynamic exploration of spatial data. 
Interactive spatial plots allow users to toggle features (e.g., biomarkers) and multiple annotations simultaneously, while a pin-color option ensures consistent color mapping across analyses. These designs help researchers intuitively explore spatial relationships by switching between different cell populations and identify patterns before performing detailed quantitative analyses. In addition, SPAC supports comparative visualization, such as overlaying manual classifications with unsupervised clustering or comparing spatial distributions across experimental conditions or treatments. It also enhances core analytical functions (e.g., nearest neighbor computations using SCIMAP\'s spatial distance calculations) by integrating extensive visualization configurations, including subgroup analyses, subset plots, and faceted layouts, allowing tailored visual outputs for various experimental contexts and research questions. 
+
+# Structure and Implementation
+
+The SPAC package is available at [GitHub](https://github.com/FNLCR-DMAP/SCSAWorkflow) and can be installed locally via conda. It includes five modules that streamline data processing, transformation, spatial analysis, visualization, and utility functions. The data utils module standardizes data into AnnData objects, manages annotations, rescales and normalizes features, and performs filtering, downsampling, and essential spatial computations (e.g., centroid calculation). The transformation tools module employs clustering algorithms (e.g., Phenograph, UTAG, KNN), dimensionality reduction, and normalization methods (batch normalization, z-score, arcsinh) to translate high-dimensional data into biological interpretation. The spatial analysis module offers specialized functions like spatial interaction matrices, Ripley's L statistic with edge correction, and efficient KDTree-based neighborhood profiling. It supports stratified analyses, capturing spatial signatures of cell phenotypes. 
The visualization module provides interactive and customizable visualizations, allowing dynamic exploration of spatial relationships and comparative visualization across experimental conditions. The utils module includes helper functions for input validation, naming conventions, regex searches, and user-friendly error handling to ensure data integrity. 
+
+All SPAC modules are interoperable, forming a cohesive workflow \autoref{fig:workflow}. By adopting the AnnData format, SPAC ensures broad compatibility with existing single-cell analysis tools, produces high-quality figures, and facilitates easy export for external analyses. SPAC adheres to enterprise-level software engineering standards, featuring extensive unit testing, rigorous edge-case evaluation, comprehensive logging, and clear, context-rich error handling. These practices ensure reliability, adaptability, and ease of use across various deployment environments, including interactive Jupyter notebooks, analytic platforms (e.g., Code Ocean [@CodeOcean], Palantir Foundry [@PalantirTechnologies]), and real-time dashboards such as Shiny. Emphasizing readability and maintainability, SPAC provides a versatile and enhanced analytical solution for spatial single-cell analyses. To date, SPAC has been used in the analysis of over 8 datasets with over 30 million cells across diverse studies [@Keretsu:2022]. 
+
+![Overview of the SPAC Workflow. The schematic presents an integrated pipeline for spatial single-cell analysis. 
Segmented cell data with spatial coordinates from various imaging platforms are ingested, normalized, clustered and phenotyped, and analyzed spatially to assess cell distribution and interactions while maintaining consistent data lineage.\label{fig:workflow}](figure.tif) + +# Acknowledgements + +We thank our collaborators at the National Cancer Institute Frederick National Laboratory, the Purdue Data Mine program, and the single-cell and spatial imaging communities for their essential contributions and resources. + +# References diff --git a/setup.py b/setup.py index 5c3873e4..ce335996 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='spac', - version="0.8.7", + version="0.9.0", description=( 'SPatial Analysis for single-Cell analysis (SPAC)' 'is a Scalable Python package for single-cell spatial protein data ' diff --git a/src/spac/__init__.py b/src/spac/__init__.py index c010a40d..f8b63dd6 100644 --- a/src/spac/__init__.py +++ b/src/spac/__init__.py @@ -22,7 +22,7 @@ functions.extend(module_functions) # Define the package version before using it in __all__ -__version__ = "0.8.7" +__version__ = "0.9.0" # Define a __all__ list to specify which functions should be considered public __all__ = functions diff --git a/src/spac/data_utils.py b/src/spac/data_utils.py index 9124ee34..d03f35bf 100644 --- a/src/spac/data_utils.py +++ b/src/spac/data_utils.py @@ -481,9 +481,9 @@ def _select_values_dataframe(data, annotation, values, exclude_values): # Proceed with filtering based on values or exclude_values if values is not None: - filtered_data = data[data[annotation].isin(values)] + filtered_data = data[data[annotation].astype(str).isin(values)] elif exclude_values is not None: - filtered_data = data[~data[annotation].isin(exclude_values)] + filtered_data = data[~data[annotation].astype(str).isin(exclude_values)] count = filtered_data.shape[0] logging.info( @@ -525,9 +525,9 @@ def _select_values_anndata(data, annotation, values, exclude_values): # Proceed 
with filtering based on values or exclude_values if values is not None: - filtered_data = data[data.obs[annotation].isin(values)].copy() + filtered_data = data[data.obs[annotation].astype(str).isin(values)].copy() elif exclude_values is not None: - filtered_data = data[~data.obs[annotation].isin(exclude_values)].copy() + filtered_data = data[~data.obs[annotation].astype(str).isin(exclude_values)].copy() count = filtered_data.n_obs logging.info( @@ -1151,6 +1151,7 @@ def summarize_dataframe( A dictionary where each key is a column name and its value is another dictionary with: - 'data_type': either 'numeric' or 'categorical' + - 'missing_count': int - 'missing_indices': list of row indices with missing values - 'summary': summary statistics if numeric or unique labels with counts if categorical @@ -1205,7 +1206,7 @@ def summarize_dataframe( print(f"Summary for column '{col}':") print(f"Type: {col_info['data_type']}") print("Count missing indices:", col_info['count_missing_indices']) - print("Missing indices:", col_info['missing_indices']) + # print("Missing indices:", col_info['missing_indices']) print("Details:", col_info['summary']) print("-" * 40) return results diff --git a/src/spac/transformations.py b/src/spac/transformations.py index f497d402..b2044f1c 100644 --- a/src/spac/transformations.py +++ b/src/spac/transformations.py @@ -11,6 +11,8 @@ from scipy.sparse import issparse from typing import List, Union, Optional from numpy.lib import NumpyVersion +from sklearn.neighbors import KNeighborsClassifier +from sklearn.preprocessing import LabelEncoder import multiprocessing import parmap from spac.utag_functions import utag @@ -104,6 +106,120 @@ def phenograph_clustering( adata.uns["phenograph_features"] = features +def knn_clustering( + adata, + features, + annotation, + layer=None, + k=50, + output_annotation="knn", + associated_table=None, + missing_label="no_label", + **kwargs): + """ + Calculate knn clusters using sklearn KNeighborsClassifier + + The 
function will add these two attributes to `adata`: + `.obs[output_annotation]` + The assigned int64 class labels by KNeighborsClassifier + + `.uns[output_annotation_features]` + The features used to calculate the knn clusters + + Parameters + ---------- + adata : anndata.AnnData + The AnnData object. + + features : list of str + The variables that would be included in fitting the KNN classifier. + + annotation : str + The name of the annotation used for classifying the data + + layer : str, optional + The layer to be used. + + k : int, optional + The number of nearest neighbor to be used in creating the graph. + + output_annotation : str, optional + The name of the output layer where the clusters are stored. + + associated_table : str, optional + If set, use the corresponding key `adata.obsm` to calcuate the + clustering. Takes priority over the layer argument. + + missing_label : str or int + The value of missing annotations in adata.obs[annotation] + + Returns + ------- + None + adata is updated inplace + """ + + # read in data, validate annotation in the call here + _validate_transformation_inputs( + adata=adata, + layer=layer, + associated_table=associated_table, + features=features, + annotation=annotation, + ) + + if not isinstance(k, int) or k <= 0: + raise ValueError( + f"`k` must be a positive integer. Received value: `{k}`" + ) + + data = _select_input_features( + adata=adata, + layer=layer, + associated_table=associated_table, + features=features + ) + + # boolean masks for labeled and unlabeled data + annotation_data = adata.obs[annotation] + annotation_mask = annotation_data != missing_label + annotation_mask &= pd.notnull(annotation_data) + unlabeled_mask = ~annotation_mask + + # check that annotation is non-trivial + if all(annotation_mask): + raise ValueError( + f"All cells are labeled in the annotation `{annotation}`." + " Please provide a mix of labeled and unlabeled data." 
+ ) + elif not any(annotation_mask): + raise ValueError( + f"No cells are labeled in the annotation `{annotation}`." + " Please provide a mix of labeled and unlabeled data." + ) + + # fit knn classifier to labeled data and predict on unlabeled data + data_labeled = data[annotation_mask] + label_encoder = LabelEncoder() + annotation_labeled = label_encoder.fit_transform( + annotation_data[annotation_mask] + ) + + classifier = KNeighborsClassifier(n_neighbors=k, **kwargs) + classifier.fit(data_labeled, annotation_labeled) + + data_unlabeled = data[unlabeled_mask] + knn_predict = classifier.predict(data_unlabeled) + predicted_labels = label_encoder.inverse_transform(knn_predict) + + # format output and place predictions/data in right location + adata.obs[output_annotation] = np.nan + adata.obs[output_annotation][unlabeled_mask] = predicted_labels + adata.obs[output_annotation][annotation_mask] = \ + annotation_data[annotation_mask] + adata.uns[f"{output_annotation}_features"] = features + + def get_cluster_info(adata, annotation, features=None, layer=None): """ Retrieve information about clusters based on specific annotation. @@ -316,7 +432,8 @@ def _validate_transformation_inputs( adata: anndata, layer: Optional[str] = None, associated_table: Optional[str] = None, - features: Optional[Union[List[str], str]] = None + features: Optional[Union[List[str], str]] = None, + annotation: Optional[str] = None, ) -> None: """ Validate inputs for transformation functions. @@ -331,6 +448,8 @@ def _validate_transformation_inputs( Name of the key in `obsm` that contains the numpy array. features : list of str or str, optional Names of features to use for transformation. 
+ annotation: str, optional + Name of annotation column in `obs` that contains class labels Raises ------ @@ -355,6 +474,9 @@ def _validate_transformation_inputs( if features is not None: check_feature(adata, features=features) + if annotation is not None: + check_annotation(adata, annotations=annotation) + def _select_input_features(adata: anndata, layer: str = None, @@ -1088,13 +1210,13 @@ def run_utag_clustering( adata : anndata.AnnData The AnnData object. features : list - List of features to use for clustering or for PCA. Default + List of features to use for clustering or for PCA. Default (None) is to use all. k : int The number of nearest neighbor to be used in creating the graph. Default is 15. resolution : float - Resolution parameter for the clustering, higher resolution produces + Resolution parameter for the clustering, higher resolution produces more clusters. Default is 1. max_dist : float Maximum distance to cut edges within a graph. Default is 20. @@ -1107,8 +1229,8 @@ def run_utag_clustering( n_iterations : int Number of iterations for the clustering. slide_key: str - Key of adata.obs containing information on the batch structure - of the data.In general, for image data this will often be a variable + Key of adata.obs containing information on the batch structure + of the data.In general, for image data this will often be a variable indicating the imageb so image-specific effects are removed from data. Default is "Slide". @@ -1118,14 +1240,14 @@ def run_utag_clustering( Updated AnnData object with clustering results. 
""" resolutions = [resolution] - + _validate_transformation_inputs( adata=adata, layer=layer, associated_table=associated_table, features=features ) - + # add print the current k value if not isinstance(k, int) or k <= 0: raise ValueError(f"`k` must be a positive integer, but received {k}.") @@ -1143,7 +1265,7 @@ def run_utag_clustering( adata_utag.X = data else: adata_utag = adata.copy() - + utag_results = utag( adata_utag, slide_key=slide_key, @@ -1152,15 +1274,15 @@ def run_utag_clustering( apply_clustering=True, clustering_method="leiden", resolutions=resolutions, - leiden_kwargs={"n_iterations": n_iterations, + leiden_kwargs={"n_iterations": n_iterations, "random_state": random_state}, n_pcs=n_pcs, parallel=parallel, processes=n_jobs, k=k, ) - # change camel case to snake - curClusterCol = 'UTAG Label_leiden_' + str(resolution) - cluster_list = utag_results.obs[curClusterCol].copy() + # change camel case to snake + cur_cluster_col = 'UTAG Label_leiden_' + str(resolution) + cluster_list = utag_results.obs[cur_cluster_col].copy() adata.obs[output_annotation] = cluster_list.copy() adata.uns["utag_features"] = features diff --git a/src/spac/utils.py b/src/spac/utils.py index 634388c1..f3506a20 100644 --- a/src/spac/utils.py +++ b/src/spac/utils.py @@ -1,6 +1,7 @@ import re import anndata as ad import numpy as np +import matplotlib import matplotlib.cm as cm import pandas as pd import logging @@ -683,7 +684,7 @@ def color_mapping( raise ValueError("Opacity must be between 0 and 1") try: - cmap = cm.get_cmap(color_map) + cmap = matplotlib.colormaps.get_cmap(color_map) except ValueError: raise ValueError(f"Invalid color map name: {color_map}") @@ -1007,7 +1008,11 @@ def get_defined_color_map(adata, defined_color_map=None, annotations=None, "an annotation column must be specified." 
) # Generate a color mapping based on unique values in the annotation - unique_labels = np.unique(adata.obs[annotations].values) + if isinstance(annotations, str): + annotations = [annotations] + combined_labels = np.concatenate( + [adata.obs[col].astype(str).values for col in annotations]) + unique_labels = np.unique(combined_labels) return color_mapping( unique_labels, color_map=colorscale, diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 0ab0ee11..db44e510 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -1,3 +1,4 @@ +import logging import seaborn as sns import seaborn import pandas as pd @@ -21,15 +22,139 @@ from spac.data_utils import select_values import logging import warnings +import datashader as ds +import datashader.transfer_functions as tf + import base64 import time import json - +import re +from typing import Dict, List, Union +import matplotlib.colors as mcolors +import matplotlib.patches as mpatch +from functools import partial +from collections import OrderedDict # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +def heatmap_datashader(x, y, labels=None, theme=None, + x_axis_title="Component 1", y_axis_title="Component 2", + plot_title=None, **kwargs): + """ + Generates a heatmap using Datashader for large-scale scatter data. + + Parameters + ---------- + x : iterable + X-axis coordinates. + y : iterable + Y-axis coordinates. + labels : iterable, optional + Categorical labels for subgrouping data. + theme : str, optional, default='viridis' + Colormap theme for visualization. + x_axis_title : str, optional, default='Component 1' + Label for the x-axis. + y_axis_title : str, optional, default='Component 2' + Label for the y-axis. + plot_title : str, optional + Title of the plot. + **kwargs : dict, optional + Additional keyword arguments (e.g., 'fig_width', 'fig_height'). 
+ + Returns + ------- + matplotlib.figure.Figure + A Matplotlib figure containing the heatmap visualization. + """ + + # Ensure x and y are iterable + if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"): + raise ValueError("x and y must be array-like.") + if len(x) != len(y): + raise ValueError("x and y must have the same length.") + if labels is not None and len(labels) != len(x): + raise ValueError("Labels length should match x and y length.") + + # Define available color themes + themes = { + 'Elevation': ds.colors.Elevation, + 'viridis': ds.colors.viridis, + 'Hot': ds.colors.Hot, + 'Set1': ds.colors.Set1, + 'Set2': ds.colors.Set2, + 'Set3': ds.colors.Set3, + 'Sets1to3': ds.colors.Sets1to3, + 'inferno': ds.colors.inferno, + 'color_lookup': ds.colors.color_lookup, + } + cmap = themes.get(theme, ds.colors.viridis) # Default to 'viridis' if theme is not specified + + # Create a DataFrame for processing + coords = pd.DataFrame({"x": x, "y": y}) + if labels is not None: + coords["labels"] = labels + + # Determine plot boundaries + x_min, x_max = coords["x"].min(), coords["x"].max() + y_min, y_max = coords["y"].min(), coords["y"].max() + + # Supply default ranges if not in kwargs + canvas_kwargs = { + 'plot_width': 600, + 'plot_height': 400, + 'x_range': (x_min, x_max), + 'y_range': (y_min, y_max) + } + canvas_kwargs.update(kwargs) + + create_canvas = partial(ds.Canvas, **canvas_kwargs) + + if labels is not None: + categories = pd.Categorical(coords["labels"]).categories + num_categories = len(categories) + + rows = (num_categories // 3) + (1 if num_categories % 3 != 0 else 0) + fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows)) + axes = axes.flatten() + + for i, cat in enumerate(categories): + subset = coords[coords["labels"] == cat] + canvas = create_canvas() + agg = canvas.points(subset, x="x", y="y", agg=ds.count()) + img = tf.shade(agg, cmap=cmap).to_pil() + + ax = axes[i] + ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) + 
ax.set_title(f"{plot_title} - {cat}" if plot_title else str(cat)) + ax.set_xlabel(x_axis_title) + ax.set_ylabel(y_axis_title) + + for j in range(i + 1, len(axes)): + fig.delaxes(axes[j]) + else: + canvas = create_canvas() + agg = canvas.points(coords, x="x", y="y", agg=ds.count()) + img = tf.shade(agg, cmap=cmap).to_pil() + + fig, ax = plt.subplots( + figsize=( + canvas_kwargs["plot_width"] / 100, + canvas_kwargs["plot_height"] / 100 + ) + ) + ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) + ax.set_title(plot_title if plot_title else "Density Plot") + ax.set_xlabel(x_axis_title) + ax.set_ylabel(y_axis_title) + + plt.tight_layout() + return fig + + + def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, @@ -483,7 +608,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, axs : matplotlib.axes.Axes or list of Axes The Axes object(s) of the histogram plot(s). Returns a single Axes if only one plot is created, otherwise returns a list of Axes. - + df : pandas.DataFrame DataFrame containing the data used for plotting the histogram. @@ -529,8 +654,9 @@ def histogram(adata, feature=None, annotation=None, layer=None, data_column = feature if feature else annotation - # Check for negative values and apply log1p transformation if + # Check for negative values and apply log1p transformation if # x_log_scale is True + if x_log_scale: if (df[data_column] < 0).any(): print( @@ -576,25 +702,25 @@ def calculate_histogram(data, bins, bin_edges=None): Parameters: - data (pd.Series): The input data to be binned. - - bins (int or sequence): Number of bins (if numeric) or unique categories + - bins (int or sequence): Number of bins (if numeric) or unique categories (if categorical). - - bin_edges (array-like, optional): Predefined bin edges for numeric data. + - bin_edges (array-like, optional): Predefined bin edges for numeric data. If None, automatic binning is used. 
Returns: - pd.DataFrame: A DataFrame containing the following columns: - - `count`: + - `count`: Frequency of values in each bin. - - `bin_left`: + - `bin_left`: Left edge of each bin (for numeric data). - - `bin_right`: + - `bin_right`: Right edge of each bin (for numeric data). - - `bin_center`: - Center of each bin (for numeric data) or category labels + - `bin_center`: + Center of each bin (for numeric data) or category labels (for categorical data). - + """ - + # Check if the data is numeric or categorical if pd.api.types.is_numeric_dtype(data): if bin_edges is None: @@ -612,7 +738,7 @@ def calculate_histogram(data, bins, bin_edges=None): else: counts = data.value_counts().sort_index() return pd.DataFrame({ - 'bin_center': counts.index, + 'bin_center': counts.index, 'bin_left': counts.index, 'bin_right': counts.index, 'count': counts.values @@ -641,7 +767,7 @@ def calculate_histogram(data, bins, bin_edges=None): group_data = plot_data[ plot_data[group_by] == group ][data_column] - group_hist = calculate_histogram(group_data, kwargs['bins'], + group_hist = calculate_histogram(group_data, kwargs['bins'], bin_edges=global_bin_edges) group_hist[group_by] = group hist_data.append(group_hist) @@ -651,8 +777,8 @@ def calculate_histogram(data, bins, bin_edges=None): kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") - - sns.histplot(data=hist_data, x='bin_center', weights='count', + + sns.histplot(data=hist_data, x='bin_center', weights='count', hue=group_by, ax=ax, **kwargs) # If plotting feature specify which layer if feature: @@ -671,11 +797,11 @@ def calculate_histogram(data, bins, bin_edges=None): ax_array = ax_array.flatten() for i, ax_i in enumerate(ax_array): - group_data = plot_data[plot_data[group_by] == + group_data = plot_data[plot_data[group_by] == groups[i]][data_column] hist_data = calculate_histogram(group_data, kwargs['bins']) - sns.histplot(data=hist_data, x="bin_center", ax=ax_i, + sns.histplot(data=hist_data, 
x="bin_center", ax=ax_i, weights='count', **kwargs) # If plotting feature specify which layer if feature: @@ -712,17 +838,17 @@ def calculate_histogram(data, bins, bin_edges=None): # Precompute histogram data for single plot hist_data = calculate_histogram(plot_data[data_column], kwargs['bins']) if pd.api.types.is_numeric_dtype(plot_data[data_column]): - ax.set_xlim(hist_data['bin_left'].min(), + ax.set_xlim(hist_data['bin_left'].min(), hist_data['bin_right'].max()) - + sns.histplot( - data=hist_data, + data=hist_data, x='bin_center', - weights="count", - ax=ax, + weights="count", + ax=ax, **kwargs ) - + # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') @@ -753,9 +879,9 @@ def calculate_histogram(data, bins, bin_edges=None): ax.set_ylabel(ylabel) if len(axs) == 1: - return {"fig": fig, "axs": axs[0], "df": plot_data} + return {"fig": fig, "axs": axs[0], "df": hist_data} else: - return {"fig": fig, "axs": axs, "df": plot_data} + return {"fig": fig, "axs": axs, "df": hist_data} def heatmap(adata, column, layer=None, **kwargs): """ @@ -1611,35 +1737,24 @@ def boxplot_interactive( DPI (dots per inch) for the figure. Default is 200. defined_color_map : str, optional - Predefined color mapping stored in adata.uns for specific labels. - Default is None, which will generate the color mapping automatically. - - annotation_colorscale : str, default='viridis' - Name of the color scale to use for the dots when annotation - is used. - - feature_colorscale: str, default='seismic' - Name of the color scale to use for the dots when feature - is used. - - figure_type : {"interactive", "static", "png"}, default = "interactive" - If "interactive", the plot is interactive, allowing for zooming - and panning. - If "static", the plot is static. - If "png", the plot is returned as a PNG image. - - return_metrics: bool, default = False - If True, the function also returns the computed boxplot metrics. 
- - **kwargs : additional keyword arguments - Any other keyword arguments passed to the underlying plotting function. + Key in 'adata.uns' holding a pre-computed color dictionary. + Falls back to automatic generation from 'annotation' values. + ax : matplotlib.axes.Axes, optional + A Matplotlib Axes object. Currently, this parameter is not used by the + underlying plotting functions (Seaborn's `catplot`/`displot`), which + will always generate a new figure and axes. The `ax` key in the + returned dictionary will contain the Axes from these new plots. + This parameter is maintained for API consistency and potential + future enhancements. Default is None. + **kwargs : dict + Additional arguments for seaborn figure-level functions. Returns ------- A dictionary containing the following keys: fig : plotly.graph_objects.Figure or str The generated boxplot figure, which can be either: - - If `figure_type` is "static": A base64-encoded PNG + - If `figure_type` is "static": A base64-encoded PNG image string - If `figure_type` is "interactive": A Plotly figure object @@ -1989,14 +2104,14 @@ def boxplot_from_statistics( 'hovermode': False, 'clickmode': 'none', 'modebar_remove': [ - 'toimage', - 'zoom', - 'zoomin', + 'toimage', + 'zoom', + 'zoomin', 'zoomout', - 'select', - 'pan', - 'lasso', - 'autoscale', + 'select', + 'pan', + 'lasso', + 'autoscale', 'resetscale' ], 'legend_itemclick': False, @@ -3208,145 +3323,270 @@ def _plot_spatial_distance_dispatch( plot_type, stratify_by=None, facet_plot=False, - **kwargs + distance_col="distance", + hue_axis="group", + palette=None, + **kwargs, ): """ - Decides the figure layout based on 'stratify_by' and 'facet_plot' - and dispatches actual plotting calls. 
- - Logic: - 1) If stratify_by and facet_plot => single figure with subplots (faceted) - 2) If stratify_by and not facet_plot => multiple figures, one per group - 3) If stratify_by is None => single figure (no subplots) - - This function calls seaborn figure-level functions (catplot or displot). + Dispatch a seaborn call to visualise nearest-neighbor distances. + Returns Axes object(s) for further customization. + + Layout logic + ------------ + 1. ``stratify_by`` & ``facet_plot`` → Faceted plot, returns ``Axes`` + or ``List[Axes]`` for the "ax" key. + 2. ``stratify_by`` & not ``facet_plot`` → List of plots, returns + ``List[Axes]`` for the "ax" key. + 3. ``stratify_by`` is None → Single plot, returns ``Axes`` or + ``List[Axes]`` (if plot_type creates facets) for the "ax" key. Parameters ---------- df_long : pd.DataFrame - Tidy DataFrame with columns ['cellid', 'group', 'distance', + Tidy DataFrame returned by `_prepare_spatial_distance_data` with + a long layout and with columns ['cellid', 'group', 'distance', 'phenotype', 'stratify_by']. method : {'numeric', 'distribution'} - Determines which seaborn function is used (catplot or displot). + ``'numeric'`` → :pyfunc:`seaborn.catplot` + ``'distribution'`` → :pyfunc:`seaborn.displot` plot_type : str - For method='numeric': 'box', 'violin', 'boxen', etc. - For method='distribution': 'hist', 'kde', 'ecdf', etc. + Kind forwarded to Seaborn. + Numeric (`method='numeric'`) – box, violin, boxen, strip, swarm, etc. + Distribution (`method='distribution'`) – hist, kde, ecdf, etc. stratify_by : str or None - Column name for grouping. If None, no grouping is done. - facet_plot : bool - If True, subplots in a single figure (faceted). - If False, separate figures (one per group) or a single figure. + Column used to split data. *None* for no splitting. + facet_plot : bool, default False + If True with stratify_by, create a faceted grid, otherwise + returns individual axes. 
+ distance_col : str, default 'distance' + Column name in df_long holding the numeric distance values. + 'distance' – raw Euclidean / pixel / micron distances. + 'log_distance' – natural-log‐transformed distances. + The axis label is automatically adjusted. + hue_axis : str, default 'group' + Column that encodes the hue (color) dimension. + palette : dict or str or None + • dict → color map forwarded to seaborn/Matplotlib. + • str → any Seaborn/Matplotlib palette name + • None → defaults chosen by Seaborn + Typically the pin‑color map prepared upstream. + **kwargs - Additional seaborn plotting arguments (e.g., col_wrap=2). + Extra keyword args propagated to Seaborn. Legend control + (e.g. `legend=False`) should be passed here if needed. Returns ------- dict - Dictionary with two keys: - "data": the DataFrame (df_long) - "fig": a Matplotlib Figure or a list of Figures - - Raises ------ ValueError - If 'method' is invalid (not 'numeric' or 'distribution'). - - Examples -------- - Called internally by 'visualize_nearest_neighbor'. Typically not used - directly by end users. 
+ { + 'data': pandas.DataFrame, # the input df_long + 'ax' : matplotlib.axes.Axes | list[Axes] + } """ - distance_col = kwargs.pop('distance_col', 'distance') - hue_axis = kwargs.pop('hue_axis', None) - - if method not in ['numeric', 'distribution']: + if method not in ("numeric", "distribution"): raise ValueError("`method` must be 'numeric' or 'distribution'.") - # Set up the plotting function using partial - if method == 'numeric': - plot_func = partial( + # Choose plotting function + if method == "numeric": + _plot_base = partial( sns.catplot, data=None, x=distance_col, - y='group', - kind=plot_type + y="group", + hue=hue_axis, + kind=plot_type, + palette=palette, ) else: # distribution - plot_func = partial( + _plot_base = partial( sns.displot, data=None, x=distance_col, - hue=hue_axis if hue_axis else None, - kind=plot_type + hue=hue_axis, + kind=plot_type, + palette=palette, ) - # Helper to plot a single figure or faceted figure - def _make_figure(data, **kws): - g = plot_func(data=data, **kws) - if distance_col == 'log_distance': - x_label = "Log(Nearest Neighbor Distance)" - else: - x_label = "Nearest Neighbor Distance" + # Single plotting wrapper to create Axes object(s) + def _make_axes_object(_data, **kws_plot): + g = _plot_base(data=_data, **kws_plot) + + axis_label = ( + "Log(Nearest Neighbor Distance)" + if "log" in distance_col + else "Nearest Neighbor Distance" + ) + + g.set_axis_labels(axis_label, None) - # Set axis label based on whether log transform was applied - if hasattr(g, 'set_axis_labels'): - g.set_axis_labels(x_label, None) + if g.axes.size == 1: + returned_ax = g.ax else: - # Fallback if 'set_axis_labels' is unavailable - plt.xlabel(x_label) + returned_ax = g.axes.flatten().tolist() - return g.fig + return returned_ax - figures = [] + # Build axes + final_axes_object = None - # Branching logic for figure creation if stratify_by and facet_plot: - # Single figure with faceted subplots (col=stratify_by) - fig = _make_figure(df_long, 
col=stratify_by, **kwargs) - figures.append(fig) - + final_axes_object = _make_axes_object( + df_long, col=stratify_by, **kwargs + ) elif stratify_by and not facet_plot: - # Multiple separate figures, one per unique value in stratify_by - categories = df_long[stratify_by].unique() - for cat in categories: - subset = df_long[df_long[stratify_by] == cat] - fig = _make_figure(subset, **kwargs) - figures.append(fig) - else: - # Single figure (no subplots) - fig = _make_figure(df_long, **kwargs) - figures.append(fig) - - # Return dictionary: { 'data': DataFrame, 'fig': Figure(s) } - result = {"data": df_long} - if len(figures) == 1: - result["fig"] = figures[0] + list_of_all_axes = [] + for category_value in df_long[stratify_by].unique(): + data_subset = df_long[df_long[stratify_by] == category_value] + axes_or_list_for_category = _make_axes_object( + data_subset, **kwargs + ) + if isinstance(axes_or_list_for_category, list): + list_of_all_axes.extend(axes_or_list_for_category) + else: + list_of_all_axes.append(axes_or_list_for_category) + final_axes_object = list_of_all_axes else: - result["fig"] = figures - return result + final_axes_object = _make_axes_object(df_long, **kwargs) + + return {"data": df_long, "ax": final_axes_object} + + +# Build a master HEX palette and cache it inside the AnnData object +# ----------------------------------------------------------------------------- +# WHAT Convert every entry in ``color_dict_rgb`` (which may contain RGB tuples, +# "rgb()" strings, or already‑hex values) into a canonical six‑digit HEX +# string, storing the results in ``palette_hex``. +# WHY Downstream plotting utilities (Matplotlib / Seaborn) expect colours in +# HEX. Performing the conversion once, here, guarantees a uniform format +# for all later plots and prevents inconsistencies when colours are +# re‑used. +# HOW The helper ``_css_rgb_or_hex_to_hex`` normalises each colour. 
The +# resulting dictionary is cached under ``adata.uns['_spac_palettes']`` so +# that *any* later function can retrieve the same palette by name. +# ``defined_color_map or annotation`` forms a unique key that ties the +# palette to either a user‑defined map or the current annotation field. +def _css_rgb_or_hex_to_hex(col, keep_alpha=False): + """ + Normalise a CSS-style color string to a hexadecimal value or + a valid Matplotlib color name. + + Parameters + ---------- + col : str + Accepted formats: + * '#abc', '#aabbcc', '#rrggbbaa' + * 'rgb(r,g,b)' or 'rgba(r,g,b,a)', where r, g, b are 0-255 and + a is 0-1 or 0-255 + * any named Matplotlib color + + keep_alpha : bool, optional + If True and the input includes alpha, return an 8-digit hex; + otherwise drop the alpha channel. Default is False. + + Returns + ------- + str + * Lower-case colour name or + * 6- or 8-digit lower-case hex. + + Raises + ------ + ValueError + If the color cannot be interpreted. + + Examples + -------- + >>> _css_rgb_or_hex_to_hex('gold') + 'gold' + >>> _css_rgb_or_hex_to_hex('rgb(255,0,0)') + '#ff0000' + >>> _css_rgb_or_hex_to_hex('rgba(255,0,0,0.5)', keep_alpha=True) + '#ff000080' + """ + + col = col.strip().lower() + + # Compile the rgb()/rgba() matcher locally to satisfy style request. + rgb_re = re.compile( + r'rgba?\s*\(' + r'\s*([0-9]{1,3})\s*,' + r'\s*([0-9]{1,3})\s*,' + r'\s*([0-9]{1,3})' + r'(?:\s*,\s*([0-9]*\.?[0-9]+))?' + r'\s*\)', + re.I, + ) + + # 1. direct hex + if col.startswith('#'): + return mcolors.to_hex(col, keep_alpha=keep_alpha).lower() + + # 2. 
rgb()/rgba() + match = rgb_re.fullmatch(col) + if match: + r, g, b, a = match.groups() + r, g, b = map(int, (r, g, b)) + if not all(0 <= v <= 255 for v in (r, g, b)): + raise ValueError( + f'RGB components in "{col}" must be between 0 and 255' + ) + rgba = [r / 255, g / 255, b / 255] + if a is not None: + a_val = float(a) + if a_val > 1: # user supplied 0-255 alpha + a_val /= 255 + rgba.append(a_val) + return mcolors.to_hex(rgba, keep_alpha=keep_alpha).lower() + + # 3. named color + if col in mcolors.get_named_colors_mapping(): + return col # let Matplotlib handle named colors + + # 4. unsupported format + raise ValueError(f'Unsupported color format: "{col}"') + + +# Helper function (can be defined at module level) +def _ordered_unique_figs(axes_list: list): + """ + Helper to get unique figures from a list of axes, + preserving first-seen order. + """ + seen = OrderedDict() + for ax_item in axes_list: # Assumes axes_list is indeed a list + fig = getattr(ax_item, 'figure', None) + if fig is not None: + seen.setdefault(fig, None) + return list(seen) def visualize_nearest_neighbor( adata, annotation, + distance_from, + distance_to=None, stratify_by=None, spatial_distance='spatial_distance', - distance_from=None, - distance_to=None, facet_plot=False, + method=None, plot_type=None, log=False, - method=None, + annotation_colorscale='rainbow', + defined_color_map=None, + ax=None, **kwargs ): """ Visualize nearest-neighbor (spatial distance) data between groups of cells - as numeric or distribution plots. + with optional pin-color map via numeric or distribution plots. - This user-facing function assembles the data by calling - `_prepare_spatial_distance_data` and then creates plots through - `_plot_spatial_distance_dispatch`. + This landing function first constructs a tidy long-form DataFrame via + function `_prepare_spatial_distance_data`, then dispatches plotting to + function `_plot_spatial_distance_dispatch`. 
A pin-color feature guarantees
+ consistent mapping from annotation labels to colors across figures,
+ drawing the mapping from ``adata.uns`` (if present) or generating one
+ automatically through `spac.utils.color_mapping`.

Plot arrangement logic:
1) If stratify_by is not None and facet_plot=True => single figure

@@ -3361,26 +3601,38 @@ def visualize_nearest_neighbor(
Annotated data matrix with distances in `adata.obsm[spatial_distance]`.
annotation : str
Column in `adata.obs` containing cell phenotypes or annotations.
- stratify_by : str, optional
- Column in `adata.obs` used to group or stratify data (e.g. imageid).
- spatial_distance : str, optional
- Key in `adata.obsm` storing the distance DataFrame. Default is
- 'spatial_distance'.
distance_from : str
Reference phenotype from which distances are measured. Required.
distance_to : str or list of str, optional
Target phenotype(s) to measure distance to. If None, uses all
available phenotypes.
+ stratify_by : str, optional
+ Column in `adata.obs` used to group or stratify data (e.g. imageid).
+ spatial_distance : str, optional
+ Key in `adata.obsm` storing the distance DataFrame. Default is
+ 'spatial_distance'.
facet_plot : bool, optional
If True (and stratify_by is not None), subplots in a single figure.
- Else, multiple or single figure(s).
- plot_type : str, optional
- For method='numeric': 'box', 'violin', 'boxen', etc.
- For method='distribution': 'hist', 'kde', 'ecdf', etc.
- log : bool, optional
- If True, applies np.log1p transform to the distance values.
+ Otherwise, multiple or single figure(s).
method : {'numeric', 'distribution'}
Determines the plotting style (catplot vs displot).
+ plot_type : str or None, optional
+ Specific seaborn plot kind. If None, sensible defaults are selected
+ ('boxen' for numeric, 'kde' for distribution).
+ For method='numeric': 'box', 'violin', 'boxen', 'strip', 'swarm'.
+ For method='distribution': 'hist', 'kde', 'ecdf'.
+ log : bool, optional
+ If True, applies np.log1p transform to the distance values.
+ annotation_colorscale : str, optional
+ Matplotlib colormap name used when auto-generating a new mapping.
+ Ignored if 'defined_color_map' is provided.
+ defined_color_map : str, optional
+ Key in 'adata.uns' holding a pre-computed color dictionary.
+ Falls back to automatic generation from 'annotation' values.
+ ax : matplotlib.axes.Axes, optional
+ The matplotlib Axes containing the analysis plots.
+ The returned ax is the passed ax or new ax created.
+ Only works if plotting a single component. Default is None.
**kwargs : dict
Additional arguments for seaborn figure-level functions.

@@ -3388,53 +3640,64 @@
-------
dict
{
- "data": pd.DataFrame, # Tidy DataFrame used for plotting
- "fig": Figure or list[Figure] # Single or multiple figures
+ 'data': pd.DataFrame, # long-form table for plotting
+ 'fig' : matplotlib.figure.Figure | list[Figure] | None,
+ 'ax': matplotlib.axes.Axes | list[matplotlib.axes.Axes],
+ 'palette': dict # {label: '#rrggbb'}
}

Raises
------
ValueError
- If required parameters are missing or invalid.
+ If required parameters are invalid.

Examples
--------
- >>> # Numeric box plot comparing Tumor distances to multiple targets
>>> res = visualize_nearest_neighbor(
... adata=my_adata,
... annotation='cell_type',
- ... stratify_by='sample_id',
- ... spatial_distance='spatial_distance',
- ... distance_from='Tumor',
- ... distance_to=['Stroma', 'Immune'],
- ... facet_plot=True,
+ ... distance_from='Tumour',
+ ... distance_to=['Stroma', 'B cell'],
+ ... method='numeric',
... plot_type='box',
- ... method='numeric'
- ... )
- >>> df_long, fig = res["data"], res["fig"]
-
- >>> # Distribution plot (kde) for a single target, single figure
- >>> res2 = visualize_nearest_neighbor(
- ... adata=my_adata,
- ... annotation='cell_type',
- ... distance_from='Tumor',
- ... distance_to='Stroma',
- ... method='distribution',
- ...
plot_type='kde' + ... facet_plot=True, + ... stratify_by='image_id', + ... defined_color_map='pin_color_map' ... ) - >>> df_dist, fig2 = res2["data"], res2["fig"] + >>> fig = res['fig'] # matplotlib.figure.Figure + >>> ax_list = res['ax'] # list[matplotlib.axes.Axes] (faceted plot) + >>> df = res['data'] # long-form DataFrame + >>> ax_list[0].set_title('Tumour → Stroma distances') """ - if distance_from is None: - raise ValueError( - "Please specify the 'distance_from' phenotype. It indicates " - "the reference group from which distances are measured." - ) if method not in ['numeric', 'distribution']: raise ValueError( "Invalid 'method'. Please choose 'numeric' or 'distribution'." ) + # Determine plot_type if not provided + if plot_type is None: + plot_type = 'boxen' if method == 'numeric' else 'kde' + + # If log=True, the column name is 'log_distance', else 'distance' + distance_col = 'log_distance' if log else 'distance' + + # Build/fetch color palette + color_dict_rgb = get_defined_color_map( + adata=adata, + defined_color_map=defined_color_map, + annotations=annotation, + colorscale=annotation_colorscale + ) + + palette_hex = { + k: _css_rgb_or_hex_to_hex(v) for k, v in color_dict_rgb.items() + } + adata.uns.setdefault('_spac_palettes', {})[ + f"{defined_color_map or annotation}_hex" + ] = palette_hex + + # Reshape data df_long = _prepare_spatial_distance_data( adata=adata, annotation=annotation, @@ -3445,25 +3708,102 @@ def visualize_nearest_neighbor( log=log ) - # Determine plot_type if not provided - if plot_type is None: - plot_type = 'boxen' if method == 'numeric' else 'kde' + # Filter the full palette to include only the target groups present in + # df_long['group']. These are the groups that will actually be used for hue + # in the plot. 
+ # Derive a palette tailored to *this* figure + # ----------------------------------------------------------------------------- + # WHAT ``plot_specific_palette`` keeps only the colours that correspond to the + # groups actually present in the tidy DataFrame ``df_long``. + # WHY Passing the full master palette could create legend entries (and colour + # assignments) for groups that do not appear in the current subset, + # cluttering the figure. Trimming the palette ensures a clean, accurate + # legend and avoids any mismatch between data and colour. + # HOW ``target_groups_in_plot`` is the list of unique group labels in the + # plot. For each label we look up its HEX code in ``palette_hex``; if a + # colour exists we copy the mapping into the new dictionary. + target_groups_in_plot = df_long['group'].astype(str).unique() + + plot_specific_palette = { + str(group): palette_hex.get(str(group)) + for group in target_groups_in_plot + if palette_hex.get(str(group)) is not None + } - # If log=True, the column name is 'log_distance', else 'distance' - distance_col = 'log_distance' if log else 'distance' + # Assemble kwargs & dispatch + # Inject the palette into the plotting dispatcher + # ----------------------------------------------------------------------------- + # WHAT Two keyword arguments are added/overwritten: + # • ``hue_axis='group'`` tells the plotting function to colour elements + # by the ``group`` column. + # • ``palette=plot_specific_palette`` supplies the exact colour mapping + # we just created. + # WHY Explicitly specifying both the hue axis and its palette guarantees that + # every group is rendered with the intended colour, bypassing Seaborn’s + # default colour cycle and preventing accidental re‑ordering. + # HOW ``dispatch_kwargs`` starts as a copy of any user‑supplied kwargs; the + # call to ``update`` adds these palette‑related keys before control is + # handed off to the generic plotting helper. 
+ + dispatch_kwargs = dict(kwargs) + dispatch_kwargs.update({ + 'hue_axis': 'group', + 'palette': plot_specific_palette + }) + if method == 'numeric': + dispatch_kwargs.setdefault('saturation', 1.0) + + # Set legend=False to allow for custom legend creation by the caller + # The user can still override this by passing legend=True in kwargs + dispatch_kwargs.setdefault('legend', False) - # Dispatch to the plot logic - result_dict = _plot_spatial_distance_dispatch( + disp = _plot_spatial_distance_dispatch( df_long=df_long, method=method, plot_type=plot_type, stratify_by=stratify_by, facet_plot=facet_plot, distance_col=distance_col, - **kwargs + **dispatch_kwargs ) - return result_dict + returned_axes = disp['ax'] + fig_object = None # Initialize + + if isinstance(returned_axes, list): + if returned_axes: + # Unique figures, preserved in axis order + unique_figs_ordered = _ordered_unique_figs(returned_axes) + + if unique_figs_ordered: # at least one valid figure + if stratify_by and not facet_plot: + # one figure per category → return the ordered list + fig_object = unique_figs_ordered + else: + # single-figure layout (facet grid or no stratify) + if len(unique_figs_ordered) == 1: + fig_object = unique_figs_ordered[0] + # first (and usually only) figure + else: # defensive fallback + logging.warning( + "Multiple figures detected in a single-figure " + "scenario; using the first one." 
+ ) + # Return the first one + fig_object = unique_figs_ordered[0] + # empty list → keep fig_object = None + elif returned_axes is not None: + # single Axes → grab its figure + fig_object = getattr(returned_axes, 'figure', None) + # returned_axes is None → fig_object stays None + + return { + 'data': disp['data'], + 'fig': fig_object, + 'ax': disp['ax'], + 'palette': plot_specific_palette # Return the filtered palette + } + import json import plotly.graph_objects as go @@ -3477,7 +3817,7 @@ def present_summary_as_html(summary_dict: dict) -> str: For each specified column, the HTML includes: - Column name and data type - Count and list of missing indices - - Summary details presented in a table (for numeric: stats; + - Summary details presented in a table (for numeric: stats; categorical: unique values and counts) Parameters @@ -3562,11 +3902,11 @@ def present_summary_as_figure(summary_dict: dict) -> go.Figure: clean_data = {} for k, v in info['summary'].items(): # Check if the value is a NumPy integer - if isinstance(v, np.integer): - clean_data[k] = int(v) + if isinstance(v, np.integer): + clean_data[k] = int(v) # Check if the value is a NumPy float - elif isinstance(v, np.floating): - clean_data[k] = float(v) + elif isinstance(v, np.floating): + clean_data[k] = float(v) else: # Keep the value as is if it's already a standard type clean_data[k] = v diff --git a/tests/test_data_utils/test_select_values.py b/tests/test_data_utils/test_select_values.py index 3561edb1..ec8f45bd 100644 --- a/tests/test_data_utils/test_select_values.py +++ b/tests/test_data_utils/test_select_values.py @@ -17,7 +17,8 @@ def setUp(self): # AnnData setup with values 'X', 'Y', 'Z' for the same 'column1' self.adata = ad.AnnData( np.random.rand(6, 2), - obs={'column1': ['X', 'Y', 'X', 'Y', 'X', 'Z']} + obs={'column1': ['X', 'Y', 'X', 'Y', 'X', 'Z'], + 'numerical': [1, 2, 3, 1, 2, 3]} ) def test_dataframe_nonexistent_annotation(self): @@ -47,6 +48,19 @@ def 
test_select_values_dataframe_typical_case(self):
unique_values_in_result = result_df['column1'].unique().tolist()
self.assertCountEqual(unique_values_in_result, expected_values)

+ def test_select_values_dataframe_numerical_case(self):
+ """
+ Test selecting specified numerical values from a DataFrame column.
+ """
+ result_df = select_values(self.df, 'column2', ['1'])
+ # Expecting 1 row where column2 is 1
+ self.assertEqual(len(result_df), 1)
+ # Assert that the sets of unique values in the result and expected
+ # values are identical.
+ expected_values = [1]
+ unique_values_in_result = result_df['column2'].unique().tolist()
+ self.assertCountEqual(unique_values_in_result, expected_values)
+
def test_select_values_adata_typical_case(self):
"""
Test selecting specified values from an AnnData object.
@@ -58,6 +72,17 @@ def test_select_values_adata_typical_case(self):
expected_values = ['X', 'Y']
self.assertCountEqual(unique_values_in_result, expected_values)

+ def test_select_values_adata_numerical_case(self):
+ """
+ Test selecting specified numerical values from an AnnData object.
+ """
+ result_adata = select_values(self.adata, 'numerical', ['1'])
+ # Expecting 2 rows where 'numerical' is 1
+ self.assertEqual(result_adata.n_obs, 2)
+ unique_values_in_result = result_adata.obs['numerical'].unique().tolist()
+ expected_values = [1]
+ self.assertCountEqual(unique_values_in_result, expected_values)
+
def test_exclude_values_dataframe_typical_case(self):
"""
Test excluding specified values from a DataFrame column.
diff --git a/tests/test_transformations/test_knn_clustering.py b/tests/test_transformations/test_knn_clustering.py new file mode 100644 index 00000000..8efd1d97 --- /dev/null +++ b/tests/test_transformations/test_knn_clustering.py @@ -0,0 +1,233 @@ +import unittest +import numpy as np +import pandas as pd +from anndata import AnnData +from spac.transformations import knn_clustering + + +class TestKnnClustering(unittest.TestCase): + def setUp(self): + """ + Set up a test environment for KNN clustering. + + This method is run before each test in the TestKnnClustering class. It initializes a synthetic + AnnData object (`adata`) that simulates a dataset for supervised clustering tasks. The + dataset includes features and class annotations, with a portion of the labels intentionally + set to "no_label" to test the handling of missing values. + + The attributes of the created AnnData object include: + + - `adata.X`: A 2D numpy array representing the feature matrix, where: + - The first half of the rows are generated from a normal distribution with a mean of 10. + - The second half of the rows are generated from a normal distribution with a mean of 100. + + - `adata.obs['classes']`: Class annotations for each row in `adata`, where approximately + half of the rows have missing labels represented by "no_label". The labels are as follows: + - 0: Corresponds to data points with a mean around 10. + - 1: Corresponds to data points with a mean around 100. + + - `adata.obs['all_missing_classes']`: An array where all entries are set to "no_label", + simulating a scenario where no class labels are available. + + - `adata.obs['no_missing_classes']`: An array containing all class labels (0 and 1), + indicating that all data points have valid annotations. + + - `adata.obs['alt_classes']`: An alternative class label array where "no_label" entries + are replaced with NaN values, allowing for testing scenarios that require handling + missing values as NaNs. 
+ + Additionally, this method sets up several attributes for use in tests: + + - `self.annotation`: A string representing the column name for class labels in `obs`. + - `self.alt_annotation`: A string representing the column name for alternative class labels. + - `self.layer`: A string indicating which layer of data to use for KNN clustering. + - `self.features`: A list of feature names used in the AnnData object, which includes + "gene1" and "gene2". + """ + ######### + # adata # + ######### + + # Generate 6 rows, two with mean centered at (10, 10) and two with means at (100, 100) + data = np.array([ + np.concatenate( + ( + np.random.normal(10, 1, 3), + np.random.normal(10, 1, 3) + ) + ), + np.concatenate( + ( + np.random.normal(100, 1, 3), + np.random.normal(100, 1, 3) + ) + ), + ]).reshape(-1, 2) + + # Generate class labels, label 0 = mean at (10, 10), label 1 = mean at (100, 100) + full_class_labels = np.array([0, 0, 0, 1, 1, 1],dtype=object) + class_labels = np.array([0, 0, "no_label", "no_label", 1, 1],dtype=object) + alt_class_labels = np.array([0, 0, np.nan, np.nan, 1, 1],dtype=object) + + # Wrap into an AnnData object + self.dataset = data + self.adata = AnnData( + X=self.dataset, var=pd.DataFrame(index=["gene1", "gene2"]) + ) + + self.adata.layers["counts"] = self.dataset + self.adata.obsm["derived_features"] = self.dataset + self.adata.obs["classes"] = class_labels + + # annotations with all labels missing or present + self.adata.obs["all_missing_classes"] = np.array(["no_label" for x in full_class_labels]) + self.adata.obs["no_missing_classes"] = full_class_labels + self.adata.obs["alt_classes"] = alt_class_labels + + # non-adata parameters for unittests + self.annotation = "classes" + self.alt_annotation = "alt_classes" + self.layer = "counts" + self.features = ["gene1", "gene2"] + + + def test_typical_case(self): + # This test checks if the function correctly adds 'knn' to the + # AnnData object's obs attribute and if it correctly sets + # 
'knn_features' in the AnnData object's uns attribute. + knn_clustering( + adata=self.adata, + features=self.features, + annotation=self.annotation, + layer=self.layer, + k = 2 + ) + self.assertIn("knn", self.adata.obs) + self.assertEqual(self.adata.uns["knn_features"], self.features) + + def test_output_annotation(self): + # This test checks if the function correctly adds the "output_annotation" + # to the # AnnData object's obs attribute + output_annotation_name = "my_output_annotation" + knn_clustering( + adata=self.adata, + features=self.features, + annotation=self.annotation, + layer=self.layer, + k = 2, + output_annotation=output_annotation_name, + ) + self.assertIn(output_annotation_name, self.adata.obs) + + def test_layer_none_case(self): + # This test checks if the function works correctly when layer is None. + knn_clustering( + adata=self.adata, + features=self.features, + annotation=self.annotation, + layer=None, + k = 2 + ) + self.assertIn("knn", self.adata.obs) + self.assertEqual(self.adata.uns["knn_features"], self.features) + + def test_invalid_k(self): + # This test checks if the function raises a ValueError when the + # k argument is not a positive integer and checks the error message + invalid_k_value = 'invalid' + err_msg = (f"`k` must be a positive integer. Received value: `{invalid_k_value}`") + with self.assertRaisesRegex(ValueError, err_msg): + knn_clustering( + adata=self.adata, + features=self.features, + annotation=self.annotation, + layer=self.layer, + k=invalid_k_value, + ) + + def test_trivial_label(self): + # This test checks if the data is fully labeled or missing labels for every datapoint + # and the associated error messages + + # all datapoints labeled + no_missing_annotation = "no_missing_classes" + err_msg = (f"All cells are labeled in the annotation `{no_missing_annotation}`. 
Please provide a mix of labeled and unlabeled data.") + with self.assertRaisesRegex(ValueError, err_msg): + knn_clustering( + adata=self.adata, + features=self.features, + annotation=no_missing_annotation, + layer=self.layer, + k = 2 + ) + + # no datapoints labeled + all_missing_annotation = "all_missing_classes" + err_msg = (f"No cells are labeled in the annotation `{all_missing_annotation}`. Please provide a mix of labeled and unlabeled data.") + with self.assertRaisesRegex(ValueError, err_msg): + knn_clustering( + adata=self.adata, + features=self.features, + annotation="all_missing_classes", + layer=self.layer, + k = 2 + ) + + def test_clustering_accuracy(self): + knn_clustering( + adata=self.adata, + features=self.features, + annotation=self.annotation, + layer="counts", + k=2, + ) + + self.assertIn("knn", self.adata.obs) + self.assertEqual(len(np.unique(self.adata.obs["knn"])), 2) + + def test_associated_features(self): + # Run knn using the derived feature and generate two clusters + output_annotation = "derived_knn" + associated_table = "derived_features" + knn_clustering( + adata=self.adata, + features=None, + annotation=self.annotation, + layer=None, + k=2, + output_annotation=output_annotation, + associated_table=associated_table, + ) + + self.assertEqual(len(np.unique(self.adata.obs[output_annotation])), 2) + + def test_missing_label(self): + # This test checks that the missing label parameter works as intended + #first knn call with normal data + knn_clustering( + adata=self.adata, + features=self.features, + annotation=self.annotation, + layer="counts", + k=2, + output_annotation="knn_1", + associated_table=None, + missing_label = "no_label" + ) + #second knn call with alt_class data + knn_clustering( + adata=self.adata, + features=self.features, + annotation=self.alt_annotation, + layer="counts", + k=2, + output_annotation="knn_2", + associated_table=None, + missing_label = np.nan + ) + + #assert that they produce the same final label + 
self.assertTrue(all(self.adata.obs["knn_1"]==self.adata.obs["knn_2"])) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transformations/test_utag_clustering.py b/tests/test_transformations/test_utag_clustering.py index 57b329a2..4cbb44ca 100644 --- a/tests/test_transformations/test_utag_clustering.py +++ b/tests/test_transformations/test_utag_clustering.py @@ -53,24 +53,46 @@ def create_syn_data(self): # make a dataset with non normal distribution of genes, so that clustering # done with PCAs and with features will produce different clusters def create_adata_complex(self, n_cells_complex=500): - # Generate spatial coordinates in a circular pattern + # Creates a complex AnnData object with spatial gene expression patterns. + # Step 1: Generate spatial coordinates in a circular pattern + # - theta represents angular position (0 to 2π radians) + # - r represents radial distance from center (0 to 10 units) theta = np.random.uniform(0, 2*np.pi, n_cells_complex) r = np.random.uniform(0, 10, n_cells_complex) + # Step 2: Convert polar coordinates (r, theta) to Cartesian coordinates (x, y) x_coord = r * np.cos(theta) y_coord = r * np.sin(theta) - # Radial distance-dependent genes (higher expression at the periphery) - gene1 = np.exp(r/5) + np.random.normal(0, 0.5, n_cells_complex) - gene2 = -np.exp(r/5) + np.random.normal(0, 0.5, n_cells_complex) - # Angular position-dependent genes (periodic pattern) + # Step 3: Create radial distance-dependent genes + # - gene1: Expression increases with distance from center (exponential gradient) + # - gene2: Expression decreases with distance from center (negative exponential gradient) + # - Both include random noise to simulate biological variability + gene1 = np.exp(r/5) + np.random.normal(0, 0.5, n_cells_complex) # Higher expression at periphery + gene2 = -np.exp(r/5) + np.random.normal(0, 0.5, n_cells_complex) # Higher expression at center + # Step 4: Create angular position-dependent genes + # - gene3: 
Expression follows sinusoidal pattern based on angular position + # - gene4: Expression follows cosine pattern based on angular position + # - These create 3 peaks/valleys around the circle (frequency=3) + # - The cosine pattern in gene4 is shifted 30° (π/6 radians) compared to gene3's sine pattern + # - Adds random noise with lower standard deviation (0.3) gene3 = np.sin(3*theta) + np.random.normal(0, 0.3, n_cells_complex) gene4 = np.cos(3*theta) + np.random.normal(0, 0.3, n_cells_complex) - # Quadrant-specific genes + # Step 5: Identify quadrants based on Cartesian coordinates + # - Quadrant 1: x>0, y>0 (top right) + # - Quadrant 2: x<0, y>0 (top left) + # - Quadrant 3: x<0, y<0 (bottom left) + # - Quadrant 4: x>0, y<0 (bottom right) quadrant = np.where((x_coord > 0) & (y_coord > 0), 1, np.where((x_coord < 0) & (y_coord > 0), 2, np.where((x_coord < 0) & (y_coord < 0), 3, 4))) + # Step 6: Create genes with quadrant-specific expression patterns + # - gene5: Highly expressed in top half (quadrants 1,2) + # - gene6: Highly expressed in right half (quadrants 1,4) + # - Both include random noise to simulate biological variability gene5 = np.where(np.isin(quadrant, [1, 2]), 3, 0) + np.random.normal(0, 0.3, n_cells_complex) gene6 = np.where(np.isin(quadrant, [1, 4]), 3, 0) + np.random.normal(0, 0.3, n_cells_complex) - # Random noise genes + # Step 7: Create control genes with random expression (no spatial pattern) + # - gene7, gene8: Random normal distribution with no spatial dependency + # - These simulate genes that are not spatially regulated gene7 = np.random.normal(0, 1, n_cells_complex) gene8 = np.random.normal(0, 1, n_cells_complex) # Combine all genes @@ -96,8 +118,9 @@ def create_adata_complex(self, n_cells_complex=500): ) ) - # Add raw counts layer (here our main matrix is already "normalized") + # Add raw counts layer adata_complex.layers['counts'] = expression_matrix.copy() + # Add spatial coordinates adata_complex.obsm["spatial"] = 
np.random.rand(n_cells_complex, n_cells_complex) return adata_complex @@ -213,7 +236,9 @@ def test_invalid_k(self): output_annotation="UTAG", associated_table=None, parallel=False) - + + # test is done on the simple synthetic data with 2 populations of 500 cells + # each and 2 features def test_clustering_accuracy(self): run_utag_clustering(adata=self.syn_data, features=None, diff --git a/tests/test_utils/test_get_defined_color_map.py b/tests/test_utils/test_get_defined_color_map.py index 7dba59dd..c2e0a360 100644 --- a/tests/test_utils/test_get_defined_color_map.py +++ b/tests/test_utils/test_get_defined_color_map.py @@ -86,7 +86,8 @@ def test_generate_color_map(self): self.assertIn('a', result) self.assertIn('b', result) # Check that the colors are correctly generated. - self.assertTrue(all(isinstance(color, str) for color in result.values())) + self.assertTrue(all(isinstance(color, str) for color + in result.values())) def test_missing_annotations(self): """ @@ -101,6 +102,18 @@ def test_missing_annotations(self): ): get_defined_color_map(dummy, defined_color_map=None) + def test_generate_color_map_multiple_annotations(self): + """ + Test handling of list-based annotations, + raises a NotImplementedError. 
+ """ + obs = {'my_ann': pd.Series(['a', 'b', 'a']), + 'my_ann_2': pd.Series(['a', 'b', 'a'])} + dummy = DummyAnnData(uns={'dummy': {}}, obs=obs) + annos_list = list(('my_ann', 'my_ann_2')) + result = get_defined_color_map(dummy, + annotations=annos_list) + self.assertIsNotNone(result) if __name__ == '__main__': unittest.main() diff --git a/tests/test_visualization/test_css_rgb_or_hex_to_hex.py b/tests/test_visualization/test_css_rgb_or_hex_to_hex.py new file mode 100644 index 00000000..d71f412e --- /dev/null +++ b/tests/test_visualization/test_css_rgb_or_hex_to_hex.py @@ -0,0 +1,89 @@ +import unittest +import re # Import re if used directly in tests, though _css_rgb_or_hex_to_hex encapsulates its re usage +import matplotlib.colors as mcolors +from spac.visualization import _css_rgb_or_hex_to_hex + +class TestCssRgbOrHexToHex(unittest.TestCase): + """ + Test suite for the _css_rgb_or_hex_to_hex function, + focusing on major features and error handling. + """ + + def test_valid_hex_colors(self): + """ + Test valid hex color conversions: + - 6-digit (lowercase, uppercase, mixed) to lowercase. + - 3-digit to 6-digit lowercase. + - 8-digit (with alpha) to 8-digit or 6-digit lowercase. 
+ """ + self.assertEqual(_css_rgb_or_hex_to_hex('#ff0000'), '#ff0000') + self.assertEqual(_css_rgb_or_hex_to_hex('#FF00AA'), '#ff00aa') + self.assertEqual(_css_rgb_or_hex_to_hex('#Ff00aA'), '#ff00aa') + # 3-digit hex + self.assertEqual(_css_rgb_or_hex_to_hex('#f0a'), '#ff00aa') + # 8-digit hex with alpha + self.assertEqual( + _css_rgb_or_hex_to_hex('#ff00aa80', keep_alpha=True), + '#ff00aa80' + ) + self.assertEqual( + _css_rgb_or_hex_to_hex('#FF00AA80', keep_alpha=False), + '#ff00aa' + ) + + # ---------- named colours ------------------------------------------- + + def test_named_colour_passthrough(self): + self.assertEqual(_css_rgb_or_hex_to_hex('gold'), 'gold') + self.assertEqual(_css_rgb_or_hex_to_hex(' GOLd '), 'gold') + + # ---------- hexadecimal forms --------------------------------------- + + def test_short_and_long_hex(self): + self.assertEqual(_css_rgb_or_hex_to_hex('#ABC'), '#aabbcc') + self.assertEqual(_css_rgb_or_hex_to_hex('#a1b2c3'), '#a1b2c3') + + def test_keep_alpha(self): + # 8-digit input keeps alpha → unchanged + self.assertEqual( + _css_rgb_or_hex_to_hex('#ff000080', keep_alpha=True), + '#ff000080', + ) + # when keep_alpha is False alpha is stripped + self.assertEqual( + _css_rgb_or_hex_to_hex('#ff000080', keep_alpha=False), + '#ff0000', + ) + + # ---------- rgb()/rgba() strings ------------------------------------ + + def test_rgb_to_hex(self): + self.assertEqual( + _css_rgb_or_hex_to_hex('rgb(255,0,0)'), + '#ff0000', + ) + + def test_rgba_to_hex_with_alpha(self): + self.assertEqual( + _css_rgb_or_hex_to_hex('rgba(255,0,0,0.5)', keep_alpha=True), + '#ff000080', # 0.5 → 0x80 + ) + # keep_alpha False strips alpha + self.assertEqual( + _css_rgb_or_hex_to_hex('rgba(255,0,0,0.5)', keep_alpha=False), + '#ff0000', + ) + + # ---------- error handling ------------------------------------------ + + def test_out_of_range_raises_value_error(self): + """RGB components >255 should trigger a descriptive ValueError.""" + with self.assertRaisesRegex( + 
ValueError, + r'RGB components in ".+" must be between 0 and 255' + ): + _css_rgb_or_hex_to_hex('rgb(300,0,0)') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_visualization/test_datashader.py b/tests/test_visualization/test_datashader.py new file mode 100644 index 00000000..c6c3a7df --- /dev/null +++ b/tests/test_visualization/test_datashader.py @@ -0,0 +1,82 @@ +import unittest +import numpy as np +import pandas as pd +from spac.visualization import heatmap_datashader +import matplotlib + +matplotlib.use('Agg') # Set the backend to 'Agg' to suppress plot window + + +class TestDataShaderHeatMap(unittest.TestCase): + def setUp(self): + """Prepare data for testing.""" + self.x = np.random.rand(10) + self.y = np.random.rand(10) + # Fixed categorical labels to ensure representation of each category + fixed_labels = ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A'] + self.labels_categorical = pd.Series(fixed_labels, dtype="category") + + def test_invalid_input_type(self): + """Test handling of invalid input types.""" + with self.assertRaises(ValueError) as context_manager: + heatmap_datashader(1, self.y, labels=self.labels_categorical) + self.assertIn("x and y must be array-like", + str(context_manager.exception)) + + def test_labels_length_mismatch(self): + """Test handling of mismatched lengths between data and labels.""" + wrong_labels = pd.Series(['A'] * 9) # Shorter than x and y + with self.assertRaises(ValueError) as context_manager: + heatmap_datashader(self.x, self.y, labels=wrong_labels) + self.assertIn("Labels length should match x and y length", + str(context_manager.exception)) + + def test_valid_input_returns_figure_basic(self): + """Test that valid input returns a matplotlib figure with expected subplots.""" + fig = heatmap_datashader(self.x, self.y, + labels=self.labels_categorical) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + num_axes = len(fig.axes) + expected_axes = self.labels_categorical.nunique() + 
self.assertEqual(num_axes, expected_axes) + + def test_labels_not_multiple_of_three(self): + """Test heatmap generation when the number of labels is not a multiple of 3.""" + x = np.random.rand(7) + y = np.random.rand(7) + labels = pd.Series(['A', 'B', 'C', 'D', 'E', 'F', 'G'], dtype="category") # 7 labels + + fig = heatmap_datashader(x, y, labels=labels) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + num_axes = len(fig.axes) + expected_axes = labels.nunique() + self.assertEqual(num_axes, expected_axes) + + for ax in fig.axes: + images = [child for child in ax.get_children() + if isinstance(child, matplotlib.image.AxesImage)] + self.assertGreater(len(images), 0, + "Expected at least one image in each subplot.") + + def test_valid_input_returns_figure(self): + """Test that valid input returns a matplotlib figure with expected subplots and images.""" + fig = heatmap_datashader(self.x, self.y, + labels=self.labels_categorical) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + # Check number of axes matches number of unique labels + num_axes = len(fig.axes) + expected_axes = self.labels_categorical.nunique() + self.assertEqual(num_axes, expected_axes) + + # Check that each axis has an image plotted + for ax in fig.axes: + images = [child for child in ax.get_children() + if isinstance(child, matplotlib.image.AxesImage)] + self.assertGreater(len(images), 0, + "Expected at least one image in each subplot.") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_visualization/test_plot_spatial_distance_dispatch.py b/tests/test_visualization/test_plot_spatial_distance_dispatch.py index 11587376..bff7dc40 100644 --- a/tests/test_visualization/test_plot_spatial_distance_dispatch.py +++ b/tests/test_visualization/test_plot_spatial_distance_dispatch.py @@ -1,6 +1,3 @@ -import os -import sys -sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../../src") import unittest import pandas as pd import matplotlib @@ -8,17 +5,27 @@ 
import matplotlib.pyplot as plt from pandas.testing import assert_frame_equal from spac.visualization import _plot_spatial_distance_dispatch +from matplotlib.axes import Axes as MatplotlibAxes class TestPlotSpatialDistanceDispatch(unittest.TestCase): def setUp(self): - # Creates a minimal DataFrame for testing + # Creates DataFrames for testing self.df_basic = pd.DataFrame({ 'cellid': ['C1', 'C2'], - 'group': ['g1', 'g1'], + 'group': ['g1', 'g1'], # Single group level 'distance': [0.5, 1.5], - 'phenotype': ['p1', 'p1'] + 'log_distance': [0.405, 0.916], # log1p approx + 'phenotype': ['p1', 'p1'] # Single phenotype level + }) + self.df_strat_and_hue = pd.DataFrame({ + 'cellid': ['C1', 'C2', 'C3', 'C4', 'C5', 'C6'], + 'group': ['g1', 'g1', 'g2', 'g2', 'g1', 'g2'], + 'distance': [0.5, 1.5, 0.7, 1.2, 0.8, 1.0], + 'log_distance': [0.405, 0.916, 0.530, 0.788, 0.587, 0.693], + 'phenotype': ['p1', 'p1', 'p2', 'p2', 'p1', 'p2'], + 'region': ['R1', 'R1', 'R2', 'R2', 'R1', 'R2'] }) def tearDown(self): @@ -27,8 +34,8 @@ def tearDown(self): def test_simple_numeric_scenario(self): """ - Tests a simple scenario with 'numeric' method and no stratify_by. - Verifies output structure and basic figure attributes. + Tests 'numeric' method, no stratify_by. + Verifies output structure and Axes attributes. 
""" result = _plot_spatial_distance_dispatch( df_long=self.df_basic, @@ -36,101 +43,145 @@ def test_simple_numeric_scenario(self): plot_type='box' ) - # Checks result structure self.assertIn('data', result) - self.assertIn('fig', result) - - # Verifies the data matches the input DataFrame + self.assertIn('ax', result) assert_frame_equal(result['data'], self.df_basic) - # Verifies figure properties - fig = result['fig'] - # Ensures there is at least one axis in the figure - self.assertTrue(len(fig.axes) > 0) + ax_obj = result['ax'] + self.assertIsInstance(ax_obj, MatplotlibAxes) + self.assertEqual(ax_obj.get_xlabel(), "Nearest Neighbor Distance") + self.assertEqual(ax_obj.get_ylabel(), 'group') - # Check axis labels - ax = fig.axes[0] - self.assertEqual(ax.get_xlabel(), "Nearest Neighbor Distance") - self.assertIn('group', ax.get_ylabel()) + def test_simple_numeric_log_distance(self): + """Tests 'numeric' method with 'log_distance'.""" + result = _plot_spatial_distance_dispatch( + df_long=self.df_basic, + method='numeric', + plot_type='box', + distance_col='log_distance' + ) + self.assertIn('ax', result) + ax_obj = result['ax'] + self.assertIsInstance(ax_obj, MatplotlibAxes) + self.assertEqual( + ax_obj.get_xlabel(), + "Log(Nearest Neighbor Distance)" + ) - def test_distribution_scenario_with_hue(self): + def test_distribution_scenario_with_explicit_hue(self): """ - Tests the 'distribution' method with a hue axis. - Verifies output structure and figure attributes. + Tests 'distribution' method with explicit hue_axis. + Verifies output structure and Axes attributes. 
""" - df_hue = self.df_basic.copy() - df_hue['phenotype'] = ['p1', 'p2'] - result = _plot_spatial_distance_dispatch( - df_long=df_hue, + df_long=self.df_strat_and_hue, method='distribution', plot_type='kde', - hue_axis='phenotype' + hue_axis='phenotype' # 'phenotype' has multiple levels ) - # Verifies the data matches the input DataFrame - assert_frame_equal(result['data'], df_hue) - fig = result['fig'] - self.assertTrue(len(fig.axes) > 0) + assert_frame_equal(result['data'], self.df_strat_and_hue) + ax_obj = result['ax'] + self.assertIsInstance(ax_obj, MatplotlibAxes) + self.assertEqual(ax_obj.get_xlabel(), "Nearest Neighbor Distance") + self.assertTrue(len(ax_obj.get_ylabel()) > 0) # e.g., 'Density' def test_stratify_and_facet_plot(self): """ - Tests the scenario with stratify_by and facet_plot=True. - Verifies the presence of multiple subplots in a single figure. + Tests stratify_by and facet_plot=True. + Verifies list of Axes and their labels. """ - df_strat = pd.DataFrame({ - 'cellid': ['C1', 'C2', 'C3', 'C4'], - 'group': ['g1', 'g1', 'g2', 'g2'], - 'distance': [0.5, 1.5, 0.7, 1.2], - 'phenotype': ['p1', 'p1', 'p2', 'p2'], - 'region': ['R1', 'R1', 'R2', 'R2'] - }) + col_wrap_val = 2 result = _plot_spatial_distance_dispatch( - df_long=df_strat, + df_long=self.df_strat_and_hue, method='numeric', plot_type='violin', - stratify_by='region', + stratify_by='region', # 'region' has R1, R2 (2 unique values) facet_plot=True, - col_wrap=2 + col_wrap=col_wrap_val # Passed to Seaborn ) - assert_frame_equal(result['data'], df_strat) - fig = result['fig'] - # Verifies the expected number of subplots in the figure - self.assertEqual(len(fig.axes), 2) + assert_frame_equal(result['data'], self.df_strat_and_hue) + ax_list = result['ax'] + self.assertIsInstance(ax_list, list) + + num_facets = self.df_strat_and_hue['region'].nunique() + self.assertEqual(len(ax_list), num_facets) + + for i, ax_item in enumerate(ax_list): + self.assertIsInstance(ax_item, MatplotlibAxes) + 
self.assertEqual(ax_item.get_xlabel(), "Nearest Neighbor Distance") + + # Determine expected y-label based on position in the wrapped grid + # Axes in the first column of a wrapped layout should have the y-label. + # Others (inner columns) should have it cleared by Seaborn. + if i % col_wrap_val == 0: + expected_ylabel = "group" + message = (f"Axes at index {i} (first in a wrapped row) " + "should have y-label 'group'") + else: + expected_ylabel = "" + message = (f"Axes at index {i} (inner in a wrapped row) " + "should have empty y-label") + self.assertEqual(ax_item.get_ylabel(), expected_ylabel, message) def test_stratify_no_facet(self): """ - Tests the scenario with stratify_by and facet_plot=False. - Verifies the presence of multiple figures, one per group. + Tests stratify_by and facet_plot=False. + Verifies list of Axes, one per group. """ - df_strat = pd.DataFrame({ - 'cellid': ['C1', 'C2', 'C3', 'C4'], - 'group': ['g1', 'g1', 'g2', 'g2'], - 'distance': [0.5, 1.5, 0.7, 1.2], - 'phenotype': ['p1', 'p1', 'p2', 'p2'], - 'region': ['R1', 'R1', 'R2', 'R2'] - }) result = _plot_spatial_distance_dispatch( - df_long=df_strat, + df_long=self.df_strat_and_hue, method='distribution', plot_type='hist', - stratify_by='region', + stratify_by='region', # 'region' has R1, R2 facet_plot=False, - bins=5 + bins=5 # Passed to Seaborn ) - assert_frame_equal(result['data'], df_strat) - figs = result['fig'] - - # Verifies the expected number of figures - self.assertIsInstance(figs, list) - self.assertEqual(len(figs), 2) # R1 and R2 + assert_frame_equal(result['data'], self.df_strat_and_hue) + axes_list = result['ax'] + self.assertIsInstance(axes_list, list) + # Expect one Axes per unique value in 'region' + self.assertEqual(len(axes_list), self.df_strat_and_hue['region'].nunique()) - # Verifies each figure has at least one axis - for fig in figs: - self.assertTrue(len(fig.axes) > 0) + for ax_item in axes_list: + self.assertIsInstance(ax_item, MatplotlibAxes) + 
self.assertEqual(ax_item.get_xlabel(), "Nearest Neighbor Distance") + self.assertTrue(len(ax_item.get_ylabel()) > 0) # e.g., 'Count' + def test_no_stratify_displot_kwargs_facet(self): + """ + Tests no stratify_by, but displot kwargs cause faceting. + Verifies list of Axes. + """ + # 'group' in df_strat_and_hue has 'g1', 'g2' + result = _plot_spatial_distance_dispatch( + df_long=self.df_strat_and_hue, + method='distribution', + plot_type='kde', + # No stratify_by, facet_plot=False (default) + # kwargs will cause faceting within _make_axes_object + col='group' # Facet by 'group' column using displot's 'col' + ) + assert_frame_equal(result['data'], self.df_strat_and_hue) + ax_list = result['ax'] + self.assertIsInstance(ax_list, list, "Expected a list of Axes due to 'col' kwarg in displot.") + # Expect one Axes per unique value in 'group' + self.assertEqual(len(ax_list), self.df_strat_and_hue['group'].nunique()) + for ax_item in ax_list: + self.assertIsInstance(ax_item, MatplotlibAxes) + self.assertEqual(ax_item.get_xlabel(), "Nearest Neighbor Distance") + + + def test_invalid_method_raises_value_error(self): + """Tests that an invalid method raises a ValueError.""" + with self.assertRaisesRegex(ValueError, "`method` must be 'numeric' or 'distribution'."): + _plot_spatial_distance_dispatch( + df_long=self.df_basic, + method='invalid_method', + plot_type='box' + ) if __name__ == '__main__': unittest.main() diff --git a/tests/test_visualization/test_visualize_nearest_neighbor.py b/tests/test_visualization/test_visualize_nearest_neighbor.py index be21d276..31b095fa 100644 --- a/tests/test_visualization/test_visualize_nearest_neighbor.py +++ b/tests/test_visualization/test_visualize_nearest_neighbor.py @@ -1,8 +1,12 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../../src") import unittest import pandas as pd import numpy as np import anndata import matplotlib +import matplotlib.collections as mcoll matplotlib.use('Agg') # Uses a 
non-interactive backend for tests import matplotlib.pyplot as plt from spac.visualization import visualize_nearest_neighbor @@ -10,127 +14,414 @@ class TestVisualizeNearestNeighbor(unittest.TestCase): - def setUp(self): - # Creates a minimal AnnData object with two cells of - # different phenotypes - data = np.array([[1.0], [2.0]]) - obs = pd.DataFrame( - { - 'cell_type': ['type1', 'type2'], - 'imageid': ['img1', 'img1'] - }, - index=['CellA', 'CellB'] - ) - self.adata = anndata.AnnData(X=data, obs=obs) - - # Creates a numeric spatial_distance DataFrame - # Each row corresponds to a cell, columns are phenotypes - # Distances represent the nearest distance from that cell - # to the given phenotype - dist_df = pd.DataFrame( - { - 'type1': [0.0, np.sqrt(2)], - 'type2': [np.sqrt(2), 0.0] - }, - index=['CellA', 'CellB'] - ) - # Ensure numeric dtype + @staticmethod + def _create_test_adata(): + """ + Creates a common AnnData object for testing various scenarios. + """ + data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + obs = pd.DataFrame({ + 'cell_type': ['typeA', 'typeB', 'typeA', 'typeC'], + 'image_id': ['img1', 'img1', 'img2', 'img2'] + }, index=['Cell1', 'Cell2', 'Cell3', 'Cell4']) + adata = anndata.AnnData(X=data, obs=obs) + + dist_df = pd.DataFrame({ + 'typeA': [0.0, 1.0, 0.0, 2.0], + 'typeB': [1.0, 0.0, 4.0, 3.0], + 'typeC': [2.0, 3.0, 5.0, 0.0] + }, index=['Cell1', 'Cell2', 'Cell3', 'Cell4']) dist_df = dist_df.astype(float) - self.adata.obsm['spatial_distance'] = dist_df + adata.obsm['spatial_distance'] = dist_df + return adata + + def setUp(self): + """Set up basic AnnData object for tests.""" + self.adata = TestVisualizeNearestNeighbor._create_test_adata() def tearDown(self): - # Closes all figures to prevent memory issues + """Close all Matplotlib figures after each test.""" plt.close('all') - def test_missing_distance_from(self): + def test_output_structure_and_types_single_plot(self): """ - Tests that the function raises a ValueError when 
'distance_from' - is not provided, matching the exact error message. + Tests the basic output structure and types for a single plot. """ - expected_msg = ( - "Please specify the 'distance_from' phenotype. It indicates " - "the reference group from which distances are measured." + result = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to='typeB', + method='numeric', + plot_type='box' ) - with self.assertRaisesRegex(ValueError, expected_msg): - visualize_nearest_neighbor( - adata=self.adata, - annotation='cell_type', - distance_from=None, - method='numeric' - ) - def test_invalid_method(self): + self.assertIsInstance(result, dict, "Result should be a dict.") + expected_keys = ['data', 'fig', 'ax', 'palette'] + for key in expected_keys: + self.assertIn(key, result, f"Key '{key}' missing in result.") + + self.assertIsInstance( + result['data'], pd.DataFrame, "'data' should be a DataFrame." + ) + self.assertIsInstance( + result['fig'], matplotlib.figure.Figure, + "'fig' should be a Figure." + ) + self.assertIsInstance( + result['ax'], matplotlib.axes.Axes, + "'ax' should be an Axes object for a single plot." + ) + self.assertIsInstance( + result['palette'], dict, "'palette' should be a dictionary." + ) + + self.assertEqual( + len(result['fig'].axes), 1, + "Single plot figure should contain one axis." + ) + self.assertIs( + result['ax'], result['fig'].axes[0], + "Returned 'ax' should be the axis in 'fig'." + ) + self.assertIs( + result['ax'].figure, result['fig'], + "Returned 'ax' should belong to returned 'fig'." + ) + + def test_minimal_numeric_plot_axis_labels_and_palette(self): """ - Tests that the function raises a ValueError when 'method' is invalid, - matching the exact error message. + Tests a minimal numeric plot, focusing on axis labels and palette. """ - expected_msg = ( - "Invalid 'method'. Please choose 'numeric' or 'distribution'." 
+ result = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to='typeB', + method='numeric', + plot_type='box' ) - with self.assertRaisesRegex(ValueError, expected_msg): - visualize_nearest_neighbor( - adata=self.adata, - annotation='cell_type', - distance_from='type1', - method='invalid_method' + + ax = result['ax'] + self.assertEqual( + ax.get_xlabel(), "Nearest Neighbor Distance", + "X-axis label mismatch." + ) + self.assertEqual( + ax.get_ylabel(), "group", + "Y-axis label mismatch for catplot." + ) + + self.assertIn( + 'typeB', result['palette'], + "'typeB' should be in the generated palette." + ) + self.assertTrue( + result['palette']['typeB'].startswith('#'), + "Palette color should be hex." + ) + + def test_minimal_distribution_plot_axis_labels(self): + """ + Tests a minimal distribution plot, focusing on axis labels. + """ + result = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to='typeB', + method='distribution', + plot_type='kde' + ) + + self.assertTrue( + len(result['fig'].axes) > 0, + "Figure should have axes for displot." + ) + ax = result['fig'].axes[0] + self.assertEqual( + ax.get_xlabel(), "Nearest Neighbor Distance", + "X-axis label mismatch." + ) + self.assertEqual( + ax.get_ylabel(), "Density", + "Y-axis label for KDE plot mismatch." + ) + + def test_stratify_by_facet_plot_true_output_structure(self): + """ + Tests output structure for stratify_by with facet_plot=True. + """ + result = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to='typeB', + method='numeric', + stratify_by='image_id', + facet_plot=True + ) + self.assertIsInstance( + result['fig'], matplotlib.figure.Figure, + "Should be a single Figure for facet plot." + ) + self.assertIsInstance( + result['ax'], list, + "'ax' should be a list of Axes for facet plot." 
+ ) + expected_num_facets = self.adata.obs['image_id'].nunique() + self.assertEqual( + len(result['ax']), expected_num_facets, + "Number of axes should match unique categories in stratify_by." + ) + for ax_item in result['ax']: + self.assertIsInstance( + ax_item, matplotlib.axes.Axes, + "Each item in 'ax' list should be an Axes object." + ) + self.assertIs( + ax_item.figure, result['fig'], + "All facet axes should belong to the same figure." ) - def test_simple_numeric_scenario(self): + def test_stratify_by_facet_plot_false_output_structure(self): """ - Tests a simple numeric scenario without stratification. - Verifies output keys and basic figure attributes. + Tests output structure for stratify_by with facet_plot=False. """ result = visualize_nearest_neighbor( adata=self.adata, annotation='cell_type', - distance_from='type1', - distance_to='type2', + distance_from='typeA', + distance_to='typeB', method='numeric', - plot_type='box' + stratify_by='image_id', + facet_plot=False ) + num_categories = self.adata.obs['image_id'].nunique() + self.assertIsInstance( + result['fig'], list, + "Should be a list of Figures for non-faceted stratified plot." + ) + self.assertEqual( + len(result['fig']), num_categories, + "Number of figures should match unique categories." + ) + for fig_item in result['fig']: + self.assertIsInstance( + fig_item, matplotlib.figure.Figure, + "Each item in 'fig' list should be a Figure." + ) - self.assertIn('data', result) - self.assertIn('fig', result) + self.assertIsInstance( + result['ax'], list, "'ax' should be a list of Axes." + ) + self.assertEqual( + len(result['ax']), num_categories, + "Number of axes should match unique categories." + ) + for i, ax_item in enumerate(result['ax']): + self.assertIsInstance( + ax_item, matplotlib.axes.Axes, + "Each item in 'ax' list should be an Axes object." + ) + self.assertIs( + ax_item.figure, result['fig'][i], + "Each ax should belong to its corresponding fig." 
+ ) - # Verifies the returned DataFrame matches expected structure - df = result['data'] - self.assertIsInstance(df, pd.DataFrame) - self.assertIn('group', df.columns) - self.assertIn('distance', df.columns) + def test_defined_color_map_generates_correct_palette(self): + """ + Tests that a defined_color_map is correctly processed. + """ + self.adata.uns['my_colors'] = { + 'typeA': 'rgb(255,0,0)', + 'typeB': '#00FF00', + 'typeC': 'blue' + } + result = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to=['typeB', 'typeC'], + method='numeric', + defined_color_map='my_colors' + ) + expected_palette = { + 'typeB': '#00ff00', + 'typeC': 'blue' + } + self.assertEqual( + result['palette']['typeB'], expected_palette['typeB'] + ) + self.assertEqual( + result['palette']['typeC'], expected_palette['typeC'] + ) + self.assertNotIn( + 'typeA', result['palette'], + "Source phenotype 'typeA' should not be in palette keys." + ) - fig = result['fig'] - # Ensures there is at least one axis in the figure - self.assertTrue(len(fig.axes) > 0) - # Verifies axis labels - ax = fig.axes[0] - self.assertEqual(ax.get_xlabel(), "Nearest Neighbor Distance") - self.assertIn('group', ax.get_ylabel()) + def test_default_plot_type_selection(self): + """ + Tests that plot_type defaults correctly and axis labels are set. + """ + # Test numeric default ('boxen' plot via catplot) + res_numeric = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to='typeB', + method='numeric' + ) + self.assertIsInstance( + res_numeric['fig'], matplotlib.figure.Figure, + "Numeric default should generate a Figure object." + ) + self.assertIsInstance( + res_numeric['ax'], matplotlib.axes.Axes, + "Numeric default should return an Axes object." + ) + ax_numeric = res_numeric['ax'] + self.assertEqual( + ax_numeric.get_xlabel(), "Nearest Neighbor Distance", + "X-axis label for numeric default mismatch." 
+ ) + self.assertEqual( + ax_numeric.get_ylabel(), "group", + "Y-axis label for numeric default (catplot) mismatch." + ) - def test_visualize_with_log_distance(self): + # Test distribution default ('kde' plot via displot) + res_dist = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to='typeB', + method='distribution' + ) + self.assertIsInstance( + res_dist['fig'], matplotlib.figure.Figure, + "Distribution default should generate a Figure object." + ) + self.assertTrue( + len(res_dist['fig'].axes) > 0, + "Distribution plot figure should have axes." + ) + ax_dist = res_dist['fig'].axes[0] + self.assertEqual( + ax_dist.get_xlabel(), "Nearest Neighbor Distance", + "X-axis label for distribution default mismatch." + ) + self.assertEqual( + ax_dist.get_ylabel(), "Density", + "Y-axis label for distribution default (KDE) mismatch." + ) + + def test_legend_default_is_false_passed_to_dispatch(self): """ - Test that visualize_nearest_neighbor correctly handles log-transformed - distances and uses the 'log_distance' column in the output. + Tests that legend=False is passed to dispatch by default. """ result = visualize_nearest_neighbor( adata=self.adata, annotation='cell_type', - distance_from='type1', - spatial_distance='spatial_distance', - log=True, + distance_from='typeA', + distance_to=['typeB', 'typeC'], method='numeric', plot_type='box' ) + fig = result['fig'] + self.assertEqual( + len(fig.legends), 0, + "Figure should not have a legend by default." + ) + if isinstance(result['ax'], matplotlib.axes.Axes): + self.assertIsNone( + result['ax'].get_legend(), "Axes should not have a legend." + ) + elif isinstance(result['ax'], list): + for ax_item in result['ax']: + self.assertIsNone( + ax_item.get_legend(), + "Each Axes in list should not have a legend." + ) - df_long = result['data'] + def test_legend_can_be_overridden_via_kwargs(self): + """ + Tests that the user can pass legend=True via kwargs. 
+ """ + result = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to=['typeB', 'typeC'], + method='numeric', + plot_type='box', + legend=True + ) fig = result['fig'] + self.assertEqual( + len(fig.legends), 1, + "Figure should have exactly one legend when legend=True." + ) + if fig.legends: # Should be true given the assertion above + legend_texts = [ + text.get_text() for text in fig.legends[0].get_texts() + ] + self.assertIn('typeB', legend_texts) + self.assertIn('typeC', legend_texts) + + def test_error_invalid_method_in_visualize_nearest_neighbor(self): + """ + Tests ValueError if 'method' is invalid. + """ + expected_msg = ("Invalid 'method'. Please choose 'numeric' or " + "'distribution'.") + with self.assertRaisesRegex(ValueError, expected_msg): + visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + method='invalid_plot_method' + ) + + # log=True → verify x‑axis label and log‑transformed max value + def test_log_distance_scale_and_axis(self): + """With log=True the data and axis must be on log scale (max ≈ 1.60944).""" + res_log = visualize_nearest_neighbor( + adata=self.adata, + annotation='cell_type', + distance_from='typeA', + distance_to='typeB', + method='numeric', + plot_type='box', + log=True, + ) - # Ensure 'log_distance' is used - self.assertIn('log_distance', df_long.columns) - self.assertNotIn('distance', df_long.columns) + df = res_log['data'] + # Hard‑coded expectation: transformed column exists + self.assertIn('log_distance', df.columns) - # Validate the plot uses the correct label for log-transformed distance - ax = fig.axes[0] - self.assertEqual(ax.get_xlabel(), "Log(Nearest Neighbor Distance)") + # Hard‑coded expected max of log1p(4.0) ≈ 1.6094379124341003 + expected_max_log = 1.6094379124341003 + self.assertAlmostEqual( + df['log_distance'].max(), + expected_max_log, + places=6, + msg='log_distance max is not log1p(4.0)' + ) + + ax = 
res_log['ax'] + # Axis label must reflect log scale + self.assertEqual( + ax.get_xlabel(), + 'Log(Nearest Neighbor Distance)' + ) + # x‑axis upper limit should at least reach the expected max value + self.assertGreaterEqual( + ax.get_xlim()[1], + expected_max_log, + msg='x‑axis upper limit less than expected log max' + ) if __name__ == '__main__':