diff --git a/src/dsa/nodes.py b/src/dsa/nodes.py index 521d24e..d7899e5 100644 --- a/src/dsa/nodes.py +++ b/src/dsa/nodes.py @@ -47,14 +47,59 @@ def _get_all_json_refs(item: Any) -> set[JsonRef]: # derived from https://github.com/langflow-ai/langflow/pull/5262 def find_cycle_vertices(edges): - # Create a directed graph from the edges - graph = nx.DiGraph(edges) - - # Find all simple cycles in the graph - cycles = list(nx.simple_cycles(graph)) - - # Flatten the list of cycles and remove duplicates - cycle_vertices = {vertex for cycle in cycles for vertex in cycle} + # Build adjacency list (directed) + adj = {} + for u, v in edges: + adj.setdefault(u, set()).add(v) + # Ensure all nodes appear in adj, even if they have only in-edges + adj.setdefault(v, set()) + + def strongconnect(v, index, stack, index_map, lowlink_map, on_stack, components): + index_map[v] = index[0] + lowlink_map[v] = index[0] + index[0] += 1 + stack.append(v) + on_stack.add(v) + + for w in adj[v]: + if w not in index_map: + strongconnect( + w, index, stack, index_map, lowlink_map, on_stack, components + ) + lowlink_map[v] = min(lowlink_map[v], lowlink_map[w]) + elif w in on_stack: + lowlink_map[v] = min(lowlink_map[v], index_map[w]) + + if lowlink_map[v] == index_map[v]: + component = set() + while True: + w = stack.pop() + on_stack.remove(w) + component.add(w) + if w == v: + break + if len(component) > 1: + components.append(component) + # Single node cycle (self-loop) + elif list(adj[v]).count(v): + components.append(component) + + # Find all strongly connected components (SCCs) + index = [0] + stack = [] + index_map = {} + lowlink_map = {} + on_stack = set() + components = [] + + for v in adj: + if v not in index_map: + strongconnect(v, index, stack, index_map, lowlink_map, on_stack, components) + + # Union all cyclic SCCs + cycle_vertices = set() + for comp in components: + cycle_vertices.update(comp) return sorted(cycle_vertices)