Commit 8dad6c3

fix trajectory collection

1 parent 2bbd53a

1 file changed (+18, -14 lines)

eval_protocol/mcp/execution/manager.py

Lines changed: 18 additions & 14 deletions
@@ -252,7 +252,7 @@ async def _execute_rollout(
         current_observation = user_message.content if user_message.content else ""
 
         user_prompt = envs.format_user_prompt(rollout_idx, current_observation)
-        conversation_history = [
+        trajectory.conversation_history = [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_prompt},
         ]
@@ -272,7 +272,7 @@ async def _execute_rollout(
 
             if user_simulator and user_simulator_state:
                 # Get user simulator messages and find the last assistant message
-                user_simulator_messages = self._get_user_simulator_messages(conversation_history)
+                user_simulator_messages = self._get_user_simulator_messages(trajectory.conversation_history)
 
                 # Last message was agent, simulated user response
                 if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
@@ -283,7 +283,7 @@ async def _execute_rollout(
                     user_content = user_message.content if user_message.content else ""
 
                     user_prompt = envs.format_user_prompt(rollout_idx, user_content)
-                    conversation_history.append({"role": "user", "content": user_prompt})
+                    trajectory.conversation_history.append({"role": "user", "content": user_prompt})
 
                     # Check if user simulator signaled termination
                     if UserSimulator.is_stop(user_message):
@@ -293,7 +293,7 @@ async def _execute_rollout(
             # In each turn: keep looping until assistant is ready to provide final response
             while not turn_completed and not trajectory.terminated:
                 tool_calls, usage_stats, finish_reason = await policy(
-                    tool_schema, rollout_idx, conversation_history
+                    tool_schema, rollout_idx, trajectory.conversation_history
                 )
 
                 # calc llm usage stats happened in this turn if there is aany
@@ -326,7 +326,7 @@ async def _execute_rollout(
                         rollout_idx,
                         tool_call,
                         tool_response,
-                        conversation_history,
+                        trajectory.conversation_history,
                         reward,
                         env_end,
                         info,
@@ -357,12 +357,14 @@ async def _execute_rollout(
                         "num_tool_calls": 1,
                     }
                     print(f"🔍 control_plane_step: {control_plane_step}")
-                    conversation_history[-1]["control_plane_step"] = control_plane_step
+                    trajectory.conversation_history[-1]["control_plane_step"] = control_plane_step
                     trajectory.control_plane_steps.append(control_plane_step)
 
                     # Log conversation state for playback if in recording mode
                     if recording_mode:
-                        policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)
+                        policy.log_conversation_state_for_playback(
+                            rollout_idx, step - 1, trajectory.conversation_history
+                        )
 
                     if env_end:
                         # if the env marks the end of the rollout, break the tool call loop
@@ -396,17 +398,21 @@ async def _execute_rollout(
                     "tool_calls": tool_calls_summary,
                     "num_tool_calls": len(tool_calls),
                 }
-                conversation_history[-1]["control_plane_step"] = control_plane_step
+                trajectory.conversation_history[-1]["control_plane_step"] = control_plane_step
                 trajectory.control_plane_steps.append(control_plane_step)
 
                 # Log conversation state for playback if in recording mode
                 if recording_mode:
-                    policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)
+                    policy.log_conversation_state_for_playback(
+                        rollout_idx, step - 1, trajectory.conversation_history
+                    )
 
                 # if the env marks end, update control plane summary and do one last policy call, then break the agent loop
                 # this is to ensure each turn ends with an assistant message, which will align with the actual agentic llm behavior
                 if env_end:
-                    _, usage_stats, finish_reason = await policy(tool_schema, rollout_idx, conversation_history)
+                    _, usage_stats, finish_reason = await policy(
+                        tool_schema, rollout_idx, trajectory.conversation_history
+                    )
                     if usage_stats:
                         trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
                         trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
@@ -424,10 +430,10 @@ async def _execute_rollout(
 
         # Log final OpenAI conversation for terminated trajectories only
        if openai_logger:
-            if conversation_history and len(conversation_history) > 0:
+            if trajectory.conversation_history and len(trajectory.conversation_history) > 0:
                 openai_logger(
                     {
-                        "messages": conversation_history,
+                        "messages": trajectory.conversation_history,
                         "metadata": {
                             "session_id": session.session_id,
                             "seed": session.seed,
@@ -453,8 +459,6 @@ async def _execute_rollout(
         if not trajectory.termination_reason and step >= steps:
             trajectory.termination_reason = TerminationReason.MAX_STEPS
 
-        trajectory.conversation_history = conversation_history
-
         # Add termination_reason to the final control_plane_step
         for msg in reversed(trajectory.conversation_history):
             if msg.get("control_plane_step"):
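
Every hunk makes the same substitution: messages are written to trajectory.conversation_history directly, rather than to a local conversation_history list that was copied onto the trajectory only at the very end of the rollout (the assignment deleted in the last hunk). Below is a minimal sketch of the before/after pattern; the Trajectory container, run_rollout helper, and message contents are hypothetical stand-ins, not the real eval_protocol types.

from dataclasses import dataclass, field
from typing import Any

@dataclass
class Trajectory:
    # History lives on the trajectory itself, so logging, control-plane
    # bookkeeping, and early-exit paths all see the up-to-date messages
    # without waiting for a final copy step.
    conversation_history: list[dict[str, Any]] = field(default_factory=list)

def run_rollout(trajectory: Trajectory, system_prompt: str, user_prompt: str) -> None:
    # Before the commit: a local list accumulated messages and was assigned
    # to trajectory.conversation_history only after the agent loop finished.
    # After the commit: the trajectory's own list is initialized and appended
    # to in place throughout the rollout.
    trajectory.conversation_history = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # ... the real agent/tool loop appends assistant and tool messages here ...
    trajectory.conversation_history.append({"role": "assistant", "content": "done"})

traj = Trajectory()
run_rollout(traj, "You are a helpful agent.", "Complete the task.")
print(len(traj.conversation_history))  # 3: history is already on the trajectory

With this pattern, a rollout that terminates early still leaves its partial history attached to the trajectory, which is presumably what "fix trajectory collection" refers to.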
