Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 67 additions & 6 deletions eval_protocol/dataset_logger/sqlite_evaluation_row_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import logging
import time
from typing import List, Optional

from peewee import CharField, Model, SqliteDatabase
from peewee import CharField, Model, SqliteDatabase, FloatField
from playhouse.sqlite_ext import JSONField

from eval_protocol.models import EvaluationRow
Expand All @@ -18,6 +20,7 @@ def __init__(self, db_path: str):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
self._db_path = db_path
self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"})
self._logger = logging.getLogger(__name__)

class BaseModel(Model):
class Meta:
Expand All @@ -26,12 +29,26 @@ class Meta:
class EvaluationRow(BaseModel): # type: ignore
rollout_id = CharField(unique=True)
data = JSONField()
updated_at = FloatField(default=lambda: time.time())

self._EvaluationRow = EvaluationRow

self._db.connect()
# Use safe=True to avoid errors when tables/indexes already exist
self._db.create_tables([EvaluationRow], safe=True)
# Attempt to add updated_at column for existing installations
try:
columns = {c.name for c in self._db.get_columns(self._EvaluationRow._meta.table_name)}
if "updated_at" not in columns:
self._db.execute_sql(
f'ALTER TABLE "{self._EvaluationRow._meta.table_name}" ADD COLUMN "updated_at" REAL'
)
# Backfill with current time
now_ts = time.time()
self._EvaluationRow.update(updated_at=now_ts).execute()
except Exception:
# Best-effort; ignore if migration not needed or fails
pass

@property
def db_path(self) -> str:
Expand All @@ -44,16 +61,60 @@ def upsert_row(self, data: dict) -> None:

with self._db.atomic("EXCLUSIVE"):
if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
self._EvaluationRow.update(data=data, updated_at=time.time()).where(
self._EvaluationRow.rollout_id == rollout_id
).execute()
else:
self._EvaluationRow.create(rollout_id=rollout_id, data=data)
self._EvaluationRow.create(rollout_id=rollout_id, data=data, updated_at=time.time())

def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
    """Return the stored ``data`` payloads, optionally for a single rollout.

    Args:
        rollout_id: When given, only rows whose ``rollout_id`` column equals
            this value are read. When ``None``, all rows are read, ordered by
            ``updated_at`` descending.

    Returns:
        A list with the ``data`` value of every matching row, in query order.
    """
    model = self._EvaluationRow

    # Build exactly one query; it is executed exactly once further down.
    if rollout_id is None:
        model_query = model.select().order_by(model.updated_at.desc())
    else:
        model_query = model.select().where(model.rollout_id == rollout_id)

    # Log SQL for debugging. Rendering is best-effort and must never break a read.
    try:
        sql_text, sql_params = model_query.sql()
        self._logger.debug(
            "[SQLITE_READ_ROWS] db=%s sql=%s params=%s", self._db_path, sql_text, sql_params
        )
    except Exception as e:
        self._logger.debug("[SQLITE_READ_ROWS] Failed to render SQL for debug: %s", e)

    # Execute and collect results as plain dicts.
    results = list(model_query.dicts())

    # Debug: summarize the first few rows. Also best-effort.
    try:
        sample_rollout_ids = []
        sample_updated = []
        for row in results[:3]:
            rid = row.get("rollout_id")
            # Prefer the rollout_id nested in the payload when it is present.
            payload = row.get("data")
            metadata = payload.get("execution_metadata") if isinstance(payload, dict) else None
            nested_rid = metadata.get("rollout_id") if isinstance(metadata, dict) else None
            if nested_rid:
                rid = nested_rid
            sample_rollout_ids.append(str(rid))
            # updated_at may be absent from the row dict; record None then.
            sample_updated.append(row.get("updated_at"))
        self._logger.debug(
            "[SQLITE_READ_ROWS] fetched_rows=%d sample_rollout_ids=%s sample_updated_at=%s",
            len(results),
            sample_rollout_ids,
            sample_updated,
        )
    except Exception as e:
        self._logger.debug("[SQLITE_READ_ROWS] Failed to summarize results for debug: %s", e)

    return [row["data"] for row in results]

def delete_row(self, rollout_id: str) -> int:
Expand Down
Empty file.
Loading
Loading