diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7a60b85 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.pyc diff --git a/environment.yml b/environment.yml index 518cf63..71fc642 100644 --- a/environment.yml +++ b/environment.yml @@ -20,6 +20,8 @@ dependencies: - scipy=1.10.1 - pip - pip: + - colorcet + - git+https://github.com/holoviz/datashader.git - ipykernel==6.29.5 - ipython==8.18.0 - ipython-genutils==0.2.0 diff --git a/server/scatterplot_server.py b/server/scatterplot_server.py index cf2d9c7..479775e 100644 --- a/server/scatterplot_server.py +++ b/server/scatterplot_server.py @@ -2,6 +2,8 @@ import anndata as ad import pandas as pd import spac.visualization +from utils.datashader_utils import scatter_heatmap +import matplotlib.pyplot as plt def scatterplot_server(input, output, session, shared): @@ -120,22 +122,33 @@ def spac_Scatter(): x = get_scatterplot_coordinates_x() y = get_scatterplot_coordinates_y() color_enabled = input.scatter_color_check() + heatmap_mode = input.scatter_heatmap_mode() x_label = input.scatter_x() y_label = input.scatter_y() title = f"Scatterplot: {x_label} vs {y_label}" - if color_enabled: - fig, ax = spac.visualization.visualize_2D_scatter( - x, y, labels=get_color_values() - ) - for a in fig.axes: - if hasattr(a, "get_ylabel") and a != ax: - a.set_ylabel(f"Colored by: {input.scatter_color()}") + if heatmap_mode: + color = get_color_values() if color_enabled else None + img = scatter_heatmap(x, y, color) + fig, ax = plt.subplots(figsize=(8, 6)) + ax.imshow(img, aspect='auto') + ax.set_title(title, fontsize=14) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.axis('on') # Show axes + return fig else: - fig, ax = spac.visualization.visualize_2D_scatter(x, y) - - ax.set_title(title, fontsize=14) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) + if color_enabled: + fig, ax = spac.visualization.visualize_2D_scatter( + x, y, labels=get_color_values() + ) + for a in fig.axes: + if hasattr(a, "get_ylabel") and a != ax: + a.set_ylabel(f"Colored by: {input.scatter_color()}") + else: + fig, ax = spac.visualization.visualize_2D_scatter(x, y) - return ax + ax.set_title(title, fontsize=14) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + return ax diff --git a/ui/scatterplot_ui.py b/ui/scatterplot_ui.py index a552a63..6c97823 100644 --- a/ui/scatterplot_ui.py +++ b/ui/scatterplot_ui.py @@ -32,6 +32,11 @@ def scatterplot_ui(): "Color by Feature", value=False ), + ui.input_checkbox( + "scatter_heatmap_mode", + "Show as Heatmap", + value=False + ), ui.div(id="main-scatter_dropdown"), ui.input_action_button( "go_scatter", diff --git a/utils/datashader_utils.py b/utils/datashader_utils.py new file mode 100644 index 0000000..f232b18 --- /dev/null +++ b/utils/datashader_utils.py @@ -0,0 +1,16 @@ +import pandas as pd +import datashader as ds +import datashader.transfer_functions as tf +from colorcet import fire + +def scatter_heatmap(x, y, color=None, width=800, height=600): + df = pd.DataFrame({'x': x, 'y': y}) + cvs = ds.Canvas(plot_width=width, plot_height=height) + if color is not None: + df['color'] = color + agg = cvs.points(df, 'x', 'y', ds.mean('color')) + img = tf.shade(agg, cmap=fire, how='eq_hist') + else: + agg = cvs.points(df, 'x', 'y', ds.count()) + img = tf.shade(agg, cmap=fire) + return img.to_pil() \ No newline at end of file