Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
# Ruff version.
rev: v0.13.3
hooks:
- id: ruff
- id: ruff-check
args: [--fix]
- id: ruff-format
- repo: https://github.com/numpy/numpydoc
Expand Down
47 changes: 32 additions & 15 deletions python/lsst/daf/butler_migrate/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import json
import logging
from collections.abc import Iterable, Iterator, Mapping
from contextlib import contextmanager
from typing import cast
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Literal, Self, cast

import sqlalchemy
from alembic.runtime.migration import MigrationContext
Expand All @@ -43,7 +43,7 @@ class RevisionConsistencyError(Exception):
"""


class Database:
class Database(AbstractContextManager):
"""Class implementing methods for database access needed for migrations.

Parameters
Expand All @@ -69,6 +69,7 @@ class Database:
def __init__(self, db_url: sqlalchemy.engine.url.URL, schema: str | None = None):
self._db_url = db_url
self._schema = schema
self._engine: sqlalchemy.engine.Engine | None = None

@classmethod
def from_repo(cls, repo: str) -> Database:
Expand Down Expand Up @@ -100,13 +101,35 @@ def schema(self) -> str | None:
"""Schema (namespace) name (`str`)."""
return self._schema

@property
def engine(self) -> sqlalchemy.engine.Engine:
"""Cached sqlalchemy Engine."""
if self._engine is None:
self._engine = sqlalchemy.engine.create_engine(self._db_url)
return self._engine

@contextmanager
def connect(self) -> Iterator[sqlalchemy.engine.Connection]:
"""Context manager for database connection."""
engine = sqlalchemy.engine.create_engine(self._db_url)
with engine.connect() as connection:
with self.engine.connect() as connection:
yield connection

def __enter__(self) -> Self:
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
try:
self.close()
except Exception:
_LOG.exception("An exception occurred during Database.close()")
return False

def close(self) -> None:
"""Cleanup connection pool."""
if self._engine is not None:
self._engine.dispose()
self._engine = None

def dimensions_namespace(self) -> str | None:
"""Return dimensions namespace from a stored configuration.

Expand All @@ -115,8 +138,6 @@ def dimensions_namespace(self) -> str | None:
namespace: `str` or `None`
Dimensions namespace or `None` if not defined.
"""
engine = sqlalchemy.engine.create_engine(self._db_url)

meta = sqlalchemy.schema.MetaData(schema=self._schema)
table = sqlalchemy.schema.Table(
"butler_attributes",
Expand All @@ -126,7 +147,7 @@ def dimensions_namespace(self) -> str | None:
)

sql = sqlalchemy.sql.select(table.columns.value).where(table.columns.name == self.dimensions_json_key)
with engine.connect() as connection:
with self.engine.connect() as connection:
result = connection.execute(sql)
row = result.fetchone()
if row is None:
Expand All @@ -150,8 +171,6 @@ def manager_versions(self, namespace: str | None = None) -> Mapping[str, tuple[s
tuple consisting of manager class name (including package/module),
version string in X.Y.Z format, and revision ID string/hash.
"""
engine = sqlalchemy.engine.create_engine(self._db_url)

meta = sqlalchemy.schema.MetaData(schema=self._schema)
table = sqlalchemy.schema.Table(
"butler_attributes",
Expand All @@ -164,7 +183,7 @@ def manager_versions(self, namespace: str | None = None) -> Mapping[str, tuple[s
managers: dict[str, str] = {}
versions: dict[str, str] = {}
sql = sqlalchemy.sql.select(table.columns.name, table.columns.value)
with engine.connect() as connection:
with self.engine.connect() as connection:
result = connection.execute(sql)
for name, value in result:
if name.startswith("config:registry.managers."):
Expand Down Expand Up @@ -211,8 +230,7 @@ def alembic_revisions(self) -> list[str]:
Returned list is empty if alembic version table does not exist or
is empty.
"""
engine = sqlalchemy.engine.create_engine(self._db_url)
with engine.connect() as connection:
with self.engine.connect() as connection:
ctx = MigrationContext.configure(
connection=connection, opts={"version_table_schema": self._schema}
)
Expand Down Expand Up @@ -286,8 +304,7 @@ def dump_schema(self, tables: list[str] | None) -> None:
List of the tables, if missing or empty then schema for all tables
is printed.
"""
engine = sqlalchemy.engine.create_engine(self._db_url)
inspector = sqlalchemy.inspect(engine)
inspector = sqlalchemy.inspect(self.engine)
table_names = sorted(inspector.get_table_names(schema=self._schema))
for table in table_names:
if tables and table not in tables:
Expand Down
51 changes: 25 additions & 26 deletions python/lsst/daf/butler_migrate/script/migrate_current.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,31 @@ def migrate_current(repo: str, mig_path: str, verbose: bool, butler: bool, names
Dimensions namespace to use when "namespace" key is not present in
``config:dimensions.json``.
"""
db = database.Database.from_repo(repo)
with database.Database.from_repo(repo) as db:
if namespace is None and db.dimensions_namespace() is None:
raise ValueError(
"The `--namespace` option is required when namespace is missing from"
" stored dimensions configuration"
)

if namespace is None and db.dimensions_namespace() is None:
raise ValueError(
"The `--namespace` option is required when namespace is missing from"
" stored dimensions configuration"
)

cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
if butler:
# Print current versions defined in butler.
script_info = scripts.Scripts(cfg)
heads = script_info.head_revisions()
manager_versions = db.manager_versions(namespace)
if manager_versions:
for manager, (klass, version, rev_id) in sorted(manager_versions.items()):
head = " (head)" if rev_id in heads else ""
print(f"{manager}: {klass} {version} -> {rev_id}{head}")
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
if butler:
# Print current versions defined in butler.
script_info = scripts.Scripts(cfg)
heads = script_info.head_revisions()
manager_versions = db.manager_versions(namespace)
if manager_versions:
for manager, (klass, version, rev_id) in sorted(manager_versions.items()):
head = " (head)" if rev_id in heads else ""
print(f"{manager}: {klass} {version} -> {rev_id}{head}")
else:
print("No manager versions defined in butler_attributes table.")
else:
print("No manager versions defined in butler_attributes table.")
else:
# Revisions from alembic.
command.current(cfg, verbose=verbose)
# Revisions from alembic.
command.current(cfg, verbose=verbose)

# Complain if alembic_version table is there but does not match manager
# versions.
if db.alembic_revisions():
script_info = scripts.Scripts(cfg)
db.validate_revisions(namespace, script_info.base_revisions())
# Complain if alembic_version table is there but does not match manager
# versions.
if db.alembic_revisions():
script_info = scripts.Scripts(cfg)
db.validate_revisions(namespace, script_info.base_revisions())
41 changes: 21 additions & 20 deletions python/lsst/daf/butler_migrate/script/migrate_downgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,29 @@ def migrate_downgrade(
Dimensions namespace to use when "namespace" key is not present in
``config:dimensions.json``.
"""
db = database.Database.from_repo(repo)
with database.Database.from_repo(repo) as db:
if namespace is None and db.dimensions_namespace() is None:
raise ValueError(
"The `--namespace` option is required when namespace is missing from"
" stored dimensions configuration"
)

if namespace is None and db.dimensions_namespace() is None:
raise ValueError(
"The `--namespace` option is required when namespace is missing from"
" stored dimensions configuration"
)
# Check that alembic versions exist in database, we do not support
# migrations from empty state.
if not db.alembic_revisions():
raise ValueError(
"Alembic version table does not exist, you may need to run `butler migrate stamp` first."
)

# Check that alembic versions exist in database, we do not support
# migrations from empty state.
if not db.alembic_revisions():
raise ValueError(
"Alembic version table does not exist, you may need to run `butler migrate stamp` first."
one_shot_arg: str | None = None
if one_shot_tree:
one_shot_arg = one_shot_tree
cfg = config.MigAlembicConfig.from_mig_path(
mig_path, repository=repo, db=db, one_shot_tree=one_shot_arg
)

one_shot_arg: str | None = None
if one_shot_tree:
one_shot_arg = one_shot_tree
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db, one_shot_tree=one_shot_arg)

# check that alembic versions are consistent with butler
script_info = scripts.Scripts(cfg)
db.validate_revisions(namespace, script_info.base_revisions())
# check that alembic versions are consistent with butler
script_info = scripts.Scripts(cfg)
db.validate_revisions(namespace, script_info.base_revisions())

command.downgrade(cfg, revision, sql=sql)
command.downgrade(cfg, revision, sql=sql)
4 changes: 2 additions & 2 deletions python/lsst/daf/butler_migrate/script/migrate_dump_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ def migrate_dump_schema(repo: str, table: list[str]) -> None:
table : `list`
List of the tables, if empty then schema for all tables is printed.
"""
db = database.Database.from_repo(repo)
db.dump_schema(table)
with database.Database.from_repo(repo) as db:
db.dump_schema(table)
40 changes: 20 additions & 20 deletions python/lsst/daf/butler_migrate/script/migrate_set_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,27 @@ def migrate_set_namespace(repo: str, namespace: str | None, update: bool) -> Non
update : `bool`
Allows update of the existing namespace.
"""
db = database.Database.from_repo(repo)
db_namespace = db.dimensions_namespace()
with database.Database.from_repo(repo) as db:
db_namespace = db.dimensions_namespace()

if not namespace:
# Print current value
if not db_namespace:
print("No namespace defined in dimensions configuration.")
else:
print("Current dimensions namespace:", db_namespace)
if not namespace:
# Print current value
if not db_namespace:
print("No namespace defined in dimensions configuration.")
else:
print("Current dimensions namespace:", db_namespace)

else:
if db_namespace and not update:
raise ValueError(
f"Namespace is already defined ({db_namespace}), use --update option to replace it."
)
else:
if db_namespace and not update:
raise ValueError(
f"Namespace is already defined ({db_namespace}), use --update option to replace it."
)

def update_namespace(config: dict) -> dict:
"""Update namespace attribute"""
config["namespace"] = namespace
return config
def update_namespace(config: dict) -> dict:
"""Update namespace attribute"""
config["namespace"] = namespace
return config

with db.connect() as connection:
attributes = butler_attributes.ButlerAttributes(connection, db.schema)
attributes.update_dimensions_json(update_namespace)
with db.connect() as connection:
attributes = butler_attributes.ButlerAttributes(connection, db.schema)
attributes.update_dimensions_json(update_namespace)
67 changes: 33 additions & 34 deletions python/lsst/daf/butler_migrate/script/migrate_stamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,42 +54,41 @@ def migrate_stamp(
manager : `str`, Optional
Name of the manager to stamp, if `None` then all managers are stamped.
"""
db = database.Database.from_repo(repo)
with database.Database.from_repo(repo) as db:
if namespace is None and db.dimensions_namespace() is None:
raise ValueError(
"The `--namespace` option is required when namespace is missing from"
" stored dimensions configuration"
)

if namespace is None and db.dimensions_namespace() is None:
raise ValueError(
"The `--namespace` option is required when namespace is missing from"
" stored dimensions configuration"
)
manager_versions = db.manager_versions(namespace)

manager_versions = db.manager_versions(namespace)
revisions: dict[str, str] = {}
for mgr_name, (klass, version, rev_id) in manager_versions.items():
_LOG.debug("found revision (%s, %s, %s) -> %s", mgr_name, klass, version, rev_id)
revisions[mgr_name] = rev_id

revisions: dict[str, str] = {}
for mgr_name, (klass, version, rev_id) in manager_versions.items():
_LOG.debug("found revision (%s, %s, %s) -> %s", mgr_name, klass, version, rev_id)
revisions[mgr_name] = rev_id
cfg: config.MigAlembicConfig | None = None
if manager:
if manager in revisions:
revisions = {manager: revisions[manager]}
else:
# If specified manager not in the database, it may mean that an
# initial "tree-root" revision needs to be added to alembic
# table, if that manager is defined in the migration trees.
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
script_info = scripts.Scripts(cfg)
base_revision = revision.rev_id(manager)
if base_revision not in script_info.base_revisions():
raise ValueError(f"Unknown manager name {manager} (not in the database or migrations)")
revisions = {manager: base_revision}

cfg: config.MigAlembicConfig | None = None
if manager:
if manager in revisions:
revisions = {manager: revisions[manager]}
if dry_run:
print("Will store these revisions in alembic version table:")
for manager, rev_id in revisions.items():
print(f" {manager}: {rev_id}")
else:
# If specified manager not in the database, it may mean that an
# initial "tree-root" revision needs to be added to alembic
# table, if that manager is defined in the migration trees.
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
script_info = scripts.Scripts(cfg)
base_revision = revision.rev_id(manager)
if base_revision not in script_info.base_revisions():
raise ValueError(f"Unknown manager name {manager} (not in the database or migrations)")
revisions = {manager: base_revision}

if dry_run:
print("Will store these revisions in alembic version table:")
for manager, rev_id in revisions.items():
print(f" {manager}: {rev_id}")
else:
if cfg is None:
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
for rev in revisions.values():
command.stamp(cfg, rev, purge=purge)
if cfg is None:
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
for rev in revisions.values():
command.stamp(cfg, rev, purge=purge)
Loading
Loading