Skip to content

Commit c0bece6

Browse files
authored
Update Eval Row Messages Mid Rollout (#125)
1 parent fd2bec1 commit c0bece6

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,12 @@ def execute_rollouts(
101101

102102
async def _execute_with_semaphore(idx):
103103
async with semaphore:
104+
evaluation_row: EvaluationRow = evaluation_rows[idx]
105+
104106
trajectory = await self._execute_rollout(
105-
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
107+
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time, evaluation_row
106108
)
107109

108-
# Convert trajectory to EvaluationRow immediately
109-
evaluation_row: EvaluationRow = evaluation_rows[idx]
110-
111110
# Handle multimodal content by extracting text from complex content structures
112111
messages = []
113112
for msg in trajectory.conversation_history:
@@ -161,6 +160,7 @@ async def _execute_rollout(
161160
recording_mode: bool,
162161
playback_mode: bool,
163162
start_time: float,
163+
evaluation_row: Optional[EvaluationRow] = None,
164164
) -> Trajectory:
165165
"""
166166
Execute a single rollout for one environment (async version for thread execution).
@@ -170,6 +170,25 @@ async def _execute_rollout(
170170
session = envs.sessions[rollout_idx]
171171
dataset_row = envs.dataset_rows[rollout_idx]
172172

173+
# Helper function to sync conversation history to evaluation_row.messages
174+
def update_evaluation_row_messages():
175+
if evaluation_row:
176+
177+
def extract_text_content(msg_dict):
178+
msg_copy = dict(msg_dict)
179+
if isinstance(msg_copy.get("content"), list):
180+
for content_block in msg_copy["content"]:
181+
if isinstance(content_block, dict) and content_block.get("type") == "text":
182+
msg_copy["content"] = content_block.get("text", "")
183+
break
184+
else:
185+
msg_copy["content"] = ""
186+
return msg_copy
187+
188+
evaluation_row.messages = [
189+
Message.model_validate(extract_text_content(msg)) for msg in trajectory.conversation_history
190+
]
191+
173192
# Initialize trajectory
174193
trajectory = Trajectory(
175194
session=session,
@@ -223,6 +242,7 @@ async def _execute_rollout(
223242
{"role": "system", "content": system_prompt},
224243
{"role": "user", "content": user_prompt},
225244
]
245+
update_evaluation_row_messages()
226246

227247
logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")
228248

@@ -251,6 +271,7 @@ async def _execute_rollout(
251271

252272
user_prompt = envs.format_user_prompt(rollout_idx, user_content)
253273
trajectory.conversation_history.append({"role": "user", "content": user_prompt})
274+
update_evaluation_row_messages()
254275

255276
# Check if user simulator signaled termination
256277
if UserSimulator.is_stop(user_message):
@@ -262,6 +283,7 @@ async def _execute_rollout(
262283
tool_calls, usage_stats, finish_reason = await policy(
263284
tool_schema, rollout_idx, trajectory.conversation_history
264285
)
286+
update_evaluation_row_messages()
265287

266288
# calc llm usage stats happened in this turn if there is aany
267289
if usage_stats:
@@ -297,6 +319,7 @@ async def _execute_rollout(
297319
env_end,
298320
info,
299321
)
322+
update_evaluation_row_messages()
300323

301324
# Update trajectory with both data and control plane information
302325
trajectory.observations.append(observation)
@@ -379,6 +402,7 @@ async def _execute_rollout(
379402
_, usage_stats, finish_reason = await policy(
380403
tool_schema, rollout_idx, trajectory.conversation_history
381404
)
405+
update_evaluation_row_messages()
382406
if usage_stats:
383407
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
384408
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens

0 commit comments

Comments
 (0)