diff --git a/.github/workflows/dynamic_annotation_db _ci.yml b/.github/workflows/dynamic_annotation_db _ci.yml index b68def9..7943da8 100644 --- a/.github/workflows/dynamic_annotation_db _ci.yml +++ b/.github/workflows/dynamic_annotation_db _ci.yml @@ -19,7 +19,7 @@ jobs: services: postgres: - image: postgis/postgis:9.6-2.5 + image: postgis/postgis:13-3.3 env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres @@ -29,20 +29,20 @@ jobs: options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.7 - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + - name: Set up Python 3.11 + uses: actions/setup-python@v5 with: - python-version: 3.7 + python-version: 3.11 - - uses: actions/cache@v2 + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('**/test_requirements.txt') }} restore-keys: | ${{ runner.os }}-pip- - - uses: actions/cache@v2 + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} diff --git a/dynamicannotationdb/annotation.py b/dynamicannotationdb/annotation.py index 8f451e1..35a4d07 100644 --- a/dynamicannotationdb/annotation.py +++ b/dynamicannotationdb/annotation.py @@ -1,43 +1,35 @@ import datetime import logging -from typing import List +from typing import List, Dict, Any from marshmallow import INCLUDE -from sqlalchemy import DDL, event +from sqlalchemy import select, update -from .database import DynamicAnnotationDB -from .errors import ( +from dynamicannotationdb.database import DynamicAnnotationDB +from dynamicannotationdb.errors import ( AnnotationInsertLimitExceeded, NoAnnotationsFoundWithID, UpdateAnnotationError, TableNameNotFound, ) -from .models import AnnoMetadata -from .schema import DynamicSchemaClient +from dynamicannotationdb.models import AnnoMetadata +from dynamicannotationdb.schema import DynamicSchemaClient class DynamicAnnotationClient: def __init__(self, sql_url: str) -> None: self.db = DynamicAnnotationDB(sql_url) self.schema = DynamicSchemaClient() + self._table = None @property def table(self): + if self._table is None: + raise ValueError("No table loaded. Use load_table() first.") return self._table def load_table(self, table_name: str): - """Load a table - - Parameters - ---------- - table_name : str - name of table - - Returns - ------- - DeclarativeMeta - the sqlalchemy table of that name - """ + """Load a table""" self._table = self.db.cached_table(table_name) return self._table @@ -95,18 +87,17 @@ def create_table( """ existing_tables = self.db._check_table_is_unique(table_name) + reference_table = None if table_metadata: reference_table, _ = self.schema._parse_schema_metadata_params( schema_type, table_name, table_metadata, existing_tables ) - else: - reference_table = None AnnotationModel = self.schema.create_annotation_model( table_name, schema_type, table_metadata=table_metadata, - with_crud_columns=with_crud_columns, + with_crud_columns=True, ) self.db.base.metadata.tables[AnnotationModel.__name__].create( @@ -134,23 +125,16 @@ def create_table( logging.info(f"Metadata for table: {table_name} is {metadata_dict}") anno_metadata = AnnoMetadata(**metadata_dict) - self.db.cached_session.add(anno_metadata) - self.db.commit_session() + + with self.db.session_scope() as session: + session.add(anno_metadata) + logging.info( f"Table: {table_name} created using {AnnotationModel} model at {creation_time}" ) return table_name - def update_table_metadata( - self, - table_name: str, - description: str = None, - user_id: str = None, - flat_segmentation_source: str = None, - read_permission: str = None, - write_permission: str = None, - notice_text: str = None, - ): + def update_table_metadata(self, table_name: str, **kwargs): r"""Update metadata for an annotation table. Parameters @@ -184,34 +168,30 @@ def update_table_metadata( TableNameNotFound If no table with 'table_name' found in the metadata table """ - metadata = ( - self.db.cached_session.query(AnnoMetadata) - .filter(AnnoMetadata.table_name == table_name) - .first() - ) - if metadata is None: - raise TableNameNotFound( - f"no table named {table_name} in database {self.sql_url} " - ) + with self.db.session_scope() as session: + metadata = session.execute( + select(AnnoMetadata).where(AnnoMetadata.table_name == table_name) + ).scalar_one_or_none() + + if metadata is None: + raise TableNameNotFound(f"no table named {table_name} in database") + + for key, value in kwargs.items(): + if hasattr(metadata, key): + setattr(metadata, key, value) + + if "notice_text" in kwargs and kwargs["notice_text"] == "": + metadata.notice_text = None + + # Explicitly flush the session to ensure the changes are visible + session.flush() + + # Refresh the metadata object to get the updated values + session.refresh(metadata) + + self.db.get_table_metadata.cache_clear() + logging.info(f"Table: {table_name} metadata updated") - update_dict = { - "description": description, - "user_id": user_id, - "flat_segmentation_source": flat_segmentation_source, - "read_permission": read_permission, - "write_permission": write_permission, - } - update_dict = {k: v for k, v in update_dict.items() if v is not None} - if notice_text is not None: - if len(notice_text) == 0: - update_dict["notice_text"] = None - else: - update_dict["notice_text"] = notice_text - for column, value in update_dict.items(): - if hasattr(metadata, str(column)): - setattr(metadata, column, value) - self.db.commit_session() - logging.info(f"Table: {table_name} metadata updated ") return self.db.get_table_metadata(table_name) def delete_table(self, table_name: str) -> bool: @@ -230,20 +210,21 @@ def delete_table(self, table_name: str) -> bool: bool whether table was successfully deleted """ - metadata = ( - self.db.cached_session.query(AnnoMetadata) - .filter(AnnoMetadata.table_name == table_name) - .first() - ) - if metadata is None: - raise TableNameNotFound( - f"no table named {table_name} in database {self.sql_url} " - ) - metadata.deleted = datetime.datetime.utcnow() - self.db.commit_session() + with self.db.session_scope() as session: + metadata = session.execute( + select(AnnoMetadata).where(AnnoMetadata.table_name == table_name) + ).scalar_one_or_none() + + if metadata is None: + raise TableNameNotFound(f"no table named {table_name} in database") + + metadata.deleted = datetime.datetime.utcnow() + return True - def insert_annotations(self, table_name: str, annotations: List[dict]): + def insert_annotations( + self, table_name: str, annotations: List[Dict[str, Any]] + ) -> List[int]: """Insert some annotations. Parameters @@ -273,36 +254,33 @@ def insert_annotations(self, table_name: str, annotations: List[dict]): formatted_anno_data = [] for annotation in annotations: - - annotation_data, __ = self.schema.split_flattened_schema_data( + annotation_data, _ = self.schema.split_flattened_schema_data( schema_type, annotation ) - if annotation.get("id"): + if "id" in annotation: annotation_data["id"] = annotation["id"] if hasattr(AnnotationModel, "created"): annotation_data["created"] = datetime.datetime.utcnow() annotation_data["valid"] = True formatted_anno_data.append(annotation_data) - annos = [ - AnnotationModel(**annotation_data) - for annotation_data in formatted_anno_data - ] + with self.db.session_scope() as session: + annos = [AnnotationModel(**data) for data in formatted_anno_data] + session.add_all(annos) + session.flush() + anno_ids = [anno.id for anno in annos] - self.db.cached_session.add_all(annos) - self.db.cached_session.flush() - anno_ids = [anno.id for anno in annos] - - ( - self.db.cached_session.query(AnnoMetadata) - .filter(AnnoMetadata.table_name == table_name) - .update({AnnoMetadata.last_modified: datetime.datetime.utcnow()}) - ) + session.execute( + update(AnnoMetadata) + .where(AnnoMetadata.table_name == table_name) + .values(last_modified=datetime.datetime.utcnow()) + ) - self.db.commit_session() return anno_ids - def get_annotations(self, table_name: str, annotation_ids: List[int]) -> List[dict]: + def get_annotations( + self, table_name: str, annotation_ids: List[int] + ) -> List[Dict[str, Any]]: """Get a set of annotations by ID Parameters @@ -319,33 +297,38 @@ def get_annotations(self, table_name: str, annotation_ids: List[int]) -> List[di """ schema_type, AnnotationModel = self._load_model(table_name) - annotations = ( - self.db.cached_session.query(AnnotationModel) - .filter(AnnotationModel.id.in_(list(annotation_ids))) - .all() - ) + with self.db.session_scope() as session: + annotations = ( + session.execute( + select(AnnotationModel).where( + AnnotationModel.id.in_(annotation_ids) + ) + ) + .scalars() + .all() + ) - anno_schema, __ = self.schema.split_flattened_schema(schema_type) + anno_schema, _ = self.schema.split_flattened_schema(schema_type) schema = anno_schema(unknown=INCLUDE) + try: data = [] - for anno in annotations: - anno_data = anno.__dict__ - anno_data["created"] = str(anno_data.get("created")) - anno_data["deleted"] = str(anno_data.get("deleted")) anno_data = { - k: v for (k, v) in anno_data.items() if k != "_sa_instance_state" + k: str(v) if isinstance(v, datetime.datetime) else v + for k, v in anno.__dict__.items() + if not k.startswith("_") } data.append(anno_data) return schema.load(data, many=True) - except Exception as e: logging.exception(e) raise NoAnnotationsFoundWithID(annotation_ids) from e - def update_annotation(self, table_name: str, annotation: dict) -> str: + def update_annotation( + self, table_name: str, annotation: Dict[str, Any] + ) -> Dict[int, int]: """Update an annotation Parameters @@ -353,7 +336,7 @@ def update_annotation(self, table_name: str, annotation: dict) -> str: table_name : str name of targeted table to update annotations annotation : dict - new data for that annotation, allows for partial updates but + new data for that annotation, allows for partial updates but requires an 'id' field to target the row Returns @@ -368,56 +351,52 @@ def update_annotation(self, table_name: str, annotation: dict) -> str: """ anno_id = annotation.get("id") if not anno_id: - return "Annotation requires an 'id' to update targeted row" - schema_type, AnnotationModel = self._load_model(table_name) + raise ValueError("Annotation requires an 'id' to update targeted row") - try: - old_anno = ( - self.db.cached_session.query(AnnotationModel) - .filter(AnnotationModel.id == anno_id) - .one() - ) - except NoAnnotationsFoundWithID as e: - raise f"No result found for {anno_id}. Error: {e}" from e + schema_type, AnnotationModel = self._load_model(table_name) - if old_anno.superceded_id: - raise UpdateAnnotationError(anno_id, old_anno.superceded_id) + with self.db.session_scope() as session: + old_anno = session.execute( + select(AnnotationModel).where(AnnotationModel.id == anno_id) + ).scalar_one_or_none() - # Merge old data with new changes - old_data = { - column.name: getattr(old_anno, column.name) - for column in old_anno.__table__.columns - } - updated_data = {**old_data, **annotation} + if old_anno is None: + raise NoAnnotationsFoundWithID(f"No result found for {anno_id}") - new_annotation, __ = self.schema.split_flattened_schema_data( - schema_type, updated_data - ) + if old_anno.superceded_id: + raise UpdateAnnotationError(anno_id, old_anno.superceded_id) - if hasattr(AnnotationModel, "created"): - new_annotation["created"] = datetime.datetime.utcnow() - if hasattr(AnnotationModel, "valid"): - new_annotation["valid"] = True + old_data = { + column.name: getattr(old_anno, column.name) + for column in old_anno.__table__.columns + } + updated_data = {**old_data, **annotation} - new_data = AnnotationModel(**new_annotation) + new_annotation, _ = self.schema.split_flattened_schema_data( + schema_type, updated_data + ) - self.db.cached_session.add(new_data) - self.db.cached_session.flush() + if hasattr(AnnotationModel, "created"): + new_annotation["created"] = datetime.datetime.utcnow() + if hasattr(AnnotationModel, "valid"): + new_annotation["valid"] = True - deleted_time = datetime.datetime.utcnow() - old_anno.deleted = deleted_time - old_anno.superceded_id = new_data.id - old_anno.valid = False - update_map = {anno_id: new_data.id} + new_data = AnnotationModel(**new_annotation) + session.add(new_data) + session.flush() - ( - self.db.cached_session.query(AnnoMetadata) - .filter(AnnoMetadata.table_name == table_name) - .update({AnnoMetadata.last_modified: datetime.datetime.utcnow()}) - ) - self.db.commit_session() + deleted_time = datetime.datetime.utcnow() + old_anno.deleted = deleted_time + old_anno.superceded_id = new_data.id + old_anno.valid = False + + session.execute( + update(AnnoMetadata) + .where(AnnoMetadata.table_name == table_name) + .values(last_modified=datetime.datetime.utcnow()) + ) - return update_map + return {anno_id: new_data.id} def delete_annotation( self, table_name: str, annotation_ids: List[int] @@ -438,45 +417,45 @@ def delete_annotation( """ schema_type, AnnotationModel = self._load_model(table_name) - annotations = ( - self.db.cached_session.query(AnnotationModel) - .filter(AnnotationModel.id.in_(annotation_ids)) - .all() - ) - deleted_ids = [] - if annotations: + with self.db.session_scope() as session: + annotations = ( + session.execute( + select(AnnotationModel).where( + AnnotationModel.id.in_(annotation_ids) + ) + ) + .scalars() + .all() + ) + + if not annotations: + return [] + deleted_time = datetime.datetime.utcnow() + deleted_ids = [] for annotation in annotations: - # TODO: This should be deprecated, as all tables should have - # CRUD columns now, but leaving this for backward safety. if not hasattr(AnnotationModel, "deleted"): - self.db.cached_session.delete(annotation) + session.delete(annotation) else: annotation.deleted = deleted_time annotation.valid = False deleted_ids.append(annotation.id) - ( - self.db.cached_session.query(AnnoMetadata) - .filter(AnnoMetadata.table_name == table_name) - .update({AnnoMetadata.last_modified: datetime.datetime.utcnow()}) + session.execute( + update(AnnoMetadata) + .where(AnnoMetadata.table_name == table_name) + .values(last_modified=datetime.datetime.utcnow()) ) - self.db.commit_session() - - else: - return None return deleted_ids def _load_model(self, table_name): metadata = self.db.get_table_metadata(table_name) - schema_type = metadata["schema_type"] + schema_type = metadata.anno_metadata.schema_type - # load reference table into metadata if not already present - ref_table = metadata.get("reference_table") - if ref_table: - reference_table_name = self.db.cached_table(ref_table) + if metadata.anno_metadata.reference_table: + self.db.cached_table(metadata.anno_metadata.reference_table) AnnotationModel = self.db.cached_table(table_name) return schema_type, AnnotationModel diff --git a/dynamicannotationdb/database.py b/dynamicannotationdb/database.py index 5202765..39d3ff2 100644 --- a/dynamicannotationdb/database.py +++ b/dynamicannotationdb/database.py @@ -1,44 +1,53 @@ import logging from contextlib import contextmanager -from typing import List - -from sqlalchemy import create_engine, func, inspect, or_ +from typing import Any, Dict, List, Tuple +from functools import lru_cache + + +from sqlalchemy import ( + engine, + MetaData, + Table, + create_engine, + func, + inspect, + or_, + select, +) from sqlalchemy.ext.automap import automap_base -from sqlalchemy.ext.declarative.api import DeclarativeMeta -from sqlalchemy.orm import Session, scoped_session, sessionmaker -from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.schema import MetaData from sqlalchemy.sql.schema import Table -from .errors import TableAlreadyExists, TableNameNotFound, TableNotInMetadata -from .models import AnnoMetadata, Base, SegmentationMetadata, AnalysisView -from .schema import DynamicSchemaClient +from dynamicannotationdb.errors import ( + TableAlreadyExists, + TableNameNotFound, + TableNotInMetadata, +) +from dynamicannotationdb.models import ( + AnalysisView, + AnnoMetadata, + Base, + SegmentationMetadata, + CombinedMetadata, +) +from dynamicannotationdb.schema import DynamicSchemaClient -class DynamicAnnotationDB: - def __init__(self, sql_url: str, pool_size=5, max_overflow=5) -> None: - self._cached_session = None - self._cached_tables = {} - self._engine = create_engine( +class DynamicAnnotationDB: + def __init__(self, sql_url: str, pool_size: int = 5, max_overflow: int = 5) -> None: + self._engine: engine = create_engine( sql_url, pool_recycle=3600, pool_size=pool_size, max_overflow=max_overflow ) self.base = Base - self.base.metadata.bind = self._engine - self.base.metadata.create_all( - tables=[AnnoMetadata.__table__, SegmentationMetadata.__table__], - checkfirst=True, - ) - self.session = scoped_session( - sessionmaker(bind=self.engine, autocommit=False, autoflush=False) - ) - self.schema_client = DynamicSchemaClient() + session_factory = sessionmaker(bind=self._engine, expire_on_commit=False) + self.Session = scoped_session(session_factory) - self._inspector = inspect(self.engine) - - self._cached_session = None - self._cached_tables = {} + self.schema_client = DynamicSchemaClient() + self._inspector = inspect(self._engine) + self._cached_tables: Dict[str, Any] = {} @property def inspector(self): @@ -48,37 +57,21 @@ def inspector(self): def engine(self): return self._engine - @property - def cached_session(self) -> Session: - if self._cached_session is None: - self._cached_session = self.session() - return self._cached_session - @contextmanager def session_scope(self): + session = self.Session() try: - yield self.cached_session - except Exception as e: - self.cached_session.rollback() - logging.exception(f"SQL Error: {e}") - raise e - finally: - self.cached_session.close() - self._cached_session = None - - def commit_session(self): - try: - self.cached_session.commit() + yield session + session.commit() except Exception as e: - self.cached_session.rollback() + session.rollback() logging.exception(f"SQL Error: {e}") - raise e + raise finally: - self.cached_session.close() - self._cached_session = None + self.Session.remove() - def get_table_sql_metadata(self, table_name: str): - self.base.metadata.reflect(bind=self.engine) + def get_table_sql_metadata(self, table_name: str) -> Table: + self.base.metadata.reflect(bind=self._engine) return self.base.metadata.tables[table_name] def get_unique_string_values(self, table_name: str): @@ -95,113 +88,94 @@ def get_unique_string_values(self, table_name: str): dictionary of column names and unique values """ model = self.cached_table(table_name) - unique_values = {} + with self.session_scope() as session: - for column_name in model.__table__.columns.keys(): - - # if the column is a string - try: - python_type = model.__table__.columns[column_name].type.python_type - except NotImplementedError: - python_type = None - if python_type == str: - query = session.query(getattr(model, column_name)).distinct() - unique_values[column_name] = [row[0] for row in query.all()] + for column_name, column in model.__table__.columns.items(): + if isinstance(column.type.python_type, str): + stmt = select(getattr(model, column_name)).distinct() + result = session.execute(stmt) + unique_values[column_name] = [row[0] for row in result.fetchall()] + return unique_values - def get_views(self, datastack_name: str): + def get_views(self, datastack_name: str) -> List[AnalysisView]: with self.session_scope() as session: - query = session.query(AnalysisView).filter( + stmt = select(AnalysisView).where( AnalysisView.datastack_name == datastack_name ) - return query.all() + result = session.execute(stmt) + return result.scalars().all() - def get_view_metadata(self, datastack_name: str, view_name: str): + def get_view_metadata(self, datastack_name: str, view_name: str) -> Dict[str, Any]: with self.session_scope() as session: - query = ( - session.query(AnalysisView) - .filter(AnalysisView.table_name == view_name) - .filter(AnalysisView.datastack_name == datastack_name) + stmt = select(AnalysisView).where( + AnalysisView.table_name == view_name, + AnalysisView.datastack_name == datastack_name, ) - result = query.one() - if hasattr(result, "__dict__"): - return self.get_automap_items(result) - else: - return result[0] + result = session.execute(stmt) + view = result.scalar_one() + return self.get_automap_items(view) - def get_table_metadata(self, table_name: str, filter_col: str = None): - data = getattr(AnnoMetadata, filter_col) if filter_col else AnnoMetadata + @lru_cache(maxsize=128) + def get_table_metadata(self, table_name: str) -> CombinedMetadata: with self.session_scope() as session: - if filter_col and data: - - query = session.query(data).filter( - AnnoMetadata.table_name == table_name + stmt = ( + select(AnnoMetadata, SegmentationMetadata) + .outerjoin( + SegmentationMetadata, + AnnoMetadata.table_name == SegmentationMetadata.annotation_table, ) - result = query.one() - - if hasattr(result, "__dict__"): - return self.get_automap_items(result) - else: - return result[0] - else: - metadata = ( - session.query(data, SegmentationMetadata) - .outerjoin( - SegmentationMetadata, - AnnoMetadata.table_name - == SegmentationMetadata.annotation_table, + .where( + or_( + AnnoMetadata.table_name == table_name, + SegmentationMetadata.table_name == table_name, ) - .filter( - or_( - AnnoMetadata.table_name == table_name, - SegmentationMetadata.table_name == table_name, - ) - ) - .all() ) - try: - if metadata: - flatted_metadata = self.flatten_join(metadata) - return flatted_metadata[0] - except NoResultFound: - return None + ) + result = session.execute(stmt).first() + + if result is None: + raise ValueError(f"No metadata found for table '{table_name}'") + + anno_metadata, seg_metadata = result + + return CombinedMetadata( + table_name=table_name, + anno_metadata=anno_metadata, + seg_metadata=seg_metadata, + ) def get_table_schema(self, table_name: str) -> str: table_metadata = self.get_table_metadata(table_name) - return table_metadata.get("schema_type") + return table_metadata.anno_metadata.schema_type def get_valid_table_names(self) -> List[str]: with self.session_scope() as session: - metadata = session.query(AnnoMetadata).all() - return [m.table_name for m in metadata if m.valid == True] + stmt = select(AnnoMetadata.table_name).where(AnnoMetadata.valid == True) + result = session.execute(stmt) + return result.scalars().all() def get_annotation_table_size(self, table_name: str) -> int: - """Get the number of annotations in a table - - Parameters - ---------- - table_name : str - name of table contained within the aligned_volume database - - Returns - ------- - int - number of annotations - """ Model = self.cached_table(table_name) with self.session_scope() as session: - return session.query(Model).count() + stmt = select(func.count()).select_from(Model) + result = session.execute(stmt) + return result.scalar_one() def get_max_id_value(self, table_name: str) -> int: model = self.cached_table(table_name) with self.session_scope() as session: - return session.query(func.max(model.id)).scalar() + stmt = select(func.max(model.id)) + result = session.execute(stmt) + return result.scalar_one() def get_min_id_value(self, table_name: str) -> int: model = self.cached_table(table_name) with self.session_scope() as session: - return session.query(func.min(model.id)).scalar() + stmt = select(func.min(model.id)) + result = session.execute(stmt) + return result.scalar_one() def get_table_row_count( self, table_name: str, filter_valid: bool = False, filter_timestamp: str = None @@ -219,27 +193,29 @@ def get_table_row_count( """ model = self.cached_table(table_name) with self.session_scope() as session: - sql_query = session.query(func.count(model.id)) + stmt = select(func.count(model.id)) if filter_valid: - sql_query = sql_query.filter(model.valid == True) + stmt = stmt.where(model.valid == True) if filter_timestamp and hasattr(model, "created"): - sql_query = sql_query.filter(model.created <= filter_timestamp) - return sql_query.scalar() + stmt = stmt.where(model.created <= filter_timestamp) + result = session.execute(stmt) + return result.scalar_one() @staticmethod - def get_automap_items(result): + def get_automap_items(result: Any) -> Dict[str, Any]: return {k: v for (k, v) in result.__dict__.items() if k != "_sa_instance_state"} - def obj_to_dict(self, obj): - if obj: - return { + def obj_to_dict(self, obj: Any) -> Dict[str, Any]: + return ( + { column.key: getattr(obj, column.key) for column in inspect(obj).mapper.column_attrs } - else: - return {} + if obj + else {} + ) - def flatten_join(self, _list: List): + def flatten_join(self, _list: List[Tuple[Any, Any]]) -> List[Dict[str, Any]]: return [{**self.obj_to_dict(a), **self.obj_to_dict(b)} for a, b in _list] def drop_table(self, table_name: str) -> bool: @@ -256,16 +232,15 @@ def drop_table(self, table_name: str) -> bool: bool whether drop was successful """ - table = self.base.metadata.tables.get(table_name) - if table: + if table := self.base.metadata.tables.get(table_name): logging.info(f"Deleting {table_name} table") self.base.metadata.drop_all(self._engine, [table], checkfirst=True) - if self._is_cached(table): - del self._cached_tables[table] + if table_name in self._cached_tables: + del self._cached_tables[table_name] return True return False - def _check_table_is_unique(self, table_name): + def _check_table_is_unique(self, table_name: str) -> List[str]: existing_tables = self._get_existing_table_names() if table_name in existing_tables: raise TableAlreadyExists( @@ -274,46 +249,39 @@ def _check_table_is_unique(self, table_name): return existing_tables def _get_existing_table_names(self, filter_valid: bool = False) -> List[str]: - """Collects table_names keys of existing tables - - Returns - ------- - list - List of table_names - """ with self.session_scope() as session: - stmt = session.query(AnnoMetadata) + stmt = select(AnnoMetadata.table_name) if filter_valid: - stmt = stmt.filter(AnnoMetadata.valid == True) - metadata = stmt.all() - return [m.table_name for m in metadata] + stmt = stmt.where(AnnoMetadata.valid == True) + result = session.execute(stmt) + return result.scalars().all() - def _get_model_from_table_name(self, table_name: str) -> DeclarativeMeta: - metadata = self.get_table_metadata(table_name) + def _get_model_from_table_name(self, table_name: str) -> Any: + combined_metadata = self.get_table_metadata(table_name) - if metadata: - if metadata["reference_table"]: + if combined_metadata is None: + raise TableNotInMetadata(f"No metadata found for table '{table_name}'") + + anno_metadata = combined_metadata.anno_metadata + seg_metadata = combined_metadata.seg_metadata + + if anno_metadata: + if anno_metadata.reference_table: return self.schema_client.create_reference_annotation_model( - table_name, - metadata["schema_type"], - metadata["reference_table"], + table_name, anno_metadata.schema_type, anno_metadata.reference_table ) - elif metadata.get("annotation_table") and table_name != metadata.get( - "annotation_table" - ): + elif seg_metadata and table_name == seg_metadata.table_name: return self.schema_client.create_segmentation_model( - metadata["annotation_table"], - metadata["schema_type"], - metadata["pcg_table_name"], + anno_metadata.table_name, + anno_metadata.schema_type, + seg_metadata.pcg_table_name, ) - else: return self.schema_client.create_annotation_model( - table_name, metadata["schema_type"] + table_name, anno_metadata.schema_type ) - else: - raise TableNotInMetadata + raise ValueError(f"Invalid metadata structure for table '{table_name}'") def _get_model_columns(self, table_name: str) -> List[tuple]: """Return list of column names and types of a given table @@ -328,23 +296,22 @@ def _get_model_columns(self, table_name: str) -> List[tuple]: list column names and types """ - db_columns = self.inspector.get_columns(table_name) + db_columns = self._inspector.get_columns(table_name) if not db_columns: raise TableNameNotFound(table_name) return [(column["name"], column["type"]) for column in db_columns] def get_view_table(self, view_name: str) -> Table: """Return the sqlalchemy table object for a view""" - if self._is_cached(view_name): + if view_name in self._cached_tables: return self._cached_tables[view_name] - else: - meta = MetaData(self._engine) - meta.reflect(views=True, only=[view_name]) - table = meta.tables[view_name] - self._cached_tables[view_name] = table - return table + meta = MetaData() + meta.reflect(bind=self._engine, views=True, only=[view_name]) + table = meta.tables[view_name] + self._cached_tables[view_name] = table + return table - def cached_table(self, table_name: str) -> DeclarativeMeta: + def cached_table(self, table_name: str): """Returns cached table 'DeclarativeMeta' callable for querying. Parameters @@ -375,41 +342,21 @@ def _load_table(self, table_name: str): bool Returns True if table exists and is loaded into cached table dict. """ - if self._is_cached(table_name): - return True - try: self._cached_tables[table_name] = self._get_model_from_table_name( table_name ) return True except TableNotInMetadata: - # cant find the table so lets try the slow reflection before giving up - self.mapped_base = automap_base() - self.mapped_base.prepare(self._engine, reflect=True) try: - model = self.mapped_base.classes[table_name] + base = automap_base() + base.prepare(self._engine, reflect=True) + model = base.classes[table_name] self._cached_tables[table_name] = model + return True except KeyError as table_error: logging.error(f"Could not load table: {table_error}") return False - except Exception as table_error: logging.error(f"Could not load table: {table_error}") return False - - def _is_cached(self, table_name: str) -> bool: - """Check if table is loaded into cached instance dict of tables - - Parameters - ---------- - table_name : str - Name of table to check if loaded - - Returns - ------- - bool - True if table is loaded else False. - """ - - return table_name in self._cached_tables diff --git a/dynamicannotationdb/errors.py b/dynamicannotationdb/errors.py index e9e4135..a415da1 100644 --- a/dynamicannotationdb/errors.py +++ b/dynamicannotationdb/errors.py @@ -13,9 +13,11 @@ def __str__(self): class TableAlreadyExists(KeyError): """Table name already exists in the Metadata table""" + class TableNotInMetadata(KeyError): """Table does not exist in the Metadata table""" + class IdsAlreadyExists(KeyError): """Annotation IDs already exists in the segmentation table""" @@ -30,9 +32,9 @@ class BadRequest(Exception): class UpdateAnnotationError(ValueError): def __init__( - self, - target_id: int, - superseded_id: int, + self, + target_id: int, + superseded_id: int, ): self.target_id = target_id self.message = f"Annotation with ID {target_id} has already been superseded by annotation ID {superseded_id}, update annotation ID {superseded_id} instead" @@ -46,7 +48,7 @@ class AnnotationInsertLimitExceeded(ValueError): """Exception raised when amount of annotations exceeds defined limit.""" def __init__( - self, limit: int, length: int, message: str = "Annotation limit exceeded" + self, limit: int, length: int, message: str = "Annotation limit exceeded" ): self.limit = limit self.message = ( diff --git a/dynamicannotationdb/interface.py b/dynamicannotationdb/interface.py index e28c6bf..9ebef57 100644 --- a/dynamicannotationdb/interface.py +++ b/dynamicannotationdb/interface.py @@ -1,7 +1,6 @@ import logging - -from sqlalchemy import create_engine -from sqlalchemy.engine.url import make_url +from sqlalchemy import create_engine, text +from sqlalchemy.engine import URL, make_url from sqlalchemy.pool import NullPool from .annotation import DynamicAnnotationClient @@ -32,7 +31,6 @@ class DynamicAnnotationInterface: linked to annotation tables. schema : Wrapper for EMAnnotationSchemas to generate dynamic sqlalchemy models. - """ def __init__( @@ -57,15 +55,22 @@ def create_or_select_database(self, url: str, aligned_volume: str): url : str base path to the sql server aligned_volume : str - name of aligned volume which the database name will inherent + name of aligned volume which the database name will inherit Returns ------- sql_url instance """ sql_base_uri = url.rpartition("/")[0] - - sql_uri = make_url(f"{sql_base_uri}/{aligned_volume}") + parsed_url = make_url(url) + sql_uri = URL.create( + drivername=parsed_url.drivername, + username=parsed_url.username, + password=parsed_url.password, + host=parsed_url.host, + port=parsed_url.port, + database=aligned_volume, + ) temp_engine = create_engine( sql_base_uri, @@ -75,11 +80,12 @@ def create_or_select_database(self, url: str, aligned_volume: str): ) with temp_engine.connect() as connection: - connection.execute("commit") database_exists = connection.execute( - f"SELECT 1 FROM pg_catalog.pg_database WHERE datname = '{sql_uri.database}'" - ) - if not database_exists.fetchone(): + text(f"SELECT 1 FROM pg_catalog.pg_database WHERE datname = :dbname"), + {"dbname": sql_uri.database}, + ).scalar() + + if not database_exists: logging.info(f"Database {aligned_volume} does not exist.") self._create_aligned_volume_database(sql_uri, connection) @@ -95,21 +101,25 @@ def _create_aligned_volume_database(self, sql_uri, connection): logging.info(f"Creating new database: {sql_uri.database}") connection.execute( - f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity \ - WHERE pid <> pg_backend_pid() AND datname = '{sql_uri.database}';" + text( + f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = :dbname" + ), + {"dbname": sql_uri.database}, ) # check if template exists, create if missing template_exist = connection.execute( - "SELECT 1 FROM pg_catalog.pg_database WHERE datname = 'template_postgis'" - ) + text( + "SELECT 1 FROM pg_catalog.pg_database WHERE datname = 'template_postgis'" + ) + ).scalar() - if not template_exist.fetchone(): + if not template_exist: # create postgis template db - connection.execute("CREATE DATABASE template_postgis") + connection.execute(text("CREATE DATABASE template_postgis")) # create postgis extension - template_uri = make_url( + template_uri = URL.create( f"{str(sql_uri).rpartition('/')[0]}/template_postgis" ) template_engine = create_engine( @@ -119,12 +129,14 @@ def _create_aligned_volume_database(self, sql_uri, connection): pool_pre_ping=True, ) with template_engine.connect() as template_connection: - template_connection.execute("CREATE EXTENSION IF NOT EXISTS postgis") + template_connection.execute( + text("CREATE EXTENSION IF NOT EXISTS postgis") + ) template_engine.dispose() # finally create new annotation database connection.execute( - f"CREATE DATABASE {sql_uri.database} TEMPLATE template_postgis" + text(f"CREATE DATABASE {sql_uri.database} TEMPLATE template_postgis") ) aligned_volume_engine = create_engine( sql_uri, diff --git a/dynamicannotationdb/models.py b/dynamicannotationdb/models.py index 23b3a4a..7070109 100644 --- a/dynamicannotationdb/models.py +++ b/dynamicannotationdb/models.py @@ -1,5 +1,5 @@ import enum - +from typing import Optional from emannotationschemas.models import Base from sqlalchemy import ( Boolean, @@ -11,13 +11,15 @@ Integer, String, Text, - Enum, JSON, ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.dialects import postgresql +from marshmallow import Schema, fields, post_load +from marshmallow_sqlalchemy import SQLAlchemyAutoSchema + # Models that will be created in the 'materialized' database. MatBase = declarative_base() # Models that will be created in the 'annotation' database. @@ -125,6 +127,13 @@ class AnnoMetadata(Base): last_modified = Column(DateTime, nullable=False) +class AnnotationMetadataSchema(SQLAlchemyAutoSchema): + class Meta: + model = AnnoMetadata + include_relationships = True + load_instance = True + + class SegmentationMetadata(Base): __tablename__ = "segmentation_table_metadata" id = Column(Integer, primary_key=True) @@ -141,6 +150,35 @@ class SegmentationMetadata(Base): ) +class SegmentationMetadataSchema(SQLAlchemyAutoSchema): + class Meta: + model = SegmentationMetadata + include_relationships = True + load_instance = True + + +class CombinedMetadata: + def __init__( + self, + table_name: str, + anno_metadata: Optional[AnnoMetadata] = None, + seg_metadata: Optional[SegmentationMetadata] = None, + ): + self.table_name = table_name + self.anno_metadata = anno_metadata + self.seg_metadata = seg_metadata + + +class CombinedMetadataSchema(Schema): + table_name = fields.Str(required=True) + anno_metadata = fields.Nested(AnnotationMetadataSchema, allow_none=True) + seg_metadata = fields.Nested(SegmentationMetadataSchema, allow_none=True) + + @post_load + def make_combined_metadata(self, data, **kwargs): + return CombinedMetadata(**data) + + class CombinedTableMetadata(Base): __tablename__ = "combined_table_metadata" __table_args__ = ( diff --git a/dynamicannotationdb/schema.py b/dynamicannotationdb/schema.py index 6eb3c52..2cb1e75 100644 --- a/dynamicannotationdb/schema.py +++ b/dynamicannotationdb/schema.py @@ -6,7 +6,7 @@ from emannotationschemas.schemas.base import ReferenceAnnotation, SegmentationField from marshmallow import EXCLUDE, Schema -from .errors import SelfReferenceTableError, TableNameNotFound +from dynamicannotationdb.errors import SelfReferenceTableError, TableNameNotFound class DynamicSchemaClient: diff --git a/dynamicannotationdb/segmentation.py b/dynamicannotationdb/segmentation.py index 3dca556..ce0d884 100644 --- a/dynamicannotationdb/segmentation.py +++ b/dynamicannotationdb/segmentation.py @@ -1,19 +1,22 @@ import datetime import logging -from typing import List +from typing import Any, Dict, List, Optional from marshmallow import INCLUDE +from sqlalchemy import and_, select +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.exc import SQLAlchemyError -from .database import DynamicAnnotationDB -from .errors import ( +from dynamicannotationdb.database import DynamicAnnotationDB +from dynamicannotationdb.errors import ( AnnotationInsertLimitExceeded, IdsAlreadyExists, UpdateAnnotationError, ) -from .key_utils import build_segmentation_table_name -from .models import SegmentationMetadata -from .schema import DynamicSchemaClient -from .errors import TableNameNotFound +from dynamicannotationdb.key_utils import build_segmentation_table_name +from dynamicannotationdb.models import SegmentationMetadata +from dynamicannotationdb.schema import DynamicSchemaClient + class DynamicSegmentationClient: def __init__(self, sql_url: str) -> None: @@ -21,38 +24,16 @@ def __init__(self, sql_url: str) -> None: self.schema = DynamicSchemaClient() def create_segmentation_table( - self, - table_name: str, - schema_type: str, - segmentation_source: str, - table_metadata: dict = None, - with_crud_columns: bool = False, - ): - """Create a segmentation table with the primary key as foreign key - to the annotation table. - - Parameters - ---------- - table_name : str - Name of annotation table to link to. - schema_type : str - schema type - segmentation_source : str - name of segmentation data source, used to create table name. - table_metadata : dict, optional - metadata to extend table behavior, by default None - with_crud_columns : bool, optional - add additional columns to track CRUD operations on rows, by default False - - Returns - ------- - str - name of segmentation table. - """ + self, + table_name: str, + schema_type: str, + segmentation_source: str, + table_metadata: Optional[Dict[str, Any]] = None, + with_crud_columns: bool = False, + ) -> str: segmentation_table_name = build_segmentation_table_name( table_name, segmentation_source ) - self.db._check_table_is_unique(segmentation_table_name) SegmentationModel = self.schema.create_segmentation_model( @@ -63,330 +44,285 @@ def create_segmentation_table( with_crud_columns, ) - if ( - not self.db.cached_session.query(SegmentationMetadata) - .filter(SegmentationMetadata.table_name == segmentation_table_name) - .scalar() - ): - SegmentationModel.__table__.create(bind=self.db._engine, checkfirst=True) - creation_time = datetime.datetime.utcnow() - metadata_dict = { - "annotation_table": table_name, - "schema_type": schema_type, - "table_name": segmentation_table_name, - "valid": True, - "created": creation_time, - "pcg_table_name": segmentation_source, - } + with self.db.session_scope() as session: + if not session.execute( + select(SegmentationMetadata).filter_by( + table_name=segmentation_table_name + ) + ).scalar(): + SegmentationModel.__table__.create( + bind=self.db._engine, checkfirst=True + ) + creation_time = datetime.datetime.utcnow() + metadata_dict = { + "annotation_table": table_name, + "schema_type": schema_type, + "table_name": segmentation_table_name, + "valid": True, + "created": creation_time, + "pcg_table_name": segmentation_source, + } + + seg_metadata = SegmentationMetadata(**metadata_dict) + session.add(seg_metadata) + + return segmentation_table_name - seg_metadata = SegmentationMetadata(**metadata_dict) + def get_linked_tables( + self, table_name: str, pcg_table_name: str + ) -> List[SegmentationMetadata]: + with self.db.session_scope() as session: try: - self.db.cached_session.add(seg_metadata) - self.db.commit_session() + stmt = select(SegmentationMetadata).filter_by( + annotation_table=table_name, pcg_table_name=pcg_table_name + ) + result = session.execute(stmt) + return result.scalars().all() except Exception as e: - logging.error(f"SQL ERROR: {e}") + raise AttributeError( + f"No table found with name '{table_name}'. Error: {e}" + ) from e - return segmentation_table_name + def get_segmentation_table_metadata( + self, table_name: str, pcg_table_name: str + ) -> Optional[Dict[str, Any]]: + seg_table_name = build_segmentation_table_name(table_name, pcg_table_name) + with self.db.session_scope() as session: + try: + stmt = select(SegmentationMetadata).filter_by(table_name=seg_table_name) + result = session.execute(stmt).scalar_one_or_none() + return self.db.get_automap_items(result) if result else None + except Exception as e: + logging.error(f"Error fetching segmentation table metadata: {e}") + return None - def get_linked_tables(self, table_name: str, pcg_table_name: str) -> List: + def get_linked_annotations( + self, table_name: str, pcg_table_name: str, annotation_ids: List[int] + ) -> List[Dict[str, Any]]: try: - return ( - self.db.cached_session.query(SegmentationMetadata) - .filter(SegmentationMetadata.annotation_table == table_name) - .filter(SegmentationMetadata.pcg_table_name == pcg_table_name) - .all() + metadata = self.db.get_table_metadata(table_name) + schema_type = metadata.anno_metadata.schema_type + seg_table_name = build_segmentation_table_name(table_name, pcg_table_name) + AnnotationModel, SegmentationModel = self._get_models( + table_name, seg_table_name ) - except Exception as e: - raise AttributeError( - f"No table found with name '{table_name}'. Error: {e}" - ) from e + with self.db.session_scope() as session: + # Perform a join query + stmt = ( + select(AnnotationModel, SegmentationModel) + .join(SegmentationModel, AnnotationModel.id == SegmentationModel.id) + .filter(AnnotationModel.id.in_(annotation_ids)) + ) + result = session.execute(stmt).all() - def get_segmentation_table_metadata(self, table_name: str, pcg_table_name: str): - seg_table_name = build_segmentation_table_name(table_name, pcg_table_name) - try: - result = ( - self.db.cached_session.query(SegmentationMetadata) - .filter(SegmentationMetadata.table_name == seg_table_name) - .one() - ) - return self.db.get_automap_items(result) - except Exception as e: - self.db.cached_session.rollback() - return None - + data = [] + for anno, seg in result: + merged_data = self._format_model_to_dict(anno) + merged_data.update(self._format_model_to_dict(seg)) + data.append(merged_data) - def get_linked_annotations( - self, table_name: str, pcg_table_name: str, annotation_ids: List[int] - ) -> dict: - """Get list of annotations from database by id. - - Parameters - ---------- - table_name : str - name of annotation table - pcg_table_name: str - name of chunked graph reference table - annotation_ids : int - annotation id - - Returns - ------- - list - list of annotation data dicts - """ + FlatSchema = self.schema.get_flattened_schema(schema_type) + schema = FlatSchema(unknown=INCLUDE) - metadata = self.db.get_table_metadata(table_name) - schema_type = metadata["schema_type"] - seg_table_name = build_segmentation_table_name(table_name, pcg_table_name) - AnnotationModel = self.db.cached_table(table_name) - SegmentationModel = self.db.cached_table(seg_table_name) + return schema.load(data, many=True) - annotations = ( - self.db.cached_session.query(AnnotationModel, SegmentationModel) - .join(SegmentationModel, SegmentationModel.id == AnnotationModel.id) - .filter(AnnotationModel.id.in_(list(annotation_ids))) - .all() - ) - - FlatSchema = self.schema.get_flattened_schema(schema_type) - schema = FlatSchema(unknown=INCLUDE) + except Exception as e: + logging.error(f"Error retrieving linked annotations: {str(e)}") + raise - data = [] - for anno, seg in annotations: - anno_data = anno.__dict__ - seg_data = seg.__dict__ - anno_data = { - k: v for (k, v) in anno_data.items() if k != "_sa_instance_state" - } - seg_data = { - k: v for (k, v) in seg_data.items() if k != "_sa_instance_state" - } - anno_data["created"] = str(anno_data.get("created")) - anno_data["deleted"] = str(anno_data.get("deleted")) + def _get_models(self, table_name: str, seg_table_name: str): + AnnotationModel = self.db.cached_table(table_name) + SegmentationModel = self.db.cached_table(seg_table_name) + return AnnotationModel, SegmentationModel - merged_data = {**anno_data, **seg_data} - data.append(merged_data) + def _format_model_to_dict(self, model): + return { + k: self._format_value(v) + for k, v in model.__dict__.items() + if not k.startswith("_") + } - return schema.load(data, many=True) + def _format_value(self, value): + return str(value) if isinstance(value, datetime.datetime) else value def insert_linked_segmentation( - self, table_name: str, pcg_table_name: str, segmentation_data: List[dict] - ): - """Insert segmentation data by linking to annotation ids. - Limited to 10,000 inserts. If more consider using a bulk insert script. - - Parameters - ---------- - table_name : str - name of annotation table - pcg_table_name: str - name of chunked graph reference table - segmentation_data : List[dict] - List of dictionaries of single segmentation data. - """ + self, + table_name: str, + pcg_table_name: str, + segmentation_data: List[Dict[str, Any]], + ) -> List[int]: insertion_limit = 10_000 - if len(segmentation_data) > insertion_limit: raise AnnotationInsertLimitExceeded(len(segmentation_data), insertion_limit) metadata = self.db.get_table_metadata(table_name) - schema_type = metadata["schema_type"] - + schema_type = metadata.anno_metadata.schema_type seg_table_name = build_segmentation_table_name(table_name, pcg_table_name) - SegmentationModel = self.db.cached_table(seg_table_name) - formatted_seg_data = [] _, segmentation_schema = self.schema.split_flattened_schema(schema_type) - for segmentation in segmentation_data: - segmentation_data = self.schema.flattened_schema_data(segmentation) - flat_data = self.schema._map_values_to_schema( - segmentation_data, segmentation_schema - ) - flat_data["id"] = segmentation["id"] + formatted_seg_data = [ + { + **self.schema._map_values_to_schema( + self.schema.flattened_schema_data(segmentation), segmentation_schema + ), + "id": segmentation["id"], + } + for segmentation in segmentation_data + ] - formatted_seg_data.append(flat_data) + with self.db.session_scope() as session: + try: + ids = [data["id"] for data in formatted_seg_data] + + # Check for existing IDs efficiently + existing_ids = set( + session.execute( + select(SegmentationModel.id).filter( + SegmentationModel.id.in_(ids) + ) + ) + .scalars() + .all() + ) - segs = [ - SegmentationModel(**segmentation_data) - for segmentation_data in formatted_seg_data - ] + if existing_ids: + raise IdsAlreadyExists( + f"Annotation IDs {existing_ids} already linked in database" + ) - ids = [data["id"] for data in formatted_seg_data] - q = self.db.cached_session.query(SegmentationModel).filter( - SegmentationModel.id.in_(list(ids)) - ) + # Bulk insert using PostgreSQL's INSERT ... ON CONFLICT + insert_stmt = insert(SegmentationModel).values(formatted_seg_data) + insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["id"]) + result = session.execute(insert_stmt.returning(SegmentationModel.id)) - ids_exist = self.db.cached_session.query(q.exists()).scalar() + seg_ids = result.scalars().all() - if ids_exist: - raise IdsAlreadyExists(f"Annotation IDs {ids} already linked in database ") - self.db.cached_session.add_all(segs) - seg_ids = [seg.id for seg in segs] - self.db.commit_session() - return seg_ids + return seg_ids + + except SQLAlchemyError as e: + logging.error(f"Error inserting linked segmentations: {str(e)}") + raise def insert_linked_annotations( - self, table_name: str, pcg_table_name: str, annotations: List[dict] - ): - """Insert annotations by type and schema. Limited to 10,000 - annotations. If more consider using a bulk insert script. - - Parameters - ---------- - table_name : str - name of annotation table - pcg_table_name: str - name of chunked graph reference table - annotations : dict - Dictionary of single annotation data. - """ + self, table_name: str, pcg_table_name: str, annotations: List[Dict[str, Any]] + ) -> List[int]: insertion_limit = 10_000 - if len(annotations) > insertion_limit: raise AnnotationInsertLimitExceeded(len(annotations), insertion_limit) metadata = self.db.get_table_metadata(table_name) - schema_type = metadata["schema_type"] + schema_type = metadata.anno_metadata.schema_type seg_table_name = build_segmentation_table_name(table_name, pcg_table_name) + AnnotationModel = self.db.cached_table(table_name) + SegmentationModel = self.db.cached_table(seg_table_name) formatted_anno_data = [] formatted_seg_data = [] - AnnotationModel = self.db.cached_table(table_name) - SegmentationModel = self.db.cached_table(seg_table_name) - logging.info(f"{AnnotationModel.__table__.columns}") - logging.info(f"{SegmentationModel.__table__.columns}") - for annotation in annotations: - anno_data, seg_data = self.schema.split_flattened_schema_data( schema_type, annotation ) - if annotation.get("id"): + if "id" in annotation: anno_data["id"] = annotation["id"] if hasattr(AnnotationModel, "created"): anno_data["created"] = datetime.datetime.utcnow() anno_data["valid"] = True formatted_anno_data.append(anno_data) formatted_seg_data.append(seg_data) + logging.info(f"DATA TO BE INSERTED: {formatted_anno_data} {formatted_seg_data}") - try: - annos = [ - AnnotationModel(**annotation_data) - for annotation_data in formatted_anno_data - ] - except Exception as e: - raise e - self.db.cached_session.add_all(annos) - self.db.cached_session.flush() - segs = [ - SegmentationModel(**segmentation_data, id=anno.id) - for segmentation_data, anno in zip(formatted_seg_data, annos) - ] - ids = [anno.id for anno in annos] - self.db.cached_session.add_all(segs) - self.db.commit_session() + + with self.db.session_scope() as session: + try: + annos = [AnnotationModel(**data) for data in formatted_anno_data] + session.add_all(annos) + session.flush() + segs = [ + SegmentationModel(**seg_data, id=anno.id) + for seg_data, anno in zip(formatted_seg_data, annos) + ] + session.add_all(segs) + session.flush() + ids = [anno.id for anno in annos] + except Exception as e: + logging.error(f"Error inserting linked annotations: {e}") + raise + return ids def update_linked_annotations( - self, table_name: str, pcg_table_name: str, annotation: dict - ): - """Updates an annotation by inserting a new row. The original annotation - will refer to the new row with a superseded_id. Does not update inplace. - - Parameters - ---------- - table_name : str - name of annotation table - pcg_table_name: str - name of chunked graph reference table - annotation : dict, annotation to update by ID - """ + self, table_name: str, pcg_table_name: str, annotation: Dict[str, Any] + ) -> Dict[int, int]: anno_id = annotation.get("id") if not anno_id: - return "Annotation requires an 'id' to update targeted row" + raise ValueError("Annotation requires an 'id' to update targeted row") metadata = self.db.get_table_metadata(table_name) - schema_type = metadata["schema_type"] - + schema_type = metadata.anno_metadata.schema_type seg_table_name = build_segmentation_table_name(table_name, pcg_table_name) - AnnotationModel = self.db.cached_table(table_name) SegmentationModel = self.db.cached_table(seg_table_name) - new_annotation, __ = self.schema.split_flattened_schema_data( + new_annotation, _ = self.schema.split_flattened_schema_data( schema_type, annotation ) - new_annotation["created"] = datetime.datetime.utcnow() new_annotation["valid"] = True - new_data = AnnotationModel(**new_annotation) + with self.db.session_scope() as session: + stmt = select(AnnotationModel, SegmentationModel).filter( + and_(AnnotationModel.id == anno_id, SegmentationModel.id == anno_id) + ) + result = session.execute(stmt).first() + + if not result: + raise ValueError(f"No annotation found with id {anno_id}") + + old_anno, old_seg = result - data = ( - self.db.cached_session.query(AnnotationModel, SegmentationModel) - .filter(AnnotationModel.id == anno_id) - .filter(SegmentationModel.id == anno_id) - .all() - ) - update_map = {} - for old_anno, old_seg in data: if old_anno.superceded_id: raise UpdateAnnotationError(anno_id, old_anno.superceded_id) - self.db.cached_session.add(new_data) - self.db.cached_session.flush() + new_data = AnnotationModel(**new_annotation) + session.add(new_data) + session.flush() deleted_time = datetime.datetime.utcnow() old_anno.deleted = deleted_time old_anno.superceded_id = new_data.id old_anno.valid = False - update_map[anno_id] = new_data.id - self.db.commit_session() - return update_map + return {anno_id: new_data.id} def delete_linked_annotation( - self, table_name: str, pcg_table_name: str, annotation_ids: List[int] - ): - """Mark annotations by for deletion by list of ids. - - Parameters - ---------- - table_name : str - name of annotation table - pcg_table_name: str - name of chunked graph reference table - annotation_ids : List[int] - list of ids to delete - - Returns - ------- - - Raises - ------ - """ + self, table_name: str, pcg_table_name: str, annotation_ids: List[int] + ) -> Optional[List[int]]: seg_table_name = build_segmentation_table_name(table_name, pcg_table_name) AnnotationModel = self.db.cached_table(table_name) SegmentationModel = self.db.cached_table(seg_table_name) - annotations = ( - self.db.cached_session.query(AnnotationModel) + with self.db.session_scope() as session: + stmt = ( + select(AnnotationModel) .join(SegmentationModel, SegmentationModel.id == AnnotationModel.id) - .filter(AnnotationModel.id.in_(list(annotation_ids))) - .all() - ) + .filter(AnnotationModel.id.in_(annotation_ids)) + ) + result = session.execute(stmt) + annotations = result.scalars().all() + + if not annotations: + return None + + deleted_time = datetime.datetime.utcnow() + for annotation in annotations: + annotation.deleted = deleted_time + annotation.valid = False + + deleted_ids = [annotation.id for annotation in annotations] - if not annotations: - return None - deleted_ids = [annotation.id for annotation in annotations] - deleted_time = datetime.datetime.utcnow() - for annotation in annotations: - annotation.deleted = deleted_time - annotation.valid = False - self.db.commit_session() return deleted_ids diff --git a/requirements.txt b/requirements.txt index 7fb916a..a5153b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ -marshmallow==3.5.1 +marshmallow emannotationschemas>=5.4.0 -sqlalchemy<1.4 +sqlalchemy +marshmallow-sqlalchemy +marshmallow-enum psycopg2-binary geoalchemy2 pytz alembic shapely -jsonschema<4.0 \ No newline at end of file +jsonschema \ No newline at end of file diff --git a/tests/test_annotation.py b/tests/test_annotation.py index 348aada..f0cb73a 100644 --- a/tests/test_annotation.py +++ b/tests/test_annotation.py @@ -79,8 +79,9 @@ def test_create_reference_table(dadb_interface, annotation_metadata): assert table_name == table table_info = dadb_interface.database.get_table_metadata(table) - assert table_info["reference_table"] == "anno_test" - + + assert table_info.anno_metadata.reference_table == "anno_test" + assert table_info.table_name == table_name def test_create_nested_reference_table(dadb_interface, annotation_metadata): table_name = "reference_tag" @@ -108,7 +109,8 @@ def test_create_nested_reference_table(dadb_interface, annotation_metadata): assert table_name == table table_info = dadb_interface.database.get_table_metadata(table) - assert table_info["reference_table"] == "presynaptic_bouton_types" + + assert table_info.anno_metadata.reference_table == "presynaptic_bouton_types" def test_bad_schema_reference_table(dadb_interface, annotation_metadata): @@ -306,4 +308,4 @@ def test_update_table_metadata(dadb_interface, annotation_metadata): table_name, description="New description" ) - assert updated_metadata["description"] == "New description" + assert updated_metadata.anno_metadata.description == "New description" diff --git a/tests/test_database.py b/tests/test_database.py index f405a51..9278204 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -3,8 +3,7 @@ import pytest from sqlalchemy import Table -from sqlalchemy.ext.declarative.api import DeclarativeMeta - +from sqlalchemy.orm import DeclarativeMeta from emannotationschemas import type_mapping @@ -13,28 +12,11 @@ def test_get_table_metadata(dadb_interface, annotation_metadata): schema_type = annotation_metadata["schema_type"] metadata = dadb_interface.database.get_table_metadata(table_name) logging.info(metadata) - assert metadata["schema_type"] == schema_type - assert metadata["table_name"] == "anno_test" - assert metadata["user_id"] == "foo@bar.com" - assert metadata["description"] == "New description" - assert metadata["voxel_resolution_x"] == 4.0 - - # test with filter to get a col value - metadata_value = dadb_interface.database.get_table_metadata( - table_name, filter_col="valid" - ) - logging.info(metadata) - assert metadata_value == True - - # test for missing column - with pytest.raises(AttributeError) as e: - bad_return = dadb_interface.database.get_table_metadata( - table_name, "missing_column" - ) - assert ( - str(e.value) == "type object 'AnnoMetadata' has no attribute 'missing_column'" - ) - + assert metadata.anno_metadata.schema_type == schema_type + assert metadata.table_name == "anno_test" + assert metadata.anno_metadata.user_id == "foo@bar.com" + assert metadata.anno_metadata.description == "New description" + assert metadata.anno_metadata.voxel_resolution_x == 4.0 def test_get_table_sql_metadata(dadb_interface, annotation_metadata): table_name = annotation_metadata["table_name"] diff --git a/tests/test_schema.py b/tests/test_schema.py index 719ee86..a209656 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,8 +1,7 @@ import marshmallow from emannotationschemas.errors import UnknownAnnotationTypeException import pytest -from sqlalchemy.ext.declarative.api import DeclarativeMeta - +from sqlalchemy.orm import DeclarativeMeta def test_get_schema(dadb_interface): valid_schema = dadb_interface.schema.get_schema("synapse")