
Commit ea2c7c8

Retail Suite (#66)
1 parent 6b018d4 commit ea2c7c8

File tree

5 files changed: +406 -68 lines changed


eval_protocol/benchmarks/suites/aime25.py

Lines changed: 3 additions & 10 deletions
@@ -1,16 +1,14 @@
 from typing import Any, Dict, List, Optional
 
+from eval_protocol.benchmarks.registry import export_benchmark
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
 from eval_protocol.pytest.default_single_turn_rollout_process import (
     default_single_turn_rollout_processor,
 )
 from eval_protocol.pytest.evaluation_test import evaluation_test
-from eval_protocol.benchmarks.registry import export_benchmark
-
 
 SYSTEM_PROMPT = (
-    "You are a helpful math assistant. Please reason step by step, and put your "
-    "final answer within \\boxed{...}."
+    "You are a helpful math assistant. Please reason step by step, and put your " "final answer within \\boxed{...}."
 )
 
 
@@ -56,9 +54,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
             Message(role="system", content=SYSTEM_PROMPT),
             Message(role="user", content=str(question)),
         ]
-        converted.append(
-            EvaluationRow(messages=messages, ground_truth=str(answer) if answer is not None else None)
-        )
+        converted.append(EvaluationRow(messages=messages, ground_truth=str(answer) if answer is not None else None))
     return converted
 
 
@@ -73,7 +69,6 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
     rollout_input_params=[{"max_tokens": 131000, "extra_body": {"reasoning_effort": "low"}}],
     rollout_processor=default_single_turn_rollout_processor,
     aggregation_method="mean",
-    threshold_of_success=None,
     num_runs=8,
     max_dataset_rows=2,
     max_concurrent_rollouts=4,
@@ -114,5 +109,3 @@ def test_aime25_pointwise(row: EvaluationRow) -> EvaluationRow:
         metrics=metrics,
     )
     return row
-
-

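Note that the SYSTEM_PROMPT edit above only collapses two adjacent string literals onto one line; Python joins adjacent literals at parse time, so the resulting prompt is byte-for-byte identical. A minimal standalone sketch of that language rule (variable names here are illustrative, not from the suite):

two_line = (
    "You are a helpful math assistant. Please reason step by step, and put your "
    "final answer within \\boxed{...}."
)
one_line = "You are a helpful math assistant. Please reason step by step, and put your " "final answer within \\boxed{...}."
assert two_line == one_line  # adjacent string literals are concatenated by the parser
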
eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 4 additions & 8 deletions
@@ -1,17 +1,16 @@
-from typing import List
-
 import csv
 import io
 import re
+from typing import List
+
 import requests
 
+from eval_protocol.benchmarks.registry import export_benchmark
 from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
-from eval_protocol.pytest.evaluation_test import evaluation_test
 from eval_protocol.pytest.default_single_turn_rollout_process import (
     default_single_turn_rollout_processor,
 )
-from eval_protocol.benchmarks.registry import export_benchmark
-
+from eval_protocol.pytest.evaluation_test import evaluation_test
 
 SYSTEM_PROMPT = (
     "You are a helpful assistant. Read the question and options carefully. "
@@ -66,7 +65,6 @@ def _extract_abcd_letter(text: str) -> str | None:
     rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
     rollout_processor=default_single_turn_rollout_processor,
     aggregation_method="mean",
-    threshold_of_success=None,
     num_runs=8,
     mode="pointwise",
 )
@@ -96,5 +94,3 @@ def gpqa_pointwise(row: EvaluationRow) -> EvaluationRow:
         },
     )
     return row
-
-
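The second gpqa hunk sits inside a helper whose body is not part of this diff, _extract_abcd_letter(text: str) -> str | None. Purely as a hypothetical illustration of what a helper with that signature could do (the real implementation in gpqa.py may differ), assuming the model is asked to end its reply with a single A-D choice:

import re

def _extract_abcd_letter_sketch(text: str) -> str | None:
    # Hypothetical sketch: take the last standalone A/B/C/D token as the chosen option.
    matches = re.findall(r"\b([ABCD])\b", text.upper())
    return matches[-1] if matches else None
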
New file
Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
+"""
+Pytest test for tau bench retail evaluation using the evaluation_test decorator.
+
+This test demonstrates how to use tau bench environments within the pytest framework,
+similar to the test_entire_retail_dataset test but integrated with the pytest evaluation system.
+"""
+
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List
+
+from eval_protocol.benchmarks.registry import export_benchmark
+from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message
+from eval_protocol.pytest import evaluation_test
+from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
+from vendor.tau2.data_model.message import (
+    AssistantMessage,
+    SystemMessage,
+    ToolCall,
+    ToolMessage,
+    UserMessage,
+)
+from vendor.tau2.data_model.tasks import Action, EvaluationCriteria, RewardType, Task, UserScenario
+from vendor.tau2.evaluator.evaluator import EnvironmentEvaluator
+from vendor.tau2.evaluator.evaluator_action import ActionEvaluator
+from vendor.tau2.evaluator.evaluator_communicate import CommunicateEvaluator
+from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator
+from vendor.tau2.registry import registry
+
+
+def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
+    """
+    Convert entries from retail dataset to EvaluationRow objects.
+    """
+    rows = []
+    test_dir = Path(__file__).parent.parent.parent.parent / "examples" / "tau2_mcp" / "tests"
+
+    # Load system prompt from file so we can change it in one place
+    domain = data[0]["environment_context"]["domain"]
+    prompt_file = test_dir / f"system_prompts/{domain}_agent_system_prompt.md"
+
+    with open(prompt_file, "r") as f:
+        system_prompt = f.read().strip()
+
+    for row in data:
+        eval_row = EvaluationRow(
+            messages=[Message(role="system", content=system_prompt)],
+            input_metadata=InputMetadata(
+                row_id=row["id"],
+                dataset_info={
+                    "environment_context": row["environment_context"],
+                    "user_simulation": row["user_simulation"],
+                    "evaluation_criteria": row["evaluation_criteria"],
+                    "user_prompt_template": row["user_prompt_template"],
+                },
+            ),
+        )
+
+        rows.append(eval_row)
+
+    return rows
+
+
+@export_benchmark("tau_bench_retail")
+@evaluation_test(
+    input_dataset=["tests/pytest/data/retail_dataset.jsonl"],
+    dataset_adapter=tau_bench_retail_to_evaluation_row,
+    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
+    rollout_input_params=[{"temperature": 0.8, "extra_body": {"reasoning_effort": "medium"}}],
+    rollout_processor=default_mcp_gym_rollout_processor,
+    num_runs=8,
+    mode="pointwise",
+    max_concurrent_rollouts=50,
+    server_script_path="examples/tau2_mcp/server.py",
+)
+def test_tau_bench_retail_evaluation(row: EvaluationRow) -> EvaluationRow:
+    """
+    Test tau bench retail evaluation using the pytest framework.
+
+    This test now uses the tau_bench_retail_reward function which automatically
+    extracts evaluation criteria from dataset entries. No wrapper needed!
+
+    Args:
+        row: EvaluationRow object from tau bench retail dataset after rollout
+
+    Returns:
+        EvaluationRow with tau2 evaluation results
+    """
+    messages = row.messages
+
+    # Get evaluation criteria and user_simulation from input_metadata.dataset_info
+    dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
+    evaluation_criteria = dataset_info.get("evaluation_criteria", {})
+
+    nl_assertions = evaluation_criteria.get("nl_assertions", [])
+    communicate_info = evaluation_criteria.get("communicate_info", [])
+    actions = evaluation_criteria.get("actions", [])
+
+    # Convert Message objects directly to tau2-bench message objects
+    trajectory_objects = []
+    for msg in messages:
+        role = msg.role
+        content = msg.content
+
+        if role == "system":
+            trajectory_objects.append(SystemMessage(role=role, content=content))
+        elif role == "assistant":
+            tau2_tool_calls = []
+            if msg.tool_calls:
+                for tool_call in msg.tool_calls:
+                    arguments = json.loads(tool_call.function.arguments)
+                    tau2_tool_call = ToolCall(
+                        id=tool_call.id,
+                        name=tool_call.function.name,
+                        arguments=arguments,
+                    )
+                    tau2_tool_calls.append(tau2_tool_call)
+
+            trajectory_objects.append(AssistantMessage(role=role, content=content, tool_calls=tau2_tool_calls))
+        elif role == "user":
+            trajectory_objects.append(UserMessage(role=role, content=content))
+        elif role == "tool":
+            tool_id = msg.tool_call_id
+            trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content))
+
+    reward = 1.0
+
+    evaluation_criteria = EvaluationCriteria(
+        nl_assertions=nl_assertions,
+        communicate_info=communicate_info,
+        actions=actions,
+        reward_basis=[  # Use this to adjust how to calculate reward. Tau2-bench uses DB and COMMUNICATE by default for retail tasks.
+            RewardType.DB,
+            RewardType.COMMUNICATE,
+        ],
+    )
+
+    task = Task(
+        id="Filler", evaluation_criteria=evaluation_criteria, user_scenario=UserScenario(instructions="Filler")
+    )  # id and user_scenario are required for the Task type but not used in calculating reward
+
+    if RewardType.DB in task.evaluation_criteria.reward_basis:
+        env_reward_info = EnvironmentEvaluator.calculate_reward(
+            environment_constructor=registry.get_env_constructor("retail"),
+            task=task,
+            full_trajectory=trajectory_objects,
+        )
+    if RewardType.ACTION in task.evaluation_criteria.reward_basis:
+        action_reward_info = ActionEvaluator.calculate_reward(
+            task=task,
+            full_trajectory=trajectory_objects,
+        )
+    if RewardType.COMMUNICATE in task.evaluation_criteria.reward_basis:
+        communicate_reward_info = CommunicateEvaluator.calculate_reward(
+            task=task,
+            full_trajectory=trajectory_objects,
+        )
+    if RewardType.NL_ASSERTION in task.evaluation_criteria.reward_basis:
+        nl_reward_info = NLAssertionsEvaluator.calculate_reward(
+            task=task,
+            full_trajectory=trajectory_objects,
+        )
+
+    reward = 1.0
+    env_bases = {RewardType.DB, RewardType.ENV_ASSERTION}
+    action_bases = {RewardType.ACTION}
+    nl_bases = {RewardType.NL_ASSERTION}
+    comm_bases = {RewardType.COMMUNICATE}
+    task_reward_basis = set(task.evaluation_criteria.reward_basis)
+
+    reward_breakdown = {}
+    if task_reward_basis & env_bases:
+        if env_reward_info.reward_breakdown is not None:
+            reward_breakdown.update(env_reward_info.reward_breakdown)
+        reward *= env_reward_info.reward
+    if task_reward_basis & action_bases:
+        if action_reward_info.reward_breakdown is not None:
+            reward_breakdown.update(action_reward_info.reward_breakdown)
+        reward *= action_reward_info.reward
+    if task_reward_basis & nl_bases:
+        if nl_reward_info.reward_breakdown is not None:
+            reward_breakdown.update(nl_reward_info.reward_breakdown)
+        reward *= nl_reward_info.reward
+    if task_reward_basis & comm_bases:
+        if communicate_reward_info.reward_breakdown is not None:
+            reward_breakdown.update(communicate_reward_info.reward_breakdown)
+        reward *= communicate_reward_info.reward
+
+    # Generate reason showing only failed components
+    failed_reasons = []
+
+    if task_reward_basis & env_bases and env_reward_info.reward == 0:
+        failed_reasons.append("❌ Environment/DB check failed")
+
+    if task_reward_basis & action_bases and action_reward_info.reward == 0:
+        failed_actions = []
+        if hasattr(action_reward_info, "action_checks") and action_reward_info.action_checks:
+            failed_actions = [
+                f"{ac.action.name}({ac.action.arguments})"
+                for ac in action_reward_info.action_checks
+                if not ac.action_match
+            ]
+        if failed_actions:
+            failed_reasons.append(f"❌ Failed actions: {failed_actions}")
+        else:
+            failed_reasons.append("❌ Actions failed")
+
+    if task_reward_basis & nl_bases and nl_reward_info.reward == 0:
+        failed_nl = []
+        if hasattr(nl_reward_info, "nl_assertions") and nl_reward_info.nl_assertions:
+            failed_nl = [nla.nl_assertion for nla in nl_reward_info.nl_assertions if not nla.met]
+        if failed_nl:
+            failed_reasons.append(f"❌ Failed NL assertions: {failed_nl}")
+        else:
+            failed_reasons.append("❌ NL Assertions failed")
+
+    if task_reward_basis & comm_bases and communicate_reward_info.reward == 0:
+        failed_comm = []
+        if hasattr(communicate_reward_info, "communicate_checks") and communicate_reward_info.communicate_checks:
+            failed_comm = [cc.info for cc in communicate_reward_info.communicate_checks if not cc.met]
+        if failed_comm:
+            failed_reasons.append(f"❌ Failed communication: {failed_comm}")
+        else:
+            failed_reasons.append("❌ Communication failed")
+
+    # If everything passed, show success
+    reason = "\n".join(failed_reasons) if failed_reasons else "✅ All checks passed"
+
+    row.evaluation_result = EvaluateResult(
+        score=reward,
+        reason=reason,
+        metrics={},
+    )
+    return row
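For reference, tau_bench_retail_to_evaluation_row reads five keys from every entry of retail_dataset.jsonl: id, environment_context (whose domain field selects the system-prompt file), user_simulation, evaluation_criteria, and user_prompt_template. A minimal hypothetical entry showing just that shape (the values are made up, not taken from the real dataset):

example_entry = {
    "id": "retail_task_001",                      # hypothetical row id
    "environment_context": {"domain": "retail"},  # "domain" picks retail_agent_system_prompt.md
    "user_simulation": {},                        # passed through into dataset_info
    "evaluation_criteria": {                      # read back by the test via dataset_info
        "nl_assertions": [],
        "communicate_info": [],
        "actions": [],
    },
    "user_prompt_template": "",                   # passed through into dataset_info
}
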

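Because the selected component rewards are multiplied together in the new test, any basis that scores 0.0 zeroes the final score, and partial scores compound. A tiny worked sketch of that combination rule in isolation (the component values here are hypothetical):

from functools import reduce

component_rewards = {"db": 1.0, "communicate": 0.0}  # e.g. DB check passed, communication check failed
final_reward = reduce(lambda acc, r: acc * r, component_rewards.values(), 1.0)
assert final_reward == 0.0  # one failed basis drives the product to zero
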