forked from viksit/differentiable-programming
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdiff_prog_learnable_graph.py
More file actions
170 lines (139 loc) · 6.74 KB
/
diff_prog_learnable_graph.py
File metadata and controls
170 lines (139 loc) · 6.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# %% Differentiable Programming for Learnable Graphs: Optimizing Agent Workflows with DSPy
import datetime as dt
import random
from typing import List, Literal

import dspy
from dotenv import load_dotenv
from dspy import Example

from visualize_workflow import visualize
# Read environment variables from a local .env file (python-dotenv) at import time,
# before main() builds dspy.LM; presumably this is where the OpenAI API key comes from — confirm.
load_dotenv()
# %% Shared branch-output schema (the dspy LM itself is configured inside main())
# Output schema shared by every branch module. In this file it is only attached as
# `_sig=ServiceReply` on the dspy.Prediction each branch returns; it is never passed
# to dspy.Predict, so it does not drive any LM call itself.
# NOTE: documented with `#` comments rather than a class docstring, because DSPy uses
# a Signature's docstring as its prompt instructions.
class ServiceReply(dspy.Signature):
    tag: str = dspy.OutputField()  # route tag: eta | missing | driver | fallback
    body: str = dspy.OutputField()  # free-form customer-facing message
# %% Router with closed label set
# Constrained-label LLM node: inspects the incoming support ticket and returns one of four route tags: eta, missing, driver, or fallback.
class RouterSignature(dspy.Signature):
    """Choose the single best route for this customer ticket."""

    # NOTE: the docstring above is the prompt instruction DSPy sends to the LM;
    # keep it (and the field desc below) stable unless you mean to change the prompt.

    # Raw customer support ticket text.
    ticket: str = dspy.InputField()
    # Typed as a Literal (not a free-form str) so DSPy presents exactly these four
    # labels as the valid answers and validates the parsed output against them,
    # instead of passing arbitrary model text through to the router.
    route: Literal["eta", "missing", "driver", "fallback"] = dspy.OutputField(
        desc='Choose the single best route for this customer ticket from: eta | missing | driver | fallback.'
    )
class Router(dspy.Module):
    """LLM routing node: predicts a route tag for a support ticket."""

    def __init__(self):
        # Initialise the dspy.Module base-class state first, as DSPy's custom-module
        # pattern expects; the original constructor skipped this call.
        super().__init__()
        self.step = dspy.Predict(RouterSignature)  # Use the signature class directly

    def forward(self, ticket: str):
        """Return the predicted route tag, lower-cased and stripped of surrounding whitespace.

        The tag is not validated here; the caller decides what to do with
        unrecognised values.
        """
        return self.step(ticket=ticket).route.lower().strip()
# %% Branch modules (use unified ServiceReply)
# Synthesises a plausible delivery window and turns that timestamp into a customer-friendly sentence.
# Returns a dspy.Prediction carrying the ServiceReply fields (tag, body).
class LatePath(dspy.Module):
    """ETA branch: invents a 10-20 minute delivery estimate (no real tracking lookup).

    ``ticket`` is accepted for interface parity with the other branches but unused.
    """

    def forward(self, ticket: str):
        delay_min = random.randint(10, 20)
        arrival = dt.datetime.now() + dt.timedelta(minutes=delay_min)
        # 12-hour wall-clock time in the machine's local zone, e.g. "07:45 PM".
        arrival_label = arrival.strftime("%I:%M %p")
        return dspy.Prediction(
            tag="eta",
            body=f"Courier is about {delay_min} min away — arriving ≈ {arrival_label}.",
            _sig=ServiceReply,
        )
# Stub: returns a canned message saying the missing item was verified and refunded (no actual check or refund happens here).
class MissingPath(dspy.Module):
    """Missing-item branch: always answers with the same canned refund notice."""

    def forward(self, ticket: str):
        reply_text = "Item verified missing via photo. Refund has been issued."
        prediction = dspy.Prediction(tag="missing", body=reply_text, _sig=ServiceReply)
        return prediction
# a stub that would locate the courier and respond with tag="driver"
class DriverPath(dspy.Module):
    """Driver branch (stub): reports the courier as located without any real lookup."""

    def forward(self, ticket: str):
        reply_fields = dict(
            tag="driver",
            body="Driver located, contact info sent.",
            _sig=ServiceReply,
        )
        return dspy.Prediction(**reply_fields)
# a final catch-all that directs the user to FAQ or human support, returning tag="fallback"
class FallbackPath(dspy.Module):
    """Catch-all branch: points the customer at the FAQ or live support."""

    def forward(self, ticket: str):
        fallback_text = "Please see our FAQ or reach live support."
        return dspy.Prediction(
            tag="fallback",
            body=fallback_text,
            _sig=ServiceReply,
        )
# %% Top-level agent
class SupportAgent(dspy.Module):
    """Top-level workflow: route a ticket with the LLM Router, then run exactly one branch.

    Every branch returns a dspy.Prediction with ``tag`` and ``body`` fields.
    Any route other than "eta", "missing" or "driver" (including "fallback" and
    unexpected model output) is handled by FallbackPath.
    """

    def __init__(self):
        # Initialise the dspy.Module base-class state before registering sub-modules;
        # the original constructor skipped this call.
        super().__init__()
        self.router = Router()
        self.eta = LatePath()
        self.missing = MissingPath()
        self.driver = DriverPath()
        self.fallback = FallbackPath()

    def forward(self, ticket: str):
        route = self.router(ticket=ticket)
        # Local dispatch table. Keeping it off `self` leaves the module's attribute
        # set (what DSPy introspects for sub-modules) exactly as before.
        branches = {
            "eta": self.eta,
            "missing": self.missing,
            "driver": self.driver,
        }
        branch = branches.get(route, self.fallback)
        return branch(ticket=ticket)
# %% Tiny dataset
# Each dataset is written as (ticket text, gold route tag) pairs and wrapped as dspy
# Examples whose only input field is "ticket"; "tag" is the label tag_match compares.
train = [
    Example(ticket=text, tag=label).with_inputs("ticket")
    for text, label in [
        ("Order #8123 is 20 minutes late. Any ETA?", "eta"),
        ("Driver stuck in traffic for order #9041.", "driver"),
        ("Burger arrived but fries missing in order #6677.", "missing"),
        ("App shows delivered yet nothing here (#7001).", "missing"),
        ("Order #5502 delayed, need update please.", "eta"),
        ("I didn't get the soda in combo meal #5502.", "missing"),
        ("Driver phone off — can you locate? Order #4321.", "driver"),
        ("How do I cancel my late order #8890?", "fallback"),
        ("Missing dipping sauce order #9988.", "missing"),
        ("Order #7002 already 25 min late. Where is it?", "eta"),
    ]
]
dev = [
    Example(ticket=text, tag=label).with_inputs("ticket")
    for text, label in [
        ("Order #1234 is late — where is it?", "eta"),
        ("Missing fries in order #5678, need refund.", "missing"),
        ("Driver got flat tire. What's new ETA for #2020?", "eta"),
        ("Never got my drink with order #3003.", "missing"),
        ("Order #4040 taking forever. Any update?", "eta"),
        ("Is there a way to track courier? Order #5050.", "driver"),
        ("Half my toppings missing on pizza #6060.", "missing"),
        ("App shows delivered but nothing arrived (#7070).", "missing"),
        ("FAQ didn't help. Cancel late order #8080.", "fallback"),
        ("What's status of order #9090? It's 30 min late.", "eta"),
    ]
]
# %% Simple metric function that compares predicted tags only
def tag_match(ex, pred, trace=None):
    """DSPy metric: 1.0 when the predicted route tag equals the gold tag, else 0.0.

    Only ``tag`` is compared (the reply ``body`` is ignored); ``trace`` is
    accepted to fit DSPy's metric signature but is unused.
    """
    expected, predicted = ex.tag, pred.tag
    return float(expected == predicted)
# %% training mipro agent
# Create a DSPy MIPROv2 optimizer, scored by tag_match on the training set, that uses the
# configured LM to search over candidate prompts (instructions and few-shot demos).
def train_support_agent(train: List[Example]):
    """Optimize a fresh SupportAgent with MIPROv2 and return the compiled program.

    Candidates are scored with tag_match (route tag only). ``auto="light"``
    selects MIPROv2's lightest search preset; demos are capped at
    2 bootstrapped and 4 labeled examples.
    """
    optimizer = dspy.MIPROv2(
        metric=tag_match,
        auto="light",
        num_threads=1,
    )
    student = SupportAgent()
    compiled = optimizer.compile(
        student,
        trainset=train,
        # Skip MIPROv2's interactive confirmation prompt so the script runs unattended.
        requires_permission_to_run=False,
        max_bootstrapped_demos=2,
        max_labeled_demos=4,
    )
    return compiled
# %% function to visualize agent
def viz(agent: SupportAgent):
    """Pass ``agent`` to the project's ``visualize`` helper (from visualize_workflow).

    Whatever ``visualize`` returns is discarded; this wrapper returns None.
    """
    visualize(agent)
    return None
# %% main function
def main():
    """Run the demo end to end.

    Steps:
    1. Configure GPT-4o as the global DSPy LM.
    2. Answer one sample ticket.
    3. Score the un-optimized agent on ``dev``.
    4. Optimize a fresh agent with MIPROv2 on ``train`` and score it.
    5. Visualize the un-optimized agent.
    """
    dspy.configure(lm=dspy.LM(model="openai/gpt-4o"))

    # Smoke test: push one ticket through the zero-shot agent.
    baseline = SupportAgent()
    sample_reply = baseline(ticket="Order #8123 is 20 minutes late. Any ETA?")
    print(sample_reply)

    # One evaluator reused for both scores so the numbers are directly comparable.
    evaluate_dev = dspy.Evaluate(
        devset=dev,
        metric=tag_match,
        num_threads=1,
        display_progress=True,
        display_table=5,
    )

    zero_shot_score = evaluate_dev(baseline)
    print("Zero-shot dev score:", zero_shot_score)

    optimized = train_support_agent(train)
    post_opt_score = evaluate_dev(optimized)
    print("Post-opt dev score:", post_opt_score)

    # Draws the baseline (un-optimized) agent, matching the original behavior.
    viz(baseline)


# %% main entrypoint
if __name__ == "__main__":
    main()