Skip to content

Commit 986452f

Browse files
authored
Catch rollout exception + end multi turn with assistant messsage (#40)
* gen session_id from client * catch error catch error add * catch error catch error add * add final assistant response * update * add * record rollout status * fix ut
1 parent c35d7f0 commit 986452f

File tree

9 files changed

+330
-273
lines changed

9 files changed

+330
-273
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 11 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mcp.client.streamable_http import streamablehttp_client
1717

1818
from ...types import MCPSession
19+
from mcp.types import Implementation
1920

2021
logger = logging.getLogger(__name__)
2122

@@ -50,19 +51,16 @@ async def initialize_session(self, session: MCPSession) -> None:
5051

5152
exit_stack = AsyncExitStack()
5253

53-
client_info = None
54-
if session.seed is not None or (session.dataset_row and session.dataset_row.environment_context):
55-
from mcp.types import Implementation
56-
57-
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
58-
if session.seed is not None:
59-
client_info._extra["seed"] = session.seed
60-
if session.dataset_row and session.dataset_row.environment_context:
61-
client_info._extra["config"] = session.dataset_row.environment_context
62-
if session.dataset_row and session.dataset_row.id:
63-
client_info._extra["dataset_row_id"] = session.dataset_row.id
64-
if session.model_id:
65-
client_info._extra["model_id"] = session.model_id
54+
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
55+
client_info._extra["session_id"] = session.session_id
56+
if session.seed is not None:
57+
client_info._extra["seed"] = session.seed
58+
if session.dataset_row and session.dataset_row.environment_context:
59+
client_info._extra["config"] = session.dataset_row.environment_context
60+
if session.dataset_row and session.dataset_row.id:
61+
client_info._extra["dataset_row_id"] = session.dataset_row.id
62+
if session.model_id:
63+
client_info._extra["model_id"] = session.model_id
6664

6765
read_stream, write_stream, _ = await exit_stack.enter_async_context(
6866
streamablehttp_client(session.base_url, terminate_on_close=True)
@@ -77,32 +75,6 @@ async def initialize_session(self, session: MCPSession) -> None:
7775
session._mcp_session = mcp_session
7876
session._exit_stack = exit_stack
7977

80-
# Update session ID to match server's calculation (for control plane sync)
81-
if client_info and hasattr(client_info, "_extra"):
82-
extra_data = client_info._extra
83-
if extra_data and isinstance(extra_data, dict):
84-
85-
seed_value = extra_data.get("seed")
86-
config_value = extra_data.get("config", {})
87-
dataset_row_id_value = extra_data.get("dataset_row_id")
88-
model_id_value = extra_data.get("model_id")
89-
90-
stable_data = {
91-
"seed": seed_value,
92-
"config": config_value,
93-
"dataset_row_id": dataset_row_id_value,
94-
"model_id": model_id_value,
95-
"name": client_info.name,
96-
"version": client_info.version,
97-
}
98-
99-
stable_str = json.dumps(stable_data, sort_keys=True)
100-
server_session_id = hashlib.md5(stable_str.encode()).hexdigest()
101-
102-
# Update the session ID to match what the server generated
103-
session.session_id = server_session_id
104-
logger.info(f"Updated session ID to match server: {server_session_id}")
105-
10678
# PRE-WARM: Discover and cache tools immediately after session initialization
10779
# This prevents concurrent list_tools() calls later
10880
await self._prewarm_tools_cache(session)

eval_protocol/mcp/execution/manager.py

Lines changed: 249 additions & 220 deletions
Large diffs are not rendered by default.

eval_protocol/mcp/mcpgym.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,12 @@ def _get_session_id(self, ctx: Context) -> str:
146146
print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}")
147147

148148
if extra_data and isinstance(extra_data, dict):
149-
# Create a stable session ID based on seed and other config
149+
# use the client generated session id
150+
if "session_id" in extra_data:
151+
print(f"🔍 _get_session_id: using client generated session_id: {extra_data['session_id']}")
152+
return extra_data["session_id"]
153+
154+
# fallback to create a stable session ID based on seed and other config
150155
seed_value = extra_data.get("seed")
151156
config_value = extra_data.get("config", {})
152157
dataset_row_id_value = extra_data.get("dataset_row_id")

eval_protocol/mcp/session/manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st
219219

220220
async def close(self):
221221
"""Closes all MCP sessions."""
222+
print(f"🧹 Resetting {self.n} MCP sessions in MCP server...")
223+
cleanup_tasks = [self.connection_manager.reset_session(session) for session in self.sessions]
224+
await asyncio.gather(*cleanup_tasks)
222225
print(f"🧹 Closing {self.n} MCP sessions...")
223226
tasks = [self.connection_manager.close_session(session) for session in self.sessions]
224227
await asyncio.gather(*tasks)

eval_protocol/mcp_env.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,34 @@
5353
from .mcp.session.manager import GeneralMCPVectorEnv
5454
from .models import EvaluationRow
5555
from .types import DatasetRow, MCPSession, MCPToolCall
56+
import asyncio
57+
import hashlib
58+
import json
5659

5760
logger = logging.getLogger(__name__)
5861

5962

63+
def gen_session_id(dataset_row: DatasetRow, model_id: str) -> str:
64+
"""
65+
Generate a session ID for a dataset row
66+
"""
67+
seed_value = dataset_row.seed
68+
config_value = dataset_row.environment_context
69+
dataset_row_id_value = dataset_row.id
70+
model_id_value = model_id
71+
72+
stable_data = {
73+
"seed": seed_value,
74+
"config": config_value,
75+
"dataset_row_id": dataset_row_id_value,
76+
"model_id": model_id_value,
77+
}
78+
79+
stable_str = json.dumps(stable_data, sort_keys=True)
80+
81+
return hashlib.md5(stable_str.encode()).hexdigest()
82+
83+
6084
async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
6185
"""
6286
Reset mcp server sessions
@@ -162,9 +186,10 @@ async def make(
162186

163187
dataset_rows.append(dataset_row)
164188

189+
session_id = gen_session_id(dataset_row, model_id)
165190
# Create MCP session
166191
session = MCPSession(
167-
session_id=dataset_row.id,
192+
session_id=session_id,
168193
base_url=base_url,
169194
seed=dataset_row.seed,
170195
model_id=model_id,
@@ -198,9 +223,11 @@ async def make(
198223
)
199224
dataset_rows.append(dataset_row)
200225

226+
session_id = gen_session_id(dataset_row, model_id)
227+
201228
# Create MCP session
202229
session = MCPSession(
203-
session_id=f"session_{i}",
230+
session_id=session_id,
204231
base_url=base_url,
205232
seed=seeds[i],
206233
model_id=model_id,

eval_protocol/models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,21 @@ class EvalMetadata(BaseModel):
220220
passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold")
221221

222222

223+
class RolloutStatus(BaseModel):
224+
"""Status of the rollout."""
225+
226+
"""
227+
running: Unfinished rollout which is still in progress.
228+
finished: Rollout finished successfully.
229+
error: Rollout failed.
230+
stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop).
231+
"""
232+
status: Literal["running", "finished", "error", "stopped"] = Field(
233+
"finished", description="Status of the rollout."
234+
)
235+
error_message: Optional[str] = Field(None, description="Error message if the rollout failed.")
236+
237+
223238
class EvaluationRow(BaseModel):
224239
"""
225240
Unified data structure for a single evaluation unit that contains messages,
@@ -244,6 +259,11 @@ class EvaluationRow(BaseModel):
244259
description="Metadata related to the input (dataset info, model config, session data, etc.).",
245260
)
246261

262+
rollout_status: RolloutStatus = Field(
263+
default_factory=RolloutStatus,
264+
description="The status of the rollout.",
265+
)
266+
247267
# Ground truth reference (moved from EvaluateResult to top level)
248268
ground_truth: Optional[str] = Field(
249269
default=None, description="Optional ground truth reference for this evaluation."

eval_protocol/types/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class TerminationReason(str, Enum):
1616
MAX_STEPS = "max_steps"
1717
CONTROL_PLANE_SIGNAL = "control_plane_signal"
1818
USER_STOP = "user_stop"
19+
ERROR = "error"
1920

2021

2122
@dataclass

tests/pytest/test_frozen_lake.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
similar to the test_frozen_lake_e2e test but integrated with the pytest evaluation system.
66
"""
77

8-
98
from typing import Any, Dict, List
109

1110
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams, MetricResult
@@ -18,7 +17,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
1817
Convert entries from frozen lake dataset to EvaluationRow objects.
1918
"""
2019
rows = []
21-
20+
2221
for row in data:
2322
eval_row = EvaluationRow(
2423
messages=[Message(role="system", content=row["system_prompt"])],
@@ -27,14 +26,15 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
2726
dataset_info={
2827
"environment_context": row["environment_context"],
2928
"user_prompt_template": row["user_prompt_template"],
30-
}
31-
)
29+
},
30+
),
3231
)
33-
32+
3433
rows.append(eval_row)
35-
34+
3635
return rows
3736

37+
3838
@evaluation_test(
3939
input_dataset=["tests/pytest/data/frozen_lake_dataset.jsonl"],
4040
dataset_adapter=frozen_lake_to_evaluation_row,
@@ -50,13 +50,13 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
5050
def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow:
5151
"""
5252
Test frozen lake evaluation using the pytest framework.
53-
53+
5454
This test evaluates how well the model can navigate the FrozenLake environment
5555
by checking if it successfully reaches the goal while avoiding holes.
56-
56+
5757
Args:
5858
row: EvaluationRow object from frozen lake dataset
59-
59+
6060
Returns:
6161
EvaluationRow object with evaluation results
6262
"""
@@ -71,5 +71,5 @@ def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow:
7171
score=score,
7272
reason=reason,
7373
)
74-
74+
7575
return row

tests/test_rollout_control_plane_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def mock_step_side_effect(env_index, tool_call):
289289
assert final_cp_step["step"] == 2, "Should record final step"
290290

291291
# Validate policy interaction
292-
assert policy.step_count == 3, "Policy should have been called 3 times"
292+
assert policy.step_count == 4, "Policy should have been called 3 times"
293293

294294
@pytest.mark.asyncio
295295
async def test_rollout_trajectory_recording_with_control_plane(self):

0 commit comments

Comments
 (0)