Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion alphaquant/plotting/fcviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_
shortened_xticklabels = False,
remove_leaf_labels_in_tree = False,
hide_root_in_tree = False,
exclude_outlier_fragments = True):
exclude_outlier_fragments = True,
highlight_excluded_nodes = True,
show_excluded_node_counts = True,
show_exclusion_legend = True):
"""
Configuration class for plotting.

Expand All @@ -139,6 +142,10 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_
exclude_outlier_fragments (bool): Whether to exclude fragment ions marked as outliers from plots.
When True (default), only fragments used in statistical aggregation are displayed.
Mirrors the fragment_outlier_filtering behavior from the analysis pipeline.
highlight_excluded_nodes (bool): Whether tree plots should highlight nodes excluded from aggregation.
show_excluded_node_counts (bool): Whether visible parents should annotate how many hidden descendants
were excluded from aggregation.
show_exclusion_legend (bool): Whether tree plots should include a legend for exclusion highlighting.
"""
self.label_rotation = label_rotation
self.add_stripplot = add_stripplot
Expand All @@ -158,6 +165,9 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_
self.remove_leaf_labels_in_tree = remove_leaf_labels_in_tree
self.hide_root_in_tree = hide_root_in_tree
self.exclude_outlier_fragments = exclude_outlier_fragments
self.highlight_excluded_nodes = highlight_excluded_nodes
self.show_excluded_node_counts = show_excluded_node_counts
self.show_exclusion_legend = show_exclusion_legend

# Node annotation configuration
self.show_node_annotations = show_node_annotations
Expand All @@ -180,6 +190,9 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_
'min_reps': 'reps={}',
'fraction_consistent': 'cons={:.2f}',
'is_included': 'incl={}',
'exclude_residual_decorrelation': 'decorr excl={}',
'is_outlier_fragment': 'frag excl={}',
'is_outlier_peptide': 'pep excl={}',
'missingval': 'miss={}'
}
else:
Expand Down
165 changes: 132 additions & 33 deletions alphaquant/plotting/treeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import networkx as nx
import anytree
import re
import shlex
from matplotlib import gridspec
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import alphaquant.cluster.cluster_utils as aqcluster_utils
import alphaquant.plotting.base_functions as aqviz
Expand Down Expand Up @@ -77,7 +79,7 @@ def _define_colorlist(self):
self._colorlist_hex = [aqviz.rgb_to_hex(x) for x in self._plotconfig.colorlist]

def _format_graph(self):
pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, **self._graph_parameters.layout_params)
pos = _graphviz_layout(self.graph, self._graph_parameters.layout_params)

root_id = id(self._protein)
hide_root = getattr(self._plotconfig, 'hide_root_in_tree', False)
Expand All @@ -87,14 +89,8 @@ def _format_graph(self):

for node in nodes_to_draw:
matching_anynode = self._id2anytree_node[node]
is_included = matching_anynode.is_included
if not is_included:
self._graph_parameters.node_options["alpha"] = self._graph_parameters.alpha_excluded
self._graph_parameters.node_options["node_color"] = self._determine_cluster_color(matching_anynode)
# Allow overriding node size from plotconfig
if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None:
self._graph_parameters.node_options["node_size"] = self._plotconfig.node_size
nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **self._graph_parameters.node_options)
node_options = self._get_node_options(matching_anynode)
nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **node_options)

label_dict = nx.get_node_attributes(self.graph, 'label')

Expand All @@ -112,9 +108,10 @@ def _format_graph(self):
rotation = self._plotconfig.label_rotation if len(matching_anynode.children) == 0 else 0

self._ax.text(x, y, labelstring, verticalalignment='center', horizontalalignment='center', fontsize=self._plotconfig.node_fontsize, family='monospace',
weight = "bold", rotation = rotation)
weight = "bold", rotation = rotation, color=self._get_label_color(matching_anynode))

nx.draw_networkx_edges(self.graph, pos, edgelist=edges_to_draw, ax=self._ax, **self._graph_parameters.edge_options)
self._add_exclusion_legend(nodes_to_draw)

# Add vertical padding to avoid cutting labels at top/bottom and hide axis frame
try:
Expand All @@ -133,25 +130,97 @@ def _format_graph(self):
def _determine_cluster_color(self, anynode):
return self._colorlist_hex[anynode.cluster]

def _get_node_options(self, anynode):
node_options = dict(self._graph_parameters.node_options)

if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None:
node_options["node_size"] = self._plotconfig.node_size

node_options["node_color"] = self._determine_cluster_color(anynode)
node_options["alpha"] = self._graph_parameters.alpha_included
node_options["edgecolors"] = self._graph_parameters.included_edge_color
node_options["linewidths"] = self._graph_parameters.included_linewidth

if self._should_highlight_exclusions() and self._is_directly_excluded(anynode):
node_options["node_color"] = self._graph_parameters.excluded_color
node_options["alpha"] = self._graph_parameters.alpha_excluded
node_options["edgecolors"] = self._graph_parameters.excluded_edge_color
node_options["linewidths"] = self._graph_parameters.excluded_linewidth

return node_options

def _get_label_color(self, anynode):
if self._should_highlight_exclusions() and self._is_directly_excluded(anynode):
return self._graph_parameters.excluded_label_color
return self._graph_parameters.included_label_color

def _should_highlight_exclusions(self):
return getattr(self._plotconfig, "highlight_excluded_nodes", True)

@classmethod
def _is_directly_excluded(cls, anynode):
return (
not getattr(anynode, "is_included", True)
or getattr(anynode, "exclude_residual_decorrelation", False)
or getattr(anynode, "is_outlier_fragment", False)
or getattr(anynode, "is_outlier_peptide", False)
)

def _add_exclusion_legend(self, nodes_to_draw):
if not (
self._should_highlight_exclusions()
and getattr(self._plotconfig, "show_exclusion_legend", True)
):
return

drawn_anynodes = [self._id2anytree_node[node] for node in nodes_to_draw]
has_direct = any(self._is_directly_excluded(node) for node in drawn_anynodes)
if not has_direct:
return

legend_handles = []
if has_direct:
legend_handles.append(
Line2D(
[0],
[0],
marker="o",
linestyle="",
markerfacecolor=self._graph_parameters.excluded_color,
markeredgecolor=self._graph_parameters.excluded_edge_color,
markeredgewidth=self._graph_parameters.excluded_linewidth,
markersize=8,
label="excluded from aggregation",
)
)

self._ax.legend(
handles=legend_handles,
loc="upper right",
frameon=False,
fontsize=max(7, self._plotconfig.node_fontsize - 3),
)

@staticmethod
def render_tree(root):
for pre, _, node in anytree.RenderTree(root):
print("%s%s" % (pre, node.name))


class GraphParameters():
def __init__(self):
self.included_color = "skyblue"
self.excluded_color = "lightgrey"
self.excluded_color = "#D9D9D9"
self.included_edge_color = "#404040"
self.excluded_edge_color = "#808080"
self.included_label_color = "#202020"
self.excluded_label_color = "#333333"
self.alpha_included = 0.6 # More transparent nodes
self.alpha_excluded = 0.3 # More transparent excluded nodes
self.alpha_excluded = 0.8
self.included_linewidth = 1
self.excluded_linewidth = 2.0

self.node_options = {
"node_color": self.included_color,
"node_size": 1500,
"linewidths": 1,
"linewidths": self.included_linewidth,
"edgecolors": self.included_edge_color,
"alpha": self.alpha_included, # default alpha
}

Expand All @@ -171,6 +240,38 @@ def __init__(self):
}


def _graphviz_layout(graph, layout_params):
try:
return nx.drawing.nx_agraph.graphviz_layout(graph, **layout_params)
except ImportError:
return _graphviz_plain_layout(graph, layout_params)


def _graphviz_plain_layout(graph, layout_params):
import graphviz

prog = layout_params.get("prog", "dot")
graph_attr = {}
for token in shlex.split(layout_params.get("args", "")):
if token.startswith("-G") and "=" in token:
key, value = token[2:].split("=", 1)
graph_attr[key] = value

dot = graphviz.Digraph(engine=prog, graph_attr=graph_attr)
for node in graph.nodes:
dot.node(str(node))
for parent, child in graph.edges:
dot.edge(str(parent), str(child))

plain = dot.pipe(format="plain").decode("utf-8")
pos = {}
for line in plain.splitlines():
parts = line.split()
if len(parts) >= 4 and parts[0] == "node":
pos[int(parts[1])] = (float(parts[2]) * 72.0, float(parts[3]) * 72.0)
return pos



class TreeLabelFormatter:
@classmethod
Expand Down Expand Up @@ -252,8 +353,17 @@ def get_annotation_lines(cls, node, plotconfig):
formatted = f"{attr}={value}"

annotations.append(formatted)

return annotations

@classmethod
def get_exclusion_annotation_lines(cls, node):
"""Return compact annotations for aggregation-excluded nodes."""
if GraphCreator._is_directly_excluded(node):
return ["excluded"]

return []


class AnnotatedGraphCreator(GraphCreator):
"""Enhanced GraphCreator that supports configurable node annotations."""
Expand All @@ -264,7 +374,7 @@ def __init__(self, protein, ax, plotconfig):

def _format_graph(self):
"""Override _format_graph to use the enhanced label formatter."""
pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, **self._graph_parameters.layout_params)
pos = _graphviz_layout(self.graph, self._graph_parameters.layout_params)

root_id = id(self._protein)
hide_root = getattr(self._plotconfig, 'hide_root_in_tree', False)
Expand All @@ -274,17 +384,8 @@ def _format_graph(self):

for node in nodes_to_draw:
matching_anynode = self._id2anytree_node[node]
is_included = matching_anynode.is_included
if not is_included:
self._graph_parameters.node_options["alpha"] = self._graph_parameters.alpha_excluded
else:
self._graph_parameters.node_options["alpha"] = self._graph_parameters.alpha_included

self._graph_parameters.node_options["node_color"] = self._determine_cluster_color(matching_anynode)
# Allow overriding node size from plotconfig
if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None:
self._graph_parameters.node_options["node_size"] = self._plotconfig.node_size
nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **self._graph_parameters.node_options)
node_options = self._get_node_options(matching_anynode)
nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **node_options)

label_dict = nx.get_node_attributes(self.graph, 'label')

Expand Down Expand Up @@ -316,9 +417,10 @@ def _format_graph(self):

self._ax.text(x, y, labelstring, verticalalignment='center', horizontalalignment='center',
fontsize=fontsize, family='monospace', weight="bold",
rotation=rotation)
rotation=rotation, color=self._get_label_color(matching_anynode))

nx.draw_networkx_edges(self.graph, pos, edgelist=edges_to_draw, ax=self._ax, **self._graph_parameters.edge_options)
self._add_exclusion_legend(nodes_to_draw)

# Add vertical padding to avoid cutting labels at top/bottom and hide axis frame
try:
Expand Down Expand Up @@ -389,6 +491,3 @@ def define_tree_fig_and_ax(self):
fig_width = min(max(8, num_leaves * 1.3),100)
fig_height = max(8, max_depth * 2)
self.fig, self.ax_tree = plt.subplots(figsize=(fig_width, fig_height))



Loading