Skip to content

Commit e68248b

Browse files
committed
hotfix for mcp gym rollout server
1 parent e8d83cd commit e68248b

File tree

2 files changed

+8
-18
lines changed

2 files changed

+8
-18
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,26 +53,16 @@ async def initialize_session(self, session: MCPSession) -> None:
5353

5454
exit_stack = AsyncExitStack()
5555

56-
# Attach client metadata for the server to consume (session_id, seed, config, etc.).
57-
# The server inspects a private `_extra` dict on client_info, so we populate it here.
58-
client_info = Implementation(name="reward-kit", version="1.0.0")
59-
extra_data: Dict[str, Any] = {"session_id": session.session_id}
56+
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={}) # pyright: ignore[reportCallIssue]
57+
client_info._extra["session_id"] = session.session_id # pyright: ignore[reportAttributeAccessIssue]
6058
if session.seed is not None:
61-
extra_data["seed"] = session.seed
59+
client_info._extra["seed"] = session.seed # pyright: ignore[reportAttributeAccessIssue]
6260
if session.dataset_row and session.dataset_row.environment_context:
63-
extra_data["config"] = session.dataset_row.environment_context
61+
client_info._extra["config"] = session.dataset_row.environment_context # pyright: ignore[reportAttributeAccessIssue]
6462
if session.dataset_row and session.dataset_row.id:
65-
extra_data["dataset_row_id"] = session.dataset_row.id
63+
client_info._extra["dataset_row_id"] = session.dataset_row.id # pyright: ignore[reportAttributeAccessIssue]
6664
if session.model_id:
67-
extra_data["model_id"] = session.model_id
68-
69-
# Merge with any existing _extra dict instead of overwriting
70-
existing_extra = getattr(client_info, "_extra", None)
71-
merged_extra: Dict[str, Any] = {}
72-
if isinstance(existing_extra, dict):
73-
merged_extra.update(existing_extra)
74-
merged_extra.update(extra_data)
75-
setattr(client_info, "_extra", merged_extra)
65+
client_info._extra["model_id"] = session.model_id # pyright: ignore[reportAttributeAccessIssue]
7666

7767
read_stream, write_stream, _ = await exit_stack.enter_async_context(
7868
streamablehttp_client(session.base_url, terminate_on_close=True)

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
199199

200200
class MCPGymRolloutProcessor(RolloutProcessor):
201201
"""
202-
Rollout processor for tau bench environments.
202+
Rollout processor for MCP gym environments.
203203
204-
This processor starts an MCP server, creates tau bench environments, and returns rollout tasks
204+
This processor starts an MCP server, creates an environment, and returns rollout tasks
205205
using the eval_protocol framework with proper cleanup handling.
206206
"""
207207

0 commit comments

Comments
 (0)