Skip to content

Commit a43c0ae

Browse files
committed
Add Database.close() and allow use in context manager
1 parent fad9519 commit a43c0ae

File tree

9 files changed

+177
-168
lines changed

9 files changed

+177
-168
lines changed

python/lsst/daf/butler_migrate/database.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
import json
2525
import logging
2626
from collections.abc import Iterable, Iterator, Mapping
27-
from contextlib import contextmanager
28-
from typing import cast
27+
from contextlib import AbstractContextManager, contextmanager
28+
from typing import Any, Literal, Self, cast
2929

3030
import sqlalchemy
3131
from alembic.runtime.migration import MigrationContext
@@ -43,7 +43,7 @@ class RevisionConsistencyError(Exception):
4343
"""
4444

4545

46-
class Database:
46+
class Database(AbstractContextManager):
4747
"""Class implementing methods for database access needed for migrations.
4848
4949
Parameters
@@ -69,6 +69,7 @@ class Database:
6969
def __init__(self, db_url: sqlalchemy.engine.url.URL, schema: str | None = None):
7070
self._db_url = db_url
7171
self._schema = schema
72+
self._engine: sqlalchemy.engine.Engine | None = None
7273

7374
@classmethod
7475
def from_repo(cls, repo: str) -> Database:
@@ -100,13 +101,34 @@ def schema(self) -> str | None:
100101
"""Schema (namespace) name (`str`)."""
101102
return self._schema
102103

104+
@property
105+
def engine(self) -> sqlalchemy.engine.Engine:
106+
"""Cached sqlalchemy Engine."""
107+
if self._engine is None:
108+
self._engine = sqlalchemy.engine.create_engine(self._db_url)
109+
return self._engine
110+
103111
@contextmanager
104112
def connect(self) -> Iterator[sqlalchemy.engine.Connection]:
105113
"""Context manager for database connection."""
106-
engine = sqlalchemy.engine.create_engine(self._db_url)
107-
with engine.connect() as connection:
114+
with self.engine.connect() as connection:
108115
yield connection
109116

117+
def __enter__(self) -> Self:
118+
return self
119+
120+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
121+
try:
122+
self.close()
123+
except Exception:
124+
_LOG.exception("An exception occurred during Database.close()")
125+
return False
126+
127+
def close(self) -> None:
128+
"""Cleanup connection pool."""
129+
self.engine.dispose()
130+
self._engine = None
131+
110132
def dimensions_namespace(self) -> str | None:
111133
"""Return dimensions namespace from a stored configuration.
112134
@@ -115,8 +137,6 @@ def dimensions_namespace(self) -> str | None:
115137
namespace: `str` or `None`
116138
Dimensions namespace or `None` if not defined.
117139
"""
118-
engine = sqlalchemy.engine.create_engine(self._db_url)
119-
120140
meta = sqlalchemy.schema.MetaData(schema=self._schema)
121141
table = sqlalchemy.schema.Table(
122142
"butler_attributes",
@@ -126,7 +146,7 @@ def dimensions_namespace(self) -> str | None:
126146
)
127147

128148
sql = sqlalchemy.sql.select(table.columns.value).where(table.columns.name == self.dimensions_json_key)
129-
with engine.connect() as connection:
149+
with self.engine.connect() as connection:
130150
result = connection.execute(sql)
131151
row = result.fetchone()
132152
if row is None:
@@ -150,8 +170,6 @@ def manager_versions(self, namespace: str | None = None) -> Mapping[str, tuple[s
150170
tuple consisting of manager class name (including package/module),
151171
version string in X.Y.Z format, and revision ID string/hash.
152172
"""
153-
engine = sqlalchemy.engine.create_engine(self._db_url)
154-
155173
meta = sqlalchemy.schema.MetaData(schema=self._schema)
156174
table = sqlalchemy.schema.Table(
157175
"butler_attributes",
@@ -164,7 +182,7 @@ def manager_versions(self, namespace: str | None = None) -> Mapping[str, tuple[s
164182
managers: dict[str, str] = {}
165183
versions: dict[str, str] = {}
166184
sql = sqlalchemy.sql.select(table.columns.name, table.columns.value)
167-
with engine.connect() as connection:
185+
with self.engine.connect() as connection:
168186
result = connection.execute(sql)
169187
for name, value in result:
170188
if name.startswith("config:registry.managers."):
@@ -211,8 +229,7 @@ def alembic_revisions(self) -> list[str]:
211229
Returned list is empty if alembic version table does not exist or
212230
is empty.
213231
"""
214-
engine = sqlalchemy.engine.create_engine(self._db_url)
215-
with engine.connect() as connection:
232+
with self.engine.connect() as connection:
216233
ctx = MigrationContext.configure(
217234
connection=connection, opts={"version_table_schema": self._schema}
218235
)
@@ -286,8 +303,7 @@ def dump_schema(self, tables: list[str] | None) -> None:
286303
List of the tables, if missing or empty then schema for all tables
287304
is printed.
288305
"""
289-
engine = sqlalchemy.engine.create_engine(self._db_url)
290-
inspector = sqlalchemy.inspect(engine)
306+
inspector = sqlalchemy.inspect(self.engine)
291307
table_names = sorted(inspector.get_table_names(schema=self._schema))
292308
for table in table_names:
293309
if tables and table not in tables:

python/lsst/daf/butler_migrate/script/migrate_current.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,32 +51,31 @@ def migrate_current(repo: str, mig_path: str, verbose: bool, butler: bool, names
5151
Dimensions namespace to use when "namespace" key is not present in
5252
``config:dimensions.json``.
5353
"""
54-
db = database.Database.from_repo(repo)
54+
with database.Database.from_repo(repo) as db:
55+
if namespace is None and db.dimensions_namespace() is None:
56+
raise ValueError(
57+
"The `--namespace` option is required when namespace is missing from"
58+
" stored dimensions configuration"
59+
)
5560

56-
if namespace is None and db.dimensions_namespace() is None:
57-
raise ValueError(
58-
"The `--namespace` option is required when namespace is missing from"
59-
" stored dimensions configuration"
60-
)
61-
62-
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
63-
if butler:
64-
# Print current versions defined in butler.
65-
script_info = scripts.Scripts(cfg)
66-
heads = script_info.head_revisions()
67-
manager_versions = db.manager_versions(namespace)
68-
if manager_versions:
69-
for manager, (klass, version, rev_id) in sorted(manager_versions.items()):
70-
head = " (head)" if rev_id in heads else ""
71-
print(f"{manager}: {klass} {version} -> {rev_id}{head}")
61+
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
62+
if butler:
63+
# Print current versions defined in butler.
64+
script_info = scripts.Scripts(cfg)
65+
heads = script_info.head_revisions()
66+
manager_versions = db.manager_versions(namespace)
67+
if manager_versions:
68+
for manager, (klass, version, rev_id) in sorted(manager_versions.items()):
69+
head = " (head)" if rev_id in heads else ""
70+
print(f"{manager}: {klass} {version} -> {rev_id}{head}")
71+
else:
72+
print("No manager versions defined in butler_attributes table.")
7273
else:
73-
print("No manager versions defined in butler_attributes table.")
74-
else:
75-
# Revisions from alembic.
76-
command.current(cfg, verbose=verbose)
74+
# Revisions from alembic.
75+
command.current(cfg, verbose=verbose)
7776

78-
# Complain if alembic_version table is there but does not match manager
79-
# versions.
80-
if db.alembic_revisions():
81-
script_info = scripts.Scripts(cfg)
82-
db.validate_revisions(namespace, script_info.base_revisions())
77+
# Complain if alembic_version table is there but does not match manager
78+
# versions.
79+
if db.alembic_revisions():
80+
script_info = scripts.Scripts(cfg)
81+
db.validate_revisions(namespace, script_info.base_revisions())

python/lsst/daf/butler_migrate/script/migrate_downgrade.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,29 @@ def migrate_downgrade(
5454
Dimensions namespace to use when "namespace" key is not present in
5555
``config:dimensions.json``.
5656
"""
57-
db = database.Database.from_repo(repo)
57+
with database.Database.from_repo(repo) as db:
58+
if namespace is None and db.dimensions_namespace() is None:
59+
raise ValueError(
60+
"The `--namespace` option is required when namespace is missing from"
61+
" stored dimensions configuration"
62+
)
5863

59-
if namespace is None and db.dimensions_namespace() is None:
60-
raise ValueError(
61-
"The `--namespace` option is required when namespace is missing from"
62-
" stored dimensions configuration"
63-
)
64+
# Check that alembic versions exist in database, we do not support
65+
# migrations from empty state.
66+
if not db.alembic_revisions():
67+
raise ValueError(
68+
"Alembic version table does not exist, you may need to run `butler migrate stamp` first."
69+
)
6470

65-
# Check that alembic versions exist in database, we do not support
66-
# migrations from empty state.
67-
if not db.alembic_revisions():
68-
raise ValueError(
69-
"Alembic version table does not exist, you may need to run `butler migrate stamp` first."
71+
one_shot_arg: str | None = None
72+
if one_shot_tree:
73+
one_shot_arg = one_shot_tree
74+
cfg = config.MigAlembicConfig.from_mig_path(
75+
mig_path, repository=repo, db=db, one_shot_tree=one_shot_arg
7076
)
7177

72-
one_shot_arg: str | None = None
73-
if one_shot_tree:
74-
one_shot_arg = one_shot_tree
75-
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db, one_shot_tree=one_shot_arg)
76-
77-
# check that alembic versions are consistent with butler
78-
script_info = scripts.Scripts(cfg)
79-
db.validate_revisions(namespace, script_info.base_revisions())
78+
# check that alembic versions are consistent with butler
79+
script_info = scripts.Scripts(cfg)
80+
db.validate_revisions(namespace, script_info.base_revisions())
8081

81-
command.downgrade(cfg, revision, sql=sql)
82+
command.downgrade(cfg, revision, sql=sql)

python/lsst/daf/butler_migrate/script/migrate_dump_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ def migrate_dump_schema(repo: str, table: list[str]) -> None:
3939
table : `list`
4040
List of the tables, if empty then schema for all tables is printed.
4141
"""
42-
db = database.Database.from_repo(repo)
43-
db.dump_schema(table)
42+
with database.Database.from_repo(repo) as db:
43+
db.dump_schema(table)

python/lsst/daf/butler_migrate/script/migrate_set_namespace.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,27 +44,27 @@ def migrate_set_namespace(repo: str, namespace: str | None, update: bool) -> Non
4444
update : `bool`
4545
Allows update of the existing namespace.
4646
"""
47-
db = database.Database.from_repo(repo)
48-
db_namespace = db.dimensions_namespace()
47+
with database.Database.from_repo(repo) as db:
48+
db_namespace = db.dimensions_namespace()
4949

50-
if not namespace:
51-
# Print current value
52-
if not db_namespace:
53-
print("No namespace defined in dimensions configuration.")
54-
else:
55-
print("Current dimensions namespace:", db_namespace)
50+
if not namespace:
51+
# Print current value
52+
if not db_namespace:
53+
print("No namespace defined in dimensions configuration.")
54+
else:
55+
print("Current dimensions namespace:", db_namespace)
5656

57-
else:
58-
if db_namespace and not update:
59-
raise ValueError(
60-
f"Namespace is already defined ({db_namespace}), use --update option to replace it."
61-
)
57+
else:
58+
if db_namespace and not update:
59+
raise ValueError(
60+
f"Namespace is already defined ({db_namespace}), use --update option to replace it."
61+
)
6262

63-
def update_namespace(config: dict) -> dict:
64-
"""Update namespace attribute"""
65-
config["namespace"] = namespace
66-
return config
63+
def update_namespace(config: dict) -> dict:
64+
"""Update namespace attribute"""
65+
config["namespace"] = namespace
66+
return config
6767

68-
with db.connect() as connection:
69-
attributes = butler_attributes.ButlerAttributes(connection, db.schema)
70-
attributes.update_dimensions_json(update_namespace)
68+
with db.connect() as connection:
69+
attributes = butler_attributes.ButlerAttributes(connection, db.schema)
70+
attributes.update_dimensions_json(update_namespace)

python/lsst/daf/butler_migrate/script/migrate_stamp.py

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -54,42 +54,41 @@ def migrate_stamp(
5454
manager : `str`, Optional
5555
Name of the manager to stamp, if `None` then all managers are stamped.
5656
"""
57-
db = database.Database.from_repo(repo)
57+
with database.Database.from_repo(repo) as db:
58+
if namespace is None and db.dimensions_namespace() is None:
59+
raise ValueError(
60+
"The `--namespace` option is required when namespace is missing from"
61+
" stored dimensions configuration"
62+
)
5863

59-
if namespace is None and db.dimensions_namespace() is None:
60-
raise ValueError(
61-
"The `--namespace` option is required when namespace is missing from"
62-
" stored dimensions configuration"
63-
)
64+
manager_versions = db.manager_versions(namespace)
6465

65-
manager_versions = db.manager_versions(namespace)
66+
revisions: dict[str, str] = {}
67+
for mgr_name, (klass, version, rev_id) in manager_versions.items():
68+
_LOG.debug("found revision (%s, %s, %s) -> %s", mgr_name, klass, version, rev_id)
69+
revisions[mgr_name] = rev_id
6670

67-
revisions: dict[str, str] = {}
68-
for mgr_name, (klass, version, rev_id) in manager_versions.items():
69-
_LOG.debug("found revision (%s, %s, %s) -> %s", mgr_name, klass, version, rev_id)
70-
revisions[mgr_name] = rev_id
71+
cfg: config.MigAlembicConfig | None = None
72+
if manager:
73+
if manager in revisions:
74+
revisions = {manager: revisions[manager]}
75+
else:
76+
# If specified manager not in the database, it may mean that an
77+
# initial "tree-root" revision needs to be added to alembic
78+
# table, if that manager is defined in the migration trees.
79+
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
80+
script_info = scripts.Scripts(cfg)
81+
base_revision = revision.rev_id(manager)
82+
if base_revision not in script_info.base_revisions():
83+
raise ValueError(f"Unknown manager name {manager} (not in the database or migrations)")
84+
revisions = {manager: base_revision}
7185

72-
cfg: config.MigAlembicConfig | None = None
73-
if manager:
74-
if manager in revisions:
75-
revisions = {manager: revisions[manager]}
86+
if dry_run:
87+
print("Will store these revisions in alembic version table:")
88+
for manager, rev_id in revisions.items():
89+
print(f" {manager}: {rev_id}")
7690
else:
77-
# If specified manager not in the database, it may mean that an
78-
# initial "tree-root" revision needs to be added to alembic
79-
# table, if that manager is defined in the migration trees.
80-
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
81-
script_info = scripts.Scripts(cfg)
82-
base_revision = revision.rev_id(manager)
83-
if base_revision not in script_info.base_revisions():
84-
raise ValueError(f"Unknown manager name {manager} (not in the database or migrations)")
85-
revisions = {manager: base_revision}
86-
87-
if dry_run:
88-
print("Will store these revisions in alembic version table:")
89-
for manager, rev_id in revisions.items():
90-
print(f" {manager}: {rev_id}")
91-
else:
92-
if cfg is None:
93-
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
94-
for rev in revisions.values():
95-
command.stamp(cfg, rev, purge=purge)
91+
if cfg is None:
92+
cfg = config.MigAlembicConfig.from_mig_path(mig_path, repository=repo, db=db)
93+
for rev in revisions.values():
94+
command.stamp(cfg, rev, purge=purge)

0 commit comments

Comments
 (0)