Skip to content

Commit df2d034

Browse files
committed
make completionparams optional and clean up tests
1 parent 1ceb651 commit df2d034

File tree

4 files changed

+32
-41
lines changed

4 files changed

+32
-41
lines changed

eval_protocol/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class InputMetadata(BaseModel):
181181
model_config = ConfigDict(extra="allow")
182182

183183
row_id: Optional[str] = Field(None, description="Unique string to ID the row")
184-
completion_params: CompletionParams = Field(..., description="Completion endpoint parameters used")
184+
completion_params: Optional[CompletionParams] = Field(None, description="Completion endpoint parameters used")
185185
dataset_info: Optional[Dict[str, Any]] = Field(
186186
None, description="Dataset row details: seed, system_prompt, environment_context, etc"
187187
)

tests/pytest/test_frozen_lake.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,19 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
1919
"""
2020
rows = []
2121

22-
for entry in data:
23-
row = EvaluationRow(
24-
messages=[Message(role="system", content=entry.get("system_prompt", ""))],
22+
for row in data:
23+
eval_row = EvaluationRow(
24+
messages=[Message(role="system", content=row["system_prompt"])],
2525
input_metadata=InputMetadata(
26-
row_id=entry.get("id"),
27-
completion_params=CompletionParams(model="placeholder"), # This gets populated by the rollout processor
26+
row_id=row["id"],
2827
dataset_info={
29-
"environment_context": entry.get("environment_context", {}),
30-
"user_prompt_template": entry.get("user_prompt_template", ""),
28+
"environment_context": row["environment_context"],
29+
"user_prompt_template": row["user_prompt_template"],
3130
}
3231
)
3332
)
3433

35-
rows.append(row)
34+
rows.append(eval_row)
3635

3736
return rows
3837

tests/pytest/test_lunar_lander.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,19 @@ def lunar_lander_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio
1818
"""
1919
rows = []
2020

21-
for entry in data:
22-
row = EvaluationRow(
23-
messages=[Message(role="system", content=entry.get("system_prompt", ""))],
21+
for row in data:
22+
eval_row = EvaluationRow(
23+
messages=[Message(role="system", content=row["system_prompt"])],
2424
input_metadata=InputMetadata(
25-
row_id=entry.get("id"),
26-
completion_params=CompletionParams(model="placeholder"), # This gets populated by the rollout processor
25+
row_id=row["id"],
2726
dataset_info={
28-
"environment_context": entry.get("environment_context", {}),
29-
"user_prompt_template": entry.get("user_prompt_template", ""),
27+
"environment_context": row["environment_context"],
28+
"user_prompt_template": row["user_prompt_template"],
3029
}
3130
)
3231
)
3332

34-
rows.append(row)
33+
rows.append(eval_row)
3534

3635
return rows
3736

tests/pytest/test_tau_bench_airline.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -35,35 +35,28 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
3535
rows = []
3636
test_dir = Path(__file__).parent.parent.parent / "examples" / "tau2_mcp" / "tests"
3737

38-
for entry in data:
39-
# Load system prompt from file so we can change it in one place
40-
domain = entry["environment_context"]["domain"]
41-
prompt_file = test_dir / f"system_prompts/{domain}_agent_system_prompt.md"
42-
43-
with open(prompt_file, "r") as f:
44-
system_prompt = f.read().strip()
45-
46-
messages = [Message(role="system", content=system_prompt)]
47-
48-
evaluation_criteria = entry.get("evaluation_criteria", {})
49-
user_simulation = entry.get("user_simulation", {})
50-
user_prompt_template = entry.get("user_prompt_template", "")
51-
52-
row = EvaluationRow(
53-
messages=messages,
38+
# Load system prompt from file so we can change it in one place
39+
domain = data[0]["environment_context"]["domain"]
40+
prompt_file = test_dir / f"system_prompts/{domain}_agent_system_prompt.md"
41+
42+
with open(prompt_file, "r") as f:
43+
system_prompt = f.read().strip()
44+
45+
for row in data:
46+
eval_row = EvaluationRow(
47+
messages=[Message(role="system", content=system_prompt)],
5448
input_metadata=InputMetadata(
55-
row_id=entry.get("id"),
56-
completion_params=CompletionParams(model="placeholder"), # This gets populated by the rollout processor
49+
row_id=row["id"],
5750
dataset_info={
58-
"environment_context": entry.get("environment_context"),
59-
"user_simulation": user_simulation,
60-
"evaluation_criteria": evaluation_criteria,
61-
"user_prompt_template": user_prompt_template,
51+
"environment_context": row["environment_context"],
52+
"user_simulation": row["user_simulation"],
53+
"evaluation_criteria": row["evaluation_criteria"],
54+
"user_prompt_template": row["user_prompt_template"],
6255
}
6356
),
6457
)
6558

66-
rows.append(row)
59+
rows.append(eval_row)
6760

6861
return rows
6962

@@ -94,7 +87,7 @@ def save_single_trajectory(trajectory_record: Dict, row_id: str, output_dir: str
9487
rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
9588
rollout_processor=default_mcp_gym_rollout_processor,
9689
threshold_of_success=0.4,
97-
num_runs=4,
90+
num_runs=1,
9891
mode="pointwise",
9992
max_concurrent_rollouts=32,
10093
server_script_path="examples/tau2_mcp/server.py",

0 commit comments

Comments (0)