Skip to content

Commit c726a57

Browse files
committed
fix
1 parent 692f3ad commit c726a57

File tree

2 files changed

+47
-43
lines changed

2 files changed

+47
-43
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
9797
chunks.append(chunk)
9898
response = litellm.stream_chunk_builder(chunks, messages_payload)
9999
else:
100+
tc = time.perf_counter()
101+
# print(f"run_id {row.execution_metadata.run_id} request_params: {json.dumps(request_params)}")
100102
response = await acompletion(**request_params)
103+
print(f"run_id {row.execution_metadata.run_id} time taken: {time.perf_counter() - tc} speculation_enabled: {request_params.get('extra_body', {}).get('prediction', None) is not None}")
101104

102105
assert response is not None, "Response is None"
103106
assert isinstance(response, ModelResponse), "Response should be ModelResponse"

eval_protocol/pytest/priority_scheduler.py

Lines changed: 44 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -318,52 +318,53 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
318318
del self.active_rollouts[row_index]
319319
await self._update_rollout_pbar_postfix()
320320

321-
# 4. Update sample state and schedule next run (streaming)
322-
async with sample_state.lock:
323-
sample_state.active_runs -= 1
324-
sample_state.completed_runs += 1
325-
326-
# Extract history from this run's result
327-
if result_row:
328-
last_msg = result_row.last_assistant_message()
329-
if last_msg and last_msg.content:
330-
sample_state.history.append(str(last_msg.content))
331-
else:
332-
sample_state.history.append("")
333-
334-
# In groupwise mode, buffer results
335-
if self.mode == "groupwise":
336-
if result_row:
337-
self.groups_buffer[row_index].append(result_row)
338-
# Check if all runs for this sample are complete
339-
if sample_state.completed_runs >= self.rollout_n:
340-
full_group = self.groups_buffer.pop(row_index, [])
341-
if full_group:
342-
t = asyncio.create_task(_run_eval(full_group))
343-
self.background_tasks.add(t)
344-
t.add_done_callback(self.background_tasks.discard)
345-
346-
# Schedule next run if:
347-
# 1. There are more runs to do
348-
# 2. We haven't hit in_group_minibatch_size concurrent runs for this sample
349-
if (sample_state.next_run_idx < self.rollout_n and
350-
sample_state.active_runs < self.in_group_minibatch_size):
321+
# 4. Update sample state and schedule next run (streaming)
322+
# Must be in finally to ensure state is updated even on exception
323+
async with sample_state.lock:
324+
sample_state.active_runs -= 1
325+
sample_state.completed_runs += 1
351326

352-
next_run_idx = sample_state.next_run_idx
353-
sample_state.next_run_idx += 1
354-
sample_state.active_runs += 1
327+
# Extract history from this run's result
328+
if result_row:
329+
last_msg = result_row.last_assistant_message()
330+
if last_msg and last_msg.content:
331+
sample_state.history.append(str(last_msg.content))
332+
else:
333+
sample_state.history.append("")
355334

356-
# High priority (0) to finish this sample ASAP
357-
# Use current accumulated history for speculation
358-
priority = (0, row_index, next_run_idx)
335+
# In groupwise mode, buffer results
336+
if self.mode == "groupwise":
337+
if result_row:
338+
self.groups_buffer[row_index].append(result_row)
339+
# Check if all runs for this sample are complete
340+
if sample_state.completed_runs >= self.rollout_n:
341+
full_group = self.groups_buffer.pop(row_index, [])
342+
if full_group:
343+
t = asyncio.create_task(_run_eval(full_group))
344+
self.background_tasks.add(t)
345+
t.add_done_callback(self.background_tasks.discard)
359346

360-
new_task = RolloutTask(
361-
priority=priority,
362-
sample_state=sample_state,
363-
run_idx=next_run_idx,
364-
history_snapshot=list(sample_state.history), # Snapshot current history
365-
)
366-
self.queue.put_nowait(new_task)
347+
# Schedule next run if:
348+
# 1. There are more runs to do
349+
# 2. We haven't hit in_group_minibatch_size concurrent runs for this sample
350+
if (sample_state.next_run_idx < self.rollout_n and
351+
sample_state.active_runs < self.in_group_minibatch_size):
352+
353+
next_run_idx = sample_state.next_run_idx
354+
sample_state.next_run_idx += 1
355+
sample_state.active_runs += 1
356+
357+
# High priority (0) to finish this sample ASAP
358+
# Use current accumulated history for speculation
359+
priority = (0, row_index, next_run_idx)
360+
361+
new_task = RolloutTask(
362+
priority=priority,
363+
sample_state=sample_state,
364+
run_idx=next_run_idx,
365+
history_snapshot=list(sample_state.history), # Snapshot current history
366+
)
367+
self.queue.put_nowait(new_task)
367368

368369
def _format_active_rollouts(self) -> str:
369370
"""Format active rollouts for display in progress bar."""

0 commit comments

Comments
 (0)