1717 policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")
1818
1919 # Create environments with evaluation_rows configuration
20- envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
20+ envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
2121
2222 # Execute tool-calling rollouts
2323 evaluation_rows = await ep.rollout(envs, policy=policy, steps=512)
5151from .mcp .session .manager import GeneralMCPVectorEnv
5252from .models import EvaluationRow
5353from .types import DatasetRow , MCPSession , MCPToolCall
54+ import asyncio
5455
5556logger = logging .getLogger (__name__ )
5657
5758
58- def make (
async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
    """
    Reset the MCP server sessions held by a vector environment.

    One ``envs.connection_manager.reset_session(session)`` call is created
    for each entry in ``envs.sessions``; all of them are then awaited
    together through ``asyncio.gather``.

    Args:
        envs: Environment whose ``sessions`` are to be reset.
    """
    pending_resets = []
    for mcp_session in envs.sessions:
        pending_resets.append(envs.connection_manager.reset_session(mcp_session))

    # Await all resets concurrently rather than one after another.
    await asyncio.gather(*pending_resets)
65+
66+
67+ async def make (
5968 env_spec : str ,
6069 evaluation_rows : Optional [List [EvaluationRow ]] = None ,
6170 dataset : Optional [List [Dict ]] = None ,
6271 n : Optional [int ] = None ,
6372 seeds : Optional [List [int ]] = None ,
6473 model_id : str = "unknown" ,
6574 user_prompt_formatter : Optional [Callable ] = None ,
75+ reset_sessions : bool = False ,
6676) -> GeneralMCPVectorEnv :
6777 """
6878 Create general MCP environments driven by evaluation_rows configuration.
@@ -75,19 +85,20 @@ def make(
7585 seeds: List of seeds (for backward compatibility)
7686 model_id: Model identifier
7787 user_prompt_formatter: Optional callback for formatting user prompts
88+ reset_sessions: Whether to reset sessions before returning the environment
7889
7990 Returns:
8091 General MCP environment that works with any MCP server
8192
8293 Example:
8394 # EvaluationRow approach (preferred)
84- envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
95+ envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
8596
8697 # Dataset approach (backward compatibility)
87- envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
98+ envs = await ep.make("http://localhost:8000/mcp", dataset=dataset)
8899
89100 # Legacy approach (backward compatibility)
90- envs = ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
101+ envs = await ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
91102 """
92103 # Parse environment specification - make sure URL format is correct
93104 base_url = env_spec
@@ -160,8 +171,6 @@ def make(
160171 )
161172 sessions .append (session )
162173
163- return GeneralMCPVectorEnv (sessions , dataset_rows , user_prompt_formatter )
164-
165174 else :
166175 # Legacy approach for backward compatibility
167176 if n is None :
@@ -198,7 +207,14 @@ def make(
198207 )
199208 sessions .append (session )
200209
201- return GeneralMCPVectorEnv (sessions , dataset_rows , user_prompt_formatter )
210+ mcp_envs = GeneralMCPVectorEnv (sessions , dataset_rows , user_prompt_formatter )
211+ tasks = [mcp_envs .connection_manager .initialize_session (session ) for session in sessions ]
212+ await asyncio .gather (* tasks )
213+
214+ if reset_sessions :
215+ await reset_mcp_sessions (mcp_envs )
216+
217+ return mcp_envs
202218
203219
204220async def rollout (
@@ -266,7 +282,7 @@ async def rollout(
266282 raise ValueError ("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL" )
267283
268284 auto_model_id = model_id or getattr (policy , "model_id" , "unknown" )
269- envs = make (envs , evaluation_rows = evaluation_rows , dataset = dataset , model_id = auto_model_id )
285+ envs = await make (envs , evaluation_rows = evaluation_rows , dataset = dataset , model_id = auto_model_id )
270286
271287 # Use the new ExecutionManager for execution
272288 execution_manager = ExecutionManager ()
0 commit comments