Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions ehrapy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast

import numpy as np
import scipy.sparse as sp

P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -248,3 +249,25 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return wrapper

return decorator


def asarray(a):
import numpy as np

return np.asarray(a)


def as_dense_dask_array(a, chunk_size=1000):
import dask.array as da

return da.from_array(a, chunks=chunk_size)


ARRAY_TYPES_NUMERIC = (
asarray,
as_dense_dask_array,
sp.csr_array,
sp.csc_array,
) # add coo_array once supported in AnnData
ARRAY_TYPES_NUMERIC_3D_ABLE = (asarray, as_dense_dask_array) # add coo_array once supported in AnnData
ARRAY_TYPES_NONNUMERIC = (asarray, as_dense_dask_array)
1 change: 1 addition & 0 deletions ehrapy/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ehrapy.plot._catplot import catplot
from ehrapy.plot._colormaps import * # noqa: F403
from ehrapy.plot._missingno import * # noqa: F403
from ehrapy.plot._sankey import plot_sankey, plot_sankey_time
from ehrapy.plot._scanpy_pl_api import * # noqa: F403
from ehrapy.plot._survival_analysis import cox_ph_forestplot, kaplan_meier, ols
from ehrapy.plot.causal_inference._dowhy import causal_effect
Expand Down
182 changes: 182 additions & 0 deletions ehrapy/plot/_sankey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List

import ehrdata as ed
import holoviews as hv
import numpy as np
import pandas as pd
from holoviews import opts

hv.extension("bokeh")

if TYPE_CHECKING:
from anndata import AnnData
from ehrdata import EHRData


def plot_sankey(
edata: EHRData | AnnData,
*,
columns: list[str],
show: bool = False,
**kwargs,
) -> hv.Sankey:
"""Create a Sankey diagram showing relationships across observation columns.

Args:
edata : Central data object containing observation data
columns : Column names from edata.obs to visualize
show: If True, display the plot immediately. If False, only return the plot object without displaying.
**kwargs: Additional styling options passed to `holoviews.opts.Sankey`. See HoloViews Sankey documentation for full list of options.

Returns:
holoviews.Sankey

Examples:
>>> import ehrdata as ed
>>> edata = ed.dt.diabetes_130_fairlearn(columns_obs_only=["gender", "race"])
>>> ep.pl.plot_sankey(edata, columns=["gender", "race"])

"""
df = edata.obs[columns]

labels = []
for col in columns:
labels.extend([f"{col}: {val}" for val in df[col].unique()])
labels = list(dict.fromkeys(labels)) # keep order & unique

# Build links between consecutive columns
sources, targets, values = [], [], []
source_levels, target_levels = [], []
for i in range(len(columns) - 1):
col_from, col_to = columns[i], columns[i + 1]
flows = df.groupby([col_from, col_to]).size().reset_index(name="count")
for _, row in flows.iterrows():
source = f"{col_from}: {row[col_from]}"
target = f"{col_to}: {row[col_to]}"
sources.append(source)
targets.append(target)
values.append(row["count"])
source_levels.append(col_from)
target_levels.append(col_to)

sankey_df = pd.DataFrame(
{
"source": sources,
"target": targets,
"value": values,
"source_level": source_levels,
"target_level": target_levels,
}
)

sankey = hv.Sankey(sankey_df, kdims=["source", "target"], vdims=["value"])
default_opts = {"label_position": "right", "show_values": True, "title": f"Patient flows: {columns[0]} over time"}

default_opts.update(kwargs)

sankey = sankey.opts(opts.Sankey(**default_opts))

if show:
from IPython.display import display

display(sankey)

return sankey


def plot_sankey_time(
edata: EHRData | AnnData,
*,
columns: list[str],
layer: str,
state_labels: dict[int, str] | None = None,
show: bool = False,
**kwargs,
) -> hv.Sankey:
"""Create a Sankey diagram showing patient state transitions over time.

This function visualizes how patients transition between different states
(e.g., disease severity, treatment status) across consecutive time points.
Each node represents a state at a specific time point, and flows show the
number of patients transitioning between states.

Args:
edata: Central data object containing observation data
columns: Column names from edata.obs to visualize
layer: Name of the layer in `edata.layers` containing the feature data to visualize.
state_labels: Mapping from numeric state values to readable labels. If None, state values
will be displayed as strings of their numeric codes (e.g., "0", "1", "2"). Default: "None"
show: If True, display the plot immediately. If False, only return the plot object without displaying.
**kwargs: Additional styling options passed to `holoviews.opts.Sankey`. See HoloViews Sankey documentation for full list of options.

Returns:
holoviews.Sankey

Examples:
>>> import numpy as np
>>> import pandas as pd
>>> import ehrdata as ed
>>>
>>> layer = np.array(
... [
... [[1, 0, 1], [0, 1, 0]], # patient 1: treatment, disease_flare
... [[0, 1, 1], [1, 0, 0]], # patient 2: treatment, disease_flare
... [[1, 1, 0], [0, 0, 1]], # patient 3: treatment, disease_flare
... ]
... )
>>>
>>> edata = ed.EHRData(
... layers={"layer_1": layer},
... obs=pd.DataFrame(index=["patient_1", "patient_2", "patient_3"]),
... var=pd.DataFrame(index=["treatment", "disease_flare"]),
... tem=pd.DataFrame(index=["visit_0", "visit_1", "visit_2"]),
... )
>>>
>>> plot_sankey_time(edata, columns=["disease_flare"], layer="layer_1", state_labels={0: "no flare", 1: "flare"})


"""
flare_data = edata[:, edata.var_names.isin(columns), :].layers[layer][:, 0, :]

time_steps = edata.tem.index.tolist()
# states = edata.var.loc[columns[0]].values

if state_labels is None:
unique_states = np.unique(flare_data)
unique_states = unique_states[~np.isnan(unique_states)]
state_labels = {int(state): str(state) for state in unique_states}
# state_labels = {int(state): states[int(state)] for state in unique_states} if the categorical variables values are also in layer

state_values = sorted(state_labels.keys())
state_names = [state_labels[val] for val in state_values]

sources, targets, values = [], [], []
for t in range(len(time_steps) - 1):
for s_from_idx, s_from_val in enumerate(state_values):
for s_to_idx, s_to_val in enumerate(state_values):
count = np.sum((flare_data[:, t] == s_from_val) & (flare_data[:, t + 1] == s_to_val))
if count > 0:
source_label = f"{state_names[s_from_idx]} ({time_steps[t]})"
target_label = f"{state_names[s_to_idx]} ({time_steps[t + 1]})"
sources.append(source_label)
targets.append(target_label)
values.append(int(count))

sankey_df = pd.DataFrame({"source": sources, "target": targets, "value": values})

sankey = hv.Sankey(sankey_df, kdims=["source", "target"], vdims=["value"])

default_opts = {"label_position": "right", "show_values": True, "title": f"Patient flows: {columns[0]} over time"}

default_opts.update(kwargs)

sankey = sankey.opts(opts.Sankey(**default_opts))

if show:
from IPython.display import display

display(sankey)

return sankey
Loading
Loading