Commit 368c44b

Tau2, Frozen Lake, and Lunar Lander pytest (#9)
* working pytest
* comment
* fix test
* adding tests
* full dataset
1 parent 7b3defb commit 368c44b

18 files changed: +893 -32 lines

.github/workflows/ci.yml

Lines changed: 4 additions & 1 deletion
@@ -85,10 +85,13 @@ jobs:
           FIREWORKS_ACCOUNT_ID: ${{ secrets.FIREWORKS_ACCOUNT_ID }}
           PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
         run: |
-          # Run most tests in parallel, but explicitly ignore tests that manage their own servers
+          # Run most tests in parallel, but explicitly ignore tests that manage their own servers or are slow
           uv run pytest \
             -n auto \
             --ignore=tests/test_batch_evaluation.py \
+            --ignore=tests/pytest/test_frozen_lake.py \
+            --ignore=tests/pytest/test_lunar_lander.py \
+            --ignore=tests/pytest/test_tau_bench_airline.py \
             --cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10

       - name: Store coverage file
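The three newly excluded suites can still be run on their own; presumably an invocation along these lines (test paths and the uv runner are copied from the diff, the rest is local setup) exercises them outside CI:

    uv run pytest tests/pytest/test_frozen_lake.py tests/pytest/test_lunar_lander.py tests/pytest/test_tau_bench_airline.py -v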

eval_protocol/__init__.py

Lines changed: 3 additions & 4 deletions
@@ -17,6 +17,8 @@
 from .mcp_env import (
     AnthropicPolicy,
     OpenAIPolicy,
+    LiteLLMPolicy,
+    FireworksPolicy,
     make,
     rollout,
     test_mcp,
@@ -60,6 +62,7 @@
     # MCP Environment API
     "make",
     "rollout",
+    "LiteLLMPolicy",
     "AnthropicPolicy",
     "FireworksPolicy",
     "OpenAIPolicy",
@@ -73,10 +76,6 @@
     "mcp",
 ]

-# Add FireworksPolicy to exports if available
-if _FIREWORKS_AVAILABLE:
-    __all__.insert(__all__.index("OpenAIPolicy") + 1, "FireworksPolicy")
-
 from . import _version

 __version__ = _version.get_versions()["version"]
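With the conditional insertion gone, FireworksPolicy (and the new LiteLLMPolicy) are always listed in __all__. A minimal sketch, assuming the package is installed, of what that means for callers:

    # Sketch: the policy classes are now unconditional exports of eval_protocol.
    # Previously FireworksPolicy was appended to __all__ only when _FIREWORKS_AVAILABLE was true.
    import eval_protocol as ep

    print("LiteLLMPolicy" in ep.__all__)    # expected: True
    print("FireworksPolicy" in ep.__all__)  # expected: True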

eval_protocol/mcp/execution/manager.py

Lines changed: 17 additions & 2 deletions
@@ -173,10 +173,25 @@ async def _execute_with_semaphore(idx):
         # Convert trajectories to unified EvaluationRow format
         evaluation_rows = []
         for trajectory in trajectories:
-            messages = [Message.model_validate(msg) for msg in trajectory.conversation_history]
+            # Handle multimodal content by extracting text from complex content structures
+            messages = []
+            for msg in trajectory.conversation_history:
+                # Create a copy to avoid modifying the original
+                msg_dict = dict(msg)
+
+                # Handle multimodal content (list of content blocks) by extracting text
+                if isinstance(msg_dict.get("content"), list):
+                    text_content = None
+                    for content_block in msg_dict["content"]:
+                        if isinstance(content_block, dict) and content_block.get("type") == "text":
+                            text_content = content_block.get("text")
+                            break
+                    msg_dict["content"] = text_content or ""
+
+                messages.append(Message.model_validate(msg_dict))

             input_metadata = InputMetadata(
-                row_id=trajectory.session.session_id,
+                row_id=trajectory.session.dataset_row.id if trajectory.session.dataset_row else None,
                 dataset_info=asdict(trajectory.session.dataset_row) if trajectory.session.dataset_row else {},
                 completion_params=CompletionParams(
                     model=policy.model_id,
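The new loop flattens multimodal message content before Pydantic validation: when a message's content is a list of blocks, only the first text block is kept and everything else is dropped. A standalone sketch of that behavior (hypothetical helper name, same logic as the diff):

    from typing import Any, Dict

    def flatten_content(msg: Dict[str, Any]) -> Dict[str, Any]:
        """Hypothetical helper mirroring the diff: keep only the first text block."""
        msg_dict = dict(msg)  # copy so the original message is untouched
        content = msg_dict.get("content")
        if isinstance(content, list):
            text_content = None
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text_content = block.get("text")
                    break
            msg_dict["content"] = text_content or ""
        return msg_dict

    # An image block plus a text block collapses to just the text.
    multimodal = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            {"type": "text", "text": "What is in this image?"},
        ],
    }
    assert flatten_content(multimodal)["content"] == "What is in this image?"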

eval_protocol/mcp_env.py

Lines changed: 52 additions & 22 deletions
@@ -13,21 +13,18 @@
     Usage remains the same:
         import eval_protocol as ep

-        # Load dataset with environment configuration and prompts
-        dataset = load_jsonl("dataset.jsonl")
-
         # Create general policy (environment-agnostic)
         policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")

-        # Create environments with dataset-driven configuration
-        envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
+        # Create environments with evaluation_rows configuration
+        envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)

         # Execute tool-calling rollouts
         evaluation_rows = await ep.rollout(envs, policy=policy, steps=512)

     Key Features:
     - General tool-calling interface that works with any MCP environment
-    - Dataset-driven configuration with system prompts and user prompt templates
+    - EvaluationRow-driven configuration with system prompts and user prompt templates
     - Automatic MCP tool discovery from servers
     - **PROPER MCP PATTERN**: Initial state obtained from MCP resources during session establishment
     - Tools used only for actions/interactions, not for getting initial state
@@ -50,7 +47,7 @@

 # Import all functionality from the new modular components
 from .mcp.execution.manager import ExecutionManager
-from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LLMBasePolicy, OpenAIPolicy
+from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LLMBasePolicy, OpenAIPolicy, LiteLLMPolicy
 from .mcp.session.manager import GeneralMCPVectorEnv
 from .models import EvaluationRow
 from .types import DatasetRow, MCPSession, MCPToolCall
@@ -60,18 +57,20 @@

 def make(
     env_spec: str,
+    evaluation_rows: Optional[List[EvaluationRow]] = None,
     dataset: Optional[List[Dict]] = None,
     n: Optional[int] = None,
     seeds: Optional[List[int]] = None,
     model_id: str = "unknown",
     user_prompt_formatter: Optional[Callable] = None,
 ) -> GeneralMCPVectorEnv:
     """
-    Create general MCP environments driven by dataset configuration.
+    Create general MCP environments driven by evaluation_rows configuration.

     Args:
         env_spec: MCP server URL
-        dataset: List of dataset rows with prompts and context (preferred)
+        evaluation_rows: List of EvaluationRow objects containing messages and metadata (preferred)
+        dataset: List of dataset entries (for backward compatibility)
         n: Number of environments (for backward compatibility)
         seeds: List of seeds (for backward compatibility)
         model_id: Model identifier
@@ -81,8 +80,10 @@ def make(
         General MCP environment that works with any MCP server

     Example:
-        # New dataset-driven approach (preferred)
-        dataset = load_jsonl("dataset.jsonl")
+        # EvaluationRow approach (preferred)
+        envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
+
+        # Dataset approach (backward compatibility)
         envs = ep.make("http://localhost:8000/mcp", dataset=dataset)


@@ -97,13 +98,39 @@ def make(
     if not base_url.endswith("/"):
         base_url += "/"

-    # Handle dataset-driven vs legacy approaches
-    if dataset is not None:
-        # New dataset-driven approach
+    # Convert evaluation_rows to dataset format if provided
+    internal_dataset = []
+
+    if evaluation_rows:
+        for i, row in enumerate(evaluation_rows):
+            dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
+
+            system_message = row.get_system_message()
+            system_prompt = system_message.content or ""
+
+            dataset_entry = {
+                "id": row.input_metadata.row_id if row.input_metadata and row.input_metadata.row_id else f"task_{i}",
+                "system_prompt": system_prompt,
+                "user_prompt_template": dataset_info.get("user_prompt_template", ""),
+                "environment_context": dataset_info.get("environment_context", {}),
+                "user_simulation": dataset_info.get("user_simulation", {}),
+                "evaluation_criteria": dataset_info.get("evaluation_criteria", {})
+            }
+            internal_dataset.append(dataset_entry)
+    elif dataset:
+        # Use provided dataset directly for backward compatibility
+        internal_dataset = dataset
+
+    dataset_rows = []
+    sessions = []
+
+    # Handle evaluation_rows vs legacy approaches
+    if internal_dataset:
+        # New evaluation_rows approach
         dataset_rows = []
         sessions = []

-        for row in dataset:
+        for row in internal_dataset:
             # Parse dataset row
             if isinstance(row, dict):
                 # Handle seed from both old location (backward compatibility) and new location
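The loop above flattens each EvaluationRow into the legacy dataset dict that the rest of make() already consumes. A standalone sketch of that per-row mapping (plain dicts and a hypothetical helper name; the comments note which EvaluationRow fields feed each argument):

    # Re-implementation of the per-row conversion on plain dicts, runnable without the package.
    def to_dataset_entry(system_prompt, dataset_info, row_id, index):
        return {
            "id": row_id or f"task_{index}",
            "system_prompt": system_prompt or "",
            "user_prompt_template": dataset_info.get("user_prompt_template", ""),
            "environment_context": dataset_info.get("environment_context", {}),
            "user_simulation": dataset_info.get("user_simulation", {}),
            "evaluation_criteria": dataset_info.get("evaluation_criteria", {}),
        }

    entry = to_dataset_entry(
        system_prompt="You are playing Frozen Lake.",        # row.get_system_message().content
        dataset_info={"environment_context": {"seed": 42}},  # row.input_metadata.dataset_info
        row_id="task_0",                                      # row.input_metadata.row_id
        index=0,
    )
    print(entry["environment_context"])  # {'seed': 42}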
@@ -138,7 +165,7 @@ def make(
     else:
         # Legacy approach for backward compatibility
         if n is None:
-            raise ValueError("Either 'dataset' or 'n' must be provided")
+            raise ValueError("Either 'evaluation_rows' or 'n' must be provided")

         # Generate seeds if not provided
         if seeds is None:
@@ -178,6 +205,7 @@ async def rollout(
     envs: GeneralMCPVectorEnv,
     policy: Union[FireworksPolicy, LLMBasePolicy, Callable],
     *,
+    evaluation_rows: Optional[List[EvaluationRow]] = None,
     dataset: Optional[List[Dict]] = None,
     model_id: Optional[str] = None,
     steps: int = 512,
@@ -191,13 +219,14 @@

     This works with ANY MCP environment because:
     1. Policy receives tool schemas and makes tool calls
-    2. Environment prompts come from dataset
+    2. Environment prompts come from evaluation_rows
     3. No hardcoded environment logic

     Args:
         envs: Either a GeneralMCPVectorEnv instance or the MCP server URL
         policy: Policy that takes tool schemas, observations, prompts and returns tool calls
-        dataset: Dataset used when envs is a URL (required for automatic env creation)
+        evaluation_rows: EvaluationRow list used when envs is a URL (for automatic env creation)
+        dataset: Dataset list used for backward compatibility when envs is a URL
         model_id: Model identifier used when creating environments. Defaults to ``policy.model_id`` when available.
         steps: Maximum steps per rollout
         openai_format_log_file: Optional file to log clean OpenAI format for terminated trajectories only
@@ -220,7 +249,7 @@
         trajectories = await ep.rollout(
             "http://localhost:8000/mcp/",
             policy,
-            dataset=my_dataset,
+            evaluation_rows=my_evaluation_rows,
             model_id=policy.model_id,
         )

@@ -233,11 +262,11 @@
     """
     # Automatically create environments if a base URL is provided
     if isinstance(envs, str):
-        if dataset is None:
-            raise ValueError("'dataset' must be provided when envs is a URL")
+        if evaluation_rows is None and dataset is None:
+            raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL")

         auto_model_id = model_id or getattr(policy, "model_id", "unknown")
-        envs = make(envs, dataset=dataset, model_id=auto_model_id)
+        envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)

     # Use the new ExecutionManager for execution
     execution_manager = ExecutionManager()
@@ -304,6 +333,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
     "AnthropicPolicy",
     "FireworksPolicy",
     "OpenAIPolicy",
+    "LiteLLMPolicy",
     "LLMBasePolicy",  # New base class for OpenAI integration
     "GeneralMCPVectorEnv",
     "MCPToolCall",

eval_protocol/models.py

Lines changed: 7 additions & 0 deletions
@@ -243,6 +243,13 @@ def get_conversation_length(self) -> int:
         """Returns the number of messages in the conversation."""
         return len(self.messages)

+    def get_system_message(self) -> Message:
+        """Returns the system message from the conversation. Returns empty Message if none found."""
+        system_messages = [msg for msg in self.messages if msg.role == "system"]
+        if not system_messages:
+            return Message(role="system", content="")
+        return system_messages[0]
+
     def get_assistant_messages(self) -> List[Message]:
         """Returns only the assistant messages from the conversation."""
         return [msg for msg in self.messages if msg.role == "assistant"]
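A quick sketch of the new helper's behavior (assuming an EvaluationRow can be built from a plain message list; field names beyond messages may differ):

    from eval_protocol.models import EvaluationRow, Message

    with_system = EvaluationRow(messages=[
        Message(role="system", content="You are a grader."),
        Message(role="user", content="Score this answer."),
    ])
    without_system = EvaluationRow(messages=[Message(role="user", content="Hi")])

    print(with_system.get_system_message().content)     # "You are a grader."
    print(without_system.get_system_message().content)  # "" (an empty system Message, not None)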
