diff --git a/app.py b/app.py index a2635c2..3f6a4f5 100644 --- a/app.py +++ b/app.py @@ -116,6 +116,7 @@ def server(input, output, session): shared = { "preloaded_data": preloaded_data, # Preloaded data for initial load + "preloaded_file_path": file_path, # Path to preloaded file for filename extraction "data_loaded": data_loaded, # Reactive to track if data is loaded "adata_main": adata_main, # Main anndata object } @@ -124,6 +125,11 @@ def server(input, output, session): # and add them to the shared dictionary for key in data_keys: shared[key] = reactive.Value(None) + + # Add reactive value for input filename + shared['input_filename'] = reactive.Value(None) + shared['boxplot_fig'] = reactive.Value(None) + shared['sankey_fig'] = reactive.Value(None) # Individual server components getting_started_server(input, output, session, shared) diff --git a/server/anno_vs_anno_server.py b/server/anno_vs_anno_server.py index 90a96a1..d13782b 100644 --- a/server/anno_vs_anno_server.py +++ b/server/anno_vs_anno_server.py @@ -11,50 +11,94 @@ def anno_vs_anno_server(input, output, session, shared): @reactive.event(input.go_sk1, ignore_none=True) def spac_Sankey(): adata = ad.AnnData( - X=shared['X_data'].get(), - obs=pd.DataFrame(shared['obs_data'].get()), - layers=shared['layers_data'].get(), + X=shared['X_data'].get(), + obs=pd.DataFrame(shared['obs_data'].get()), + layers=shared['layers_data'].get(), dtype=shared['X_data'].get().dtype ) - if adata is not None: - fig = spac.visualization.sankey_plot( - adata, - source_annotation=input.sk1_anno1(), - target_annotation=input.sk1_anno2() - ) - return fig - return None + if adata is None: + return None + fig = spac.visualization.sankey_plot( + adata, + source_annotation=input.sk1_anno1(), + target_annotation=input.sk1_anno2() + ) + shared['sankey_fig'].set(fig) # Store figure for HTML download + return fig @output @render_widget @reactive.event(input.go_rhm1, ignore_none=True) def spac_Relational(): adata = ad.AnnData( - X=shared['X_data'].get(), + X=shared['X_data'].get(), obs=pd.DataFrame(shared['obs_data'].get()) ) - if adata is not None: - result = spac.visualization.relational_heatmap( - adata, - source_annotation=input.rhm_anno1(), - target_annotation=input.rhm_anno2() - ) - shared['df_relational'].set(result['data']) - return result['figure'] - return None + if adata is None: + return None + result = spac.visualization.relational_heatmap( + adata, + source_annotation=input.rhm_anno1(), + target_annotation=input.rhm_anno2() + ) + shared['df_relational'].set(result['data']) + return result['figure'] + @render.download(filename="relational_data.csv") def download_df_1(): df = shared['df_relational'].get() - if df is not None: - csv_string = df.to_csv(index=False) - csv_bytes = csv_string.encode("utf-8") - return csv_bytes, "text/csv" - return None + if df is None: + return None + csv_string = df.to_csv(index=False) + csv_bytes = csv_string.encode("utf-8") + return csv_bytes, "text/csv" + @render.ui @reactive.event(input.go_rhm1, ignore_none=True) def download_button_ui_1(): - if shared['df_relational'].get() is not None: - return ui.download_button("download_df_1", "Download Data", class_="btn-warning") - return None + if shared['df_relational'].get() is None: + return None + return ui.download_button("download_df_1", "Download Data", class_="btn-warning") + + def get_sankey_html_filename(): + """Generate HTML download filename for sankey.""" + input_filename = shared['input_filename'].get() + if input_filename: + return f"{input_filename}_sankey.html" + return "sankey.html" + + @render.download(filename=get_sankey_html_filename) + def download_sankey_html(): + fig = shared['sankey_fig'].get() + if fig is None: + return None + html_string = fig.to_html(include_plotlyjs='cdn') + html_bytes = html_string.encode("utf-8") + return html_bytes, "text/html" + + @render.ui + @reactive.event(input.go_sk1, ignore_none=True) + def download_button_ui_sankey(): + if shared['sankey_fig'].get() is None: + return None + return ui.input_action_button( + "show_download_modal_sankey", + "Download Data", + class_="btn-warning" + ) + + @reactive.Effect + @reactive.event(input.show_download_modal_sankey) + def show_download_modal_sankey(): + m = ui.modal( + ui.div( + ui.download_button("download_sankey_html", "HTML", class_="btn-primary"), + style="display: flex; gap: 10px; justify-content: center;" + ), + title="Select a Format:", + easy_close=True, + footer=None + ) + ui.modal_show(m) diff --git a/server/annotations_server.py b/server/annotations_server.py index 04242a0..9a2734f 100644 --- a/server/annotations_server.py +++ b/server/annotations_server.py @@ -17,24 +17,24 @@ def spac_Histogram_2(): adata, annotation=input.h2_anno() ).values() - shared['df_histogram2'].set(df) + shared['df_histogram2'].set(df) ax.tick_params(axis='x', rotation=input.anno_slider(), labelsize=10) return fig - # 2) If "Group By" is CHECKED, we must always supply a + # 2) If "Group By" is CHECKED, we must always supply a # valid multiple parameter else: - # If user also checked "Plot Together", use their selected + # If user also checked "Plot Together", use their selected # stack type if input.h2_together_check(): # e.g. 'stack', 'dodge', etc. - multiple_param = input.h2_together_drop() + multiple_param = input.h2_together_drop() together_flag = True else: # If grouping by but not "plot together", pick a default layout # or 'dodge' or any valid string - multiple_param = "layer" + multiple_param = "layer" together_flag = False fig, ax, df = spac.visualization.histogram( @@ -44,12 +44,12 @@ def spac_Histogram_2(): together=together_flag, multiple=multiple_param ).values() - shared['df_histogram2'].set(df) + shared['df_histogram2'].set(df) axes = ax if isinstance(ax, (list, np.ndarray)) else [ax] for ax in axes: ax.tick_params( - axis='x', - rotation=input.anno_slider(), + axis='x', + rotation=input.anno_slider(), labelsize=10 ) return fig @@ -59,23 +59,24 @@ def spac_Histogram_2(): @render.ui @reactive.event(input.go_h2, ignore_none=True) def download_histogram_button_ui(): - if shared['df_histogram2'].get() is not None: - return ui.download_button( - "download_histogram2_df", - "Download Data", - class_="btn-warning" - ) - return None + if shared['df_histogram2'].get() is None: + return None + return ui.download_button( + "download_histogram2_df", + "Download Data", + class_="btn-warning" + ) + @render.download(filename="annotation_histogram_data.csv") def download_histogram2_df(): df = shared['df_histogram2'].get() - if df is not None: - csv_string = df.to_csv(index=False) - csv_bytes = csv_string.encode("utf-8") - return csv_bytes, "text/csv" - return None + if df is None: + return None + csv_string = df.to_csv(index=False) + csv_bytes = csv_string.encode("utf-8") + return csv_bytes, "text/csv" histogram2_ui_initialized = reactive.Value(False) @@ -86,8 +87,8 @@ def histogram_reactivity_2(): if btn and not ui_initialized: dropdown = ui.input_select( - "h2_anno_1", - "Select an Annotation", + "h2_anno_1", + "Select an Annotation", choices=shared['obs_names'].get() ) ui.insert_ui( @@ -97,8 +98,8 @@ def histogram_reactivity_2(): ) together_check = ui.input_checkbox( - "h2_together_check", - "Plot Together", + "h2_together_check", + "Plot Together", value=True ) ui.insert_ui( @@ -120,18 +121,18 @@ def histogram_reactivity_2(): def update_stack_type_dropdown(): if input.h2_together_check(): dropdown_together = ui.input_select( - "h2_together_drop", - "Select Stack Type", - choices=['stack', 'layer', 'dodge', 'fill'], + "h2_together_drop", + "Select Stack Type", + choices=['stack', 'layer', 'dodge', 'fill'], selected='stack' ) ui.insert_ui( ui.div({ - "id": "inserted-dropdown_together-1"}, + "id": "inserted-dropdown_together-1"}, dropdown_together ), selector="#main-h2_together_drop", where="beforeEnd" - ) + ) else: ui.remove_ui("#inserted-dropdown_together-1") diff --git a/server/boxplot_server.py b/server/boxplot_server.py index c381492..e89148d 100644 --- a/server/boxplot_server.py +++ b/server/boxplot_server.py @@ -6,7 +6,7 @@ def boxplot_server(input, output, session, shared): - # Helper functions for reusability +# Helper functions for reusability def on_outlier_check(): selected_choice = input.bp_outlier_check() return None if selected_choice == "none" else selected_choice @@ -32,60 +32,97 @@ def spac_Boxplot(): if not input.bp_output_type(): return None - else: - + else: adata = ad.AnnData( - X=shared['X_data'].get(), - obs=pd.DataFrame(shared['obs_data'].get()), - var=pd.DataFrame(shared['var_data'].get()), - layers=shared['layers_data'].get(), + X=shared['X_data'].get(), + obs=pd.DataFrame(shared['obs_data'].get()), + var=pd.DataFrame(shared['var_data'].get()), + layers=shared['layers_data'].get(), dtype=shared['X_data'].get().dtype ) # Proceed only if adata is valid - if adata is not None and adata.var is not None: - - fig, df = spac.visualization.boxplot_interactive( - adata, - annotation=on_anno_check(), - layer=on_layer_check(), - features=list(input.bp_features()), - showfliers=on_outlier_check(), - log_scale=input.bp_log_scale(), - orient=on_orient_check(), - figure_height=3, - figure_width=4.8, - figure_type="interactive" - ).values() - - # Return the interactive Plotly figure object - shared['df_boxplot'].set(df) - print(type(fig)) - return fig - - return None - - - @render.download(filename="boxplot_data.csv") + if adata is None and adata.var is None: + return None + + fig, df = spac.visualization.boxplot_interactive( + adata, + annotation=on_anno_check(), + layer=on_layer_check(), + features=list(input.bp_features()), + showfliers=on_outlier_check(), + log_scale=input.bp_log_scale(), + orient=on_orient_check(), + figure_height=3, + figure_width=4.8, + figure_type="interactive" + ).values() + + # Return the interactive Plotly figure object + shared['df_boxplot'].set(df) + shared['boxplot_fig'].set(fig) # Store figure for HTML download + print(type(fig)) + return fig + + + + def get_boxplot_csv_filename(): + """Generate CSV download filename.""" + input_filename = shared['input_filename'].get() + if input_filename: + return f"{input_filename}_boxplot.csv" + return "boxplot.csv" + + def get_boxplot_html_filename(): + """Generate HTML download filename.""" + input_filename = shared['input_filename'].get() + if input_filename: + return f"{input_filename}_boxplot.html" + return "boxplot.html" + + @render.download(filename=get_boxplot_csv_filename) def download_boxplot(): df = shared['df_boxplot'].get() - if df is not None: - csv_string = df.to_csv(index=False) - csv_bytes = csv_string.encode("utf-8") - return csv_bytes, "text/csv" - return None - + if df is None: + return None + csv_string = df.to_csv(index=False) + csv_bytes = csv_string.encode("utf-8") + return csv_bytes, "text/csv" + + @render.download(filename=get_boxplot_html_filename) + def download_boxplot_html(): + fig = shared['boxplot_fig'].get() + if fig is None: + return None + html_string = fig.to_html(include_plotlyjs='cdn') + html_bytes = html_string.encode("utf-8") + return html_bytes, "text/html" @render.ui @reactive.event(input.go_bp, ignore_none=True) def download_button_ui1(): - if shared['df_boxplot'].get() is not None: - return ui.download_button( - "download_boxplot", - "Download Data", - class_="btn-warning" - ) - return None + if shared['df_boxplot'].get() is None: + return None + return ui.input_action_button( + "show_download_modal_bp", + "Download Data", + class_="btn-warning" + ) + + @reactive.Effect + @reactive.event(input.show_download_modal_bp) + def show_download_modal(): + m = ui.modal( + ui.div( + ui.download_button("download_boxplot", "CSV", class_="btn-primary me-2"), + ui.download_button("download_boxplot_html", "HTML", class_="btn-primary"), + style="display: flex; gap: 10px; justify-content: center;" + ), + title="Select a Format:", + easy_close=True, + footer=None + ) + ui.modal_show(m) @output @@ -96,38 +133,37 @@ def boxplot_static(): This function produces a static (Plotly) boxplot image. """ - # Only run this function if both conditions are met + # Only run this function if both conditions are met if input.bp_output_type(): return None - else: - + else: adata = ad.AnnData( - X=shared['X_data'].get(), - obs=pd.DataFrame(shared['obs_data'].get()), - var=pd.DataFrame(shared['var_data'].get()), - layers=shared['layers_data'].get(), + X=shared['X_data'].get(), + obs=pd.DataFrame(shared['obs_data'].get()), + var=pd.DataFrame(shared['var_data'].get()), + layers=shared['layers_data'].get(), dtype=shared['X_data'].get().dtype ) # Proceed only if adata is valid - if adata is not None and adata.var is not None: - - fig, df = spac.visualization.boxplot_interactive( - adata, - annotation=on_anno_check(), - layer=on_layer_check(), - features=list(input.bp_features()), - showfliers=on_outlier_check(), - log_scale=input.bp_log_scale(), - orient=on_orient_check(), - figure_height=3, - figure_width=4.8, - figure_type="static" - ).values() - - return fig - - return None + if adata is None and adata.var is None: + return None + + fig, df = spac.visualization.boxplot_interactive( + adata, + annotation=on_anno_check(), + layer=on_layer_check(), + features=list(input.bp_features()), + showfliers=on_outlier_check(), + log_scale=input.bp_log_scale(), + orient=on_orient_check(), + figure_height=3, + figure_width=4.8, + figure_type="static" + ).values() + + return fig + diff --git a/server/data_input_server.py b/server/data_input_server.py index ba7cbc0..f83ce23 100644 --- a/server/data_input_server.py +++ b/server/data_input_server.py @@ -1,38 +1,93 @@ +""" +Data input server module for SPAC Shiny application. + +This module handles file uploads and data loading with caching support +for improved performance across all analysis modules. +""" + +import os +import re + from shiny import render, reactive -import pickle -import anndata as ad +from utils.data_processing import cached_load_data + + +def sanitize_filename(filename): + """Extract base filename and sanitize it for use in download names.""" + # Remove path and extension + base_name = os.path.splitext(os.path.basename(filename))[0] + # Replace spaces and special characters with underscores + sanitized = re.sub(r'[^a-zA-Z0-9_-]', '_', base_name) + # Remove multiple consecutive underscores + sanitized = re.sub(r'_+', '_', sanitized) + # Remove leading/trailing underscores + sanitized = sanitized.strip('_') + return sanitized def data_input_server(input, output, session, shared): + """ + Server logic for data input and loading. + + Parameters + ---------- + input : shiny.session.Inputs + Shiny input object + output : shiny.session.Outputs + Shiny output object + session : shiny.session.Session + Shiny session object + shared : dict + Shared reactive values across modules + """ + @reactive.Effect def adata_filter(): + """ + Load and cache AnnData from file upload or preloaded data. + + Uses cached loading to avoid repeated file I/O, providing + performance benefits across all analysis modules. + """ print("Calling Data") file_info = input.input_file() + if not file_info: # Only set preloaded data if it exists if shared['preloaded_data'] is not None: shared['adata_main'].set(shared['preloaded_data']) shared['data_loaded'].set(True) + # Extract filename from preloaded file path + preloaded_path = shared.get('preloaded_file_path', 'dev_example.pickle') + filename = sanitize_filename(preloaded_path) + shared['input_filename'].set(filename) else: shared['data_loaded'].set(False) + shared['input_filename'].set(None) else: file_path = file_info[0]['datapath'] - with open(file_path, 'rb') as file: - if file_path.endswith('.pickle'): - shared['adata_main'].set(pickle.load(file)) - elif file_path.endswith('.h5ad'): - shared['adata_main'].set(ad.read_h5ad(file_path)) - else: - shared['adata_main'].set(ad.read(file_path)) + # Extract and store filename + filename = sanitize_filename(file_path) + shared['input_filename'].set(filename) + + # Use cached loader for performance - data shared across all modules + shared['adata_main'].set(cached_load_data(file_path)) # Set to True if a file is successfully uploaded shared['data_loaded'].set(True) @reactive.Effect def update_parts(): + """ + Extract and update shared data components from loaded AnnData. + + Parses the main AnnData object and populates shared reactive + values for use across all modules (spatial, violin, etc.). + """ print("Updating Parts") adata = shared['adata_main'].get() + if adata is not None: - + # Extract all AnnData components if hasattr(adata, 'X'): shared['X_data'].set(adata.X) else: @@ -90,12 +145,12 @@ def update_parts(): else: shared['uns_names'].set(None) - # Extract spatial_distance column names if available via helper + # Extract spatial distance columns from utils.data_processing import get_spatial_distance_columns - spatial_cols = get_spatial_distance_columns(adata) shared['spatial_distance_columns'].set(spatial_cols) else: + # Clear all shared data if no AnnData loaded shared['obs_data'].set(None) shared['obsm_data'].set(None) shared['layers_data'].set(None) @@ -109,7 +164,9 @@ def update_parts(): shared['uns_names'].set(None) shared['spatial_distance_columns'].set(None) - + # ...existing render functions (print_obs_names, formatted_obs_names, etc.)... + # Keep all your existing @reactive.Calc and @render.text/ui functions + @reactive.Calc @render.text def print_obs_names(): @@ -124,7 +181,6 @@ def print_obs_names(): return "Annotations: " + obs_str else: return "Empty" - return @reactive.Calc @render.text @@ -140,19 +196,15 @@ def print_obsm_names(): return "Associated Tables: " + obsm_str else: return "Empty" - return @reactive.Calc @render.text def print_layers_names(): layers = shared['layers_names'].get() - # If there are no layers at all, just say "None" if not layers: return "Tables: None" - # If there's more than one layer if len(layers) > 1: layers_str = ", ".join(layers) - # If there's exactly one layer else: layers_str = layers[0] return "Tables: " + layers_str @@ -169,7 +221,6 @@ def print_uns_names(): else: uns_str = uns[0] if uns else "" return "Unstructured Data: " + uns_str - return @reactive.Calc @render.text @@ -181,7 +232,6 @@ def print_rows(): return str(shape[0]) else: return "Empty" - return @reactive.Calc @render.text @@ -190,12 +240,10 @@ def print_columns(): if not shape: return "None" if shape is not None: - return str(shape[1]) + return str(shape[1]) else: return "Empty" - return - # Formatted UI outputs for better display @reactive.Calc @render.ui def formatted_obs_names(): diff --git a/server/effect_update_server.py b/server/effect_update_server.py index 7a1ac9e..3c64ba5 100644 --- a/server/effect_update_server.py +++ b/server/effect_update_server.py @@ -16,7 +16,7 @@ def update_select_input_feat(): choices = shared['var_names'].get() ui.update_select("h1_feat", choices=choices) ui.update_select("umap_feat", choices=choices) - if choices is not None: + if choices: ui.update_select("bp_features", choices=choices) ui.update_selectize("bp_features", selected=choices[:2]) @@ -24,14 +24,13 @@ def update_select_input_feat(): def update_select_input_anno(): choices = shared['obs_names'].get() ui.update_select("bp_anno", choices=choices) - if choices is not None: + if choices: new_choices = choices + ["No Annotation"] ui.update_select("bp_anno", choices=new_choices) ui.update_select("h2_anno", choices=choices) ui.update_select("hm1_anno", choices=choices) - if choices is not None and len(choices) > 1: - + if choices and len(choices) > 1: ui.update_select("sk1_anno1", choices=choices) ui.update_selectize("sk1_anno1", selected=choices[0]) ui.update_select("sk1_anno2", choices=choices) @@ -52,7 +51,7 @@ def update_nearest_neighbor_choices(): # Get spatial_distance columns from shared state phenotype_choices = shared['spatial_distance_columns'].get() - if phenotype_choices is not None and len(phenotype_choices) > 0: + if phenotype_choices and len(phenotype_choices) > 0: # Update source label dropdown ui.update_select( "nn_source_label", @@ -78,7 +77,7 @@ def update_nearest_neighbor_choices(): @reactive.Effect def update_select_input_layer(): - if shared['layers_names'].get() is not None: + if shared['layers_names'].get(): new_choices = shared['layers_names'].get() + ["Original"] ui.update_select("h1_layer", choices=new_choices) ui.update_select("bp_layer", choices=new_choices) @@ -122,7 +121,7 @@ def update_rl_region_and_slide_labels(): if label_counts is None: return - if region_name is not None and region_name in label_counts: + if region_name and region_name in label_counts: # Use helper to fetch sorted labels for the region annotation from utils.data_processing import get_annotation_top_labels @@ -161,7 +160,7 @@ def print_obsm_names(): obsm = shared['obsm_names'].get() if not obsm: return "Associated Tables: None" - if obsm is not None: + if obsm: if len(obsm) > 1: obsm_str = ", ".join(obsm) else: @@ -234,7 +233,7 @@ def update_subset_labels(): adata = shared['adata_main'].get() selected_anno = input.subset_anno_select() - if adata is not None and selected_anno: + if adata and selected_anno: labels = adata.obs[selected_anno].unique().tolist() print(f"Updating labels for {selected_anno}: {labels}") ui.update_selectize("subset_label_select", choices=labels) @@ -243,7 +242,7 @@ def update_subset_labels(): @reactive.event(input.go_subset, ignore_none=True) def subset_stratification(): adata = shared['adata_main'].get() - if adata is not None: + if adata: annotation = input.subset_anno_select() labels = list(input.subset_label_select()) @@ -297,7 +296,7 @@ def print_subset_history(): def store_master_copy(): """Store a master copy of the adata object when first loaded.""" adata = shared['adata_main'].get() - if adata is not None and adata_master.get() is None: + if adata and adata_master.get() is None: # Make a copy of the adata object and store it as the master copy adata_master.set(adata.copy()) @@ -315,7 +314,7 @@ def restore_to_master(): Restore adata_main to the master copy stored in adata_master. """ master_data = adata_master.get() - if master_data is not None: + if master_data: # Set adata_main to a copy of the master data to ensure # independence shared['adata_main'].set(master_data.copy()) diff --git a/server/feat_vs_anno_server.py b/server/feat_vs_anno_server.py index b559023..f12e0fe 100644 --- a/server/feat_vs_anno_server.py +++ b/server/feat_vs_anno_server.py @@ -12,14 +12,14 @@ def on_layer_check(): def on_dendro_check(): ''' Check if dendrogram is enabled and return the appropriate values. - If dendrogram is enabled, + If dendrogram is enabled, return a tuple (annotation dendrogram, feature dendrogram). - If dendrogram is disabled, + If dendrogram is disabled, return (None, None) to indicate that no dendrogram is available. ''' return ( - (input.h2_anno_dendro(), input.h2_feat_dendro()) - if input.dendogram() + (input.h2_anno_dendro(), input.h2_feat_dendro()) + if input.dendogram() else (None, None) ) @@ -28,35 +28,34 @@ def on_dendro_check(): @reactive.event(input.go_hm1, ignore_none=True) def spac_Heatmap(): adata = ad.AnnData( - X=shared['X_data'].get(), - obs=pd.DataFrame(shared['obs_data'].get()), - var=pd.DataFrame(shared['var_data'].get()), - layers=shared['layers_data'].get(), + X=shared['X_data'].get(), + obs=pd.DataFrame(shared['obs_data'].get()), + layers=shared['layers_data'].get(), dtype=shared['X_data'].get().dtype ) - if adata is not None: + if adata: vmin = input.min_select() - vmax = input.max_select() - cmap = input.hm1_cmap() # Get the selected color map from the dropdown - kwargs = {"vmin": vmin,"vmax": vmax,} + vmax = input.max_select() + cmap = input.hm1_cmap() # Get the selected color map from the dropdown + kwargs = {"vmin": vmin,"vmax": vmax,} cluster_annotations, cluster_features = on_dendro_check() df, fig, ax = spac.visualization.hierarchical_heatmap( - adata, - annotation=input.hm1_anno(), - layer=on_layer_check(), - z_score=None, + adata, + annotation=input.hm1_anno(), + layer=on_layer_check(), + z_score=None, cluster_annotations=cluster_annotations, - cluster_feature=cluster_features, + cluster_feature=cluster_features, **kwargs ) # Only update if a non-default color map is selected - if cmap != "viridis": + if cmap != "viridis": fig.ax_heatmap.collections[0].set_cmap(cmap) shared['df_heatmap'].set(df) - + # Rotate x-axis labels fig.ax_heatmap.set_xticklabels( fig.ax_heatmap.get_xticklabels(), @@ -69,7 +68,7 @@ def spac_Heatmap(): return fig return None - + heatmap_ui_initialized = reactive.Value(False) @reactive.effect @@ -102,7 +101,7 @@ def heatmap_reactivity(): @render.download(filename="heatmap_data.csv") def download_df(): df = shared['df_heatmap'].get() - if df is not None: + if df: csv_string = df.to_csv(index=False) csv_bytes = csv_string.encode("utf-8") return csv_bytes, "text/csv" @@ -121,9 +120,9 @@ def download_button_ui(): @reactive.event(input.hm1_layer) def update_min_max(): adata = ad.AnnData( - X=shared['X_data'].get(), - obs=pd.DataFrame(shared['obs_data'].get()), - var=pd.DataFrame(shared['var_data'].get()), + X=shared['X_data'].get(), + obs=pd.DataFrame(shared['obs_data'].get()), + var=pd.DataFrame(shared['var_data'].get()), layers=shared['layers_data'].get() ) if input.hm1_layer() == "Original": @@ -139,10 +138,10 @@ def update_min_max(): ui.remove_ui("#inserted-max_num") min_num = ui.input_numeric( - "min_select", - "Minimum", - min_val, - min=min_val, + "min_select", + "Minimum", + min_val, + min=min_val, max=max_val ) ui.insert_ui( @@ -150,12 +149,12 @@ def update_min_max(): selector="#main-min_num", where="beforeEnd", ) - + max_num = ui.input_numeric( - "max_select", - "Maximum", - max_val, - min=min_val, + "max_select", + "Maximum", + max_val, + min=min_val, max=max_val ) ui.insert_ui( diff --git a/server/features_server.py b/server/features_server.py index b2207c3..eaf2171 100644 --- a/server/features_server.py +++ b/server/features_server.py @@ -15,10 +15,10 @@ def on_layer_check(): @reactive.event(input.go_h1, ignore_none=True) def spac_Histogram_1(): adata = ad.AnnData( - X=shared['X_data'].get(), - obs=pd.DataFrame(shared['obs_data'].get()), - var=pd.DataFrame(shared['var_data'].get()), - layers=shared['layers_data'].get(), + X=shared['X_data'].get(), + obs=pd.DataFrame(shared['obs_data'].get()), + var=pd.DataFrame(shared['var_data'].get()), + layers=shared['layers_data'].get(), dtype=shared['X_data'].get().dtype ) @@ -44,7 +44,7 @@ def spac_Histogram_1(): kwargs["together"] = input.h1_together_check() if input.h1_together_check(): kwargs["multiple"] = input.h1_together_drop() - + fig1, ax, df = spac.visualization.histogram(**kwargs).values() axes = ax if isinstance(ax, (list, np.ndarray)) else [ax] @@ -59,8 +59,17 @@ def spac_Histogram_1(): @render.download(filename="features_histogram_data.csv") def download_histogram1_df(): + """ + Download the histogram data as a CSV file. + + Returns + ------- + tuple or None + A tuple containing the CSV bytes and the MIME type if data is available, + otherwise None. + """ df = shared['df_histogram1'].get() - if df is not None: + if df is not None and not df.empty: csv_string = df.to_csv(index=False) csv_bytes = csv_string.encode("utf-8") return csv_bytes, "text/csv" @@ -70,10 +79,19 @@ def download_histogram1_df(): @render.ui @reactive.event(input.go_h1, ignore_none=True) def download_histogram1_button_ui(): - if shared['df_histogram1'].get() is not None: + """ + Render the download button for the histogram data. + + Returns + ------- + shiny.ui.Tag or None + The download button UI element if data is available, otherwise None. + """ + df = shared['df_histogram1'].get() + if df is not None and not df.empty: return ui.download_button( - "download_histogram1_df", - "Download Data", + "download_histogram1_df", + "Download Data", class_="btn-warning" ) return None @@ -86,8 +104,8 @@ def histogram_reactivity(): if btn and not ui_initialized: dropdown = ui.input_select( - "h1_anno", - "Select an Annotation", + "h1_anno", + "Select an Annotation", choices=shared['obs_names'].get() ) ui.insert_ui( @@ -97,8 +115,8 @@ def histogram_reactivity(): ) together_check = ui.input_checkbox( - "h1_together_check", - "Plot Together", + "h1_together_check", + "Plot Together", value=True ) ui.insert_ui( @@ -121,17 +139,17 @@ def histogram_reactivity(): def update_stack_type_dropdown(): if input.h1_together_check(): dropdown_together = ui.input_select( - "h1_together_drop", - "Select Stack Type", - choices=['stack', 'layer', 'dodge', 'fill'], + "h1_together_drop", + "Select Stack Type", + choices=['stack', 'layer', 'dodge', 'fill'], selected='stack' ) ui.insert_ui( ui.div( - {"id": "inserted-dropdown_together"}, + {"id": "inserted-dropdown_together"}, dropdown_together ), selector="#main-h1_together_drop", - where="beforeEnd",) + where="beforeEnd",) else: ui.remove_ui("#inserted-dropdown_together") diff --git a/server/nearest_neighbor_server.py b/server/nearest_neighbor_server.py index fe4cbb0..bd83cb3 100644 --- a/server/nearest_neighbor_server.py +++ b/server/nearest_neighbor_server.py @@ -88,14 +88,14 @@ def nn_color_mapping_ui(): """ adata = get_adata() choices = {"None": "None (Auto)"} - - if adata is not None: + + if adata: # Extract available color mappings from uns if hasattr(adata, 'uns') and adata.uns is not None: for key in adata.uns.keys(): if key.endswith('_color_map') or 'color' in key.lower(): choices[key] = key - + return ui.input_select( "nn_color_mapping", ui.tags.span( @@ -144,21 +144,21 @@ def nn_visualization_plot(): # Auto-detect annotation column matching spatial_distance phenotypes annotation = None spatial_distance_key = "spatial_distance" # Use hardcoded default - + # Check if spatial distance data is in obsm or uns distance_df = None if spatial_distance_key in adata.obsm: distance_df = adata.obsm[spatial_distance_key] elif spatial_distance_key in adata.uns: distance_df = adata.uns[spatial_distance_key] - - if distance_df is not None and hasattr(distance_df, 'columns'): + + if distance_df and hasattr(distance_df, 'columns'): spatial_phenotypes = set(distance_df.columns) - + # Find annotation column that contains matching phenotypes for col in adata.obs.columns: is_categorical = (adata.obs[col].dtype == 'object' or - adata.obs[col].dtype.name == 'category') + adata.obs[col].dtype.name == 'category') if is_categorical: obs_phenotypes = set(adata.obs[col].unique()) # Check if there's significant overlap (80%+) @@ -167,7 +167,7 @@ def nn_visualization_plot(): if len(overlap) >= len(spatial_phenotypes) * 0.8: annotation = col break - + if annotation is None: # Fallback: use the first categorical column for col in adata.obs.columns: @@ -199,7 +199,7 @@ def nn_visualization_plot(): "Annotation": annotation, "Source_Anchor_Cell_Label": source_label, "Target_Cell_Label": (",".join(target_labels) - if target_labels else "All"), + if target_labels else "All"), "ImageID": image_id or "None", "Plot_Method": input.nn_plot_method(), "Plot_Type": get_plot_type(), @@ -209,8 +209,8 @@ def nn_visualization_plot(): "X_Axis_Label_Rotation": input.nn_x_axis_rotation(), "Shared_X_Axis_Title_": input.nn_shared_x_title(), "X_Axis_Title_Font_Size": (input.nn_x_title_fontsize() - if input.nn_x_title_fontsize() - else "None"), + if input.nn_x_title_fontsize() + else "None"), "Defined_Color_Mapping": get_color_mapping() or "None", "Figure_Width": input.nn_figure_width(), "Figure_Height": input.nn_figure_height(), @@ -263,7 +263,7 @@ def download_df_nn(): CSV bytes and content type """ df = shared['df_nn'].get() - if df is not None: + if df: csv_string = df.to_csv(index=False) csv_bytes = csv_string.encode("utf-8") return csv_bytes, "text/csv" @@ -280,7 +280,7 @@ def download_button_ui_nn(): shiny.ui element or None Download button UI or None if no data """ - if shared['df_nn'].get() is not None: + if shared['df_nn'].get(): return ui.download_button( "download_df_nn", "Download Data", diff --git a/server/ripleyL_server.py b/server/ripleyL_server.py index 823d005..535e46a 100644 --- a/server/ripleyL_server.py +++ b/server/ripleyL_server.py @@ -107,7 +107,7 @@ def spac_ripley_l_plot(): @render.download(filename="ripley_plot_data.csv") def download_df_rl(): df = shared['df_ripley'].get() - if df is not None: + if df: csv_string = df.to_csv(index=False) csv_bytes = csv_string.encode("utf-8") return csv_bytes, "text/csv" @@ -116,7 +116,7 @@ def download_df_rl(): @render.ui @reactive.event(input.go_rl, ignore_none=True) def download_button_ui_rl(): - if shared['df_ripley'].get() is not None: + if shared['df_ripley'].get(): return ui.download_button( "download_df_rl", "Download Data", diff --git a/server/scatterplot_server.py b/server/scatterplot_server.py index cf2d9c7..64228e1 100644 --- a/server/scatterplot_server.py +++ b/server/scatterplot_server.py @@ -83,8 +83,8 @@ def scatter_reactivity(): if btn and not scatter_ui_initialized.get(): # Insert the color selection dropdown if not already initialized dropdown = ui.input_select( - "scatter_color", - "Select Feature", + "scatter_color", + "Select Feature", choices=shared['var_names'].get() ) ui.insert_ui( @@ -104,14 +104,14 @@ def get_color_values(): if selected_feature is None: return None adata = ad.AnnData( - X=shared['X_data'].get(), + X=shared['X_data'].get(), var=pd.DataFrame(shared['var_data'].get()) ) if selected_feature in adata.var_names: column_index = adata.var_names.get_loc(selected_feature) color_values = adata.X[:, column_index] return color_values - return None + return None @output @render.plot diff --git a/server/spatial_server.py b/server/spatial_server.py index 7ab1731..eaf1748 100644 --- a/server/spatial_server.py +++ b/server/spatial_server.py @@ -15,8 +15,8 @@ def slide_reactivity(): if btn and not ui_initialized: dropdown_slide = ui.input_select( - "slide_select_drop", - "Select the Slide Annotation", + "slide_select_drop", + "Select the Slide Annotation", choices=shared['obs_names'].get()) ui.insert_ui( ui.div({"id": "inserted-slide_dropdown"}, dropdown_slide), @@ -25,8 +25,8 @@ def slide_reactivity(): ) dropdown_label = ui.input_select( - "slide_select_label", - "Select a Slide", + "slide_select_label", + "Select a Slide", choices=[] ) ui.insert_ui( @@ -58,8 +58,8 @@ def region_reactivity(): if btn and not ui_initialized: dropdown_region = ui.input_select( - "region_select_drop", - "Select the Region Annotation", + "region_select_drop", + "Select the Region Annotation", choices=shared['obs_names'].get()) ui.insert_ui( ui.div({"id": "inserted-region_dropdown"}, dropdown_region), @@ -68,13 +68,13 @@ def region_reactivity(): ) dropdown_label = ui.input_select( - "region_label_select", + "region_label_select", "Select a Region", choices=[] ) ui.insert_ui( ui.div( - {"id": "inserted-region_label_select_dropdown"}, + {"id": "inserted-region_label_select_dropdown"}, dropdown_label ), selector="#main-region_label_select_dropdown", @@ -100,16 +100,16 @@ def update_region_select_drop(): @reactive.event(input.go_sp1, ignore_none=True) def spac_Spatial(): adata = ad.AnnData( - X=shared['X_data'].get(), - var=pd.DataFrame(shared['var_data'].get()), - obsm=shared['obsm_data'].get(), - obs=shared['obs_data'].get(), - dtype=shared['X_data'].get().dtype, + X=shared['X_data'].get(), + var=pd.DataFrame(shared['var_data'].get()), + obsm=shared['obsm_data'].get(), + obs=shared['obs_data'].get(), + dtype=shared['X_data'].get().dtype, layers=shared['layers_data'].get() ) slide_check = input.slide_select_check() region_check = input.region_select_check() - if adata is not None: + if adata: if slide_check is False and region_check is False: adata_subset = adata elif slide_check is True and region_check is False: @@ -139,7 +139,7 @@ def spac_Spatial(): if "spatial_feat" not in input or input.spatial_feat() is None: return None layer = ( - None if input.spatial_layer() == "Original" + None if input.spatial_layer() == "Original" else input.spatial_layer() ) out = spac.visualization.interactive_spatial_plot( @@ -163,29 +163,29 @@ def spac_Spatial(): else: return None out[0]['image_object'].update_xaxes( - showticklabels=True, - ticks="outside", - tickwidth=2, + showticklabels=True, + ticks="outside", + tickwidth=2, ticklen=10 ) out[0]['image_object'].update_yaxes( - showticklabels=True, - ticks="outside", - tickwidth=2, + showticklabels=True, + ticks="outside", + tickwidth=2, ticklen=10 ) return out[0]['image_object'] return None - #Track UI State + #Track UI State spatial_annotation_initialized = reactive.Value(False) spatial_feature_initialized = reactive.Value(False) - + @reactive.effect def spatial_reactivity(): flipper = shared['data_loaded'].get() - if flipper is not False: + if flipper: btn = input.spatial_rb() if btn == "Annotation": @@ -211,8 +211,8 @@ def spatial_reactivity(): elif btn == "Feature": if not spatial_feature_initialized.get(): dropdown = ui.input_select( - "spatial_feat", - "Select a Feature", + "spatial_feat", + "Select a Feature", choices=shared['var_names'].get() ) ui.insert_ui( @@ -223,8 +223,8 @@ def spatial_reactivity(): where="beforeEnd" ) table_select = ui.input_select( - "spatial_layer", - "Select a Table", + "spatial_layer", + "Select a Table", choices=shared['layers_names'].get() + ["Original"], selected="Original" ) diff --git a/server/umap_server.py b/server/umap_server.py index f0b1ff4..696b6bf 100644 --- a/server/umap_server.py +++ b/server/umap_server.py @@ -61,15 +61,14 @@ def spac_UMAP(): @reactive.effect def umap_reactivity(): flipper = shared['data_loaded'].get() - if flipper is not False: + if flipper: btn = input.umap_rb() - if btn == "Annotation": if not umap_annotation_initialized.get(): # Create the Annotation dropdown dropdown = ui.input_select( - "umap_rb_anno", - "Select an Annotation", + "umap_rb_anno", + "Select an Annotation", choices=shared['obs_names'].get(), ) ui.insert_ui( @@ -89,8 +88,8 @@ def umap_reactivity(): if not umap_feature_initialized.get(): # Create the Feature dropdown dropdown1 = ui.input_select( - "umap_rb_feat", - "Select a Feature", + "umap_rb_feat", + "Select a Feature", choices=shared['var_names'].get() ) ui.insert_ui( @@ -102,9 +101,9 @@ def umap_reactivity(): # Create the Table dropdown new_choices = shared['layers_names'].get() + ["Original"] table_umap = ui.input_select( - "umap_layer", - "Select a Table", - choices=new_choices, + "umap_layer", + "Select a Table", + choices=new_choices, selected=["Original"] ) ui.insert_ui( @@ -186,14 +185,14 @@ def spac_UMAP2(): @reactive.effect def umap_reactivity2(): flipper = shared['data_loaded'].get() - if flipper is not False: + if flipper: btn = input.umap_rb2() if btn == "Annotation": if not umap2_annotation_initialized.get(): dropdown = ui.input_select( - "umap_rb_anno2", - "Select an Annotation", + "umap_rb_anno2", + "Select an Annotation", choices=shared['obs_names'].get() ) ui.insert_ui( @@ -210,8 +209,8 @@ def umap_reactivity2(): elif btn == "Feature": if not umap2_feature_initialized.get(): dropdown1 = ui.input_select( - "umap_rb_feat2", - "Select a Feature", + "umap_rb_feat2", + "Select a Feature", choices=shared['var_names'].get() ) ui.insert_ui( diff --git a/ui/anno_vs_anno_ui.py b/ui/anno_vs_anno_ui.py index fbbf96e..46a7146 100644 --- a/ui/anno_vs_anno_ui.py +++ b/ui/anno_vs_anno_ui.py @@ -96,6 +96,10 @@ def anno_vs_anno_ui(): "go_sk1", "Generate Sankey Plot", class_="btn-success" + ), + ui.div( + {"style": "padding-top: 20px;"}, + ui.output_ui("download_button_ui_sankey") ) ), ui.column( diff --git a/utils/data_processing.py b/utils/data_processing.py index ccf0ac3..74efc73 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -3,9 +3,125 @@ This module contains functions for loading and processing data for the SPAC Shiny app. """ +from pathlib import Path +import logging +import pickle + import anndata as ad import pandas as pd +logger = logging.getLogger(__name__) + +# Simple in-memory cache for loaded datasets +_data_cache = {} + + +def cached_load_data(file_path): + """ + Load AnnData with caching to avoid repeated file I/O. + + This function caches loaded datasets in system RAM, providing + significant performance improvements when switching between + analysis modules or reloading the same dataset. + + Parameters + ---------- + file_path : str or Path + Path to the data file (.h5ad, .pickle, or other AnnData format) + + Returns + ------- + anndata.AnnData + Loaded AnnData object (cached if previously loaded) + + Notes + ----- + Cache is stored in system RAM and persists until app restart. + Subsequent loads of the same file are nearly instant. + + Examples + -------- + >>> adata = cached_load_data("data/sample.h5ad") + Loading data from data/sample.h5ad + >>> # Second call with same file - instant from cache + >>> adata = cached_load_data("data/sample.h5ad") + Retrieved data from cache: data/sample.h5ad + """ + # Convert to absolute path for consistent cache keys + cache_key = str(Path(file_path).resolve()) + + # Check cache first + if cache_key in _data_cache: + logger.info(f"Retrieved data from cache: {file_path}") + return _data_cache[cache_key] + + # Load data from file + logger.info(f"Loading data from {file_path}") + + with open(file_path, 'rb') as file: + if file_path.endswith('.pickle'): + adata = pickle.load(file) + elif file_path.endswith('.h5ad'): + adata = ad.read_h5ad(file_path) + else: + adata = ad.read(file_path) + + # Store in cache + _data_cache[cache_key] = adata + logger.info( + f"Cached data: {adata.n_obs} cells, " + f"{adata.n_vars} genes" + ) + + return adata + + +def clear_data_cache(): + """ + Clear the data cache to free system RAM. + + This removes all cached datasets from memory. Useful if you need + to free up RAM or reload data from disk. + + Examples + -------- + >>> clear_data_cache() + Cleared data cache + """ + global _data_cache + _data_cache.clear() + logger.info("Cleared data cache") + + +def get_cache_info(): + """ + Get information about currently cached datasets. + + Returns + ------- + dict + Dictionary with cache statistics including number of cached + files and total memory usage + + Examples + -------- + >>> info = get_cache_info() + >>> print(f"Cached files: {info['num_files']}") + Cached files: 2 + """ + import sys + + total_size = sum( + sys.getsizeof(adata) + for adata in _data_cache.values() + ) + + return { + 'num_files': len(_data_cache), + 'cached_files': list(_data_cache.keys()), + 'total_size_mb': total_size / (1024 * 1024) + } + def read_html_file(filepath): """Read HTML file content"""