Skip to content

Commit a50c3f6

Browse files
authored
handle streaming httpclient error and closure in same asyncio task an… (#49)
* handle streaming httpclient error and closure in same asyncio task and context * revert ep.make back * fix ut: * add interrupt termination reason * remove comment
1 parent 16149d2 commit a50c3f6

File tree

20 files changed

+101
-106
lines changed

20 files changed

+101
-106
lines changed

eval_protocol/mcp/client/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,10 +539,10 @@ async def close_session(self, session: MCPSession) -> None:
539539
await session._exit_stack.aclose()
540540
except asyncio.CancelledError:
541541
# Handle cancellation gracefully (especially important for Python 3.12)
542-
logger.debug(f"Session {session.session_id} close was cancelled")
542+
logger.error(f"Session {session.session_id} close was cancelled")
543543
except Exception as e:
544544
# Hitting this error, probably because of use of threads: "Attempted to exit cancel scope in a different task than it was entered in"
545-
logger.debug(f"Error closing session {session.session_id}: {e}")
545+
logger.error(f"Error closing session {session.session_id}: {e}")
546546
finally:
547547
session._exit_stack = None
548548
session._mcp_session = None

eval_protocol/mcp/execution/base_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ async def _generate_live_tool_calls(
220220
return mcp_tool_calls, usage_stats
221221
else:
222222
# No tool calls in response - this is normal when episode ends or LLM provides only text
223-
logger.info(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
223+
logger.debug(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
224224
return [
225225
MCPToolCall(
226226
tool_name="_no_tool_call",

eval_protocol/mcp/execution/manager.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,12 @@ async def execute_rollouts(
9797

9898
async def _execute_with_semaphore(idx):
9999
async with semaphore:
100-
return await self._execute_rollout(
100+
result = await self._execute_rollout(
101101
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
102102
)
103103

104+
return result
105+
104106
tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
105107
# exceptions will be try catched inside single _execute_rollout
106108
trajectories = await asyncio.gather(*tasks)
@@ -112,9 +114,6 @@ async def _execute_with_semaphore(idx):
112114

113115
shared_tool_schema = envs.tool_schemas
114116

115-
# Clean up
116-
await envs.close()
117-
118117
# Enhanced reporting with control plane info
119118
successful = sum(1 for traj in trajectories if traj.total_reward > 0)
120119
terminated_by_control_plane = sum(
@@ -175,8 +174,11 @@ async def _execute_with_semaphore(idx):
175174
TerminationReason.USER_STOP,
176175
}:
177176
evaluation_rows[idx].rollout_status.status = "finished"
178-
elif trajectory.termination_reason == TerminationReason.MAX_STEPS:
177+
elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}:
179178
evaluation_rows[idx].rollout_status.status = "stopped"
179+
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
180+
"termination_reason", trajectory.termination_reason
181+
)
180182
else:
181183
evaluation_rows[idx].rollout_status.status = "error"
182184
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
@@ -226,6 +228,7 @@ async def _execute_rollout(
226228
"total_tokens": 0,
227229
},
228230
)
231+
failure_reason = None
229232
try:
230233
current_observation, tool_schema = await envs.reset(session)
231234
system_prompt = dataset_row.system_prompt
@@ -311,8 +314,7 @@ async def _execute_rollout(
311314
# If there's no user simulator, no tool call means policy failed and we should terminate the rollout
312315
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
313316
trajectory.terminated = True
314-
trajectory.termination_reason = TerminationReason.ERROR
315-
trajectory.control_plane_summary.update({"error_message": "No expected tool call"})
317+
trajectory.termination_reason = TerminationReason.INTERRUPTED
316318
break
317319

318320
# Execute each tool call sequentially
@@ -466,11 +468,26 @@ async def _execute_rollout(
466468
logger.info(
467469
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
468470
)
471+
472+
except asyncio.CancelledError:
473+
logger.error(f"🚨 AsyncIO Cancel Error in roll out {rollout_idx}", exc_info=True)
474+
failure_reason = "asyncio context cancelled"
469475
except Exception as e:
470476
logger.error(f"🚨 Error in rollout {rollout_idx}: {e}", exc_info=True)
471-
trajectory.terminated = True
472-
trajectory.termination_reason = TerminationReason.ERROR
473-
trajectory.control_plane_summary.update({"error_message": str(e)})
477+
failure_reason = str(e)
478+
finally:
479+
if failure_reason:
480+
trajectory.terminated = True
481+
trajectory.termination_reason = TerminationReason.ERROR
482+
trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"})
483+
try:
484+
await envs.connection_manager.reset_session(session)
485+
except:
486+
logger.error(f"Error resetting session {session.session_id}")
487+
try:
488+
await envs.connection_manager.close_session(session)
489+
except:
490+
logger.error(f"Error closing session {session.session_id}")
474491
return trajectory
475492

476493
async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]:

eval_protocol/mcp/session/manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ async def reset(self, session: MCPSession) -> Tuple[Any, List[Dict]]:
5858
5959
This is thread-safe and can be called from worker threads.
6060
"""
61+
await self.connection_manager.initialize_session(session)
6162
# Get available tools from MCP server
6263
tool_schemas = await self.connection_manager.discover_tools(session)
6364

eval_protocol/mcp_env.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")
1818
1919
# Create environments with evaluation_rows configuration
20-
envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
20+
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
2121
2222
# Execute tool-calling rollouts
2323
evaluation_rows = await ep.rollout(envs, policy=policy, steps=512)
@@ -86,18 +86,17 @@ async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
8686
Reset mcp server sessions
8787
"""
8888
tasks = [envs.connection_manager.reset_session(session) for session in envs.sessions]
89-
await asyncio.gather(*tasks)
89+
await asyncio.gather(*tasks, return_exceptions=True)
9090

9191

92-
async def make(
92+
def make(
9393
env_spec: str,
9494
evaluation_rows: Optional[List[EvaluationRow]] = None,
9595
dataset: Optional[List[Dict]] = None,
9696
n: Optional[int] = None,
9797
seeds: Optional[List[int]] = None,
9898
model_id: str = "unknown",
9999
user_prompt_formatter: Optional[Callable] = None,
100-
reset_sessions: bool = False,
101100
) -> GeneralMCPVectorEnv:
102101
"""
103102
Create general MCP environments driven by evaluation_rows configuration.
@@ -110,20 +109,19 @@ async def make(
110109
seeds: List of seeds (for backward compatibility)
111110
model_id: Model identifier
112111
user_prompt_formatter: Optional callback for formatting user prompts
113-
reset_sessions: Whether to reset sessions before returning the environment
114112
115113
Returns:
116114
General MCP environment that works with any MCP server
117115
118116
Example:
119117
# EvaluationRow approach (preferred)
120-
envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
118+
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
121119
122120
# Dataset approach (backward compatibility)
123-
envs = await ep.make("http://localhost:8000/mcp", dataset=dataset)
121+
envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
124122
125123
# Legacy approach (backward compatibility)
126-
envs = await ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
124+
envs = ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
127125
"""
128126
# Parse environment specification - make sure URL format is correct
129127
base_url = env_spec
@@ -236,12 +234,6 @@ async def make(
236234
sessions.append(session)
237235

238236
mcp_envs = GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
239-
tasks = [mcp_envs.connection_manager.initialize_session(session) for session in sessions]
240-
await asyncio.gather(*tasks)
241-
242-
if reset_sessions:
243-
await reset_mcp_sessions(mcp_envs)
244-
245237
return mcp_envs
246238

247239

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ async def default_mcp_gym_rollout_processor(
226226
)
227227

228228
# Create MCP environments directly from evaluation_rows
229-
envs = await ep.make(
229+
envs = ep.make(
230230
"http://localhost:9700/mcp/",
231231
evaluation_rows=rows,
232232
model_id=policy.model_id,

eval_protocol/types/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@ class TerminationReason(str, Enum):
1111
MAX_STEPS: Trajectory ends because we hit the step limit
1212
CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition)
1313
USER_STOP: Trajectory ends because the simulated user signals to stop
14+
INTERRUPTED: Trajectory ends unexpectedly, for example, expecting tool call but there is no tool call
15+
ERROR: Trajectory ends because of an error
1416
"""
1517

1618
MAX_STEPS = "max_steps"
1719
CONTROL_PLANE_SIGNAL = "control_plane_signal"
1820
USER_STOP = "user_stop"
21+
INTERRUPTED = "interrupted"
1922
ERROR = "error"
2023

2124

examples/blackjack_mcp/tests/test_record_and_replay_e2e.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_
215215
assert playback_policy.is_playback_mode(), "Should be in playback mode in CI"
216216

217217
# Create environments for playback
218-
playback_envs = await ep.make(
218+
playback_envs = ep.make(
219219
"http://localhost:9500/mcp/",
220220
dataset=blackjack_dataset,
221221
model_id=playback_policy.model_id,
@@ -250,7 +250,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_
250250
assert not policy.is_playback_mode(), "Should be in recording mode initially"
251251

252252
# Create environments
253-
envs = await ep.make(
253+
envs = ep.make(
254254
"http://localhost:9500/mcp/",
255255
dataset=blackjack_dataset,
256256
model_id=policy.model_id,
@@ -310,7 +310,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_
310310
assert playback_policy.is_playback_mode(), "Should be in playback mode"
311311

312312
# Create new environments for playback
313-
playback_envs = await ep.make(
313+
playback_envs = ep.make(
314314
"http://localhost:9500/mcp/",
315315
dataset=blackjack_dataset,
316316
model_id=playback_policy.model_id,
@@ -462,7 +462,7 @@ async def test_blackjack_step_by_step(conda_isolation_recording_file):
462462
]
463463

464464
# Create environment pointing to conda-isolated server
465-
envs = await ep.make(
465+
envs = ep.make(
466466
f"http://localhost:{port}/mcp/",
467467
dataset=test_dataset,
468468
model_id=policy.model_id,
@@ -570,7 +570,7 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording
570570
policy = create_blackjack_static_policy(action_sequence=["HIT", "HIT", "STICK"])
571571

572572
# Create multiple environments
573-
envs = await ep.make(
573+
envs = ep.make(
574574
f"http://localhost:{server.port}/mcp/",
575575
dataset=multi_env_dataset,
576576
model_id=policy.model_id,
@@ -992,7 +992,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks
992992
assert playback_policy.is_playback_mode(), "Should be in playback mode in CI"
993993

994994
# Create environments for playback
995-
playback_envs = await ep.make(
995+
playback_envs = ep.make(
996996
"http://localhost:9500/mcp/",
997997
dataset=multi_env_dataset,
998998
model_id=playback_policy.model_id,
@@ -1033,7 +1033,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks
10331033
assert not policy.is_playback_mode(), "Should be in recording mode initially"
10341034

10351035
# Create multiple environments
1036-
envs = await ep.make(
1036+
envs = ep.make(
10371037
f"http://localhost:{server.port}/mcp/",
10381038
dataset=multi_env_dataset,
10391039
model_id=policy.model_id,
@@ -1149,7 +1149,7 @@ async def test_control_plane_state_querying(multi_env_dataset):
11491149
policy = create_blackjack_static_policy(action_sequence=["HIT", "STAND"])
11501150

11511151
# Create environments
1152-
envs = await ep.make(
1152+
envs = ep.make(
11531153
f"http://localhost:{server.port}/mcp/",
11541154
dataset=multi_env_dataset[:2], # Use only 2 environments for faster testing
11551155
model_id=policy.model_id,

examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ async def test_production_server_record_and_replay(
224224
assert playback_policy.is_playback_mode(), "Should be in playback mode in CI"
225225

226226
# Create environments for playback
227-
playback_envs = await ep.make(
227+
playback_envs = ep.make(
228228
"http://localhost:9500/mcp/",
229229
dataset=cliff_walking_dataset,
230230
model_id=playback_policy.model_id,
@@ -259,7 +259,7 @@ async def test_production_server_record_and_replay(
259259
assert not policy.is_playback_mode(), "Should be in recording mode initially"
260260

261261
# Create environments
262-
envs = await ep.make(
262+
envs = ep.make(
263263
"http://localhost:9500/mcp/",
264264
dataset=cliff_walking_dataset,
265265
model_id=policy.model_id,
@@ -318,7 +318,7 @@ async def test_production_server_record_and_replay(
318318
assert playback_policy.is_playback_mode(), "Should be in playback mode"
319319

320320
# Create new environments for playback
321-
playback_envs = await ep.make(
321+
playback_envs = ep.make(
322322
"http://localhost:9500/mcp/",
323323
dataset=cliff_walking_dataset,
324324
model_id=playback_policy.model_id,
@@ -471,7 +471,7 @@ async def test_cliff_walking_step_by_step(conda_isolation_recording_file):
471471
]
472472

473473
# Create environment pointing to conda-isolated server
474-
envs = await ep.make(
474+
envs = ep.make(
475475
f"http://localhost:{port}/mcp/",
476476
dataset=test_dataset,
477477
model_id=policy.model_id,
@@ -589,7 +589,7 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording
589589
)
590590

591591
# Create multiple environments
592-
envs = await ep.make(
592+
envs = ep.make(
593593
f"http://localhost:{server.port}/mcp/",
594594
dataset=multi_env_dataset,
595595
model_id=policy.model_id,
@@ -1018,7 +1018,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks
10181018
assert playback_policy.is_playback_mode(), "Should be in playback mode in CI"
10191019

10201020
# Create environments for playback
1021-
playback_envs = await ep.make(
1021+
playback_envs = ep.make(
10221022
"http://localhost:9500/mcp/",
10231023
dataset=multi_env_dataset,
10241024
model_id=playback_policy.model_id,
@@ -1059,7 +1059,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks
10591059
assert not policy.is_playback_mode(), "Should be in recording mode initially"
10601060

10611061
# Create multiple environments
1062-
envs = await ep.make(
1062+
envs = ep.make(
10631063
f"http://localhost:{server.port}/mcp/",
10641064
dataset=multi_env_dataset,
10651065
model_id=policy.model_id,
@@ -1178,7 +1178,7 @@ async def test_control_plane_state_querying(multi_env_dataset):
11781178
policy = create_cliff_walking_static_policy(action_sequence=["UP", "UP"])
11791179

11801180
# Create environments
1181-
envs = await ep.make(
1181+
envs = ep.make(
11821182
f"http://localhost:{server.port}/mcp/",
11831183
dataset=multi_env_dataset[:2], # Use only 2 environments for faster testing
11841184
model_id=policy.model_id,

examples/frozen_lake_mcp/test_basic_functionality.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ async def test_basic_server_functionality():
4646
policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b", temperature=0.2)
4747

4848
# Create environment pointing to local server
49-
envs = await ep.make("http://localhost:8000/mcp/", dataset=test_dataset, model_id=policy.model_id)
49+
envs = ep.make("http://localhost:8000/mcp/", dataset=test_dataset, model_id=policy.model_id)
5050
print("✅ Successfully connected to MCP server")
5151

5252
# Test 2: Try to make tool calls (we'll simulate this for now)

0 commit comments

Comments
 (0)