import asyncio
import os
from collections import defaultdict
from typing import Dict, List

from eval_protocol.models import EvaluationRow


class MicroBatchDataBuffer:
    """
    Buffers evaluation results and writes them to disk in minibatches.

    A sample is considered complete only once all of its runs have finished;
    completed samples are flushed to disk in batches of batch_size.
    """
    def __init__(self, num_runs: int, batch_size: int, output_path_template: str):
        self.num_runs = num_runs
        self.batch_size = batch_size
        self.output_path_template = output_path_template
        # row_id -> rows collected so far for that sample's runs
        self.pending_samples: Dict[str, List[EvaluationRow]] = defaultdict(list)
        # Samples whose runs have all finished, awaiting the next flush
        self.completed_samples_buffer: List[List[EvaluationRow]] = []
        self.batch_index = 0
        self.lock = asyncio.Lock()

    async def add_result(self, row: EvaluationRow):
        """
        Add a single evaluation result.
        Coroutine-safe: concurrent calls are serialized by the internal asyncio lock.
        """
        async with self.lock:
            row_id = row.input_metadata.row_id
            if not row_id:
                # Should not happen in a valid EP workflow; a unique row_id is
                # required to group runs of the same sample together.
                return

            self.pending_samples[row_id].append(row)

            if len(self.pending_samples[row_id]) >= self.num_runs:
                # All runs for this sample have finished
                completed_rows = self.pending_samples.pop(row_id)
                self.completed_samples_buffer.append(completed_rows)

                if len(self.completed_samples_buffer) >= self.batch_size:
                    await self._flush_unsafe()

    async def _flush_unsafe(self):
        """
        Not coroutine-safe; assumes the caller already holds self.lock.
        """
        if not self.completed_samples_buffer:
            return

        if "{index}" in self.output_path_template:
            # One file per flushed batch, e.g. "results/batch_{index}.jsonl"
            output_path = self.output_path_template.format(index=self.batch_index)
            mode = "w"
        else:
            output_path = self.output_path_template
            mode = "a"  # Append to a single file if there is no index placeholder

        # Ensure the output directory exists
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        # Write flattened rows as JSONL, one EvaluationRow per line
        with open(output_path, mode) as f:
            for sample_rows in self.completed_samples_buffer:
                for row in sample_rows:
                    f.write(row.model_dump_json() + "\n")

        self.completed_samples_buffer = []
        self.batch_index += 1

    async def close(self):
        """
        Flush any remaining samples in the buffer.
        """
        async with self.lock:
            # Also flush pending (incomplete) samples to avoid data loss
            if self.pending_samples:
                for rows in self.pending_samples.values():
                    self.completed_samples_buffer.append(rows)
                self.pending_samples.clear()

            if self.completed_samples_buffer:
                await self._flush_unsafe()
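
# Usage sketch (illustrative only, not part of the buffer itself): drives the
# buffer from a concurrent evaluation loop. `rows` and `evaluate_fn` are
# hypothetical placeholders supplied by the caller; evaluate_fn is assumed to
# return an EvaluationRow with input_metadata.row_id set, as add_result()
# requires. Each row is evaluated num_runs times, results may finish in any
# order, and every batch_size completed samples land in results/batch_0.jsonl,
# results/batch_1.jsonl, and so on.
async def _example_run_evaluations(rows, evaluate_fn, num_runs: int = 4, batch_size: int = 8):
    buffer = MicroBatchDataBuffer(
        num_runs=num_runs,
        batch_size=batch_size,
        output_path_template="results/batch_{index}.jsonl",
    )

    async def _evaluate_and_buffer(row, run_idx: int):
        result = await evaluate_fn(row, run_idx)
        await buffer.add_result(result)

    # One task per (row, run) pair; the buffer groups results back by row_id.
    await asyncio.gather(
        *(
            _evaluate_and_buffer(row, run_idx)
            for row in rows
            for run_idx in range(num_runs)
        )
    )

    # Flush the final partial batch plus any samples with missing runs.
    await buffer.close()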