@@ -35,35 +35,28 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval
3535 rows = []
3636 test_dir = Path (__file__ ).parent .parent .parent / "examples" / "tau2_mcp" / "tests"
3737
38- for entry in data :
39- # Load system prompt from file so we can change it in one place
40- domain = entry ["environment_context" ]["domain" ]
41- prompt_file = test_dir / f"system_prompts/{ domain } _agent_system_prompt.md"
42-
43- with open (prompt_file , "r" ) as f :
44- system_prompt = f .read ().strip ()
45-
46- messages = [Message (role = "system" , content = system_prompt )]
47-
48- evaluation_criteria = entry .get ("evaluation_criteria" , {})
49- user_simulation = entry .get ("user_simulation" , {})
50- user_prompt_template = entry .get ("user_prompt_template" , "" )
51-
52- row = EvaluationRow (
53- messages = messages ,
38+ # Load system prompt from file so we can change it in one place
39+ domain = data [0 ]["environment_context" ]["domain" ]
40+ prompt_file = test_dir / f"system_prompts/{ domain } _agent_system_prompt.md"
41+
42+ with open (prompt_file , "r" ) as f :
43+ system_prompt = f .read ().strip ()
44+
45+ for row in data :
46+ eval_row = EvaluationRow (
47+ messages = [Message (role = "system" , content = system_prompt )],
5448 input_metadata = InputMetadata (
55- row_id = entry .get ("id" ),
56- completion_params = CompletionParams (model = "placeholder" ), # This gets populated by the rollout processor
49+ row_id = row ["id" ],
5750 dataset_info = {
58- "environment_context" : entry . get ( "environment_context" ) ,
59- "user_simulation" : user_simulation ,
60- "evaluation_criteria" : evaluation_criteria ,
61- "user_prompt_template" : user_prompt_template ,
51+ "environment_context" : row [ "environment_context" ] ,
52+ "user_simulation" : row [ " user_simulation" ] ,
53+ "evaluation_criteria" : row [ " evaluation_criteria" ] ,
54+ "user_prompt_template" : row [ " user_prompt_template" ] ,
6255 }
6356 ),
6457 )
6558
66- rows .append (row )
59+ rows .append (eval_row )
6760
6861 return rows
6962
@@ -94,7 +87,7 @@ def save_single_trajectory(trajectory_record: Dict, row_id: str, output_dir: str
9487 rollout_input_params = [{"temperature" : 0.0 , "max_tokens" : 4096 }],
9588 rollout_processor = default_mcp_gym_rollout_processor ,
9689 threshold_of_success = 0.4 ,
97- num_runs = 4 ,
90+ num_runs = 1 ,
9891 mode = "pointwise" ,
9992 max_concurrent_rollouts = 32 ,
10093 server_script_path = "examples/tau2_mcp/server.py" ,
0 commit comments