Skip to content

Commit d1bc6d0

Browse files
authored
[CVAT][Exchange Oracle] Use Alembic in tests (#2620)
* Use Alembic in tests * Use RuntimeError instead of assert statement
1 parent b71589e commit d1bc6d0

File tree

4 files changed

+134
-25
lines changed

4 files changed

+134
-25
lines changed

packages/examples/cvat/exchange-oracle/alembic/env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def run_migrations_offline() -> None:
4949
target_metadata=target_metadata,
5050
literal_binds=True,
5151
dialect_opts={"paramstyle": "named"},
52+
output_buffer=config.output_buffer,
5253
)
5354

5455
with context.begin_transaction():

packages/examples/cvat/exchange-oracle/alembic/versions/1707173682_non_unique_escrows_c1e74c227cfe.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
77
"""
88

9+
import os
10+
11+
from sqlalchemy import Column, String, delete, func, select
12+
from sqlalchemy.orm import declarative_base
13+
914
from alembic import op
1015

1116
# revision identifiers, used by Alembic.
@@ -14,6 +19,14 @@
1419
branch_labels = None
1520
depends_on = None
1621

22+
Base = declarative_base()
23+
24+
25+
class Project(Base):
26+
__tablename__ = "projects"
27+
id = Column(String, primary_key=True, index=True)
28+
escrow_address = Column(String(42), unique=False, nullable=False)
29+
1730

1831
def upgrade() -> None:
1932
# ### commands auto generated by Alembic - please adjust! ###
@@ -22,6 +35,23 @@ def upgrade() -> None:
2235

2336

2437
def downgrade() -> None:
38+
offline_mode = op.get_context().environment_context.is_offline_mode()
39+
if not (offline_mode or "TESTING" in os.environ or "test" in op.get_bind().engine.url):
40+
raise RuntimeError(
41+
"This downgrade deletes data and should only run in a test environment."
42+
"If you are sure you want to run it, set the TESTING environment variable."
43+
)
44+
45+
op.execute(
46+
delete(Project).where(
47+
Project.escrow_address.in_(
48+
select(Project.escrow_address)
49+
.group_by(Project.escrow_address)
50+
.having(func.count(Project.escrow_address) > 1)
51+
)
52+
)
53+
)
54+
2555
# ### commands auto generated by Alembic - please adjust! ###
2656
op.create_unique_constraint("projects_escrow_address_key", "projects", ["escrow_address"])
2757
# ### end Alembic commands ###

packages/examples/cvat/exchange-oracle/alembic/versions/1727205503_add_assignment_updated_at_2cbc85686054.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import enum
1010

1111
import sqlalchemy as sa
12-
from sqlalchemy import Column, DateTime, Enum, String
13-
from sqlalchemy.orm import Session, declarative_base
12+
from sqlalchemy import Column, DateTime, Enum, String, update
13+
from sqlalchemy.orm import declarative_base
1414
from sqlalchemy.sql import func
1515

1616
from alembic import op
@@ -51,25 +51,36 @@ class Assignment(Base):
5151

5252

5353
def define_initial_updated_at():
54-
bind = op.get_bind()
55-
session = Session(bind=bind)
56-
57-
session.query(Assignment).filter(
58-
Assignment.updated_at == None,
59-
Assignment.status.in_(
60-
[AssignmentStatuses.expired, AssignmentStatuses.rejected, AssignmentStatuses.canceled]
61-
),
62-
).update({Assignment.updated_at: Assignment.expires_at})
63-
64-
session.query(Assignment).filter(
65-
Assignment.updated_at == None,
66-
Assignment.status == AssignmentStatuses.completed,
67-
).update({Assignment.updated_at: Assignment.completed_at})
68-
69-
session.query(Assignment).filter(
70-
Assignment.updated_at == None,
71-
# fallback for invalid entries above + handling of status == "created"
72-
).update({Assignment.updated_at: Assignment.created_at})
54+
# First update: expired, rejected, and canceled assignments
55+
# using op.execute instead of session.execute to support offline migrations
56+
op.execute(
57+
update(Assignment)
58+
.where(
59+
Assignment.updated_at == None,
60+
Assignment.status.in_(
61+
[
62+
AssignmentStatuses.expired,
63+
AssignmentStatuses.rejected,
64+
AssignmentStatuses.canceled,
65+
]
66+
),
67+
)
68+
.values(updated_at=Assignment.expires_at)
69+
)
70+
71+
# Second update: completed assignments
72+
op.execute(
73+
update(Assignment)
74+
.where(Assignment.updated_at == None, Assignment.status == AssignmentStatuses.completed)
75+
.values(updated_at=Assignment.completed_at)
76+
)
77+
78+
# Third update: fallback for invalid entries and handling status == "created"
79+
op.execute(
80+
update(Assignment)
81+
.where(Assignment.updated_at == None)
82+
.values(updated_at=Assignment.created_at)
83+
)
7384

7485

7586
def upgrade() -> None:

packages/examples/cvat/exchange-oracle/tests/conftest.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,87 @@
11
import os
22
from collections.abc import Generator
3+
from dataclasses import dataclass
4+
from io import StringIO
5+
from pathlib import Path
36

47
os.environ["DEBUG"] = "1"
58

69
import pytest
710
from fastapi.testclient import TestClient
11+
from sqlalchemy import TextClause, text
812
from sqlalchemy.orm import Session
13+
from sqlalchemy_utils import create_database, database_exists, drop_database
914

15+
from alembic import command as alembic_command
16+
from alembic.config import Config
1017
from src import app
11-
from src.db import Base, SessionLocal, engine
18+
from src.db import SessionLocal, engine
19+
20+
alembic_config = Config(Path(__file__).parent.parent / "alembic.ini")
21+
22+
23+
@dataclass
24+
class AlembicSQL:
25+
upgrade: TextClause
26+
downgrade: TextClause
27+
28+
29+
@pytest.fixture(scope="session")
30+
def alembic() -> AlembicSQL:
31+
"""
32+
Captures the SQL generated by Alembic for upgrade/downgrade operations.
33+
Doesn't actually run migrations.
34+
"""
35+
alembic_config.output_buffer = StringIO()
36+
alembic_command.upgrade(alembic_config, "head", sql=True)
37+
upgrade_sql = alembic_config.output_buffer.getvalue()
38+
39+
alembic_config.output_buffer = StringIO()
40+
alembic_command.downgrade(alembic_config, "head:base", sql=True)
41+
downgrade_sql = alembic_config.output_buffer.getvalue()
42+
43+
return AlembicSQL(text(upgrade_sql), text(downgrade_sql))
44+
45+
46+
@pytest.fixture(scope="session", autouse=True)
47+
def setup_db(alembic) -> None:
48+
assert "test" in engine.url.database, "The test database must be used for testing."
49+
if database_exists(engine.url):
50+
drop_database(engine.url)
51+
create_database(engine.url)
52+
yield # Run the test cases
53+
54+
# Upgrade to the latest version after all tests are done,
55+
# this helps with inspection of the latest schema.
56+
with engine.connect() as connection:
57+
connection.execute(alembic.upgrade)
1258

1359

1460
@pytest.fixture(autouse=True)
15-
def db():
16-
Base.metadata.drop_all(bind=engine)
17-
Base.metadata.create_all(bind=engine)
61+
def init_db(alembic) -> None:
62+
"""
63+
Runs the recorded Alembic upgrade and downgrade SQL for each test.
64+
This ensures correctness of alembic migrations.
65+
"""
66+
try:
67+
with engine.connect() as connection:
68+
connection.execute(alembic.upgrade)
69+
except Exception as e:
70+
raise RuntimeError(
71+
"Failed to upgrade migrations, `alembic upgrade head` would fail."
72+
" inspect the cause error and change migrations accordingly."
73+
) from e
74+
75+
yield # Run the test case
76+
77+
try:
78+
with engine.connect() as connection:
79+
connection.execute(alembic.downgrade)
80+
except Exception as e:
81+
raise RuntimeError(
82+
"Failed to downgrade migrations, `alembic downgrade head:base` would fail."
83+
" inspect the cause error and change migrations accordingly."
84+
) from e
1885

1986

2087
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)