eval_protocol/dataset_logger/sqlite_evaluation_row_store.py (9 additions, 3 deletions)
@@ -1,13 +1,13 @@
 import os
 from typing import List, Optional

-from peewee import CharField, DatabaseError, Model, SqliteDatabase
+from peewee import CharField, Model, SqliteDatabase
 from playhouse.sqlite_ext import JSONField

 from eval_protocol.event_bus.sqlite_event_bus_database import (
     SQLITE_HARDENED_PRAGMAS,
     DatabaseCorruptedError,
     check_and_repair_database,
+    execute_with_sqlite_retry,
 )
 from eval_protocol.models import EvaluationRow

@@ -55,7 +55,13 @@ def upsert_row(self, data: dict) -> None:
         if rollout_id is None:
             raise ValueError("execution_metadata.rollout_id is required to upsert a row")

-        with self._db.atomic("EXCLUSIVE"):
+        execute_with_sqlite_retry(lambda: self._do_upsert(rollout_id, data))
+
+    def _do_upsert(self, rollout_id: str, data: dict) -> None:
+        """Internal method to perform the actual upsert within a transaction."""
+        # Use IMMEDIATE instead of EXCLUSIVE for better concurrency
+        # IMMEDIATE acquires a reserved lock immediately but allows concurrent reads
+        with self._db.atomic("IMMEDIATE"):
             if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
                 self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
             else:
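A note on the EXCLUSIVE-to-IMMEDIATE change above: in SQLite's default rollback-journal mode, BEGIN EXCLUSIVE blocks readers for the duration of the transaction, while BEGIN IMMEDIATE takes the write lock up front but still permits concurrent reads; under WAL, which these stores enable, the two behave the same, so the retry wrapper does most of the work here. Below is a condensed sketch of the resulting call path. The model class and the insert branch (collapsed in the diff view above) are illustrative stand-ins, not the PR's exact code:

    from peewee import CharField, Model, SqliteDatabase
    from playhouse.sqlite_ext import JSONField

    from eval_protocol.event_bus.sqlite_event_bus_database import execute_with_sqlite_retry

    db = SqliteDatabase("rows.db", pragmas={"journal_mode": "wal"})  # illustrative path/pragmas

    class RowModel(Model):  # stand-in for the store's internal EvaluationRow model
        rollout_id = CharField(primary_key=True)
        data = JSONField()

        class Meta:
            database = db

    db.create_tables([RowModel])  # ensure the table exists for this sketch

    def upsert_row(rollout_id: str, data: dict) -> None:
        # Lock contention surfaces as OperationalError("database is locked");
        # the wrapper retries with exponential backoff instead of failing once.
        execute_with_sqlite_retry(lambda: _do_upsert(rollout_id, data))

    def _do_upsert(rollout_id: str, data: dict) -> None:
        # IMMEDIATE takes the write lock at BEGIN, so the exists/update/insert
        # sequence cannot interleave with another writer.
        with db.atomic("IMMEDIATE"):
            if RowModel.select().where(RowModel.rollout_id == rollout_id).exists():
                RowModel.update(data=data).where(RowModel.rollout_id == rollout_id).execute()
            else:
                RowModel.insert(rollout_id=rollout_id, data=data).execute()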
eval_protocol/event_bus/__init__.py (1 addition)
@@ -4,6 +4,7 @@
 from eval_protocol.event_bus.sqlite_event_bus_database import (
     DatabaseCorruptedError,
     check_and_repair_database,
+    execute_with_sqlite_retry,
     SQLITE_HARDENED_PRAGMAS,
 )

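With this re-export, callers can import the retry helper from the package root rather than the submodule. A minimal hypothetical caller:

    from eval_protocol.event_bus import execute_with_sqlite_retry

    def _write() -> int:
        # placeholder for any peewee/SQLite operation that can raise
        # OperationalError("database is locked") under contention
        return 1

    result = execute_with_sqlite_retry(_write)  # returns 1; retries only on lock errors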
eval_protocol/event_bus/sqlite_event_bus_database.py (65 additions, 11 deletions)
@@ -1,14 +1,60 @@
 import os
 import time
-from typing import Any, List
+from typing import Any, Callable, List, TypeVar
 from uuid import uuid4

-from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, SqliteDatabase
+import backoff
+from peewee import BooleanField, CharField, DatabaseError, DateTimeField, Model, OperationalError, SqliteDatabase
 from playhouse.sqlite_ext import JSONField

 from eval_protocol.event_bus.logger import logger


+# Retry configuration for database operations
+SQLITE_RETRY_MAX_TRIES = 5
+SQLITE_RETRY_MAX_TIME = 30  # seconds
+
+
+def _is_database_locked_error(e: Exception) -> bool:
+    """Check if an exception is a database locked error."""
+    error_str = str(e).lower()
+    return "database is locked" in error_str or "locked" in error_str
+
+
+T = TypeVar("T")
+
+
+def execute_with_sqlite_retry(operation: Callable[[], T]) -> T:
+    """
+    Execute a database operation with exponential backoff retry on lock errors.
+
+    Uses the backoff library for consistent retry behavior across the codebase.
+    Retries only on OperationalError with "database is locked" message.
+
+    Args:
+        operation: A callable that performs the database operation
+
+    Returns:
+        The result of the operation
+
+    Raises:
+        OperationalError: If the operation fails after all retries
+    """
+
+    @backoff.on_exception(
+        backoff.expo,
+        OperationalError,
+        max_tries=SQLITE_RETRY_MAX_TRIES,
+        max_time=SQLITE_RETRY_MAX_TIME,
+        giveup=lambda e: not _is_database_locked_error(e),
+        jitter=backoff.full_jitter,
+    )
+    def _execute() -> T:
+        return operation()
+
+    return _execute()
+
+
 # SQLite pragmas for hardened concurrency safety
 SQLITE_HARDENED_PRAGMAS = {
     "journal_mode": "wal",  # Write-Ahead Logging for concurrent reads/writes
@@ -148,13 +194,15 @@ def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
             else:
                 serialized_data = data

-            self._Event.create(
-                event_id=str(uuid4()),
-                event_type=event_type,
-                data=serialized_data,
-                timestamp=time.time(),
-                process_id=process_id,
-                processed=False,
+            execute_with_sqlite_retry(
+                lambda: self._Event.create(
+                    event_id=str(uuid4()),
+                    event_type=event_type,
+                    data=serialized_data,
+                    timestamp=time.time(),
+                    process_id=process_id,
+                    processed=False,
+                )
             )
         except Exception as e:
             logger.warning(f"Failed to publish event to database: {e}")
@@ -188,14 +236,20 @@ def get_unprocessed_events(self, process_id: str) -> List[dict]:
     def mark_event_processed(self, event_id: str) -> None:
         """Mark an event as processed."""
         try:
-            self._Event.update(processed=True).where(self._Event.event_id == event_id).execute()
+            execute_with_sqlite_retry(
+                lambda: self._Event.update(processed=True).where(self._Event.event_id == event_id).execute()
+            )
         except Exception as e:
             logger.debug(f"Failed to mark event as processed: {e}")

     def cleanup_old_events(self, max_age_hours: int = 24) -> None:
         """Clean up old processed events."""
         try:
             cutoff_time = time.time() - (max_age_hours * 3600)
-            self._Event.delete().where((self._Event.processed) & (self._Event.timestamp < cutoff_time)).execute()
+            execute_with_sqlite_retry(
+                lambda: self._Event.delete()
+                .where((self._Event.processed) & (self._Event.timestamp < cutoff_time))
+                .execute()
+            )
         except Exception as e:
             logger.debug(f"Failed to cleanup old events: {e}")
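One subtlety in the lambda-wrapped calls above: peewee queries only reach SQLite when .execute() runs, and the zero-argument lambda keeps that call inside the retry loop. Passing the result of self._Event.delete()....execute() directly would run the statement once, before the wrapper could ever retry it. A sketch of the same pattern against a stand-in model (the Event parameter and function name are illustrative):

    import time

    from eval_protocol.event_bus.sqlite_event_bus_database import execute_with_sqlite_retry

    def cleanup_old_events(Event, max_age_hours: int = 24) -> int:
        """Delete processed events older than the cutoff, retrying on lock errors."""
        cutoff_time = time.time() - (max_age_hours * 3600)
        # The lambda defers execution: the DELETE is built and run inside each
        # retry attempt, so every attempt issues the full statement again.
        return execute_with_sqlite_retry(
            lambda: Event.delete()
            .where((Event.processed) & (Event.timestamp < cutoff_time))
            .execute()
        )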