forked from viksit/differentiable-programming
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualize_workflow.py
More file actions
59 lines (51 loc) · 1.54 KB
/
visualize_workflow.py
File metadata and controls
59 lines (51 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import dspy
import networkx as nx
import matplotlib.pyplot as plt
def _submodules(module):
    """Yield the dspy.Module instances held by *module*'s instance attributes.

    Besides attributes that are modules themselves, this also looks one level
    into list/tuple values and dict values, so sub-modules kept in containers
    (e.g. ``self.steps = [...]``) are discovered too.
    """
    for value in vars(module).values():
        if isinstance(value, dspy.Module):
            yield value
        elif isinstance(value, (list, tuple)):
            yield from (item for item in value if isinstance(item, dspy.Module))
        elif isinstance(value, dict):
            yield from (
                item for item in value.values() if isinstance(item, dspy.Module)
            )


def _edges(node, parent=None, out=None):
    """Collect ``(parent_class_name, child_class_name)`` edges of a module tree.

    Nodes are identified by class name only. As a result, distinct
    sub-modules of the same class collapse into one graph node, and a module
    holding a child of its own class produces a self-loop edge. Modules whose
    class is named exactly ``"Predict"`` are skipped, together with their
    children. Subclasses with other names are still expanded.

    Args:
        node: root dspy.Module to walk.
        parent: class name to link ``node`` from; ``None`` (or empty) for the
            root, in which case no edge is emitted for ``node`` itself.
        out: list to append edges to; a new list is created when ``None``.

    Returns:
        ``out``, with edges appended in depth-first order.
    """
    if out is None:
        out = []
    # ids of modules already expanded: stops infinite recursion on reference
    # cycles and avoids re-walking a sub-module shared by several parents.
    expanded = set()

    def walk(module, parent_name):
        name = module.__class__.__name__
        if name == "Predict":
            return  # skip internal nodes
        if parent_name:
            out.append((parent_name, name))
        if id(module) in expanded:
            return  # edge recorded above; don't expand the same instance twice
        expanded.add(id(module))
        for child in _submodules(module):
            walk(child, name)

    walk(node, parent)
    return out
# ── prettier palette & styling ──────────────────────────────────────────
def visualize(agent, *, show=True):
    """Draw the module hierarchy of a DSPy *agent* as a directed graph.

    Edges come from ``_edges(agent)``, and nodes are labelled by class name.
    The root node is always drawn, even when the agent has no sub-modules.

    Args:
        agent: the root dspy.Module to draw.
        show: when True (the default), call ``plt.show()``. Pass False to keep
            the figure open, e.g. to call ``fig.savefig(...)`` instead.

    Returns:
        The matplotlib Figure the graph was drawn on.
    """
    graph = nx.DiGraph()
    # Add the root explicitly; otherwise an agent without (non-Predict)
    # sub-modules yields no edges, hence no nodes, and an empty plot.
    graph.add_node(agent.__class__.__name__)
    graph.add_edges_from(_edges(agent))

    # layout — spring looks cleaner with a k tweak; fixed seed keeps it stable
    pos = nx.spring_layout(graph, k=0.7, seed=42)

    # node/edge styling
    node_size = 2000
    node_color = "#009E73"  # a pleasant teal
    edge_color = "#444444"
    label_color = "#222222"

    fig = plt.figure(figsize=(9, 7), facecolor="#fafafa")
    nx.draw_networkx_nodes(
        graph, pos,
        node_size=node_size,
        node_color=node_color,
        edgecolors="#555555",
        linewidths=1.5,
        alpha=0.9,
    )
    nx.draw_networkx_edges(
        graph, pos,
        # Must match the drawn node size: networkx uses it to stop each arrow
        # at the node boundary. With the default size, the arrowheads end up
        # hidden underneath the large node markers.
        node_size=node_size,
        width=2.0,
        edge_color=edge_color,
        arrowsize=20,
        arrowstyle="-|>",
    )
    nx.draw_networkx_labels(
        graph, pos,
        font_size=10,
        font_family="DejaVu Sans",
        font_color=label_color,
    )
    plt.title("DSPy Workflow Graph", fontsize=14, pad=15, color="#333333")
    plt.axis("off")
    plt.tight_layout()
    if show:
        plt.show()
    return fig