
Commit fffd75c

Modified Tau (#25)
* changed tests
* more change
* gpt-oss example e2e
* uv lock
* fix test
1 parent ae6117c · commit fffd75c

7 files changed: +89 -63 lines


eval_protocol/mcp/execution/manager.py

Lines changed: 14 additions & 22 deletions
@@ -42,6 +42,7 @@ async def execute_rollouts(
         steps: int = 512,
         openai_format_log_file: Optional[str] = None,
         max_concurrent_rollouts: int = 8,
+        evaluation_rows: Optional[List[EvaluationRow]] = None,
     ) -> List[EvaluationRow]:
         """
         Execute general rollouts using tool calling interface with automatic record/playback.
@@ -135,9 +136,11 @@ async def _execute_with_semaphore(idx):
         # Add note about control plane separation
         logger.info(f"🎛️ Trajectories include control plane separation")

-        # Convert trajectories to unified EvaluationRow format
-        evaluation_rows = []
-        for trajectory in trajectories:
+        # Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility.
+        if evaluation_rows is None:
+            evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories]
+
+        for idx, trajectory in enumerate(trajectories):
             # Handle multimodal content by extracting text from complex content structures
             messages = []
             for msg in trajectory.conversation_history:
@@ -155,26 +158,15 @@ async def _execute_with_semaphore(idx):

                 messages.append(Message.model_validate(msg_dict))

-            input_metadata = InputMetadata(
-                row_id=trajectory.session.dataset_row.id if trajectory.session.dataset_row else None,
-                dataset_info=asdict(trajectory.session.dataset_row) if trajectory.session.dataset_row else {},
-                completion_params=CompletionParams(
-                    model=policy.model_id,
-                    temperature=getattr(policy, "temperature", None),
-                    max_tokens=getattr(policy, "max_tokens", None),
-                    max_tool_calls=getattr(policy, "max_tools_per_turn", None),
-                ),
-                session_data={
-                    "timestamp": time.time(),
-                },
-            )
-            evaluation_row = EvaluationRow(
-                messages=messages,
-                tools=shared_tool_schema,
-                input_metadata=input_metadata,
-                usage=trajectory.usage,
+            evaluation_rows[idx].messages = messages
+            evaluation_rows[idx].tools = shared_tool_schema
+            evaluation_rows[idx].usage = trajectory.usage
+            evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
+                model=policy.model_id,
+                temperature=getattr(policy, "temperature", None),
+                max_tokens=getattr(policy, "max_tokens", None),
+                max_tool_calls=getattr(policy, "max_tools_per_turn", None),
             )
-            evaluation_rows.append(evaluation_row)

         return evaluation_rows
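The upshot: execute_rollouts can now populate caller-supplied EvaluationRow objects in place, so metadata the caller set up front (row_id, dataset_info) survives the rollout instead of being rebuilt from the trajectory. A minimal sketch of the new call shape, assuming envs and policy are already constructed; the helper name is illustrative:

from typing import List

from eval_protocol.mcp.execution.manager import ExecutionManager
from eval_protocol.models import EvaluationRow

async def run_with_prebuilt_rows(envs, policy, rows: List[EvaluationRow]) -> List[EvaluationRow]:
    manager = ExecutionManager()
    # Pre-built rows keep their row_id and dataset_info; the manager fills in
    # messages, tools, usage, and completion_params for each trajectory.
    return await manager.execute_rollouts(
        envs,
        policy,
        steps=512,
        max_concurrent_rollouts=8,
        evaluation_rows=rows,  # omit to get internally created empty rows (old behavior)
    )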

eval_protocol/mcp/execution/policy.py

Lines changed: 11 additions & 0 deletions
@@ -64,6 +64,9 @@ def __init__(
         self.num_retries = num_retries
         self.retry_strategy = retry_strategy

+        # Store additional API parameters from kwargs
+        self.additional_params = kwargs
+
         # Only initialize LiteLLM in live mode (not in playback mode)
         if not self._is_playback:
             self._setup_litellm_caching(use_caching, cache_type, redis_url)
@@ -166,6 +169,14 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
             "base_url": self.base_url,
         }

+        # Add additional parameters from kwargs (like reasoning_effort)
+        if self.additional_params:
+            request_params.update(self.additional_params)
+
+        # Tell LiteLLM to allow reasoning_effort if it's present
+        if "reasoning_effort" in self.additional_params:
+            request_params["allowed_openai_params"] = ["reasoning_effort"]
+
         # Add tools if provided
         if tools:
             request_params["tools"] = tools
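Taken together: any extra keyword arguments given to the policy constructor are kept in self.additional_params and merged into every LiteLLM request, and when reasoning_effort is among them the request also sets allowed_openai_params so LiteLLM forwards the non-standard field rather than dropping it. A hedged sketch of the intended usage; the class name LiteLLMPolicy and the other constructor arguments are inferred from the rest of this commit:

from eval_protocol.mcp.execution.policy import LiteLLMPolicy

policy = LiteLLMPolicy(
    model_id="fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
    temperature=0.8,
    max_tokens=4096,
    reasoning_effort="high",  # captured by **kwargs and stored in self.additional_params
)
# At request time the policy effectively does:
#   request_params.update(self.additional_params)
#   request_params["allowed_openai_params"] = ["reasoning_effort"]
# so the parameter reaches providers that understand it.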

eval_protocol/mcp_env.py

Lines changed: 4 additions & 3 deletions
@@ -40,18 +40,19 @@
 - Resources provide static/configuration data, tools provide dynamic actions
 """

+import asyncio
+
 # For legacy compatibility - import the facade functions
 import logging
 import random
 from typing import Any, Callable, Dict, List, Optional, Union

 # Import all functionality from the new modular components
 from .mcp.execution.manager import ExecutionManager
-from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LLMBasePolicy, OpenAIPolicy, LiteLLMPolicy
+from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LiteLLMPolicy, LLMBasePolicy, OpenAIPolicy
 from .mcp.session.manager import GeneralMCPVectorEnv
 from .models import EvaluationRow
 from .types import DatasetRow, MCPSession, MCPToolCall
-import asyncio

 logger = logging.getLogger(__name__)

@@ -288,7 +289,7 @@ async def rollout(
     execution_manager = ExecutionManager()

     return await execution_manager.execute_rollouts(
-        envs, policy, steps, openai_format_log_file, max_concurrent_rollouts
+        envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
     )
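Note that rollout forwards these arguments positionally, so evaluation_rows rides in sixth position; the test update further down adds a trailing None for exactly that reason. For clarity, a keyword-form sketch of the forwarded call (not the facade itself):

from eval_protocol.mcp.execution.manager import ExecutionManager

async def rollout_keyword_form(envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows):
    # Same call as above, spelled with keywords instead of positions.
    execution_manager = ExecutionManager()
    return await execution_manager.execute_rollouts(
        envs,
        policy,
        steps=steps,
        openai_format_log_file=openai_format_log_file,
        max_concurrent_rollouts=max_concurrent_rollouts,
        evaluation_rows=evaluation_rows,
    )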

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 12 additions & 5 deletions
@@ -1,18 +1,17 @@
 import asyncio
+import atexit
 import os
+import signal
+import socket
 import subprocess
 import time
-import socket
 from pathlib import Path
 from typing import List, Optional

 import eval_protocol as ep
 from eval_protocol.models import EvaluationRow, Message
 from eval_protocol.pytest.types import RolloutProcessorConfig

-import atexit
-import signal
-

 class MCPServerManager:
     """Manages MCP server lifecycle for testing."""
@@ -188,13 +187,16 @@ async def default_mcp_gym_rollout_processor(
     """
     Rollout processor for tau bench environments.

+
     This processor starts an MCP server, creates tau bench environments, and runs rollouts
     using the eval_protocol framework, following the pattern from test_tau2_e2e.py.

+
     Args:
         rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
         config: RolloutProcessorConfig with model and other parameters

+
     Returns:
         List of EvaluationRow objects with completed conversations
     """
@@ -207,6 +209,7 @@ async def default_mcp_gym_rollout_processor(
         model_id=config.model,
         temperature=config.input_params.get("temperature", 0.0),
         max_tokens=config.input_params.get("max_tokens", 4096),
+        reasoning_effort=config.input_params.get("reasoning_effort", None),
     )

     # Create MCP environments directly from evaluation_rows
@@ -218,7 +221,11 @@ async def default_mcp_gym_rollout_processor(

     # Run rollout with environments and policy
     evaluation_rows = await ep.rollout(
-        envs, policy=policy, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts
+        envs,
+        policy=policy,
+        evaluation_rows=rows,
+        steps=config.steps,
+        max_concurrent_rollouts=config.max_concurrent_rollouts,
     )

     return evaluation_rows
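Because evaluation_rows=rows is threaded all the way down, the rows the processor returns are the same objects the dataset adapter produced, now filled with the rollout conversation; dataset_info and row_id set by the adapter are untouched. A small illustrative check, assuming one trajectory per input row (which is how the environments are built from the rows):

from typing import List

from eval_protocol.models import EvaluationRow

def check_rows_populated_in_place(rows: List[EvaluationRow], completed: List[EvaluationRow]) -> None:
    # `completed` is what default_mcp_gym_rollout_processor returned for `rows`.
    assert len(completed) == len(rows)
    assert completed[0] is rows[0]  # filled in place, not rebuilt copies
    assert completed[0].messages, "rollout conversation is now present on the row"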

tests/pytest/test_tau_bench_airline.py

Lines changed: 20 additions & 23 deletions
@@ -10,10 +10,9 @@
 from pathlib import Path
 from typing import Any, Dict, List

-from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams
+from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest import evaluation_test
 from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
-
 from vendor.tau2.data_model.message import (
     AssistantMessage,
     SystemMessage,
@@ -28,20 +27,21 @@
 from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator
 from vendor.tau2.registry import registry

+
 def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
     """
     Convert entries from airline dataset to EvaluationRow objects.
     """
     rows = []
     test_dir = Path(__file__).parent.parent.parent / "examples" / "tau2_mcp" / "tests"
-
+
     # Load system prompt from file so we can change it in one place
     domain = data[0]["environment_context"]["domain"]
     prompt_file = test_dir / f"system_prompts/{domain}_agent_system_prompt.md"
-
+
     with open(prompt_file, "r") as f:
         system_prompt = f.read().strip()
-
+
     for row in data:
         eval_row = EvaluationRow(
             messages=[Message(role="system", content=system_prompt)],
@@ -52,47 +52,46 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
                     "user_simulation": row["user_simulation"],
                     "evaluation_criteria": row["evaluation_criteria"],
                     "user_prompt_template": row["user_prompt_template"],
-                }
+                },
             ),
         )
-
+
         rows.append(eval_row)
-
+
     return rows

+
 @evaluation_test(
     input_dataset=["tests/pytest/data/airline_dataset.jsonl"],
     dataset_adapter=tau_bench_airline_to_evaluation_row,
-    model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
-    rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
+    model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
+    rollout_input_params=[{"temperature": 0.8, "max_tokens": 4096, "reasoning_effort": "high"}],
     rollout_processor=default_mcp_gym_rollout_processor,
     threshold_of_success=0.4,
     num_runs=1,
     mode="pointwise",
-    max_concurrent_rollouts=32,
+    max_concurrent_rollouts=16,
     server_script_path="examples/tau2_mcp/server.py",
 )
 def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
     """
     Test tau bench airline evaluation using the pytest framework.
-
+
     This test now uses the tau_bench_airline_reward function which automatically
     extracts evaluation criteria from dataset entries. No wrapper needed!
-
+
     Args:
-        input_dataset: List of EvaluationRow objects from tau bench airline dataset
-        input_params: Model parameters (temperature, max_tokens, etc.)
-        model: Model identifier
-
+        row: EvaluationRow object from tau bench airline dataset after rollout
+
     Returns:
-        List of evaluated EvaluationRow objects with scores and feedback
+        EvaluationRow with tau2 evaluation results
     """
     messages = row.messages
-
+
     # Get evaluation criteria and user_simulation from input_metadata.dataset_info
     dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
     evaluation_criteria = dataset_info.get("evaluation_criteria", {})
-
+
     nl_assertions = evaluation_criteria.get("nl_assertions", [])
     communicate_info = evaluation_criteria.get("communicate_info", [])
     actions = evaluation_criteria.get("actions", [])
@@ -131,10 +130,8 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
         communicate_info=communicate_info,
         actions=actions,
         reward_basis=[
-            RewardType.NL_ASSERTION,
             RewardType.DB,
             RewardType.COMMUNICATE,
-            RewardType.ACTION,
         ],
     )
@@ -230,4 +227,4 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
         reason=reason,
         metrics={},
     )
-    return row
+    return row
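The switch to gpt-oss-120b with reasoning_effort set to "high" only works because of the plumbing earlier in this commit: the value set in the decorator has to reach LiteLLM untouched. A sketch of the path it takes, using only names that appear in the diffs above:

# rollout_input_params in the @evaluation_test decorator ...
rollout_input_params = {"temperature": 0.8, "max_tokens": 4096, "reasoning_effort": "high"}

# ... is exposed to default_mcp_gym_rollout_processor as config.input_params, which reads
reasoning_effort = rollout_input_params.get("reasoning_effort", None)  # "high"

# ... and passes it to the policy constructor, whose **kwargs handling stores it and later
# adds allowed_openai_params=["reasoning_effort"] to the LiteLLM request so it is not stripped.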

tests/test_rollout_control_plane_integration.py

Lines changed: 1 addition & 0 deletions
@@ -519,6 +519,7 @@ async def test_rollout_creates_envs_from_url(self):
             5,
             None,
             8,
+            None,
         )

         assert result == ["ok"]

vendor/tau2/evaluator/evaluator_nl_assertions.py

Lines changed: 27 additions & 10 deletions
@@ -1,4 +1,7 @@
 import json
+from typing import List
+
+from pydantic import BaseModel

 from vendor.tau2.config import DEFAULT_LLM_NL_ASSERTIONS, DEFAULT_LLM_NL_ASSERTIONS_ARGS
 from vendor.tau2.data_model.message import Message, SystemMessage, UserMessage
@@ -7,6 +10,20 @@
 from vendor.tau2.utils.llm_utils import generate


+class NLAssertionResult(BaseModel):
+    """Individual NL assertion evaluation result."""
+
+    expectedOutcome: str
+    reasoning: str
+    metExpectation: bool
+
+
+class NLAssertionsResponse(BaseModel):
+    """Complete NL assertions evaluation response."""
+
+    results: List[NLAssertionResult]
+
+
 class NLAssertionsEvaluator:
     """
     Judge that evaluates whether a trajectory adheres to all the natural-language assertions.
@@ -37,9 +54,7 @@ def calculate_reward(
             reward_breakdown={RewardType.NL_ASSERTION: 1.0},
         )

-        nl_assertions_checks = cls.evaluate_nl_assertions(
-            full_trajectory, nl_assertions
-        )
+        nl_assertions_checks = cls.evaluate_nl_assertions(full_trajectory, nl_assertions)

         # Calculate reward: 1 if all expectations are met, 0 otherwise
         all_expectations_met = all(result.met for result in nl_assertions_checks)
@@ -70,9 +85,7 @@ def evaluate_nl_assertions(
         - metExpectation: Boolean indicating if the assertion was met
         - reasoning: Explanation for the evaluation
         """
-        trajectory_str = "\n".join(
-            [f"{message.role}: {message.content}" for message in trajectory]
-        )
+        trajectory_str = "\n".join([f"{message.role}: {message.content}" for message in trajectory])
         # System prompt similar to the TypeScript implementation
         system_prompt = """
         TASK
@@ -86,7 +99,7 @@ def evaluate_nl_assertions(
         - `reasoning`: a short explanation for your classification
         - `metExpectation`: `true` if the agent satisfies the expected outcomes, `false` otherwise
         - `expectedOutcome`: repeat the expectation from the input that you are grading
-
+
         Example response structure:
         {
           "results": [
@@ -102,7 +115,7 @@ def evaluate_nl_assertions(
         user_prompt = f"""
         conversation:
         {trajectory_str}
-
+
         expectedOutcomes:
         {nl_assertions}
         """
@@ -115,8 +128,12 @@ def evaluate_nl_assertions(
         assistant_message = generate(
             model=DEFAULT_LLM_NL_ASSERTIONS,
             messages=messages,
-            **DEFAULT_LLM_NL_ASSERTIONS_ARGS,
-        )
+            temperature=0.0,
+            response_format={
+                "type": "json_schema",
+                "json_schema": {"name": "NLAssertionsResponse", "schema": NLAssertionsResponse.model_json_schema()},
+            },
+        )  # Adding constrained generation to ensure the response is a valid JSON object
         result_data = json.loads(assistant_message.content)
         return [
             NLAssertionCheck(
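The two Pydantic models exist to give the judge call a JSON schema to decode against, replacing the free-form **DEFAULT_LLM_NL_ASSERTIONS_ARGS arguments. A standalone sketch of how the schema and a conforming reply fit together; the evaluator itself feeds json.loads(...) straight into NLAssertionCheck objects, so the model_validate call below is only a demonstration:

import json

from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsResponse

# Schema handed to the judge through response_format (json_schema constrained decoding).
schema = NLAssertionsResponse.model_json_schema()

# Example reply in the constrained shape (contents illustrative).
raw = '{"results": [{"expectedOutcome": "Agent confirms the booking", "reasoning": "The agent confirmed it.", "metExpectation": true}]}'
parsed = NLAssertionsResponse.model_validate(json.loads(raw))
assert parsed.results[0].metExpectation is True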
