@@ -163,7 +163,7 @@ async def _execute_with_semaphore(idx):
163163 evaluation_rows [idx ].input_metadata .row_id = envs .dataset_rows [idx ].id
164164 evaluation_rows [idx ].input_metadata .dataset_info = asdict (envs .dataset_rows [idx ])
165165 evaluation_rows [idx ].tools = shared_tool_schema
166- evaluation_rows [idx ].usage = trajectory .usage
166+ evaluation_rows [idx ].usage = CompletionUsage ( ** trajectory .usage )
167167 evaluation_rows [idx ].input_metadata .completion_params = CompletionParams (
168168 model = policy .model_id ,
169169 temperature = getattr (policy , "temperature" , None ),
@@ -260,8 +260,6 @@ async def _execute_rollout(
260260 {"role" : "user" , "content" : user_prompt },
261261 ]
262262
263- usage_stats_list : List [CompletionUsage ] = []
264-
265263 logger .info (f"🎯 Starting rollout { rollout_idx } in thread { threading .current_thread ().name } " )
266264
267265 # Run rollout loop for this specific environment
@@ -299,6 +297,12 @@ async def _execute_rollout(
299297 while not turn_completed and not trajectory .terminated :
300298 tool_calls , usage_stats = await policy (tool_schema , rollout_idx , conversation_history )
301299
300+ # calc llm usage stats happened in this turn if there is aany
301+ if usage_stats :
302+ trajectory .usage ["prompt_tokens" ] += usage_stats .prompt_tokens
303+ trajectory .usage ["completion_tokens" ] += usage_stats .completion_tokens
304+ trajectory .usage ["total_tokens" ] += usage_stats .total_tokens
305+
302306 # If no tool call is generated, turn is finished
303307 if len (tool_calls ) == 1 :
304308 # If there's a user simulator, no tool call means the policy is ready to provide final response on this turn
@@ -308,6 +312,8 @@ async def _execute_rollout(
308312 # If there's no user simulator, no tool call means policy failed and we should terminate the rollout
309313 elif tool_calls [0 ].tool_name in ["_playback_terminate" , "_no_tool_call" ]:
310314 trajectory .terminated = True
315+ trajectory .termination_reason = TerminationReason .ERROR
316+ trajectory .control_plane_summary .update ({"error_message" : "No expected tool call" })
311317 break
312318
313319 # Execute each tool call sequentially
@@ -373,10 +379,6 @@ async def _execute_rollout(
373379 if observation is not None :
374380 current_observation = observation
375381
376- # calc llm usage stats happened in this turn if there is aany
377- if usage_stats :
378- usage_stats_list .append (usage_stats )
379-
380382 # With user simulator, increment step after an entire conversation step
381383 if user_simulator is not None :
382384 step += 1
@@ -409,7 +411,9 @@ async def _execute_rollout(
409411 # tool indicates rollout should be terminated, call policy one last time to get the final response
410412 _ , usage_stats = await policy (tool_schema , rollout_idx , conversation_history )
411413 if usage_stats :
412- usage_stats_list .append (usage_stats )
414+ trajectory .usage ["prompt_tokens" ] += usage_stats .prompt_tokens
415+ trajectory .usage ["completion_tokens" ] += usage_stats .completion_tokens
416+ trajectory .usage ["total_tokens" ] += usage_stats .total_tokens
413417
414418 # Add final control plane summary
415419 trajectory .control_plane_summary .update (
@@ -460,11 +464,6 @@ async def _execute_rollout(
460464 msg ["control_plane_step" ]["termination_reason" ] = trajectory .termination_reason
461465 break
462466
463- for usage_stats in usage_stats_list :
464- trajectory .usage ["prompt_tokens" ] += usage_stats .prompt_tokens
465- trajectory .usage ["completion_tokens" ] += usage_stats .completion_tokens
466- trajectory .usage ["total_tokens" ] += usage_stats .total_tokens
467-
468467 logger .info (
469468 f"✅ Rollout { rollout_idx } completed: { trajectory .steps } steps, reward: { trajectory .total_reward :.2f} , termination: { trajectory .termination_reason } , in thread { threading .current_thread ().name } "
470469 )
0 commit comments