diff --git a/.github/workflows/dynamic_annotation_db _ci.yml b/.github/workflows/dynamic_annotation_db _ci.yml
index b68def9..7943da8 100644
--- a/.github/workflows/dynamic_annotation_db _ci.yml	
+++ b/.github/workflows/dynamic_annotation_db _ci.yml	
@@ -19,7 +19,7 @@ jobs:
     
     services:
       postgres:
-        image: postgis/postgis:9.6-2.5
+        image: postgis/postgis:13-3.3
         env:
           POSTGRES_USER: postgres
           POSTGRES_PASSWORD: postgres
@@ -29,20 +29,20 @@ jobs:
         options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5    
         
     steps:
-    - uses: actions/checkout@v2
-    - name: Set up Python 3.7
-      uses: actions/setup-python@v2
+    - uses: actions/checkout@v4
+    - name: Set up Python 3.11
+      uses: actions/setup-python@v5
       with:
-        python-version: 3.7
+        python-version: "3.11"
 
-    - uses: actions/cache@v2
+    - uses: actions/cache@v4
       with:
         path: ~/.cache/pip
         key: ${{ runner.os }}-pip-${{ hashFiles('**/test_requirements.txt') }}
         restore-keys: |
           ${{ runner.os }}-pip-
           
-    - uses: actions/cache@v2
+    - uses: actions/cache@v4
       with:
         path: ~/.cache/pip
         key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
diff --git a/dynamicannotationdb/annotation.py b/dynamicannotationdb/annotation.py
index 8f451e1..35a4d07 100644
--- a/dynamicannotationdb/annotation.py
+++ b/dynamicannotationdb/annotation.py
@@ -1,43 +1,35 @@
 import datetime
 import logging
-from typing import List
+from typing import List, Dict, Any
 
 from marshmallow import INCLUDE
-from sqlalchemy import DDL, event
+from sqlalchemy import select, update
 
-from .database import DynamicAnnotationDB
-from .errors import (
+from dynamicannotationdb.database import DynamicAnnotationDB
+from dynamicannotationdb.errors import (
     AnnotationInsertLimitExceeded,
     NoAnnotationsFoundWithID,
     UpdateAnnotationError,
     TableNameNotFound,
 )
-from .models import AnnoMetadata
-from .schema import DynamicSchemaClient
+from dynamicannotationdb.models import AnnoMetadata
+from dynamicannotationdb.schema import DynamicSchemaClient
 
 
 class DynamicAnnotationClient:
     def __init__(self, sql_url: str) -> None:
         self.db = DynamicAnnotationDB(sql_url)
         self.schema = DynamicSchemaClient()
+        self._table = None
 
     @property
     def table(self):
+        if self._table is None:
+            raise ValueError("No table loaded. Use load_table() first.")
         return self._table
 
     def load_table(self, table_name: str):
-        """Load a table
-
-        Parameters
-        ----------
-        table_name : str
-            name of table
-
-        Returns
-        -------
-        DeclarativeMeta
-            the sqlalchemy table of that name
-        """
+        """Load a table"""
         self._table = self.db.cached_table(table_name)
         return self._table
 
@@ -95,18 +87,17 @@ def create_table(
         """
         existing_tables = self.db._check_table_is_unique(table_name)
 
+        reference_table = None
         if table_metadata:
             reference_table, _ = self.schema._parse_schema_metadata_params(
                 schema_type, table_name, table_metadata, existing_tables
             )
-        else:
-            reference_table = None
 
         AnnotationModel = self.schema.create_annotation_model(
             table_name,
             schema_type,
             table_metadata=table_metadata,
-            with_crud_columns=with_crud_columns,
+            with_crud_columns=True,  # NOTE(review): ignores the caller's with_crud_columns arg — confirm this is intentional
         )
 
         self.db.base.metadata.tables[AnnotationModel.__name__].create(
@@ -134,23 +125,16 @@ def create_table(
 
         logging.info(f"Metadata for table: {table_name} is {metadata_dict}")
         anno_metadata = AnnoMetadata(**metadata_dict)
-        self.db.cached_session.add(anno_metadata)
-        self.db.commit_session()
+
+        with self.db.session_scope() as session:
+            session.add(anno_metadata)
+
         logging.info(
             f"Table: {table_name} created using {AnnotationModel} model at {creation_time}"
         )
         return table_name
 
-    def update_table_metadata(
-        self,
-        table_name: str,
-        description: str = None,
-        user_id: str = None,
-        flat_segmentation_source: str = None,
-        read_permission: str = None,
-        write_permission: str = None,
-        notice_text: str = None,
-    ):
+    def update_table_metadata(self, table_name: str, **kwargs):
         r"""Update metadata for an annotation table.
 
         Parameters
@@ -184,34 +168,30 @@ def update_table_metadata(
         TableNameNotFound
             If no table with 'table_name' found in the metadata table
         """
-        metadata = (
-            self.db.cached_session.query(AnnoMetadata)
-            .filter(AnnoMetadata.table_name == table_name)
-            .first()
-        )
-        if metadata is None:
-            raise TableNameNotFound(
-                f"no table named {table_name} in database {self.sql_url} "
-            )
+        with self.db.session_scope() as session:
+            metadata = session.execute(
+                select(AnnoMetadata).where(AnnoMetadata.table_name == table_name)
+            ).scalar_one_or_none()
+
+            if metadata is None:
+                raise TableNameNotFound(f"no table named {table_name} in database")
+
+            for key, value in kwargs.items():
+                if hasattr(metadata, key):
+                    setattr(metadata, key, value)
+
+            if "notice_text" in kwargs and kwargs["notice_text"] == "":
+                metadata.notice_text = None
+
+            # Explicitly flush the session to ensure the changes are visible
+            session.flush()
+
+            # Refresh the metadata object to get the updated values
+            session.refresh(metadata)
+
+        self.db.get_table_metadata.cache_clear()
+        logging.info(f"Table: {table_name} metadata updated")
 
-        update_dict = {
-            "description": description,
-            "user_id": user_id,
-            "flat_segmentation_source": flat_segmentation_source,
-            "read_permission": read_permission,
-            "write_permission": write_permission,
-        }
-        update_dict = {k: v for k, v in update_dict.items() if v is not None}
-        if notice_text is not None:
-            if len(notice_text) == 0:
-                update_dict["notice_text"] = None
-            else:
-                update_dict["notice_text"] = notice_text
-        for column, value in update_dict.items():
-            if hasattr(metadata, str(column)):
-                setattr(metadata, column, value)
-        self.db.commit_session()
-        logging.info(f"Table: {table_name} metadata updated ")
         return self.db.get_table_metadata(table_name)
 
     def delete_table(self, table_name: str) -> bool:
@@ -230,20 +210,21 @@ def delete_table(self, table_name: str) -> bool:
         bool
             whether table was successfully deleted
         """
-        metadata = (
-            self.db.cached_session.query(AnnoMetadata)
-            .filter(AnnoMetadata.table_name == table_name)
-            .first()
-        )
-        if metadata is None:
-            raise TableNameNotFound(
-                f"no table named {table_name} in database {self.sql_url} "
-            )
-        metadata.deleted = datetime.datetime.utcnow()
-        self.db.commit_session()
+        with self.db.session_scope() as session:
+            metadata = session.execute(
+                select(AnnoMetadata).where(AnnoMetadata.table_name == table_name)
+            ).scalar_one_or_none()
+
+            if metadata is None:
+                raise TableNameNotFound(f"no table named {table_name} in database")
+
+            metadata.deleted = datetime.datetime.utcnow()
+
         return True
 
-    def insert_annotations(self, table_name: str, annotations: List[dict]):
+    def insert_annotations(
+        self, table_name: str, annotations: List[Dict[str, Any]]
+    ) -> List[int]:
         """Insert some annotations.
 
         Parameters
@@ -273,36 +254,33 @@ def insert_annotations(self, table_name: str, annotations: List[dict]):
 
         formatted_anno_data = []
         for annotation in annotations:
-
-            annotation_data, __ = self.schema.split_flattened_schema_data(
+            annotation_data, _ = self.schema.split_flattened_schema_data(
                 schema_type, annotation
             )
-            if annotation.get("id"):
+            if "id" in annotation:
                 annotation_data["id"] = annotation["id"]
             if hasattr(AnnotationModel, "created"):
                 annotation_data["created"] = datetime.datetime.utcnow()
             annotation_data["valid"] = True
             formatted_anno_data.append(annotation_data)
 
-        annos = [
-            AnnotationModel(**annotation_data)
-            for annotation_data in formatted_anno_data
-        ]
+        with self.db.session_scope() as session:
+            annos = [AnnotationModel(**data) for data in formatted_anno_data]
+            session.add_all(annos)
+            session.flush()
+            anno_ids = [anno.id for anno in annos]
 
-        self.db.cached_session.add_all(annos)
-        self.db.cached_session.flush()
-        anno_ids = [anno.id for anno in annos]
-
-        (
-            self.db.cached_session.query(AnnoMetadata)
-            .filter(AnnoMetadata.table_name == table_name)
-            .update({AnnoMetadata.last_modified: datetime.datetime.utcnow()})
-        )
+            session.execute(
+                update(AnnoMetadata)
+                .where(AnnoMetadata.table_name == table_name)
+                .values(last_modified=datetime.datetime.utcnow())
+            )
 
-        self.db.commit_session()
         return anno_ids
 
-    def get_annotations(self, table_name: str, annotation_ids: List[int]) -> List[dict]:
+    def get_annotations(
+        self, table_name: str, annotation_ids: List[int]
+    ) -> List[Dict[str, Any]]:
         """Get a set of annotations by ID
 
         Parameters
@@ -319,33 +297,38 @@ def get_annotations(self, table_name: str, annotation_ids: List[int]) -> List[di
         """
         schema_type, AnnotationModel = self._load_model(table_name)
 
-        annotations = (
-            self.db.cached_session.query(AnnotationModel)
-            .filter(AnnotationModel.id.in_(list(annotation_ids)))
-            .all()
-        )
+        with self.db.session_scope() as session:
+            annotations = (
+                session.execute(
+                    select(AnnotationModel).where(
+                        AnnotationModel.id.in_(annotation_ids)
+                    )
+                )
+                .scalars()
+                .all()
+            )
 
-        anno_schema, __ = self.schema.split_flattened_schema(schema_type)
+        anno_schema, _ = self.schema.split_flattened_schema(schema_type)
         schema = anno_schema(unknown=INCLUDE)
+
         try:
             data = []
-
             for anno in annotations:
-                anno_data = anno.__dict__
-                anno_data["created"] = str(anno_data.get("created"))
-                anno_data["deleted"] = str(anno_data.get("deleted"))
                 anno_data = {
-                    k: v for (k, v) in anno_data.items() if k != "_sa_instance_state"
+                    k: str(v) if isinstance(v, datetime.datetime) else v
+                    for k, v in anno.__dict__.items()
+                    if not k.startswith("_")
                 }
                 data.append(anno_data)
 
             return schema.load(data, many=True)
-
         except Exception as e:
             logging.exception(e)
             raise NoAnnotationsFoundWithID(annotation_ids) from e
 
-    def update_annotation(self, table_name: str, annotation: dict) -> str:
+    def update_annotation(
+        self, table_name: str, annotation: Dict[str, Any]
+    ) -> Dict[int, int]:
         """Update an annotation
 
         Parameters
@@ -353,7 +336,7 @@ def update_annotation(self, table_name: str, annotation: dict) -> str:
         table_name : str
             name of targeted table to update annotations
         annotation : dict
-            new data for that annotation, allows for partial updates but 
+            new data for that annotation, allows for partial updates but
             requires an 'id' field to target the row
 
         Returns
@@ -368,56 +351,52 @@ def update_annotation(self, table_name: str, annotation: dict) -> str:
         """
         anno_id = annotation.get("id")
         if not anno_id:
-            return "Annotation requires an 'id' to update targeted row"
-        schema_type, AnnotationModel = self._load_model(table_name)
+            raise ValueError("Annotation requires an 'id' to update targeted row")
 
-        try:
-            old_anno = (
-                self.db.cached_session.query(AnnotationModel)
-                .filter(AnnotationModel.id == anno_id)
-                .one()
-            )
-        except NoAnnotationsFoundWithID as e:
-            raise f"No result found for {anno_id}. Error: {e}" from e
+        schema_type, AnnotationModel = self._load_model(table_name)
 
-        if old_anno.superceded_id:
-            raise UpdateAnnotationError(anno_id, old_anno.superceded_id)
+        with self.db.session_scope() as session:
+            old_anno = session.execute(
+                select(AnnotationModel).where(AnnotationModel.id == anno_id)
+            ).scalar_one_or_none()
 
-        # Merge old data with new changes
-        old_data = {
-            column.name: getattr(old_anno, column.name)
-            for column in old_anno.__table__.columns
-        }
-        updated_data = {**old_data, **annotation}
+            if old_anno is None:
+                raise NoAnnotationsFoundWithID(f"No result found for {anno_id}")
 
-        new_annotation, __ = self.schema.split_flattened_schema_data(
-            schema_type, updated_data
-        )
+            if old_anno.superceded_id:
+                raise UpdateAnnotationError(anno_id, old_anno.superceded_id)
 
-        if hasattr(AnnotationModel, "created"):
-            new_annotation["created"] = datetime.datetime.utcnow()
-        if hasattr(AnnotationModel, "valid"):
-            new_annotation["valid"] = True
+            old_data = {
+                column.name: getattr(old_anno, column.name)
+                for column in old_anno.__table__.columns
+            }
+            updated_data = {**old_data, **annotation}
 
-        new_data = AnnotationModel(**new_annotation)
+            new_annotation, _ = self.schema.split_flattened_schema_data(
+                schema_type, updated_data
+            )
 
-        self.db.cached_session.add(new_data)
-        self.db.cached_session.flush()
+            if hasattr(AnnotationModel, "created"):
+                new_annotation["created"] = datetime.datetime.utcnow()
+            if hasattr(AnnotationModel, "valid"):
+                new_annotation["valid"] = True
 
-        deleted_time = datetime.datetime.utcnow()
-        old_anno.deleted = deleted_time
-        old_anno.superceded_id = new_data.id
-        old_anno.valid = False
-        update_map = {anno_id: new_data.id}
+            new_data = AnnotationModel(**new_annotation)
+            session.add(new_data)
+            session.flush()
 
-        (
-            self.db.cached_session.query(AnnoMetadata)
-            .filter(AnnoMetadata.table_name == table_name)
-            .update({AnnoMetadata.last_modified: datetime.datetime.utcnow()})
-        )
-        self.db.commit_session()
+            deleted_time = datetime.datetime.utcnow()
+            old_anno.deleted = deleted_time
+            old_anno.superceded_id = new_data.id
+            old_anno.valid = False
+
+            session.execute(
+                update(AnnoMetadata)
+                .where(AnnoMetadata.table_name == table_name)
+                .values(last_modified=datetime.datetime.utcnow())
+            )
 
-        return update_map
+        return {anno_id: new_data.id}
 
     def delete_annotation(
         self, table_name: str, annotation_ids: List[int]
@@ -438,45 +417,45 @@ def delete_annotation(
         """
         schema_type, AnnotationModel = self._load_model(table_name)
 
-        annotations = (
-            self.db.cached_session.query(AnnotationModel)
-            .filter(AnnotationModel.id.in_(annotation_ids))
-            .all()
-        )
-        deleted_ids = []
-        if annotations:
+        with self.db.session_scope() as session:
+            annotations = (
+                session.execute(
+                    select(AnnotationModel).where(
+                        AnnotationModel.id.in_(annotation_ids)
+                    )
+                )
+                .scalars()
+                .all()
+            )
+
+            if not annotations:
+                return []
+
             deleted_time = datetime.datetime.utcnow()
+            deleted_ids = []
 
             for annotation in annotations:
-                # TODO: This should be deprecated, as all tables should have
-                # CRUD columns now, but leaving this for backward safety.
                 if not hasattr(AnnotationModel, "deleted"):
-                    self.db.cached_session.delete(annotation)
+                    session.delete(annotation)
                 else:
                     annotation.deleted = deleted_time
                     annotation.valid = False
                 deleted_ids.append(annotation.id)
 
-            (
-                self.db.cached_session.query(AnnoMetadata)
-                .filter(AnnoMetadata.table_name == table_name)
-                .update({AnnoMetadata.last_modified: datetime.datetime.utcnow()})
+            session.execute(
+                update(AnnoMetadata)
+                .where(AnnoMetadata.table_name == table_name)
+                .values(last_modified=datetime.datetime.utcnow())
             )
 
-            self.db.commit_session()
-
-        else:
-            return None
         return deleted_ids
 
     def _load_model(self, table_name):
         metadata = self.db.get_table_metadata(table_name)
-        schema_type = metadata["schema_type"]
+        schema_type = metadata.anno_metadata.schema_type
 
-        # load reference table into metadata if not already present
-        ref_table = metadata.get("reference_table")
-        if ref_table:
-            reference_table_name = self.db.cached_table(ref_table)
+        if metadata.anno_metadata.reference_table:
+            self.db.cached_table(metadata.anno_metadata.reference_table)
 
         AnnotationModel = self.db.cached_table(table_name)
         return schema_type, AnnotationModel
diff --git a/dynamicannotationdb/database.py b/dynamicannotationdb/database.py
index 5202765..39d3ff2 100644
--- a/dynamicannotationdb/database.py
+++ b/dynamicannotationdb/database.py
@@ -1,44 +1,53 @@
 import logging
 from contextlib import contextmanager
-from typing import List
-
-from sqlalchemy import create_engine, func, inspect, or_
+from typing import Any, Dict, List, Tuple
+from functools import lru_cache
+
+
+from sqlalchemy import (
+    engine,
+    MetaData,
+    Table,
+    create_engine,
+    func,
+    inspect,
+    or_,
+    select,
+)
 from sqlalchemy.ext.automap import automap_base
-from sqlalchemy.ext.declarative.api import DeclarativeMeta
-from sqlalchemy.orm import Session, scoped_session, sessionmaker
-from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm import sessionmaker, scoped_session
 from sqlalchemy.schema import MetaData
 from sqlalchemy.sql.schema import Table
 
-from .errors import TableAlreadyExists, TableNameNotFound, TableNotInMetadata
-from .models import AnnoMetadata, Base, SegmentationMetadata, AnalysisView
-from .schema import DynamicSchemaClient
 
+from dynamicannotationdb.errors import (
+    TableAlreadyExists,
+    TableNameNotFound,
+    TableNotInMetadata,
+)
+from dynamicannotationdb.models import (
+    AnalysisView,
+    AnnoMetadata,
+    Base,
+    SegmentationMetadata,
+    CombinedMetadata,
+)
+from dynamicannotationdb.schema import DynamicSchemaClient
 
-class DynamicAnnotationDB:
-    def __init__(self, sql_url: str, pool_size=5, max_overflow=5) -> None:
 
-        self._cached_session = None
-        self._cached_tables = {}
-        self._engine = create_engine(
+class DynamicAnnotationDB:
+    def __init__(self, sql_url: str, pool_size: int = 5, max_overflow: int = 5) -> None:
+        self._engine: engine.Engine = create_engine(
             sql_url, pool_recycle=3600, pool_size=pool_size, max_overflow=max_overflow
         )
         self.base = Base
-        self.base.metadata.bind = self._engine
-        self.base.metadata.create_all(
-            tables=[AnnoMetadata.__table__, SegmentationMetadata.__table__],
-            checkfirst=True,
-        )
 
-        self.session = scoped_session(
-            sessionmaker(bind=self.engine, autocommit=False, autoflush=False)
-        )
-        self.schema_client = DynamicSchemaClient()
+        session_factory = sessionmaker(bind=self._engine, expire_on_commit=False)
+        self.Session = scoped_session(session_factory)
 
-        self._inspector = inspect(self.engine)
-
-        self._cached_session = None
-        self._cached_tables = {}
+        self.schema_client = DynamicSchemaClient()
+        self._inspector = inspect(self._engine)
+        self._cached_tables: Dict[str, Any] = {}
 
     @property
     def inspector(self):
@@ -48,37 +57,21 @@ def inspector(self):
     def engine(self):
         return self._engine
 
-    @property
-    def cached_session(self) -> Session:
-        if self._cached_session is None:
-            self._cached_session = self.session()
-        return self._cached_session
-
     @contextmanager
     def session_scope(self):
+        session = self.Session()
         try:
-            yield self.cached_session
-        except Exception as e:
-            self.cached_session.rollback()
-            logging.exception(f"SQL Error: {e}")
-            raise e
-        finally:
-            self.cached_session.close()
-        self._cached_session = None
-
-    def commit_session(self):
-        try:
-            self.cached_session.commit()
+            yield session
+            session.commit()
         except Exception as e:
-            self.cached_session.rollback()
+            session.rollback()
             logging.exception(f"SQL Error: {e}")
-            raise e
+            raise
         finally:
-            self.cached_session.close()
-        self._cached_session = None
+            self.Session.remove()
 
-    def get_table_sql_metadata(self, table_name: str):
-        self.base.metadata.reflect(bind=self.engine)
+    def get_table_sql_metadata(self, table_name: str) -> Table:
+        self.base.metadata.reflect(bind=self._engine)
         return self.base.metadata.tables[table_name]
 
     def get_unique_string_values(self, table_name: str):
@@ -95,113 +88,94 @@ def get_unique_string_values(self, table_name: str):
             dictionary of column names and unique values
         """
         model = self.cached_table(table_name)
-
         unique_values = {}
+
         with self.session_scope() as session:
-            for column_name in model.__table__.columns.keys():
-
-                # if the column is a string
-                try:
-                    python_type = model.__table__.columns[column_name].type.python_type
-                except NotImplementedError:
-                    python_type = None
-                if python_type == str:
-                    query = session.query(getattr(model, column_name)).distinct()
-                    unique_values[column_name] = [row[0] for row in query.all()]
+            for column_name, column in model.__table__.columns.items():
+                if column.type.python_type is str:  # NOTE(review): python_type is a type, not an instance; may raise NotImplementedError for spatial types — original guarded this
+                    stmt = select(getattr(model, column_name)).distinct()
+                    result = session.execute(stmt)
+                    unique_values[column_name] = [row[0] for row in result.fetchall()]
+
         return unique_values
 
-    def get_views(self, datastack_name: str):
+    def get_views(self, datastack_name: str) -> List[AnalysisView]:
         with self.session_scope() as session:
-            query = session.query(AnalysisView).filter(
+            stmt = select(AnalysisView).where(
                 AnalysisView.datastack_name == datastack_name
             )
-            return query.all()
+            result = session.execute(stmt)
+            return result.scalars().all()
 
-    def get_view_metadata(self, datastack_name: str, view_name: str):
+    def get_view_metadata(self, datastack_name: str, view_name: str) -> Dict[str, Any]:
         with self.session_scope() as session:
-            query = (
-                session.query(AnalysisView)
-                .filter(AnalysisView.table_name == view_name)
-                .filter(AnalysisView.datastack_name == datastack_name)
+            stmt = select(AnalysisView).where(
+                AnalysisView.table_name == view_name,
+                AnalysisView.datastack_name == datastack_name,
             )
-            result = query.one()
-            if hasattr(result, "__dict__"):
-                return self.get_automap_items(result)
-            else:
-                return result[0]
+            result = session.execute(stmt)
+            view = result.scalar_one()
+            return self.get_automap_items(view)
 
-    def get_table_metadata(self, table_name: str, filter_col: str = None):
-        data = getattr(AnnoMetadata, filter_col) if filter_col else AnnoMetadata
+    @lru_cache(maxsize=128)  # NOTE(review): lru_cache on a method keys on self and keeps instances alive (ruff B019) — consider a per-instance cache
+    def get_table_metadata(self, table_name: str) -> CombinedMetadata:
         with self.session_scope() as session:
-            if filter_col and data:
-
-                query = session.query(data).filter(
-                    AnnoMetadata.table_name == table_name
+            stmt = (
+                select(AnnoMetadata, SegmentationMetadata)
+                .outerjoin(
+                    SegmentationMetadata,
+                    AnnoMetadata.table_name == SegmentationMetadata.annotation_table,
                 )
-                result = query.one()
-
-                if hasattr(result, "__dict__"):
-                    return self.get_automap_items(result)
-                else:
-                    return result[0]
-            else:
-                metadata = (
-                    session.query(data, SegmentationMetadata)
-                    .outerjoin(
-                        SegmentationMetadata,
-                        AnnoMetadata.table_name
-                        == SegmentationMetadata.annotation_table,
+                .where(
+                    or_(
+                        AnnoMetadata.table_name == table_name,
+                        SegmentationMetadata.table_name == table_name,
                     )
-                    .filter(
-                        or_(
-                            AnnoMetadata.table_name == table_name,
-                            SegmentationMetadata.table_name == table_name,
-                        )
-                    )
-                    .all()
                 )
-                try:
-                    if metadata:
-                        flatted_metadata = self.flatten_join(metadata)
-                        return flatted_metadata[0]
-                except NoResultFound:
-                    return None
+            )
+            result = session.execute(stmt).first()
+
+            if result is None:
+                raise ValueError(f"No metadata found for table '{table_name}'")
+
+            anno_metadata, seg_metadata = result
+
+            return CombinedMetadata(
+                table_name=table_name,
+                anno_metadata=anno_metadata,
+                seg_metadata=seg_metadata,
+            )
 
     def get_table_schema(self, table_name: str) -> str:
         table_metadata = self.get_table_metadata(table_name)
-        return table_metadata.get("schema_type")
+        return table_metadata.anno_metadata.schema_type
 
     def get_valid_table_names(self) -> List[str]:
         with self.session_scope() as session:
-            metadata = session.query(AnnoMetadata).all()
-            return [m.table_name for m in metadata if m.valid == True]
+            stmt = select(AnnoMetadata.table_name).where(AnnoMetadata.valid == True)
+            result = session.execute(stmt)
+            return result.scalars().all()
 
     def get_annotation_table_size(self, table_name: str) -> int:
-        """Get the number of annotations in a table
-
-        Parameters
-        ----------
-        table_name : str
-            name of table contained within the aligned_volume database
-
-        Returns
-        -------
-        int
-            number of annotations
-        """
         Model = self.cached_table(table_name)
         with self.session_scope() as session:
-            return session.query(Model).count()
+            stmt = select(func.count()).select_from(Model)
+            result = session.execute(stmt)
+            return result.scalar_one()
 
     def get_max_id_value(self, table_name: str) -> int:
         model = self.cached_table(table_name)
         with self.session_scope() as session:
-            return session.query(func.max(model.id)).scalar()
+            stmt = select(func.max(model.id))
+            result = session.execute(stmt)
+            return result.scalar_one()
 
     def get_min_id_value(self, table_name: str) -> int:
         model = self.cached_table(table_name)
         with self.session_scope() as session:
-            return session.query(func.min(model.id)).scalar()
+            stmt = select(func.min(model.id))
+            result = session.execute(stmt)
+            return result.scalar_one()
 
     def get_table_row_count(
         self, table_name: str, filter_valid: bool = False, filter_timestamp: str = None
@@ -219,27 +193,29 @@ def get_table_row_count(
         """
         model = self.cached_table(table_name)
         with self.session_scope() as session:
-            sql_query = session.query(func.count(model.id))
+            stmt = select(func.count(model.id))
             if filter_valid:
-                sql_query = sql_query.filter(model.valid == True)
+                stmt = stmt.where(model.valid == True)
             if filter_timestamp and hasattr(model, "created"):
-                sql_query = sql_query.filter(model.created <= filter_timestamp)
-            return sql_query.scalar()
+                stmt = stmt.where(model.created <= filter_timestamp)
+            result = session.execute(stmt)
+            return result.scalar_one()
 
     @staticmethod
-    def get_automap_items(result):
+    def get_automap_items(result: Any) -> Dict[str, Any]:
         return {k: v for (k, v) in result.__dict__.items() if k != "_sa_instance_state"}
 
-    def obj_to_dict(self, obj):
-        if obj:
-            return {
+    def obj_to_dict(self, obj: Any) -> Dict[str, Any]:
+        return (
+            {
                 column.key: getattr(obj, column.key)
                 for column in inspect(obj).mapper.column_attrs
             }
-        else:
-            return {}
+            if obj
+            else {}
+        )
 
-    def flatten_join(self, _list: List):
+    def flatten_join(self, _list: List[Tuple[Any, Any]]) -> List[Dict[str, Any]]:
         return [{**self.obj_to_dict(a), **self.obj_to_dict(b)} for a, b in _list]
 
     def drop_table(self, table_name: str) -> bool:
@@ -256,16 +232,15 @@ def drop_table(self, table_name: str) -> bool:
         bool
             whether drop was successful
         """
-        table = self.base.metadata.tables.get(table_name)
-        if table:
+        if table := self.base.metadata.tables.get(table_name):
             logging.info(f"Deleting {table_name} table")
             self.base.metadata.drop_all(self._engine, [table], checkfirst=True)
-            if self._is_cached(table):
-                del self._cached_tables[table]
+            if table_name in self._cached_tables:
+                del self._cached_tables[table_name]
             return True
         return False
 
-    def _check_table_is_unique(self, table_name):
+    def _check_table_is_unique(self, table_name: str) -> List[str]:
         existing_tables = self._get_existing_table_names()
         if table_name in existing_tables:
             raise TableAlreadyExists(
@@ -274,46 +249,39 @@ def _check_table_is_unique(self, table_name):
         return existing_tables
 
     def _get_existing_table_names(self, filter_valid: bool = False) -> List[str]:
-        """Collects table_names keys of existing tables
-
-        Returns
-        -------
-        list
-            List of table_names
-        """
         with self.session_scope() as session:
-            stmt = session.query(AnnoMetadata)
+            stmt = select(AnnoMetadata.table_name)
             if filter_valid:
-                stmt = stmt.filter(AnnoMetadata.valid == True)
-            metadata = stmt.all()
-            return [m.table_name for m in metadata]
+                stmt = stmt.where(AnnoMetadata.valid == True)
+            result = session.execute(stmt)
+            return result.scalars().all()
 
-    def _get_model_from_table_name(self, table_name: str) -> DeclarativeMeta:
-        metadata = self.get_table_metadata(table_name)
+    def _get_model_from_table_name(self, table_name: str) -> Any:
+        combined_metadata = self.get_table_metadata(table_name)
 
-        if metadata:
-            if metadata["reference_table"]:
+        if combined_metadata is None:
+            raise TableNotInMetadata(f"No metadata found for table '{table_name}'")
+
+        anno_metadata = combined_metadata.anno_metadata
+        seg_metadata = combined_metadata.seg_metadata
+
+        if anno_metadata:
+            if anno_metadata.reference_table:
                 return self.schema_client.create_reference_annotation_model(
-                    table_name,
-                    metadata["schema_type"],
-                    metadata["reference_table"],
+                    table_name, anno_metadata.schema_type, anno_metadata.reference_table
                 )
-            elif metadata.get("annotation_table") and table_name != metadata.get(
-                "annotation_table"
-            ):
+            elif seg_metadata and table_name == seg_metadata.table_name:
                 return self.schema_client.create_segmentation_model(
-                    metadata["annotation_table"],
-                    metadata["schema_type"],
-                    metadata["pcg_table_name"],
+                    anno_metadata.table_name,
+                    anno_metadata.schema_type,
+                    seg_metadata.pcg_table_name,
                 )
-
             else:
                 return self.schema_client.create_annotation_model(
-                    table_name, metadata["schema_type"]
+                    table_name, anno_metadata.schema_type
                 )
-
         else:
-            raise TableNotInMetadata
+            raise TableNotInMetadata(f"Invalid metadata structure for table '{table_name}'")
 
     def _get_model_columns(self, table_name: str) -> List[tuple]:
         """Return list of column names and types of a given table
@@ -328,23 +296,22 @@ def _get_model_columns(self, table_name: str) -> List[tuple]:
         list
             column names and types
         """
-        db_columns = self.inspector.get_columns(table_name)
+        db_columns = self._inspector.get_columns(table_name)
         if not db_columns:
             raise TableNameNotFound(table_name)
         return [(column["name"], column["type"]) for column in db_columns]
 
     def get_view_table(self, view_name: str) -> Table:
         """Return the sqlalchemy table object for a view"""
-        if self._is_cached(view_name):
+        if view_name in self._cached_tables:
             return self._cached_tables[view_name]
-        else:
-            meta = MetaData(self._engine)
-            meta.reflect(views=True, only=[view_name])
-            table = meta.tables[view_name]
-            self._cached_tables[view_name] = table
-            return table
+        meta = MetaData()
+        meta.reflect(bind=self._engine, views=True, only=[view_name])
+        table = meta.tables[view_name]
+        self._cached_tables[view_name] = table
+        return table
 
-    def cached_table(self, table_name: str) -> DeclarativeMeta:
+    def cached_table(self, table_name: str):
         """Returns cached table 'DeclarativeMeta' callable for querying.
 
         Parameters
@@ -375,41 +342,21 @@ def _load_table(self, table_name: str):
         bool
             Returns True if table exists and is loaded into cached table dict.
         """
-        if self._is_cached(table_name):
-            return True
-
         try:
             self._cached_tables[table_name] = self._get_model_from_table_name(
                 table_name
             )
             return True
         except TableNotInMetadata:
-            # cant find the table so lets try the slow reflection before giving up
-            self.mapped_base = automap_base()
-            self.mapped_base.prepare(self._engine, reflect=True)
             try:
-                model = self.mapped_base.classes[table_name]
+                base = automap_base()
+                base.prepare(self._engine, reflect=True)
+                model = base.classes[table_name]
                 self._cached_tables[table_name] = model
+                return True
             except KeyError as table_error:
                 logging.error(f"Could not load table: {table_error}")
                 return False
-
         except Exception as table_error:
             logging.error(f"Could not load table: {table_error}")
             return False
-
-    def _is_cached(self, table_name: str) -> bool:
-        """Check if table is loaded into cached instance dict of tables
-
-        Parameters
-        ----------
-        table_name : str
-            Name of table to check if loaded
-
-        Returns
-        -------
-        bool
-            True if table is loaded else False.
-        """
-
-        return table_name in self._cached_tables
diff --git a/dynamicannotationdb/errors.py b/dynamicannotationdb/errors.py
index e9e4135..a415da1 100644
--- a/dynamicannotationdb/errors.py
+++ b/dynamicannotationdb/errors.py
@@ -13,9 +13,11 @@ def __str__(self):
 class TableAlreadyExists(KeyError):
     """Table name already exists in the Metadata table"""
 
+
 class TableNotInMetadata(KeyError):
     """Table does not exist in the Metadata table"""
 
+
 class IdsAlreadyExists(KeyError):
     """Annotation IDs already exists in the segmentation table"""
 
@@ -30,9 +32,9 @@ class BadRequest(Exception):
 
 class UpdateAnnotationError(ValueError):
     def __init__(
-            self,
-            target_id: int,
-            superseded_id: int,
+        self,
+        target_id: int,
+        superseded_id: int,
     ):
         self.target_id = target_id
         self.message = f"Annotation with ID {target_id} has already been superseded by annotation ID {superseded_id}, update annotation ID {superseded_id} instead"
@@ -46,7 +48,7 @@ class AnnotationInsertLimitExceeded(ValueError):
     """Exception raised when amount of annotations exceeds defined limit."""
 
     def __init__(
-            self, limit: int, length: int, message: str = "Annotation limit exceeded"
+        self, limit: int, length: int, message: str = "Annotation limit exceeded"
     ):
         self.limit = limit
         self.message = (
diff --git a/dynamicannotationdb/interface.py b/dynamicannotationdb/interface.py
index e28c6bf..9ebef57 100644
--- a/dynamicannotationdb/interface.py
+++ b/dynamicannotationdb/interface.py
@@ -1,7 +1,6 @@
 import logging
-
-from sqlalchemy import create_engine
-from sqlalchemy.engine.url import make_url
+from sqlalchemy import create_engine, text
+from sqlalchemy.engine import URL, make_url
 from sqlalchemy.pool import NullPool
 
 from .annotation import DynamicAnnotationClient
@@ -32,7 +31,6 @@ class DynamicAnnotationInterface:
         linked to annotation tables.
     schema :
         Wrapper for EMAnnotationSchemas to generate dynamic sqlalchemy models.
-
     """
 
     def __init__(
@@ -57,15 +55,22 @@ def create_or_select_database(self, url: str, aligned_volume: str):
         url : str
             base path to the sql server
         aligned_volume : str
-            name of aligned volume which the database name will inherent
+            name of aligned volume which the database name will inherit
 
         Returns
         -------
         sql_url instance
         """
         sql_base_uri = url.rpartition("/")[0]
-
-        sql_uri = make_url(f"{sql_base_uri}/{aligned_volume}")
+        parsed_url = make_url(url)
+        sql_uri = URL.create(
+            drivername=parsed_url.drivername,
+            username=parsed_url.username,
+            password=parsed_url.password,
+            host=parsed_url.host,
+            port=parsed_url.port,
+            database=aligned_volume,
+        )
 
         temp_engine = create_engine(
             sql_base_uri,
@@ -75,11 +80,12 @@ def create_or_select_database(self, url: str, aligned_volume: str):
         )
 
         with temp_engine.connect() as connection:
-            connection.execute("commit")
             database_exists = connection.execute(
-                f"SELECT 1 FROM pg_catalog.pg_database WHERE datname = '{sql_uri.database}'"
-            )
-            if not database_exists.fetchone():
+                text("SELECT 1 FROM pg_catalog.pg_database WHERE datname = :dbname"),
+                {"dbname": sql_uri.database},
+            ).scalar()
+
+            if not database_exists:
                 logging.info(f"Database {aligned_volume} does not exist.")
                 self._create_aligned_volume_database(sql_uri, connection)
 
@@ -95,21 +101,25 @@ def _create_aligned_volume_database(self, sql_uri, connection):
         logging.info(f"Creating new database: {sql_uri.database}")
 
         connection.execute(
-            f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity \
-                           WHERE pid <> pg_backend_pid() AND datname = '{sql_uri.database}';"
+            text(
+                "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = :dbname"
+            ),
+            {"dbname": sql_uri.database},
         )
 
         # check if template exists, create if missing
         template_exist = connection.execute(
-            "SELECT 1 FROM pg_catalog.pg_database WHERE datname = 'template_postgis'"
-        )
+            text(
+                "SELECT 1 FROM pg_catalog.pg_database WHERE datname = 'template_postgis'"
+            )
+        ).scalar()
 
-        if not template_exist.fetchone():
+        if not template_exist:
             # create postgis template db
-            connection.execute("CREATE DATABASE template_postgis")
+            connection.execute(text("CREATE DATABASE template_postgis"))
 
             # create postgis extension
-            template_uri = make_url(
+            template_uri = sql_uri.set(
+                database="template_postgis"
+            )
             template_engine = create_engine(
@@ -119,12 +129,14 @@ def _create_aligned_volume_database(self, sql_uri, connection):
                 pool_pre_ping=True,
             )
             with template_engine.connect() as template_connection:
-                template_connection.execute("CREATE EXTENSION IF NOT EXISTS postgis")
+                template_connection.execute(
+                    text("CREATE EXTENSION IF NOT EXISTS postgis")
+                )
             template_engine.dispose()
 
         # finally create new annotation database
         connection.execute(
-            f"CREATE DATABASE {sql_uri.database} TEMPLATE template_postgis"
+            text(f"CREATE DATABASE {sql_uri.database} TEMPLATE template_postgis")
         )
         aligned_volume_engine = create_engine(
             sql_uri,
diff --git a/dynamicannotationdb/models.py b/dynamicannotationdb/models.py
index 23b3a4a..7070109 100644
--- a/dynamicannotationdb/models.py
+++ b/dynamicannotationdb/models.py
@@ -1,5 +1,5 @@
 import enum
-
+from typing import Optional
 from emannotationschemas.models import Base
 from sqlalchemy import (
     Boolean,
@@ -11,13 +11,15 @@
     Integer,
     String,
     Text,
-    Enum,
     JSON,
 )
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship
 from sqlalchemy.dialects import postgresql
 
+from marshmallow import Schema, fields, post_load
+from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
+
 # Models that will be created in the 'materialized' database.
 MatBase = declarative_base()
 # Models that will be created in the 'annotation' database.
@@ -125,6 +127,13 @@ class AnnoMetadata(Base):
     last_modified = Column(DateTime, nullable=False)
 
 
+class AnnotationMetadataSchema(SQLAlchemyAutoSchema):
+    class Meta:
+        model = AnnoMetadata
+        include_relationships = True
+        load_instance = True
+
+
 class SegmentationMetadata(Base):
     __tablename__ = "segmentation_table_metadata"
     id = Column(Integer, primary_key=True)
@@ -141,6 +150,35 @@ class SegmentationMetadata(Base):
     )
 
 
+class SegmentationMetadataSchema(SQLAlchemyAutoSchema):
+    class Meta:
+        model = SegmentationMetadata
+        include_relationships = True
+        load_instance = True
+
+
+class CombinedMetadata:
+    def __init__(
+        self,
+        table_name: str,
+        anno_metadata: Optional[AnnoMetadata] = None,
+        seg_metadata: Optional[SegmentationMetadata] = None,
+    ):
+        self.table_name = table_name
+        self.anno_metadata = anno_metadata
+        self.seg_metadata = seg_metadata
+
+
+class CombinedMetadataSchema(Schema):
+    table_name = fields.Str(required=True)
+    anno_metadata = fields.Nested(AnnotationMetadataSchema, allow_none=True)
+    seg_metadata = fields.Nested(SegmentationMetadataSchema, allow_none=True)
+
+    @post_load
+    def make_combined_metadata(self, data, **kwargs):
+        return CombinedMetadata(**data)
+
+
 class CombinedTableMetadata(Base):
     __tablename__ = "combined_table_metadata"
     __table_args__ = (
diff --git a/dynamicannotationdb/schema.py b/dynamicannotationdb/schema.py
index 6eb3c52..2cb1e75 100644
--- a/dynamicannotationdb/schema.py
+++ b/dynamicannotationdb/schema.py
@@ -6,7 +6,7 @@
 from emannotationschemas.schemas.base import ReferenceAnnotation, SegmentationField
 from marshmallow import EXCLUDE, Schema
 
-from .errors import SelfReferenceTableError, TableNameNotFound
+from dynamicannotationdb.errors import SelfReferenceTableError, TableNameNotFound
 
 
 class DynamicSchemaClient:
diff --git a/dynamicannotationdb/segmentation.py b/dynamicannotationdb/segmentation.py
index 3dca556..ce0d884 100644
--- a/dynamicannotationdb/segmentation.py
+++ b/dynamicannotationdb/segmentation.py
@@ -1,19 +1,22 @@
 import datetime
 import logging
-from typing import List
+from typing import Any, Dict, List, Optional
 
 from marshmallow import INCLUDE
+from sqlalchemy import and_, select
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.exc import SQLAlchemyError
 
-from .database import DynamicAnnotationDB
-from .errors import (
+from dynamicannotationdb.database import DynamicAnnotationDB
+from dynamicannotationdb.errors import (
     AnnotationInsertLimitExceeded,
     IdsAlreadyExists,
     UpdateAnnotationError,
 )
-from .key_utils import build_segmentation_table_name
-from .models import SegmentationMetadata
-from .schema import DynamicSchemaClient
-from .errors import TableNameNotFound
+from dynamicannotationdb.key_utils import build_segmentation_table_name
+from dynamicannotationdb.models import SegmentationMetadata
+from dynamicannotationdb.schema import DynamicSchemaClient
+
 
 class DynamicSegmentationClient:
     def __init__(self, sql_url: str) -> None:
@@ -21,38 +24,16 @@ def __init__(self, sql_url: str) -> None:
         self.schema = DynamicSchemaClient()
 
     def create_segmentation_table(
-            self,
-            table_name: str,
-            schema_type: str,
-            segmentation_source: str,
-            table_metadata: dict = None,
-            with_crud_columns: bool = False,
-    ):
-        """Create a segmentation table with the primary key as foreign key
-        to the annotation table.
-
-        Parameters
-        ----------
-        table_name : str
-            Name of annotation table to link to.
-        schema_type : str
-            schema type
-        segmentation_source : str
-            name of segmentation data source, used to create table name.
-        table_metadata : dict, optional
-            metadata to extend table behavior, by default None
-        with_crud_columns : bool, optional
-            add additional columns to track CRUD operations on rows, by default False
-
-        Returns
-        -------
-        str
-            name of segmentation table.
-        """
+        self,
+        table_name: str,
+        schema_type: str,
+        segmentation_source: str,
+        table_metadata: Optional[Dict[str, Any]] = None,
+        with_crud_columns: bool = False,
+    ) -> str:
         segmentation_table_name = build_segmentation_table_name(
             table_name, segmentation_source
         )
-
         self.db._check_table_is_unique(segmentation_table_name)
 
         SegmentationModel = self.schema.create_segmentation_model(
@@ -63,330 +44,285 @@ def create_segmentation_table(
             with_crud_columns,
         )
 
-        if (
-                not self.db.cached_session.query(SegmentationMetadata)
-                        .filter(SegmentationMetadata.table_name == segmentation_table_name)
-                        .scalar()
-        ):
-            SegmentationModel.__table__.create(bind=self.db._engine, checkfirst=True)
-            creation_time = datetime.datetime.utcnow()
-            metadata_dict = {
-                "annotation_table": table_name,
-                "schema_type": schema_type,
-                "table_name": segmentation_table_name,
-                "valid": True,
-                "created": creation_time,
-                "pcg_table_name": segmentation_source,
-            }
+        with self.db.session_scope() as session:
+            if not session.execute(
+                select(SegmentationMetadata).filter_by(
+                    table_name=segmentation_table_name
+                )
+            ).scalar():
+                SegmentationModel.__table__.create(
+                    bind=self.db._engine, checkfirst=True
+                )
+                creation_time = datetime.datetime.utcnow()
+                metadata_dict = {
+                    "annotation_table": table_name,
+                    "schema_type": schema_type,
+                    "table_name": segmentation_table_name,
+                    "valid": True,
+                    "created": creation_time,
+                    "pcg_table_name": segmentation_source,
+                }
+
+                seg_metadata = SegmentationMetadata(**metadata_dict)
+                session.add(seg_metadata)
+
+        return segmentation_table_name
 
-            seg_metadata = SegmentationMetadata(**metadata_dict)
+    def get_linked_tables(
+        self, table_name: str, pcg_table_name: str
+    ) -> List[SegmentationMetadata]:
+        with self.db.session_scope() as session:
             try:
-                self.db.cached_session.add(seg_metadata)
-                self.db.commit_session()
+                stmt = select(SegmentationMetadata).filter_by(
+                    annotation_table=table_name, pcg_table_name=pcg_table_name
+                )
+                result = session.execute(stmt)
+                return result.scalars().all()
             except Exception as e:
-                logging.error(f"SQL ERROR: {e}")
+                raise AttributeError(
+                    f"No table found with name '{table_name}'. Error: {e}"
+                ) from e
 
-        return segmentation_table_name
+    def get_segmentation_table_metadata(
+        self, table_name: str, pcg_table_name: str
+    ) -> Optional[Dict[str, Any]]:
+        seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
+        with self.db.session_scope() as session:
+            try:
+                stmt = select(SegmentationMetadata).filter_by(table_name=seg_table_name)
+                result = session.execute(stmt).scalar_one_or_none()
+                return self.db.get_automap_items(result) if result else None
+            except Exception as e:
+                logging.error(f"Error fetching segmentation table metadata: {e}")
+                return None
 
-    def get_linked_tables(self, table_name: str, pcg_table_name: str) -> List:
+    def get_linked_annotations(
+        self, table_name: str, pcg_table_name: str, annotation_ids: List[int]
+    ) -> List[Dict[str, Any]]:
         try:
-            return (
-                self.db.cached_session.query(SegmentationMetadata)
-                    .filter(SegmentationMetadata.annotation_table == table_name)
-                    .filter(SegmentationMetadata.pcg_table_name == pcg_table_name)
-                    .all()
+            metadata = self.db.get_table_metadata(table_name)
+            schema_type = metadata.anno_metadata.schema_type
+            seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
+            AnnotationModel, SegmentationModel = self._get_models(
+                table_name, seg_table_name
             )
 
-        except Exception as e:
-            raise AttributeError(
-                f"No table found with name '{table_name}'. Error: {e}"
-            ) from e
+            with self.db.session_scope() as session:
+                # Perform a join query
+                stmt = (
+                    select(AnnotationModel, SegmentationModel)
+                    .join(SegmentationModel, AnnotationModel.id == SegmentationModel.id)
+                    .filter(AnnotationModel.id.in_(annotation_ids))
+                )
+                result = session.execute(stmt).all()
 
-    def get_segmentation_table_metadata(self, table_name: str, pcg_table_name: str):
-        seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
-        try:
-            result = (
-                self.db.cached_session.query(SegmentationMetadata)
-                .filter(SegmentationMetadata.table_name == seg_table_name)
-                .one()
-            )
-            return self.db.get_automap_items(result)
-        except Exception as e:
-            self.db.cached_session.rollback()
-            return None
-        
+            data = []
+            for anno, seg in result:
+                merged_data = self._format_model_to_dict(anno)
+                merged_data.update(self._format_model_to_dict(seg))
+                data.append(merged_data)
 
-    def get_linked_annotations(
-            self, table_name: str, pcg_table_name: str, annotation_ids: List[int]
-    ) -> dict:
-        """Get list of annotations from database by id.
-
-        Parameters
-        ----------
-        table_name : str
-            name of annotation table
-        pcg_table_name: str
-            name of chunked graph reference table
-        annotation_ids : int
-            annotation id
-
-        Returns
-        -------
-        list
-            list of annotation data dicts
-        """
+            FlatSchema = self.schema.get_flattened_schema(schema_type)
+            schema = FlatSchema(unknown=INCLUDE)
 
-        metadata = self.db.get_table_metadata(table_name)
-        schema_type = metadata["schema_type"]
-        seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
-        AnnotationModel = self.db.cached_table(table_name)
-        SegmentationModel = self.db.cached_table(seg_table_name)
+            return schema.load(data, many=True)
 
-        annotations = (
-            self.db.cached_session.query(AnnotationModel, SegmentationModel)
-                .join(SegmentationModel, SegmentationModel.id == AnnotationModel.id)
-                .filter(AnnotationModel.id.in_(list(annotation_ids)))
-                .all()
-        )
-
-        FlatSchema = self.schema.get_flattened_schema(schema_type)
-        schema = FlatSchema(unknown=INCLUDE)
+        except Exception as e:
+            logging.error(f"Error retrieving linked annotations: {str(e)}")
+            raise
 
-        data = []
-        for anno, seg in annotations:
-            anno_data = anno.__dict__
-            seg_data = seg.__dict__
-            anno_data = {
-                k: v for (k, v) in anno_data.items() if k != "_sa_instance_state"
-            }
-            seg_data = {
-                k: v for (k, v) in seg_data.items() if k != "_sa_instance_state"
-            }
-            anno_data["created"] = str(anno_data.get("created"))
-            anno_data["deleted"] = str(anno_data.get("deleted"))
+    def _get_models(self, table_name: str, seg_table_name: str):
+        AnnotationModel = self.db.cached_table(table_name)
+        SegmentationModel = self.db.cached_table(seg_table_name)
+        return AnnotationModel, SegmentationModel
 
-            merged_data = {**anno_data, **seg_data}
-            data.append(merged_data)
+    def _format_model_to_dict(self, model):
+        return {
+            k: self._format_value(v)
+            for k, v in model.__dict__.items()
+            if not k.startswith("_")
+        }
 
-        return schema.load(data, many=True)
+    def _format_value(self, value):
+        return str(value) if isinstance(value, datetime.datetime) else value
 
     def insert_linked_segmentation(
-            self, table_name: str, pcg_table_name: str, segmentation_data: List[dict]
-    ):
-        """Insert segmentation data by linking to annotation ids.
-        Limited to 10,000 inserts. If more consider using a bulk insert script.
-
-        Parameters
-        ----------
-        table_name : str
-            name of annotation table
-        pcg_table_name: str
-            name of chunked graph reference table
-        segmentation_data : List[dict]
-            List of dictionaries of single segmentation data.
-        """
+        self,
+        table_name: str,
+        pcg_table_name: str,
+        segmentation_data: List[Dict[str, Any]],
+    ) -> List[int]:
         insertion_limit = 10_000
-
         if len(segmentation_data) > insertion_limit:
             raise AnnotationInsertLimitExceeded(len(segmentation_data), insertion_limit)
 
         metadata = self.db.get_table_metadata(table_name)
-        schema_type = metadata["schema_type"]
-
+        schema_type = metadata.anno_metadata.schema_type
         seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
-
         SegmentationModel = self.db.cached_table(seg_table_name)
-        formatted_seg_data = []
 
         _, segmentation_schema = self.schema.split_flattened_schema(schema_type)
 
-        for segmentation in segmentation_data:
-            segmentation_data = self.schema.flattened_schema_data(segmentation)
-            flat_data = self.schema._map_values_to_schema(
-                segmentation_data, segmentation_schema
-            )
-            flat_data["id"] = segmentation["id"]
+        formatted_seg_data = [
+            {
+                **self.schema._map_values_to_schema(
+                    self.schema.flattened_schema_data(segmentation), segmentation_schema
+                ),
+                "id": segmentation["id"],
+            }
+            for segmentation in segmentation_data
+        ]
 
-            formatted_seg_data.append(flat_data)
+        with self.db.session_scope() as session:
+            try:
+                ids = [data["id"] for data in formatted_seg_data]
+
+                # Check for existing IDs efficiently
+                existing_ids = set(
+                    session.execute(
+                        select(SegmentationModel.id).filter(
+                            SegmentationModel.id.in_(ids)
+                        )
+                    )
+                    .scalars()
+                    .all()
+                )
 
-        segs = [
-            SegmentationModel(**segmentation_data)
-            for segmentation_data in formatted_seg_data
-        ]
+                if existing_ids:
+                    raise IdsAlreadyExists(
+                        f"Annotation IDs {existing_ids} already linked in database"
+                    )
 
-        ids = [data["id"] for data in formatted_seg_data]
-        q = self.db.cached_session.query(SegmentationModel).filter(
-            SegmentationModel.id.in_(list(ids))
-        )
+                # Bulk insert using PostgreSQL's INSERT ... ON CONFLICT
+                insert_stmt = insert(SegmentationModel).values(formatted_seg_data)
+                insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["id"])
+                result = session.execute(insert_stmt.returning(SegmentationModel.id))
 
-        ids_exist = self.db.cached_session.query(q.exists()).scalar()
+                seg_ids = result.scalars().all()
 
-        if ids_exist:
-            raise IdsAlreadyExists(f"Annotation IDs {ids} already linked in database ")
-        self.db.cached_session.add_all(segs)
-        seg_ids = [seg.id for seg in segs]
-        self.db.commit_session()
-        return seg_ids
+                return seg_ids
+
+            except SQLAlchemyError as e:
+                logging.error(f"Error inserting linked segmentations: {str(e)}")
+                raise
 
     def insert_linked_annotations(
-            self, table_name: str, pcg_table_name: str, annotations: List[dict]
-    ):
-        """Insert annotations by type and schema. Limited to 10,000
-        annotations. If more consider using a bulk insert script.
-
-        Parameters
-        ----------
-        table_name : str
-            name of annotation table
-        pcg_table_name: str
-            name of chunked graph reference table
-        annotations : dict
-            Dictionary of single annotation data.
-        """
+        self, table_name: str, pcg_table_name: str, annotations: List[Dict[str, Any]]
+    ) -> List[int]:
         insertion_limit = 10_000
-
         if len(annotations) > insertion_limit:
             raise AnnotationInsertLimitExceeded(len(annotations), insertion_limit)
 
         metadata = self.db.get_table_metadata(table_name)
-        schema_type = metadata["schema_type"]
+        schema_type = metadata.anno_metadata.schema_type
 
         seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
+        AnnotationModel = self.db.cached_table(table_name)
+        SegmentationModel = self.db.cached_table(seg_table_name)
 
         formatted_anno_data = []
         formatted_seg_data = []
 
-        AnnotationModel = self.db.cached_table(table_name)
-        SegmentationModel = self.db.cached_table(seg_table_name)
-        logging.info(f"{AnnotationModel.__table__.columns}")
-        logging.info(f"{SegmentationModel.__table__.columns}")
-
         for annotation in annotations:
-
             anno_data, seg_data = self.schema.split_flattened_schema_data(
                 schema_type, annotation
             )
-            if annotation.get("id"):
+            if annotation.get("id") is not None:  # allow id=0, but skip explicit None
                 anno_data["id"] = annotation["id"]
             if hasattr(AnnotationModel, "created"):
                 anno_data["created"] = datetime.datetime.utcnow()
             anno_data["valid"] = True
             formatted_anno_data.append(anno_data)
             formatted_seg_data.append(seg_data)
+
         logging.info(f"DATA TO BE INSERTED: {formatted_anno_data} {formatted_seg_data}")
-        try:
-            annos = [
-                AnnotationModel(**annotation_data)
-                for annotation_data in formatted_anno_data
-            ]
-        except Exception as e:
-            raise e
-        self.db.cached_session.add_all(annos)
-        self.db.cached_session.flush()
-        segs = [
-            SegmentationModel(**segmentation_data, id=anno.id)
-            for segmentation_data, anno in zip(formatted_seg_data, annos)
-        ]
-        ids = [anno.id for anno in annos]
-        self.db.cached_session.add_all(segs)
-        self.db.commit_session()
+
+        with self.db.session_scope() as session:
+            try:
+                annos = [AnnotationModel(**data) for data in formatted_anno_data]
+                session.add_all(annos)
+                session.flush()
+                segs = [
+                    SegmentationModel(**seg_data, id=anno.id)
+                    for seg_data, anno in zip(formatted_seg_data, annos)
+                ]
+                session.add_all(segs)
+                session.flush()
+                ids = [anno.id for anno in annos]
+            except Exception as e:
+                logging.error(f"Error inserting linked annotations: {e}")
+                raise
+
         return ids
 
     def update_linked_annotations(
-            self, table_name: str, pcg_table_name: str, annotation: dict
-    ):
-        """Updates an annotation by inserting a new row. The original annotation
-        will refer to the new row with a superseded_id. Does not update inplace.
-
-        Parameters
-        ----------
-        table_name : str
-            name of annotation table
-        pcg_table_name: str
-            name of chunked graph reference table
-        annotation : dict, annotation to update by ID
-        """
+        self, table_name: str, pcg_table_name: str, annotation: Dict[str, Any]
+    ) -> Dict[int, int]:
         anno_id = annotation.get("id")
         if not anno_id:
-            return "Annotation requires an 'id' to update targeted row"
+            raise ValueError("Annotation requires an 'id' to update targeted row")
 
         metadata = self.db.get_table_metadata(table_name)
-        schema_type = metadata["schema_type"]
-
+        schema_type = metadata.anno_metadata.schema_type
         seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
-
         AnnotationModel = self.db.cached_table(table_name)
         SegmentationModel = self.db.cached_table(seg_table_name)
 
-        new_annotation, __ = self.schema.split_flattened_schema_data(
+        new_annotation, _ = self.schema.split_flattened_schema_data(
             schema_type, annotation
         )
-
         new_annotation["created"] = datetime.datetime.utcnow()
         new_annotation["valid"] = True
 
-        new_data = AnnotationModel(**new_annotation)
+        with self.db.session_scope() as session:
+            stmt = select(AnnotationModel, SegmentationModel).filter(
+                AnnotationModel.id == anno_id, SegmentationModel.id == anno_id
+            )
+            result = session.execute(stmt).first()
+
+            if not result:
+                raise ValueError(f"No annotation found with id {anno_id}")
+
+            old_anno, old_seg = result
 
-        data = (
-            self.db.cached_session.query(AnnotationModel, SegmentationModel)
-                .filter(AnnotationModel.id == anno_id)
-                .filter(SegmentationModel.id == anno_id)
-                .all()
-        )
-        update_map = {}
-        for old_anno, old_seg in data:
             if old_anno.superceded_id:
                 raise UpdateAnnotationError(anno_id, old_anno.superceded_id)
 
-            self.db.cached_session.add(new_data)
-            self.db.cached_session.flush()
+            new_data = AnnotationModel(**new_annotation)
+            session.add(new_data)
+            session.flush(); new_id = new_data.id  # capture id inside session scope (instance is expired/detached after commit)

             deleted_time = datetime.datetime.utcnow()
             old_anno.deleted = deleted_time
             old_anno.superceded_id = new_data.id
             old_anno.valid = False
-            update_map[anno_id] = new_data.id
-        self.db.commit_session()

-        return update_map
+        return {anno_id: new_id}
 
     def delete_linked_annotation(
-            self, table_name: str, pcg_table_name: str, annotation_ids: List[int]
-    ):
-        """Mark annotations by for deletion by list of ids.
-
-        Parameters
-        ----------
-        table_name : str
-            name of annotation table
-        pcg_table_name: str
-            name of chunked graph reference table
-        annotation_ids : List[int]
-            list of ids to delete
-
-        Returns
-        -------
-
-        Raises
-        ------
-        """
+        self, table_name: str, pcg_table_name: str, annotation_ids: List[int]
+    ) -> List[int] | None:
         seg_table_name = build_segmentation_table_name(table_name, pcg_table_name)
         AnnotationModel = self.db.cached_table(table_name)
         SegmentationModel = self.db.cached_table(seg_table_name)
 
-        annotations = (
-            self.db.cached_session.query(AnnotationModel)
+        with self.db.session_scope() as session:
+            stmt = (
+                select(AnnotationModel)
                 .join(SegmentationModel, SegmentationModel.id == AnnotationModel.id)
-                .filter(AnnotationModel.id.in_(list(annotation_ids)))
-                .all()
-        )
+                .filter(AnnotationModel.id.in_(annotation_ids))
+            )
+            result = session.execute(stmt)
+            annotations = result.scalars().all()
+
+            if not annotations:
+                return None
+
+            deleted_time = datetime.datetime.utcnow()
+            for annotation in annotations:
+                annotation.deleted = deleted_time
+                annotation.valid = False
+
+            deleted_ids = [annotation.id for annotation in annotations]
 
-        if not annotations:
-            return None
-        deleted_ids = [annotation.id for annotation in annotations]
-        deleted_time = datetime.datetime.utcnow()
-        for annotation in annotations:
-            annotation.deleted = deleted_time
-            annotation.valid = False
-        self.db.commit_session()
         return deleted_ids
diff --git a/requirements.txt b/requirements.txt
index 7fb916a..a5153b6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,11 @@
-marshmallow==3.5.1
+marshmallow
 emannotationschemas>=5.4.0
-sqlalchemy<1.4
+sqlalchemy
+marshmallow-sqlalchemy
+marshmallow-enum
 psycopg2-binary
 geoalchemy2
 pytz
 alembic
 shapely
-jsonschema<4.0
\ No newline at end of file
+jsonschema
\ No newline at end of file
diff --git a/tests/test_annotation.py b/tests/test_annotation.py
index 348aada..f0cb73a 100644
--- a/tests/test_annotation.py
+++ b/tests/test_annotation.py
@@ -79,8 +79,9 @@ def test_create_reference_table(dadb_interface, annotation_metadata):
     assert table_name == table
 
     table_info = dadb_interface.database.get_table_metadata(table)
-    assert table_info["reference_table"] == "anno_test"
-
+
+    assert table_info.anno_metadata.reference_table == "anno_test"
+    assert table_info.table_name == table_name
 
 def test_create_nested_reference_table(dadb_interface, annotation_metadata):
     table_name = "reference_tag"
@@ -108,7 +109,8 @@ def test_create_nested_reference_table(dadb_interface, annotation_metadata):
     assert table_name == table
 
     table_info = dadb_interface.database.get_table_metadata(table)
-    assert table_info["reference_table"] == "presynaptic_bouton_types"
+
+    assert table_info.anno_metadata.reference_table == "presynaptic_bouton_types"
 
 
 def test_bad_schema_reference_table(dadb_interface, annotation_metadata):
@@ -306,4 +308,4 @@ def test_update_table_metadata(dadb_interface, annotation_metadata):
         table_name, description="New description"
     )
 
-    assert updated_metadata["description"] == "New description"
+    assert updated_metadata.anno_metadata.description == "New description"
diff --git a/tests/test_database.py b/tests/test_database.py
index f405a51..9278204 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -3,8 +3,7 @@
 import pytest
 
 from sqlalchemy import Table
-from sqlalchemy.ext.declarative.api import DeclarativeMeta
-
+from sqlalchemy.orm import DeclarativeMeta
 from emannotationschemas import type_mapping
 
 
@@ -13,28 +12,11 @@ def test_get_table_metadata(dadb_interface, annotation_metadata):
     schema_type = annotation_metadata["schema_type"]
     metadata = dadb_interface.database.get_table_metadata(table_name)
     logging.info(metadata)
-    assert metadata["schema_type"] == schema_type
-    assert metadata["table_name"] == "anno_test"
-    assert metadata["user_id"] == "foo@bar.com"
-    assert metadata["description"] == "New description"
-    assert metadata["voxel_resolution_x"] == 4.0
-
-    # test with filter to get a col value
-    metadata_value = dadb_interface.database.get_table_metadata(
-        table_name, filter_col="valid"
-    )
-    logging.info(metadata)
-    assert metadata_value == True
-
-    # test for missing column
-    with pytest.raises(AttributeError) as e:
-        bad_return = dadb_interface.database.get_table_metadata(
-            table_name, "missing_column"
-        )
-    assert (
-        str(e.value) == "type object 'AnnoMetadata' has no attribute 'missing_column'"
-    )
-
+    assert metadata.anno_metadata.schema_type == schema_type
+    assert metadata.table_name == "anno_test"
+    assert metadata.anno_metadata.user_id == "foo@bar.com"
+    assert metadata.anno_metadata.description == "New description"
+    assert metadata.anno_metadata.voxel_resolution_x == 4.0
 
 def test_get_table_sql_metadata(dadb_interface, annotation_metadata):
     table_name = annotation_metadata["table_name"]
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 719ee86..a209656 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -1,8 +1,7 @@
 import marshmallow
 from emannotationschemas.errors import UnknownAnnotationTypeException
 import pytest
-from sqlalchemy.ext.declarative.api import DeclarativeMeta
-
+from sqlalchemy.orm import DeclarativeMeta
 
 def test_get_schema(dadb_interface):
     valid_schema = dadb_interface.schema.get_schema("synapse")