Skip to content

Commit 1a37ee1

Browse files
authored
multi rollout support (#24)
* specialize duty for session and execution manager for further customization * fix ut * support reset mcp session * fix ut * lower hallucination threshold * fix ut * fix ut * fix ut * remove useless import * clean * remove print
1 parent 4dbac4d commit 1a37ee1

File tree

22 files changed

+192
-150
lines changed

22 files changed

+192
-150
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async def initialize_session(self, session: MCPSession) -> None:
101101

102102
# Update the session ID to match what the server generated
103103
session.session_id = server_session_id
104-
logger.debug(f"Updated session ID to match server: {server_session_id}")
104+
logger.info(f"Updated session ID to match server: {server_session_id}")
105105

106106
# PRE-WARM: Discover and cache tools immediately after session initialization
107107
# This prevents concurrent list_tools() calls later
@@ -133,6 +133,24 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None:
133133
self._tools_cache[cache_key] = tool_schemas
134134
logger.debug(f"✅ PRE-WARMED {len(tool_schemas)} tools for{cache_key}")
135135

136+
async def reset_session(self, session: MCPSession) -> None:
    """
    Ask the remote MCP server to reset the server-side data of a session.

    Sends ``POST {base}/control/reset_session`` where ``{base}`` is
    ``session.base_url`` with any trailing slash and a trailing ``/mcp``
    removed. The session is identified by the ``mcp-session-id`` header and
    ``session.seed`` is sent as the JSON body.

    Args:
        session: Session whose ``base_url``, ``session_id`` and ``seed`` are
            used to build the request.

    Raises:
        httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        httpx.HTTPError: On transport failures, including the 3 second timeout.
    """
    import httpx

    # The control endpoint is a sibling of the MCP endpoint, not a child of it.
    base_url = session.base_url.rstrip("/").removesuffix("/mcp")
    url = f"{base_url}/control/reset_session"

    headers = {"mcp-session-id": session.session_id}
    body = {"seed": session.seed}

    timeout = httpx.Timeout(3.0)
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(url, headers=headers, json=body)
        resp.raise_for_status()
        # Log the raw text: decoding JSON only for a debug line could raise on a
        # non-JSON 2xx reply and turn an already-successful reset into a failure.
        logger.debug("Session %s: reset_session -> %s", session.session_id, resp.text)
153+
136154
async def discover_tools(self, session: MCPSession) -> List[Dict]:
137155
"""
138156
Discover available tools from an MCP session.

eval_protocol/mcp/execution/manager.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from ...models import CompletionParams, EvaluationRow, InputMetadata, Message
2424
from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory
25-
from ..client.connection import MCPConnectionManager
2625

2726
if TYPE_CHECKING:
2827
from ..session.manager import GeneralMCPVectorEnv
@@ -33,43 +32,9 @@
3332

3433
class ExecutionManager:
3534
"""
36-
Unified manager that handles both MCP session lifecycle and rollout execution.
37-
38-
Combines the functionality of SessionManager and RolloutManager for better
39-
organization and reduced complexity.
35+
Manage rollout for MCP environments.
4036
"""
4137

42-
def __init__(self):
43-
"""Initialize the execution manager."""
44-
self.connection_manager = MCPConnectionManager()
45-
46-
async def initialize_sessions(self, sessions: List[MCPSession]) -> None:
47-
"""
48-
Initialize multiple MCP sessions in parallel.
49-
50-
Args:
51-
sessions: List of MCPSessions to initialize
52-
"""
53-
tasks = [self.connection_manager.initialize_session(session) for session in sessions]
54-
await asyncio.gather(*tasks)
55-
56-
async def close_sessions(self, sessions: List[MCPSession]) -> None:
57-
"""
58-
Close multiple MCP sessions in parallel.
59-
60-
Args:
61-
sessions: List of MCPSessions to close
62-
"""
63-
tasks = [asyncio.create_task(self.connection_manager.close_session(session)) for session in sessions]
64-
65-
if tasks:
66-
try:
67-
# Wait for all close operations to complete
68-
await asyncio.gather(*tasks, return_exceptions=True)
69-
except asyncio.CancelledError:
70-
# Handle cancellation gracefully (especially important for Python 3.12)
71-
logger.debug("Close operation was cancelled, but sessions are marked as closed")
72-
7338
async def execute_rollouts(
7439
self,
7540
envs: "GeneralMCPVectorEnv",
@@ -178,7 +143,7 @@ async def _execute_with_semaphore(idx):
178143
for msg in trajectory.conversation_history:
179144
# Create a copy to avoid modifying the original
180145
msg_dict = dict(msg)
181-
146+
182147
# Handle multimodal content (list of content blocks) by extracting text
183148
if isinstance(msg_dict.get("content"), list):
184149
text_content = None
@@ -187,7 +152,7 @@ async def _execute_with_semaphore(idx):
187152
text_content = content_block.get("text")
188153
break
189154
msg_dict["content"] = text_content or ""
190-
155+
191156
messages.append(Message.model_validate(msg_dict))
192157

193158
input_metadata = InputMetadata(

eval_protocol/mcp/mcpgym.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional
116116
# Register tools and control plane endpoints
117117
self._register_tools()
118118
self._discover_and_register_control_plane_endpoints()
119+
self._register_session_reset_endpoint()
119120

120121
def _get_session_id(self, ctx: Context) -> str:
121122
"""
@@ -227,6 +228,28 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
227228

228229
return self.sessions[session_id]
229230

231+
def _register_session_reset_endpoint(self):
    """
    Register the ``POST /control/reset_session`` route on ``self.mcp``.

    The route reads the target session from the ``mcp-session-id`` header and
    an optional ``seed`` from the JSON body. If that session already exists in
    ``self.sessions`` it is replaced by a fresh environment created with
    ``self._new_env(seed=seed)``; unknown session ids are left untouched and
    still answered with the success message.

    Responses:
        200 ``{"message": "Session reset successfully"}``
        400 ``{"error": ...}`` when the header is missing or the body is not
            a JSON object.
    """

    @self.mcp.custom_route("/control/reset_session", methods=["POST"])
    async def reset_session_endpoint(request: Request) -> JSONResponse:
        session_id = request.headers.get("mcp-session-id")
        # Reject requests without a session id before touching the body.
        if not session_id:
            return JSONResponse({"error": "Missing mcp-session-id header"}, status_code=400)

        # An empty or malformed body would otherwise surface as an unhandled
        # exception (HTTP 500); JSON decode errors are ValueError subclasses.
        try:
            body = await request.json()
        except ValueError:
            return JSONResponse({"error": "Request body must be valid JSON"}, status_code=400)
        if not isinstance(body, dict):
            return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

        seed = body.get("seed", None)
        print(f"🔍 _register_session_reset_endpoint: Resetting session, session_id: {session_id}, seed: {seed}")
        with self.session_lock:
            if session_id in self.sessions:
                # NOTE(review): the previous env object is simply replaced here;
                # confirm that it needs no explicit teardown.
                env, obs, _ = self._new_env(seed=seed)
                self.sessions[session_id] = {
                    "env": env,
                    "obs": obs,
                    "session_data": {},
                    "session_id": session_id,
                }
        print(f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: {session_id}")
        return JSONResponse({"message": "Session reset successfully"})
252+
230253
def _discover_and_register_control_plane_endpoints(self):
231254
"""
232255
Discover and register control plane endpoints on the subclass instance.
@@ -323,7 +346,7 @@ def _update_control_plane(self, reward: float, terminated: bool, truncated: bool
323346

324347
# Log control plane update (for debugging)
325348
print(
326-
f"🎛️ Control plane updated: reward={reward}, terminated={terminated}, step={self.control_plane_state['step_count']}"
349+
f"🎛️ Control plane updated: reward={reward}, terminated={terminated}, step={self.control_plane_state['step_count']}, total_reward={self.control_plane_state['total_reward']}"
327350
)
328351

329352
def _get_or_create_session_control_plane(self, session_id: str) -> Dict[str, Any]:
@@ -365,7 +388,7 @@ def _update_session_control_plane(
365388

366389
# Log control plane update
367390
print(
368-
f"🎛️ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}"
391+
f"🎛️ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}, total_reward={control_plane['total_reward']}"
369392
)
370393

371394
def get_control_plane_state(self, session_id: str) -> Optional[Dict[str, Any]]:

eval_protocol/mcp/session/manager.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1212

1313
from ...types import DatasetRow, MCPSession, MCPToolCall
14-
from ..execution.manager import ExecutionManager
14+
from ..client.connection import MCPConnectionManager
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -44,7 +44,7 @@ def __init__(
4444
self.user_prompt_formatter = user_prompt_formatter or self._default_formatter
4545
self.n = len(sessions)
4646
self.tool_schemas = [] # Discovered from MCP servers
47-
self.execution_manager = ExecutionManager()
47+
self.connection_manager = MCPConnectionManager()
4848
self.usage_stats = {} # llm usage stats for monitoring
4949

5050
if len(sessions) != len(dataset_rows):
@@ -58,17 +58,14 @@ async def reset(self, session: MCPSession) -> Tuple[Any, List[Dict]]:
5858
5959
This is thread-safe and can be called from worker threads.
6060
"""
61-
# Establish a persistent session for each environment.
62-
await self.execution_manager.connection_manager.initialize_session(session)
63-
6461
# Get available tools from MCP server
65-
tool_schemas = await self.execution_manager.connection_manager.discover_tools(session)
62+
tool_schemas = await self.connection_manager.discover_tools(session)
6663

6764
if not self.tool_schemas:
6865
self.tool_schemas = tool_schemas
6966

7067
# PROPER MCP PATTERN: Get initial state from resources during session establishment
71-
initial_observation = await self.execution_manager.connection_manager.get_initial_state(session)
68+
initial_observation = await self.connection_manager.get_initial_state(session)
7269

7370
# Update session state
7471
session.terminated = False
@@ -119,7 +116,7 @@ async def step(self, env_index: int, tool_call: MCPToolCall) -> Tuple[Any, float
119116
)
120117

121118
# Execute the tool call via MCP protocol
122-
observation, reward, done, info = await self.execution_manager.connection_manager.call_tool(
119+
observation, reward, done, info = await self.connection_manager.call_tool(
123120
session, tool_call.tool_name, tool_call.arguments
124121
)
125122

@@ -223,5 +220,6 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st
223220
async def close(self):
    """Closes all MCP sessions.

    Every session in ``self.sessions`` is closed through the connection
    manager. Closing is best-effort: a failure for one session is logged and
    does not prevent the remaining sessions from being closed and awaited.
    """
    print(f"🧹 Closing {self.n} MCP sessions...")
    tasks = [self.connection_manager.close_session(session) for session in self.sessions]
    # return_exceptions=True keeps one failing close from aborting the await
    # and leaving the other close operations unobserved.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for session, result in zip(self.sessions, results):
        if isinstance(result, BaseException):
            logger.warning("Failed to close MCP session %s: %r", session.session_id, result)
    print(f"✅ All MCP sessions closed.")

eval_protocol/mcp_env.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")
1818
1919
# Create environments with evaluation_rows configuration
20-
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
20+
envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
2121
2222
# Execute tool-calling rollouts
2323
evaluation_rows = await ep.rollout(envs, policy=policy, steps=512)
@@ -51,18 +51,28 @@
5151
from .mcp.session.manager import GeneralMCPVectorEnv
5252
from .models import EvaluationRow
5353
from .types import DatasetRow, MCPSession, MCPToolCall
54+
import asyncio
5455

5556
logger = logging.getLogger(__name__)
5657

5758

58-
def make(
59+
async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
    """
    Reset the server-side state of every MCP session held by ``envs``.

    One ``reset_session`` call is issued through the environment's connection
    manager per session, and all of them are awaited together.
    """
    pending = (envs.connection_manager.reset_session(sess) for sess in envs.sessions)
    await asyncio.gather(*pending)
65+
66+
67+
async def make(
5968
env_spec: str,
6069
evaluation_rows: Optional[List[EvaluationRow]] = None,
6170
dataset: Optional[List[Dict]] = None,
6271
n: Optional[int] = None,
6372
seeds: Optional[List[int]] = None,
6473
model_id: str = "unknown",
6574
user_prompt_formatter: Optional[Callable] = None,
75+
reset_sessions: bool = False,
6676
) -> GeneralMCPVectorEnv:
6777
"""
6878
Create general MCP environments driven by evaluation_rows configuration.
@@ -75,19 +85,20 @@ def make(
7585
seeds: List of seeds (for backward compatibility)
7686
model_id: Model identifier
7787
user_prompt_formatter: Optional callback for formatting user prompts
88+
reset_sessions: Whether to reset sessions before returning the environment
7889
7990
Returns:
8091
General MCP environment that works with any MCP server
8192
8293
Example:
8394
# EvaluationRow approach (preferred)
84-
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
95+
envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
8596
8697
# Dataset approach (backward compatibility)
87-
envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
98+
envs = await ep.make("http://localhost:8000/mcp", dataset=dataset)
8899
89100
# Legacy approach (backward compatibility)
90-
envs = ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
101+
envs = await ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
91102
"""
92103
# Parse environment specification - make sure URL format is correct
93104
base_url = env_spec
@@ -160,8 +171,6 @@ def make(
160171
)
161172
sessions.append(session)
162173

163-
return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
164-
165174
else:
166175
# Legacy approach for backward compatibility
167176
if n is None:
@@ -198,7 +207,14 @@ def make(
198207
)
199208
sessions.append(session)
200209

201-
return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
210+
mcp_envs = GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
211+
tasks = [mcp_envs.connection_manager.initialize_session(session) for session in sessions]
212+
await asyncio.gather(*tasks)
213+
214+
if reset_sessions:
215+
await reset_mcp_sessions(mcp_envs)
216+
217+
return mcp_envs
202218

203219

204220
async def rollout(
@@ -266,7 +282,7 @@ async def rollout(
266282
raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL")
267283

268284
auto_model_id = model_id or getattr(policy, "model_id", "unknown")
269-
envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)
285+
envs = await make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)
270286

271287
# Use the new ExecutionManager for execution
272288
execution_manager = ExecutionManager()

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -182,49 +182,47 @@ def __exit__(self, exc_type, exc_val, exc_tb):
182182
return False # Don't suppress exceptions
183183

184184

185-
186-
async def default_mcp_gym_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """
    Rollout processor for tau bench environments.

    This processor starts an MCP server, creates tau bench environments, and runs rollouts
    using the eval_protocol framework, following the pattern from test_tau2_e2e.py.

    Args:
        rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
        config: RolloutProcessorConfig with model and other parameters

    Returns:
        List of EvaluationRow objects with completed conversations
    """
    # Single source of truth for the port so the server that is started and
    # the URL the client connects to cannot drift apart.
    port = 9700
    server = MCPServerManager(config.server_script_path, port=port)

    try:
        server.start()

        policy = ep.LiteLLMPolicy(
            model_id=config.model,
            temperature=config.input_params.get("temperature", 0.0),
            max_tokens=config.input_params.get("max_tokens", 4096),
        )

        # Create MCP environments directly from evaluation_rows.
        # make() is a coroutine and must be awaited.
        envs = await ep.make(
            f"http://localhost:{port}/mcp/",
            evaluation_rows=rows,
            model_id=policy.model_id,
        )

        # Run rollout with environments and policy
        evaluation_rows = await ep.rollout(
            envs, policy=policy, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts
        )

        return evaluation_rows

    finally:
        # Always clean up the server
        server.stop()

eval_protocol/types/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass, field
22
from enum import Enum
33
from typing import Any, Dict, List, Optional
4+
from mcp.client.session import ClientSession
5+
from contextlib import AsyncExitStack
46

57

68
class TerminationReason(str, Enum):
@@ -50,8 +52,8 @@ class MCPSession:
5052
last_observation: Any = None
5153

5254
# Persistent MCP connection components
53-
_exit_stack: Optional[Any] = None
54-
_mcp_session: Optional[Any] = None
55+
_exit_stack: Optional[AsyncExitStack] = None
56+
_mcp_session: Optional[ClientSession] = None
5557

5658

5759
@dataclass

0 commit comments

Comments
 (0)