@@ -10,10 +10,9 @@
 from pathlib import Path
 from typing import Any, Dict, List
 
-from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams
+from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest import evaluation_test
 from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
-
 from vendor.tau2.data_model.message import (
     AssistantMessage,
     SystemMessage,
@@ -28,20 +27,21 @@
 from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator
 from vendor.tau2.registry import registry
 
+
 def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
     """
     Convert entries from the airline dataset to EvaluationRow objects.
     """
     rows = []
     test_dir = Path(__file__).parent.parent.parent / "examples" / "tau2_mcp" / "tests"
-
+
     # Load the system prompt from a file so we can change it in one place
     domain = data[0]["environment_context"]["domain"]
     prompt_file = test_dir / f"system_prompts/{domain}_agent_system_prompt.md"
-
+
     with open(prompt_file, "r") as f:
         system_prompt = f.read().strip()
-
+
     for row in data:
         eval_row = EvaluationRow(
             messages=[Message(role="system", content=system_prompt)],
@@ -52,47 +52,46 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
                     "user_simulation": row["user_simulation"],
                     "evaluation_criteria": row["evaluation_criteria"],
                     "user_prompt_template": row["user_prompt_template"],
-                }
+                },
             ),
         )
-
+
         rows.append(eval_row)
-
+
     return rows
 
+
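
As a quick illustration of what the adapter above consumes and produces, here is a minimal sketch. The entry below is hypothetical, not a real row from airline_dataset.jsonl, and only the keys the adapter actually reads are shown; running it also assumes the airline system-prompt file exists at the path the adapter builds.

    # Hypothetical dataset entry (field values are placeholders).
    sample = [
        {
            "environment_context": {"domain": "airline"},
            "user_simulation": {"persona": "traveler rebooking a cancelled flight"},
            "evaluation_criteria": {"nl_assertions": [], "communicate_info": [], "actions": []},
            "user_prompt_template": "...",
        }
    ]

    rows = tau_bench_airline_to_evaluation_row(sample)
    # Each row starts with the shared system prompt and carries the
    # raw criteria along in input_metadata.dataset_info.
    assert rows[0].messages[0].role == "system"
    assert rows[0].input_metadata.dataset_info["user_prompt_template"] == "..."
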
 @evaluation_test(
     input_dataset=["tests/pytest/data/airline_dataset.jsonl"],
     dataset_adapter=tau_bench_airline_to_evaluation_row,
-    model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
-    rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
+    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
+    rollout_input_params=[{"temperature": 0.8, "max_tokens": 4096, "reasoning_effort": "high"}],
     rollout_processor=default_mcp_gym_rollout_processor,
     threshold_of_success=0.4,
     num_runs=1,
     mode="pointwise",
-    max_concurrent_rollouts=32,
+    max_concurrent_rollouts=16,
     server_script_path="examples/tau2_mcp/server.py",
 )
 def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
     """
     Test tau bench airline evaluation using the pytest framework.
-
+
     This test now uses the tau_bench_airline_reward function, which automatically
     extracts evaluation criteria from dataset entries. No wrapper needed!
-
+
     Args:
-        input_dataset: List of EvaluationRow objects from tau bench airline dataset
-        input_params: Model parameters (temperature, max_tokens, etc.)
-        model: Model identifier
-
+        row: EvaluationRow object from the tau bench airline dataset after rollout
+
     Returns:
-        List of evaluated EvaluationRow objects with scores and feedback
+        EvaluationRow with tau2 evaluation results
     """
     messages = row.messages
-
+
     # Get evaluation criteria and user_simulation from input_metadata.dataset_info
     dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
     evaluation_criteria = dataset_info.get("evaluation_criteria", {})
-
+
     nl_assertions = evaluation_criteria.get("nl_assertions", [])
     communicate_info = evaluation_criteria.get("communicate_info", [])
     actions = evaluation_criteria.get("actions", [])
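
For reference, the three lists extracted here drive different tau2 checks: nl_assertions are natural-language statements judged against the transcript, communicate_info lists facts the agent must relay to the user, and actions are the tool calls expected against the environment. A hypothetical shape of one criteria dict (values invented for illustration, not taken from the airline dataset):

    evaluation_criteria = {
        "nl_assertions": ["The agent confirms the itinerary before booking."],
        "communicate_info": ["the total price of the modified reservation"],
        "actions": [{"name": "modify_reservation", "arguments": {"reservation_id": "ABC123"}}],
    }
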
@@ -131,10 +130,8 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
         communicate_info=communicate_info,
         actions=actions,
         reward_basis=[
-            RewardType.NL_ASSERTION,
             RewardType.DB,
             RewardType.COMMUNICATE,
-            RewardType.ACTION,
         ],
     )
 
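
Dropping RewardType.NL_ASSERTION and RewardType.ACTION from reward_basis means only the database-state check and the communicate-info check contribute to the final reward; the other evaluators may still run, but their scores no longer gate success. Conceptually, the effect is something like the following (a simplification for illustration, not tau2's actual aggregation code):

    # Hypothetical per-component scores in [0, 1].
    component_scores = {"DB": 1.0, "COMMUNICATE": 1.0, "NL_ASSERTION": 0.0, "ACTION": 0.5}
    reward_basis = ["DB", "COMMUNICATE"]

    # Only the selected components count; the reward is 1.0
    # only when every component in the basis passes.
    reward = 1.0
    for component in reward_basis:
        reward *= component_scores[component]
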
@@ -230,4 +227,4 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
         reason=reason,
         metrics={},
     )
-    return row
+    return row
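
The test's final step, shown in the last hunk, wraps the computed score in an EvaluateResult and returns the row. As a standalone sketch of that pattern (attach_result is a hypothetical helper, and this assumes EvaluationRow exposes an evaluation_result field as in eval_protocol):

    from eval_protocol.models import EvaluateResult, EvaluationRow

    def attach_result(row: EvaluationRow, score: float, reason: str) -> EvaluationRow:
        # Mirror of the tail of the test above: record the score and
        # explanation, then hand the row back so the framework can
        # aggregate results across runs.
        row.evaluation_result = EvaluateResult(score=score, reason=reason, metrics={})
        return row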