
Commit 8219c44

Pipeline training support + priority based rollout scheduler (#358)

* add
* add priority rollout scheduler
* groupwise
* add
* add
* fix
* put it back
* add
* add postprocess
* resolve comments and fix bugs
* fix

1 parent de3aba0

File tree

6 files changed: +1085 −228 lines changed

eval_protocol/pytest/buffer.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
import asyncio
import os
from collections import defaultdict
from typing import Dict, List

from eval_protocol.models import EvaluationRow


class MicroBatchDataBuffer:
    """
    Buffers evaluation results and writes them to disk in minibatches.
    Waits for all runs of a sample to complete before considering it ready to flush to disk.
    """

    def __init__(self, num_runs: int, batch_size: int, output_path_template: str):
        self.num_runs = num_runs
        self.batch_size = batch_size
        self.output_path_template = output_path_template
        self.pending_samples: Dict[str, List[EvaluationRow]] = defaultdict(list)  # row_id -> list[EvaluationRow]
        self.completed_samples_buffer: List[List[EvaluationRow]] = []
        self.batch_index = 0
        self.lock = asyncio.Lock()

    async def add_result(self, row: EvaluationRow):
        """
        Add a single evaluation result.
        Coroutine-safe.
        """
        async with self.lock:
            row_id = row.input_metadata.row_id
            if not row_id:
                # Should not happen in a valid EP workflow; a unique row_id is
                # required to group runs of the same sample together.
                return

            self.pending_samples[row_id].append(row)

            if len(self.pending_samples[row_id]) >= self.num_runs:
                # Sample completed (all runs finished)
                completed_rows = self.pending_samples.pop(row_id)
                self.completed_samples_buffer.append(completed_rows)

            if len(self.completed_samples_buffer) >= self.batch_size:
                await self._flush_unsafe()

    async def _flush_unsafe(self):
        """
        Not coroutine-safe; assumes the lock is held by the caller.
        """
        if not self.completed_samples_buffer:
            return

        if "{index}" in self.output_path_template:
            output_path = self.output_path_template.format(index=self.batch_index)
            mode = "w"
        else:
            output_path = self.output_path_template
            mode = "a"  # Append if no index placeholder

        # Ensure directory exists
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        # Write flattened rows
        with open(output_path, mode) as f:
            for sample_rows in self.completed_samples_buffer:
                for row in sample_rows:
                    f.write(row.model_dump_json() + "\n")

        self.completed_samples_buffer = []
        self.batch_index += 1

    async def close(self):
        """
        Flush any remaining samples in the buffer.
        """
        async with self.lock:
            # Also flush pending (incomplete) samples to avoid data loss
            if self.pending_samples:
                for rows in self.pending_samples.values():
                    self.completed_samples_buffer.append(rows)
                self.pending_samples.clear()

            if self.completed_samples_buffer:
                await self._flush_unsafe()
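
For orientation (not part of the diff), here is a minimal sketch of how the buffer might be driven. The producer side and all configuration values are hypothetical, since the diff only shows the buffer itself:

import asyncio
from typing import Iterable

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.buffer import MicroBatchDataBuffer


async def write_results(rows: Iterable[EvaluationRow]) -> None:
    # Hypothetical configuration: 4 rollouts per sample, flush every
    # 8 completed samples, one JSONL file per flushed batch via {index}.
    buffer = MicroBatchDataBuffer(
        num_runs=4,
        batch_size=8,
        output_path_template="results/batch_{index}.jsonl",
    )
    for row in rows:
        await buffer.add_result(row)
    await buffer.close()  # flush stragglers, incomplete samples included

With the {index} placeholder each flush writes a fresh file in "w" mode; without it, every flush appends to a single JSONL file.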
