# multi_step_agent_optimization.py
# Forked from viksit/differentiable-programming.
# %% Using metrics and rewards to optimize multi step agents in DSPy
## Why functional workflows still fail users, and how to make behavior learnable using differentiable programming.
import dspy
import re
from diff_prog_learnable_graph import (
ServiceReply,
MissingPath,
DriverPath,
FallbackPath,
train,
dev,
)
# %% define metric function
def friendly_eta_metric(ex, pred, trace=None):
FRIENDLY_WORDS = {"hi", "hey", "please", "thanks", "thank you", "sure", "happy"}
# 0. Tag must match
if ex.tag != pred.tag:
return 0.0
# For non-ETA tickets, any correct tag scores 1.0
if ex.tag != "eta":
return 1.0
body = pred.body.lower()
rules = [
"eta" in body, # mentions ETA
any(w in body for w in FRIENDLY_WORDS), # at least one polite word
5 <= len(body.split()) <= 40, # not too short / long
re.search(r"\b\d{1,2}\s?min", body) is not None # gives a minutes estimate
]
return sum(rules) / len(rules) # 0.0-1.0
# %% Signatures
class RouterSignature(dspy.Signature):
ticket: str = dspy.InputField()
route: str = dspy.OutputField(desc='eta | missing | driver | fallback')
class LatePathSignature(dspy.Signature):
ticket: str = dspy.InputField()
body: str = dspy.OutputField(desc='Friendly ETA sentence that includes the word "eta".')
# %% Modules
class Router(dspy.Module):
def __init__(self):
self.step = dspy.Predict(RouterSignature)
def forward(self, ticket: str):
return self.step(ticket=ticket).route.lower().strip()
class LatePath(dspy.Module):
def __init__(self):
self.step = dspy.Predict(LatePathSignature)
def forward(self, ticket: str):
body = self.step(ticket=ticket).body.strip()
return dspy.Prediction(tag="eta", body=body, _sig=ServiceReply)
# %% Top-level agent
class SupportAgent(dspy.Module):
def __init__(self):
self.router = Router()
self.eta = LatePath()
self.missing = MissingPath()
self.driver = DriverPath()
self.fallback = FallbackPath()
def forward(self, ticket: str):
route = self.router(ticket=ticket)
if route == "eta":
r = self.eta(ticket=ticket)
elif route == "missing":
r = self.missing(ticket=ticket)
elif route == "driver":
r = self.driver(ticket=ticket)
else:
r = self.fallback(ticket=ticket)
return r
# %% main function
def main():
# Load LLM configurations
lm = dspy.LM('openai/gpt-4o')
dspy.configure(lm=lm)
# create evaluation harness
THREADS = 1
evaluate = dspy.Evaluate(
devset=dev,
metric=friendly_eta_metric,
num_threads=THREADS,
display_progress=True,
display_table=5,
)
# Baseline eval
support_bot = SupportAgent()
evaluate(support_bot)
# optimized eval
opt = dspy.MIPROv2(
metric=friendly_eta_metric,
auto="light", # minimal search space
num_threads=THREADS,
teacher_settings=dict(lm=lm),
)
optimized_support_bot = opt.compile(
SupportAgent(), # program to optimize
trainset=train[:100],
requires_permission_to_run=False,
max_bootstrapped_demos=4,
max_labeled_demos=4,
)
evaluate(optimized_support_bot)
# %% main entrypoint
if __name__ == "__main__":
main()