Skip to content

Commit 546ba39

Browse files
committed
close with final assistant message
2 parents a80c6e9 + b5d242c commit 546ba39

File tree

6 files changed

+63
-52
lines changed

6 files changed

+63
-52
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 11 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from mcp.client.streamable_http import streamablehttp_client
1717

1818
from ...types import MCPSession
19+
from mcp.types import Implementation
1920

2021
logger = logging.getLogger(__name__)
2122

@@ -50,19 +51,16 @@ async def initialize_session(self, session: MCPSession) -> None:
5051

5152
exit_stack = AsyncExitStack()
5253

53-
client_info = None
54-
if session.seed is not None or (session.dataset_row and session.dataset_row.environment_context):
55-
from mcp.types import Implementation
56-
57-
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
58-
if session.seed is not None:
59-
client_info._extra["seed"] = session.seed
60-
if session.dataset_row and session.dataset_row.environment_context:
61-
client_info._extra["config"] = session.dataset_row.environment_context
62-
if session.dataset_row and session.dataset_row.id:
63-
client_info._extra["dataset_row_id"] = session.dataset_row.id
64-
if session.model_id:
65-
client_info._extra["model_id"] = session.model_id
54+
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
55+
client_info._extra["session_id"] = session.session_id
56+
if session.seed is not None:
57+
client_info._extra["seed"] = session.seed
58+
if session.dataset_row and session.dataset_row.environment_context:
59+
client_info._extra["config"] = session.dataset_row.environment_context
60+
if session.dataset_row and session.dataset_row.id:
61+
client_info._extra["dataset_row_id"] = session.dataset_row.id
62+
if session.model_id:
63+
client_info._extra["model_id"] = session.model_id
6664

6765
read_stream, write_stream, _ = await exit_stack.enter_async_context(
6866
streamablehttp_client(session.base_url, terminate_on_close=True)
@@ -77,32 +75,6 @@ async def initialize_session(self, session: MCPSession) -> None:
7775
session._mcp_session = mcp_session
7876
session._exit_stack = exit_stack
7977

80-
# Update session ID to match server's calculation (for control plane sync)
81-
if client_info and hasattr(client_info, "_extra"):
82-
extra_data = client_info._extra
83-
if extra_data and isinstance(extra_data, dict):
84-
85-
seed_value = extra_data.get("seed")
86-
config_value = extra_data.get("config", {})
87-
dataset_row_id_value = extra_data.get("dataset_row_id")
88-
model_id_value = extra_data.get("model_id")
89-
90-
stable_data = {
91-
"seed": seed_value,
92-
"config": config_value,
93-
"dataset_row_id": dataset_row_id_value,
94-
"model_id": model_id_value,
95-
"name": client_info.name,
96-
"version": client_info.version,
97-
}
98-
99-
stable_str = json.dumps(stable_data, sort_keys=True)
100-
server_session_id = hashlib.md5(stable_str.encode()).hexdigest()
101-
102-
# Update the session ID to match what the server generated
103-
session.session_id = server_session_id
104-
logger.info(f"Updated session ID to match server: {server_session_id}")
105-
10678
# PRE-WARM: Discover and cache tools immediately after session initialization
10779
# This prevents concurrent list_tools() calls later
10880
await self._prewarm_tools_cache(session)

eval_protocol/mcp/execution/manager.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,7 @@ async def _execute_with_semaphore(idx):
104104

105105
tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
106106
# exceptions should be try catched inside single _execute_rollout
107+
# exceptions should be try catched inside single _execute_rollout
107108
trajectories = await asyncio.gather(*tasks)
108109

109110
# Calculate durations
@@ -386,6 +387,9 @@ async def _execute_rollout(
386387
# Log conversation state for playback if in recording mode
387388
if recording_mode:
388389
policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)
390+
# Log conversation state for playback if in recording mode
391+
if recording_mode:
392+
policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)
389393

390394
# Use control plane information for termination decision
391395
if rollout_end:

eval_protocol/mcp/mcpgym.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,12 @@ def _get_session_id(self, ctx: Context) -> str:
146146
print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}")
147147

148148
if extra_data and isinstance(extra_data, dict):
149-
# Create a stable session ID based on seed and other config
149+
# use the client generated session id
150+
if "session_id" in extra_data:
151+
print(f"🔍 _get_session_id: using client generated session_id: {extra_data['session_id']}")
152+
return extra_data["session_id"]
153+
154+
# fallback to create a stable session ID based on seed and other config
150155
seed_value = extra_data.get("seed")
151156
config_value = extra_data.get("config", {})
152157
dataset_row_id_value = extra_data.get("dataset_row_id")

eval_protocol/mcp/session/manager.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -219,6 +219,9 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st
219219

220220
async def close(self):
221221
"""Closes all MCP sessions."""
222+
print(f"🧹 Resetting {self.n} MCP sessions in MCP server...")
223+
cleanup_tasks = [self.connection_manager.reset_session(session) for session in self.sessions]
224+
await asyncio.gather(*cleanup_tasks)
222225
print(f"🧹 Closing {self.n} MCP sessions...")
223226
tasks = [self.connection_manager.close_session(session) for session in self.sessions]
224227
await asyncio.gather(*tasks)

eval_protocol/mcp_env.py

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,10 +53,34 @@
5353
from .mcp.session.manager import GeneralMCPVectorEnv
5454
from .models import EvaluationRow
5555
from .types import DatasetRow, MCPSession, MCPToolCall
56+
import asyncio
57+
import hashlib
58+
import json
5659

5760
logger = logging.getLogger(__name__)
5861

5962

63+
def gen_session_id(dataset_row: DatasetRow, model_id: str) -> str:
64+
"""
65+
Generate a session ID for a dataset row
66+
"""
67+
seed_value = dataset_row.seed
68+
config_value = dataset_row.environment_context
69+
dataset_row_id_value = dataset_row.id
70+
model_id_value = model_id
71+
72+
stable_data = {
73+
"seed": seed_value,
74+
"config": config_value,
75+
"dataset_row_id": dataset_row_id_value,
76+
"model_id": model_id_value,
77+
}
78+
79+
stable_str = json.dumps(stable_data, sort_keys=True)
80+
81+
return hashlib.md5(stable_str.encode()).hexdigest()
82+
83+
6084
async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
6185
"""
6286
Reset mcp server sessions
@@ -162,9 +186,10 @@ async def make(
162186

163187
dataset_rows.append(dataset_row)
164188

189+
session_id = gen_session_id(dataset_row, model_id)
165190
# Create MCP session
166191
session = MCPSession(
167-
session_id=dataset_row.id,
192+
session_id=session_id,
168193
base_url=base_url,
169194
seed=dataset_row.seed,
170195
model_id=model_id,
@@ -198,9 +223,11 @@ async def make(
198223
)
199224
dataset_rows.append(dataset_row)
200225

226+
session_id = gen_session_id(dataset_row, model_id)
227+
201228
# Create MCP session
202229
session = MCPSession(
203-
session_id=f"session_{i}",
230+
session_id=session_id,
204231
base_url=base_url,
205232
seed=seeds[i],
206233
model_id=model_id,

tests/pytest/test_frozen_lake.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
similar to the test_frozen_lake_e2e test but integrated with the pytest evaluation system.
66
"""
77

8-
98
from typing import Any, Dict, List
109

1110
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams, MetricResult
@@ -18,7 +17,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
1817
Convert entries from frozen lake dataset to EvaluationRow objects.
1918
"""
2019
rows = []
21-
20+
2221
for row in data:
2322
eval_row = EvaluationRow(
2423
messages=[Message(role="system", content=row["system_prompt"])],
@@ -27,14 +26,15 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
2726
dataset_info={
2827
"environment_context": row["environment_context"],
2928
"user_prompt_template": row["user_prompt_template"],
30-
}
31-
)
29+
},
30+
),
3231
)
33-
32+
3433
rows.append(eval_row)
35-
34+
3635
return rows
3736

37+
3838
@evaluation_test(
3939
input_dataset=["tests/pytest/data/frozen_lake_dataset.jsonl"],
4040
dataset_adapter=frozen_lake_to_evaluation_row,
@@ -50,13 +50,13 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
5050
def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow:
5151
"""
5252
Test frozen lake evaluation using the pytest framework.
53-
53+
5454
This test evaluates how well the model can navigate the FrozenLake environment
5555
by checking if it successfully reaches the goal while avoiding holes.
56-
56+
5757
Args:
5858
row: EvaluationRow object from frozen lake dataset
59-
59+
6060
Returns:
6161
EvaluationRow object with evaluation results
6262
"""
@@ -71,5 +71,5 @@ def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow:
7171
score=score,
7272
reason=reason,
7373
)
74-
74+
7575
return row

0 commit comments

Comments
 (0)