Skip to content

Commit 1a1046e

Browse files
authored
llm usage stats collect fix + set trajectory failure reason (#42)
* keep intermediate llm usage stats even for failure trajectories * set termination reason and error message * add
1 parent 69ba4ac commit 1a1046e

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ async def _execute_with_semaphore(idx):
163163
evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
164164
evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
165165
evaluation_rows[idx].tools = shared_tool_schema
166-
evaluation_rows[idx].usage = trajectory.usage
166+
evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage)
167167
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
168168
model=policy.model_id,
169169
temperature=getattr(policy, "temperature", None),
@@ -260,8 +260,6 @@ async def _execute_rollout(
260260
{"role": "user", "content": user_prompt},
261261
]
262262

263-
usage_stats_list: List[CompletionUsage] = []
264-
265263
logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")
266264

267265
# Run rollout loop for this specific environment
@@ -299,6 +297,12 @@ async def _execute_rollout(
299297
while not turn_completed and not trajectory.terminated:
300298
tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
301299

300+
# calc llm usage stats happened in this turn if there is aany
301+
if usage_stats:
302+
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
303+
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
304+
trajectory.usage["total_tokens"] += usage_stats.total_tokens
305+
302306
# If no tool call is generated, turn is finished
303307
if len(tool_calls) == 1:
304308
# If there's a user simulator, no tool call means the policy is ready to provide final response on this turn
@@ -308,6 +312,8 @@ async def _execute_rollout(
308312
# If there's no user simulator, no tool call means policy failed and we should terminate the rollout
309313
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
310314
trajectory.terminated = True
315+
trajectory.termination_reason = TerminationReason.ERROR
316+
trajectory.control_plane_summary.update({"error_message": "No expected tool call"})
311317
break
312318

313319
# Execute each tool call sequentially
@@ -373,10 +379,6 @@ async def _execute_rollout(
373379
if observation is not None:
374380
current_observation = observation
375381

376-
# calc llm usage stats happened in this turn if there is aany
377-
if usage_stats:
378-
usage_stats_list.append(usage_stats)
379-
380382
# With user simulator, increment step after an entire conversation step
381383
if user_simulator is not None:
382384
step += 1
@@ -409,7 +411,9 @@ async def _execute_rollout(
409411
# tool indicates rollout should be terminated, call policy one last time to get the final response
410412
_, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
411413
if usage_stats:
412-
usage_stats_list.append(usage_stats)
414+
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
415+
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
416+
trajectory.usage["total_tokens"] += usage_stats.total_tokens
413417

414418
# Add final control plane summary
415419
trajectory.control_plane_summary.update(
@@ -460,11 +464,6 @@ async def _execute_rollout(
460464
msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason
461465
break
462466

463-
for usage_stats in usage_stats_list:
464-
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
465-
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
466-
trajectory.usage["total_tokens"] += usage_stats.total_tokens
467-
468467
logger.info(
469468
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
470469
)

0 commit comments

Comments
 (0)