Skip to content

Commit d79b0cf

Browse files
committed
support reset mcp session
1 parent a908847 commit d79b0cf

File tree

19 files changed

+142
-99
lines changed

19 files changed

+142
-99
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
from contextlib import AsyncExitStack
1313
from typing import Any, Dict, List, Optional, Tuple
1414

15+
from mcp.types import EmptyResult
1516
from mcp.client.session import ClientSession
1617
from mcp.client.streamable_http import streamablehttp_client
18+
from pydantic import BaseModel
1719

1820
from ...types import MCPSession
1921

@@ -101,7 +103,7 @@ async def initialize_session(self, session: MCPSession) -> None:
101103

102104
# Update the session ID to match what the server generated
103105
session.session_id = server_session_id
104-
logger.debug(f"Updated session ID to match server: {server_session_id}")
106+
logger.info(f"Updated session ID to match server: {server_session_id}")
105107

106108
# PRE-WARM: Discover and cache tools immediately after session initialization
107109
# This prevents concurrent list_tools() calls later
@@ -133,6 +135,24 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None:
133135
self._tools_cache[cache_key] = tool_schemas
134136
logger.debug(f"✅ PRE-WARMED {len(tool_schemas)} tools for{cache_key}")
135137

138+
async def reset_session(self, session: MCPSession) -> None:
139+
"""
140+
Clean session data in remote mcp server for the given session
141+
"""
142+
import httpx
143+
144+
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
145+
url = f"{base_url}/control/reset_session"
146+
147+
headers = {"mcp-session-id": session.session_id}
148+
body = {"seed": session.seed}
149+
150+
timeout = httpx.Timeout(3.0)
151+
async with httpx.AsyncClient(timeout=timeout) as client:
152+
resp = await client.post(url, headers=headers, json=body)
153+
resp.raise_for_status()
154+
logger.debug(f"Session {session.session_id}: reset_session -> {resp.json()}")
155+
136156
async def discover_tools(self, session: MCPSession) -> List[Dict]:
137157
"""
138158
Discover available tools from an MCP session.
@@ -422,6 +442,8 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
422442
mcp_session = session._mcp_session
423443

424444
# 1. Execute the tool call via MCP protocol (DATA PLANE)
445+
print("session.session_id", session._mcp_session)
446+
print(session._mcp_session._write_stream._closed)
425447
tool_result = await mcp_session.call_tool(tool_name, arguments)
426448

427449
# Extract data plane results (observation only)

eval_protocol/mcp/mcpgym.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,14 +230,23 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
230230
def _register_session_reset_endpoint(self):
231231

232232
@self.mcp.custom_route("/control/reset_session", methods=["POST"])
233-
async def reset_session_endpoint(request: Request, ctx: Context) -> JSONResponse:
233+
async def reset_session_endpoint(request: Request) -> JSONResponse:
234234
session_id = request.headers.get("mcp-session-id")
235+
body = await request.json()
236+
seed = body.get("seed", None)
237+
print(f"🔍 _register_session_reset_endpoint: Resetting session, session_id: {session_id}, seed: {seed}")
235238
if not session_id:
236239
return JSONResponse({"error": "Missing mcp-session-id header"}, status_code=400)
237240
with self.session_lock:
238241
if session_id in self.sessions:
239-
del self.sessions[session_id]
240-
self.sessions[session_id] = self._get_or_create_session(ctx)
242+
env, obs, _ = self._new_env(seed=seed)
243+
self.sessions[session_id] = {
244+
"env": env,
245+
"obs": obs,
246+
"session_data": {},
247+
"session_id": session_id,
248+
}
249+
print(f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: {session_id}")
241250
return JSONResponse({"message": "Session reset successfully"})
242251

243252
def _discover_and_register_control_plane_endpoints(self):
@@ -336,7 +345,7 @@ def _update_control_plane(self, reward: float, terminated: bool, truncated: bool
336345

337346
# Log control plane update (for debugging)
338347
print(
339-
f"🎛️ Control plane updated: reward={reward}, terminated={terminated}, step={self.control_plane_state['step_count']}"
348+
f"🎛️ Control plane updated: reward={reward}, terminated={terminated}, step={self.control_plane_state['step_count']}, total_reward={self.control_plane_state['total_reward']}"
340349
)
341350

342351
def _get_or_create_session_control_plane(self, session_id: str) -> Dict[str, Any]:
@@ -378,7 +387,7 @@ def _update_session_control_plane(
378387

379388
# Log control plane update
380389
print(
381-
f"🎛️ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}"
390+
f"🎛️ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}, total_reward={control_plane['total_reward']}"
382391
)
383392

384393
def get_control_plane_state(self, session_id: str) -> Optional[Dict[str, Any]]:

eval_protocol/mcp_env.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")
1818
1919
# Create environments with evaluation_rows configuration
20-
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
20+
envs = await await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
2121
2222
# Execute tool-calling rollouts
2323
evaluation_rows = await ep.rollout(envs, policy=policy, steps=512)
@@ -56,14 +56,23 @@
5656
logger = logging.getLogger(__name__)
5757

5858

59-
def make(
59+
async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
60+
"""
61+
Reset mcp server sessions
62+
"""
63+
tasks = [envs.connection_manager.reset_session(session) for session in envs.sessions]
64+
await asyncio.gather(*tasks)
65+
66+
67+
async def make(
6068
env_spec: str,
6169
evaluation_rows: Optional[List[EvaluationRow]] = None,
6270
dataset: Optional[List[Dict]] = None,
6371
n: Optional[int] = None,
6472
seeds: Optional[List[int]] = None,
6573
model_id: str = "unknown",
6674
user_prompt_formatter: Optional[Callable] = None,
75+
reset_sessions: bool = False,
6776
) -> GeneralMCPVectorEnv:
6877
"""
6978
Create general MCP environments driven by evaluation_rows configuration.
@@ -76,19 +85,20 @@ def make(
7685
seeds: List of seeds (for backward compatibility)
7786
model_id: Model identifier
7887
user_prompt_formatter: Optional callback for formatting user prompts
88+
reset_sessions: Whether to reset sessions before returning the environment
7989
8090
Returns:
8191
General MCP environment that works with any MCP server
8292
8393
Example:
8494
# EvaluationRow approach (preferred)
85-
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
95+
envs = await await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
8696
8797
# Dataset approach (backward compatibility)
88-
envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
98+
envs = await await ep.make("http://localhost:8000/mcp", dataset=dataset)
8999
90100
# Legacy approach (backward compatibility)
91-
envs = ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
101+
envs = await await ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
92102
"""
93103
# Parse environment specification - make sure URL format is correct
94104
base_url = env_spec
@@ -161,8 +171,6 @@ def make(
161171
)
162172
sessions.append(session)
163173

164-
return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
165-
166174
else:
167175
# Legacy approach for backward compatibility
168176
if n is None:
@@ -199,10 +207,14 @@ def make(
199207
)
200208
sessions.append(session)
201209

202-
mcp_envs = GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
203-
tasks = [mcp_envs.connection_manager.initialize_session(session) for session in sessions]
204-
asyncio.run(asyncio.gather(*tasks))
205-
return mcp_envs
210+
mcp_envs = GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
211+
tasks = [mcp_envs.connection_manager.initialize_session(session) for session in sessions]
212+
await asyncio.gather(*tasks)
213+
214+
if reset_sessions:
215+
await reset_mcp_sessions(mcp_envs)
216+
217+
return mcp_envs
206218

207219

208220
async def rollout(

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -182,49 +182,47 @@ def __exit__(self, exc_type, exc_val, exc_tb):
182182
return False # Don't suppress exceptions
183183

184184

185-
186-
async def default_mcp_gym_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]:
185+
async def default_mcp_gym_rollout_processor(
186+
rows: List[EvaluationRow], config: RolloutProcessorConfig
187+
) -> List[EvaluationRow]:
187188
"""
188189
Rollout processor for tau bench environments.
189-
190+
190191
This processor starts an MCP server, creates tau bench environments, and runs rollouts
191192
using the eval_protocol framework, following the pattern from test_tau2_e2e.py.
192-
193+
193194
Args:
194195
rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
195196
config: RolloutProcessorConfig with model and other parameters
196-
197+
197198
Returns:
198199
List of EvaluationRow objects with completed conversations
199200
"""
200201
server = MCPServerManager(config.server_script_path, port=9700)
201-
202+
202203
try:
203204
server.start()
204-
205+
205206
policy = ep.LiteLLMPolicy(
206207
model_id=config.model,
207-
temperature=config.input_params.get('temperature', 0.0),
208-
max_tokens=config.input_params.get('max_tokens', 4096),
208+
temperature=config.input_params.get("temperature", 0.0),
209+
max_tokens=config.input_params.get("max_tokens", 4096),
209210
)
210-
211+
211212
# Create MCP environments directly from evaluation_rows
212-
envs = ep.make(
213-
'http://localhost:9700/mcp/',
213+
envs = await ep.make(
214+
"http://localhost:9700/mcp/",
214215
evaluation_rows=rows,
215216
model_id=policy.model_id,
216217
)
217-
218+
218219
# Run rollout with environments and policy
219220
evaluation_rows = await ep.rollout(
220-
envs,
221-
policy=policy,
222-
steps=config.steps,
223-
max_concurrent_rollouts=config.max_concurrent_rollouts
221+
envs, policy=policy, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts
224222
)
225-
223+
226224
return evaluation_rows
227-
225+
228226
finally:
229227
# Always clean up the server
230228
server.stop()

eval_protocol/types/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass, field
22
from enum import Enum
33
from typing import Any, Dict, List, Optional
4+
from mcp.client.session import ClientSession
5+
from contextlib import AsyncExitStack
46

57

68
class TerminationReason(str, Enum):
@@ -50,8 +52,8 @@ class MCPSession:
5052
last_observation: Any = None
5153

5254
# Persistent MCP connection components
53-
_exit_stack: Optional[Any] = None
54-
_mcp_session: Optional[Any] = None
55+
_exit_stack: Optional[AsyncExitStack] = None
56+
_mcp_session: Optional[ClientSession] = None
5557

5658

5759
@dataclass

examples/blackjack_mcp/tests/test_record_and_replay_e2e.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_
215215
assert playback_policy.is_playback_mode(), "Should be in playback mode in CI"
216216

217217
# Create environments for playback
218-
playback_envs = ep.make(
218+
playback_envs = await ep.make(
219219
"http://localhost:9500/mcp/",
220220
dataset=blackjack_dataset,
221221
model_id=playback_policy.model_id,
@@ -250,7 +250,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_
250250
assert not policy.is_playback_mode(), "Should be in recording mode initially"
251251

252252
# Create environments
253-
envs = ep.make(
253+
envs = await ep.make(
254254
"http://localhost:9500/mcp/",
255255
dataset=blackjack_dataset,
256256
model_id=policy.model_id,
@@ -310,7 +310,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_
310310
assert playback_policy.is_playback_mode(), "Should be in playback mode"
311311

312312
# Create new environments for playback
313-
playback_envs = ep.make(
313+
playback_envs = await ep.make(
314314
"http://localhost:9500/mcp/",
315315
dataset=blackjack_dataset,
316316
model_id=playback_policy.model_id,
@@ -462,7 +462,7 @@ async def test_blackjack_step_by_step(conda_isolation_recording_file):
462462
]
463463

464464
# Create environment pointing to conda-isolated server
465-
envs = ep.make(
465+
envs = await ep.make(
466466
f"http://localhost:{port}/mcp/",
467467
dataset=test_dataset,
468468
model_id=policy.model_id,
@@ -570,7 +570,7 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording
570570
policy = create_blackjack_static_policy(action_sequence=["HIT", "HIT", "STICK"])
571571

572572
# Create multiple environments
573-
envs = ep.make(
573+
envs = await ep.make(
574574
f"http://localhost:{server.port}/mcp/",
575575
dataset=multi_env_dataset,
576576
model_id=policy.model_id,
@@ -992,7 +992,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks
992992
assert playback_policy.is_playback_mode(), "Should be in playback mode in CI"
993993

994994
# Create environments for playback
995-
playback_envs = ep.make(
995+
playback_envs = await ep.make(
996996
"http://localhost:9500/mcp/",
997997
dataset=multi_env_dataset,
998998
model_id=playback_policy.model_id,
@@ -1033,7 +1033,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks
10331033
assert not policy.is_playback_mode(), "Should be in recording mode initially"
10341034

10351035
# Create multiple environments
1036-
envs = ep.make(
1036+
envs = await ep.make(
10371037
f"http://localhost:{server.port}/mcp/",
10381038
dataset=multi_env_dataset,
10391039
model_id=policy.model_id,
@@ -1149,7 +1149,7 @@ async def test_control_plane_state_querying(multi_env_dataset):
11491149
policy = create_blackjack_static_policy(action_sequence=["HIT", "STAND"])
11501150

11511151
# Create environments
1152-
envs = ep.make(
1152+
envs = await ep.make(
11531153
f"http://localhost:{server.port}/mcp/",
11541154
dataset=multi_env_dataset[:2], # Use only 2 environments for faster testing
11551155
model_id=policy.model_id,

0 commit comments

Comments
 (0)