@@ -318,52 +318,53 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
318318 del self .active_rollouts [row_index ]
319319 await self ._update_rollout_pbar_postfix ()
320320
321- # 4. Update sample state and schedule next run (streaming)
322- async with sample_state .lock :
323- sample_state .active_runs -= 1
324- sample_state .completed_runs += 1
325-
326- # Extract history from this run's result
327- if result_row :
328- last_msg = result_row .last_assistant_message ()
329- if last_msg and last_msg .content :
330- sample_state .history .append (str (last_msg .content ))
331- else :
332- sample_state .history .append ("" )
333-
334- # In groupwise mode, buffer results
335- if self .mode == "groupwise" :
336- if result_row :
337- self .groups_buffer [row_index ].append (result_row )
338- # Check if all runs for this sample are complete
339- if sample_state .completed_runs >= self .rollout_n :
340- full_group = self .groups_buffer .pop (row_index , [])
341- if full_group :
342- t = asyncio .create_task (_run_eval (full_group ))
343- self .background_tasks .add (t )
344- t .add_done_callback (self .background_tasks .discard )
345-
346- # Schedule next run if:
347- # 1. There are more runs to do
348- # 2. We haven't hit in_group_minibatch_size concurrent runs for this sample
349- if (sample_state .next_run_idx < self .rollout_n and
350- sample_state .active_runs < self .in_group_minibatch_size ):
321+ # 4. Update sample state and schedule next run (streaming)
322+ # Must be in finally to ensure state is updated even on exception
323+ async with sample_state .lock :
324+ sample_state .active_runs -= 1
325+ sample_state .completed_runs += 1
351326
352- next_run_idx = sample_state .next_run_idx
353- sample_state .next_run_idx += 1
354- sample_state .active_runs += 1
327+ # Extract history from this run's result
328+ if result_row :
329+ last_msg = result_row .last_assistant_message ()
330+ if last_msg and last_msg .content :
331+ sample_state .history .append (str (last_msg .content ))
332+ else :
333+ sample_state .history .append ("" )
355334
356- # High priority (0) to finish this sample ASAP
357- # Use current accumulated history for speculation
358- priority = (0 , row_index , next_run_idx )
335+ # In groupwise mode, buffer results
336+ if self .mode == "groupwise" :
337+ if result_row :
338+ self .groups_buffer [row_index ].append (result_row )
339+ # Check if all runs for this sample are complete
340+ if sample_state .completed_runs >= self .rollout_n :
341+ full_group = self .groups_buffer .pop (row_index , [])
342+ if full_group :
343+ t = asyncio .create_task (_run_eval (full_group ))
344+ self .background_tasks .add (t )
345+ t .add_done_callback (self .background_tasks .discard )
359346
360- new_task = RolloutTask (
361- priority = priority ,
362- sample_state = sample_state ,
363- run_idx = next_run_idx ,
364- history_snapshot = list (sample_state .history ), # Snapshot current history
365- )
366- self .queue .put_nowait (new_task )
347+ # Schedule next run if:
348+ # 1. There are more runs to do
349+ # 2. We haven't hit in_group_minibatch_size concurrent runs for this sample
350+ if (sample_state .next_run_idx < self .rollout_n and
351+ sample_state .active_runs < self .in_group_minibatch_size ):
352+
353+ next_run_idx = sample_state .next_run_idx
354+ sample_state .next_run_idx += 1
355+ sample_state .active_runs += 1
356+
357+ # High priority (0) to finish this sample ASAP
358+ # Use current accumulated history for speculation
359+ priority = (0 , row_index , next_run_idx )
360+
361+ new_task = RolloutTask (
362+ priority = priority ,
363+ sample_state = sample_state ,
364+ run_idx = next_run_idx ,
365+ history_snapshot = list (sample_state .history ), # Snapshot current history
366+ )
367+ self .queue .put_nowait (new_task )
367368
368369 def _format_active_rollouts (self ) -> str :
369370 """Format active rollouts for display in progress bar."""
0 commit comments