Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 48 additions & 17 deletions database/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import AsyncGenerator

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from .models import Base

Expand All @@ -15,7 +15,7 @@ def __init__(self, database_url: str):
# Convert postgresql:// to postgresql+asyncpg:// for async support
if database_url.startswith("postgresql://"):
database_url = database_url.replace("postgresql://", "postgresql+asyncpg://")

self.engine = create_async_engine(
database_url,
# Connection pool settings for async
Expand All @@ -34,54 +34,86 @@ def __init__(self, database_url: str):
}
)
self.async_session = async_sessionmaker(
self.engine,
class_=AsyncSession,
self.engine,
class_=AsyncSession,
expire_on_commit=False
)

async def create_tables(self):
"""Create all tables defined in the models."""
try:
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)


# Run lightweight migrations for existing tables
await self._run_migrations(conn)

# Drop Hummingbot's native tables since we use our custom orders/trades tables
await self._drop_hummingbot_tables(conn)

logger.info("Database tables created successfully")
except Exception as e:
logger.error(f"Failed to create database tables: {e}")
raise


async def _run_migrations(self, conn):
"""Run lightweight schema migrations for existing tables."""
migrations = [
# Add controller_id to executors table (default "main" for existing rows)
(
"executors", "controller_id",
"ALTER TABLE executors ADD COLUMN controller_id TEXT NOT NULL DEFAULT 'main'"
),
]
for table, column, sql in migrations:
try:
# Check if column already exists
result = await conn.execute(
text(
"SELECT column_name FROM information_schema.columns "
"WHERE table_name = :table AND column_name = :column"
),
{"table": table, "column": column}
)
if result.fetchone() is None:
await conn.execute(text(sql))
logger.info(f"Migration: added {column} to {table}")
except Exception as e:
# Column-already-exists is expected on repeat startups
err_msg = str(e).lower()
if "already exists" in err_msg or "duplicate column" in err_msg:
logger.debug(f"Migration check for {table}.{column}: {e}")
else:
logger.warning(f"Unexpected migration error for {table}.{column}: {e}")

async def _drop_hummingbot_tables(self, conn):
"""Drop Hummingbot's native database tables since we use custom ones."""
hummingbot_tables = [
"hummingbot_orders",
"hummingbot_trade_fills",
"hummingbot_trade_fills",
"hummingbot_order_status"
]

for table_name in hummingbot_tables:
try:
await conn.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
logger.info(f"Dropped Hummingbot table: {table_name}")
except Exception as e:
logger.debug(f"Could not drop table {table_name}: {e}") # Use debug since table might not exist

async def close(self):
"""Close all database connections."""
await self.engine.dispose()
logger.info("Database connections closed")

def get_session(self) -> AsyncSession:
"""Get a new database session."""
return self.async_session()

@asynccontextmanager
async def get_session_context(self) -> AsyncGenerator[AsyncSession, None]:
"""
Get a database session with automatic error handling and cleanup.

Usage:
async with db_manager.get_session_context() as session:
# Use session here
Expand All @@ -95,11 +127,10 @@ async def get_session_context(self) -> AsyncGenerator[AsyncSession, None]:
raise
finally:
await session.close()

async def health_check(self) -> bool:
"""
Check if the database connection is healthy.

Returns:
bool: True if connection is healthy, False otherwise.
"""
Expand All @@ -109,4 +140,4 @@ async def health_check(self) -> bool:
return True
except Exception as e:
logger.error(f"Database health check failed: {e}")
return False
return False
1 change: 1 addition & 0 deletions database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ class ExecutorRecord(Base):
account_name = Column(String, nullable=False, index=True)
connector_name = Column(String, nullable=False, index=True)
trading_pair = Column(String, nullable=False, index=True)
controller_id = Column(String, nullable=False, default="main", index=True)

# Timestamps
created_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True)
Expand Down
Loading
Loading