diff --git a/alphaquant/plotting/fcviz.py b/alphaquant/plotting/fcviz.py index 1851d857..97759f0d 100644 --- a/alphaquant/plotting/fcviz.py +++ b/alphaquant/plotting/fcviz.py @@ -118,7 +118,10 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_ shortened_xticklabels = False, remove_leaf_labels_in_tree = False, hide_root_in_tree = False, - exclude_outlier_fragments = True): + exclude_outlier_fragments = True, + highlight_excluded_nodes = True, + show_excluded_node_counts = True, + show_exclusion_legend = True): """ Configuration class for plotting. @@ -139,6 +142,10 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_ exclude_outlier_fragments (bool): Whether to exclude fragment ions marked as outliers from plots. When True (default), only fragments used in statistical aggregation are displayed. Mirrors the fragment_outlier_filtering behavior from the analysis pipeline. + highlight_excluded_nodes (bool): Whether tree plots should highlight nodes excluded from aggregation. + show_excluded_node_counts (bool): Whether visible parents should annotate how many hidden descendants + were excluded from aggregation. + show_exclusion_legend (bool): Whether tree plots should include a legend for exclusion highlighting. """ self.label_rotation = label_rotation self.add_stripplot = add_stripplot @@ -158,6 +165,9 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_ self.remove_leaf_labels_in_tree = remove_leaf_labels_in_tree self.hide_root_in_tree = hide_root_in_tree self.exclude_outlier_fragments = exclude_outlier_fragments + self.highlight_excluded_nodes = highlight_excluded_nodes + self.show_excluded_node_counts = show_excluded_node_counts + self.show_exclusion_legend = show_exclusion_legend # Node annotation configuration self.show_node_annotations = show_node_annotations @@ -180,6 +190,9 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_ 'min_reps': 'reps={}', 'fraction_consistent': 'cons={:.2f}', 'is_included': 'incl={}', + 'exclude_residual_decorrelation': 'decorr excl={}', + 'is_outlier_fragment': 'frag excl={}', + 'is_outlier_peptide': 'pep excl={}', 'missingval': 'miss={}' } else: diff --git a/alphaquant/plotting/treeviz.py b/alphaquant/plotting/treeviz.py index 63d2515e..7642ec74 100644 --- a/alphaquant/plotting/treeviz.py +++ b/alphaquant/plotting/treeviz.py @@ -4,7 +4,9 @@ import networkx as nx import anytree import re +import shlex from matplotlib import gridspec +from matplotlib.lines import Line2D import matplotlib.pyplot as plt import alphaquant.cluster.cluster_utils as aqcluster_utils import alphaquant.plotting.base_functions as aqviz @@ -77,7 +79,7 @@ def _define_colorlist(self): self._colorlist_hex = [aqviz.rgb_to_hex(x) for x in self._plotconfig.colorlist] def _format_graph(self): - pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, **self._graph_parameters.layout_params) + pos = _graphviz_layout(self.graph, self._graph_parameters.layout_params) root_id = id(self._protein) hide_root = getattr(self._plotconfig, 'hide_root_in_tree', False) @@ -87,14 +89,8 @@ def _format_graph(self): for node in nodes_to_draw: matching_anynode = self._id2anytree_node[node] - is_included = matching_anynode.is_included - if not is_included: - self._graph_parameters.node_options["alpha"] = self._graph_parameters.alpha_excluded - self._graph_parameters.node_options["node_color"] = self._determine_cluster_color(matching_anynode) - # Allow overriding node size from plotconfig - if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None: - self._graph_parameters.node_options["node_size"] = self._plotconfig.node_size - nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **self._graph_parameters.node_options) + node_options = self._get_node_options(matching_anynode) + nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **node_options) label_dict = nx.get_node_attributes(self.graph, 'label') @@ -112,9 +108,10 @@ def _format_graph(self): rotation = self._plotconfig.label_rotation if len(matching_anynode.children) == 0 else 0 self._ax.text(x, y, labelstring, verticalalignment='center', horizontalalignment='center', fontsize=self._plotconfig.node_fontsize, family='monospace', - weight = "bold", rotation = rotation) + weight = "bold", rotation = rotation, color=self._get_label_color(matching_anynode)) nx.draw_networkx_edges(self.graph, pos, edgelist=edges_to_draw, ax=self._ax, **self._graph_parameters.edge_options) + self._add_exclusion_legend(nodes_to_draw) # Add vertical padding to avoid cutting labels at top/bottom and hide axis frame try: @@ -133,25 +130,97 @@ def _format_graph(self): def _determine_cluster_color(self, anynode): return self._colorlist_hex[anynode.cluster] + def _get_node_options(self, anynode): + node_options = dict(self._graph_parameters.node_options) + if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None: + node_options["node_size"] = self._plotconfig.node_size + + node_options["node_color"] = self._determine_cluster_color(anynode) + node_options["alpha"] = self._graph_parameters.alpha_included + node_options["edgecolors"] = self._graph_parameters.included_edge_color + node_options["linewidths"] = self._graph_parameters.included_linewidth + + if self._should_highlight_exclusions() and self._is_directly_excluded(anynode): + node_options["node_color"] = self._graph_parameters.excluded_color + node_options["alpha"] = self._graph_parameters.alpha_excluded + node_options["edgecolors"] = self._graph_parameters.excluded_edge_color + node_options["linewidths"] = self._graph_parameters.excluded_linewidth + + return node_options + + def _get_label_color(self, anynode): + if self._should_highlight_exclusions() and self._is_directly_excluded(anynode): + return self._graph_parameters.excluded_label_color + return self._graph_parameters.included_label_color + + def _should_highlight_exclusions(self): + return getattr(self._plotconfig, "highlight_excluded_nodes", True) + + @classmethod + def _is_directly_excluded(cls, anynode): + return ( + not getattr(anynode, "is_included", True) + or getattr(anynode, "exclude_residual_decorrelation", False) + or getattr(anynode, "is_outlier_fragment", False) + or getattr(anynode, "is_outlier_peptide", False) + ) + + def _add_exclusion_legend(self, nodes_to_draw): + if not ( + self._should_highlight_exclusions() + and getattr(self._plotconfig, "show_exclusion_legend", True) + ): + return + + drawn_anynodes = [self._id2anytree_node[node] for node in nodes_to_draw] + has_direct = any(self._is_directly_excluded(node) for node in drawn_anynodes) + if not has_direct: + return + + legend_handles = [] + if has_direct: + legend_handles.append( + Line2D( + [0], + [0], + marker="o", + linestyle="", + markerfacecolor=self._graph_parameters.excluded_color, + markeredgecolor=self._graph_parameters.excluded_edge_color, + markeredgewidth=self._graph_parameters.excluded_linewidth, + markersize=8, + label="excluded from aggregation", + ) + ) + + self._ax.legend( + handles=legend_handles, + loc="upper right", + frameon=False, + fontsize=max(7, self._plotconfig.node_fontsize - 3), + ) - @staticmethod - def render_tree(root): - for pre, _, node in anytree.RenderTree(root): - print("%s%s" % (pre, node.name)) class GraphParameters(): def __init__(self): self.included_color = "skyblue" - self.excluded_color = "lightgrey" + self.excluded_color = "#D9D9D9" + self.included_edge_color = "#404040" + self.excluded_edge_color = "#808080" + self.included_label_color = "#202020" + self.excluded_label_color = "#333333" self.alpha_included = 0.6 # More transparent nodes - self.alpha_excluded = 0.3 # More transparent excluded nodes + self.alpha_excluded = 0.8 + self.included_linewidth = 1 + self.excluded_linewidth = 2.0 self.node_options = { "node_color": self.included_color, "node_size": 1500, - "linewidths": 1, + "linewidths": self.included_linewidth, + "edgecolors": self.included_edge_color, "alpha": self.alpha_included, # default alpha } @@ -171,6 +240,38 @@ def __init__(self): } +def _graphviz_layout(graph, layout_params): + try: + return nx.drawing.nx_agraph.graphviz_layout(graph, **layout_params) + except ImportError: + return _graphviz_plain_layout(graph, layout_params) + + +def _graphviz_plain_layout(graph, layout_params): + import graphviz + + prog = layout_params.get("prog", "dot") + graph_attr = {} + for token in shlex.split(layout_params.get("args", "")): + if token.startswith("-G") and "=" in token: + key, value = token[2:].split("=", 1) + graph_attr[key] = value + + dot = graphviz.Digraph(engine=prog, graph_attr=graph_attr) + for node in graph.nodes: + dot.node(str(node)) + for parent, child in graph.edges: + dot.edge(str(parent), str(child)) + + plain = dot.pipe(format="plain").decode("utf-8") + pos = {} + for line in plain.splitlines(): + parts = line.split() + if len(parts) >= 4 and parts[0] == "node": + pos[int(parts[1])] = (float(parts[2]) * 72.0, float(parts[3]) * 72.0) + return pos + + class TreeLabelFormatter: @classmethod @@ -252,8 +353,17 @@ def get_annotation_lines(cls, node, plotconfig): formatted = f"{attr}={value}" annotations.append(formatted) + return annotations + @classmethod + def get_exclusion_annotation_lines(cls, node): + """Return compact annotations for aggregation-excluded nodes.""" + if GraphCreator._is_directly_excluded(node): + return ["excluded"] + + return [] + class AnnotatedGraphCreator(GraphCreator): """Enhanced GraphCreator that supports configurable node annotations.""" @@ -264,7 +374,7 @@ def __init__(self, protein, ax, plotconfig): def _format_graph(self): """Override _format_graph to use the enhanced label formatter.""" - pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, **self._graph_parameters.layout_params) + pos = _graphviz_layout(self.graph, self._graph_parameters.layout_params) root_id = id(self._protein) hide_root = getattr(self._plotconfig, 'hide_root_in_tree', False) @@ -274,17 +384,8 @@ def _format_graph(self): for node in nodes_to_draw: matching_anynode = self._id2anytree_node[node] - is_included = matching_anynode.is_included - if not is_included: - self._graph_parameters.node_options["alpha"] = self._graph_parameters.alpha_excluded - else: - self._graph_parameters.node_options["alpha"] = self._graph_parameters.alpha_included - - self._graph_parameters.node_options["node_color"] = self._determine_cluster_color(matching_anynode) - # Allow overriding node size from plotconfig - if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None: - self._graph_parameters.node_options["node_size"] = self._plotconfig.node_size - nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **self._graph_parameters.node_options) + node_options = self._get_node_options(matching_anynode) + nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **node_options) label_dict = nx.get_node_attributes(self.graph, 'label') @@ -316,9 +417,10 @@ def _format_graph(self): self._ax.text(x, y, labelstring, verticalalignment='center', horizontalalignment='center', fontsize=fontsize, family='monospace', weight="bold", - rotation=rotation) + rotation=rotation, color=self._get_label_color(matching_anynode)) nx.draw_networkx_edges(self.graph, pos, edgelist=edges_to_draw, ax=self._ax, **self._graph_parameters.edge_options) + self._add_exclusion_legend(nodes_to_draw) # Add vertical padding to avoid cutting labels at top/bottom and hide axis frame try: @@ -389,6 +491,3 @@ def define_tree_fig_and_ax(self): fig_width = min(max(8, num_leaves * 1.3),100) fig_height = max(8, max_depth * 2) self.fig, self.ax_tree = plt.subplots(figsize=(fig_width, fig_height)) - - -