Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__/
*.pyc
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ dependencies:
- scipy=1.10.1
- pip
- pip:
- colorcet
- git+https://github.com/holoviz/datashader.git
- ipykernel==6.29.5
- ipython==8.18.0
- ipython-genutils==0.2.0
Expand Down
39 changes: 26 additions & 13 deletions server/scatterplot_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import anndata as ad
import pandas as pd
import spac.visualization
from utils.datashader_utils import scatter_heatmap
import matplotlib.pyplot as plt


def scatterplot_server(input, output, session, shared):
Expand Down Expand Up @@ -120,22 +122,33 @@ def spac_Scatter():
x = get_scatterplot_coordinates_x()
y = get_scatterplot_coordinates_y()
color_enabled = input.scatter_color_check()
heatmap_mode = input.scatter_heatmap_mode()
x_label = input.scatter_x()
y_label = input.scatter_y()
title = f"Scatterplot: {x_label} vs {y_label}"

if color_enabled:
fig, ax = spac.visualization.visualize_2D_scatter(
x, y, labels=get_color_values()
)
for a in fig.axes:
if hasattr(a, "get_ylabel") and a != ax:
a.set_ylabel(f"Colored by: {input.scatter_color()}")
if heatmap_mode:
color = get_color_values() if color_enabled else None
img = scatter_heatmap(x, y, color)
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(img, aspect='auto')
ax.set_title(title, fontsize=14)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.axis('on') # Show axes
return fig
else:
fig, ax = spac.visualization.visualize_2D_scatter(x, y)

ax.set_title(title, fontsize=14)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
if color_enabled:
fig, ax = spac.visualization.visualize_2D_scatter(
x, y, labels=get_color_values()
)
for a in fig.axes:
if hasattr(a, "get_ylabel") and a != ax:
a.set_ylabel(f"Colored by: {input.scatter_color()}")
else:
fig, ax = spac.visualization.visualize_2D_scatter(x, y)

return ax
ax.set_title(title, fontsize=14)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
return ax
5 changes: 5 additions & 0 deletions ui/scatterplot_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def scatterplot_ui():
"Color by Feature",
value=False
),
ui.input_checkbox(
"scatter_heatmap_mode",
"Show as Heatmap",
value=False
),
ui.div(id="main-scatter_dropdown"),
ui.input_action_button(
"go_scatter",
Expand Down
16 changes: 16 additions & 0 deletions utils/datashader_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pandas as pd
import datashader as ds
import datashader.transfer_functions as tf
from colorcet import fire

def scatter_heatmap(x, y, color=None, width=800, height=600):
df = pd.DataFrame({'x': x, 'y': y})
cvs = ds.Canvas(plot_width=width, plot_height=height)
if color is not None:
df['color'] = color
agg = cvs.points(df, 'x', 'y', ds.mean('color'))
img = tf.shade(agg, cmap=fire, how='eq_hist')
else:
agg = cvs.points(df, 'x', 'y', ds.count())
img = tf.shade(agg, cmap=fire)
return img.to_pil()