diff --git a/app/application.py b/app/application.py index 12bd0a90b..cd9c1383f 100644 --- a/app/application.py +++ b/app/application.py @@ -17,6 +17,7 @@ from app.db.session import configure_database_session_manager from app.dependencies.common import forbid_extra_query_params from app.errors import ApiError, ApiErrorCode +from app.graphql.router import graphql_router from app.logger import L from app.routers import router from app.schemas.api import ErrorResponse @@ -106,3 +107,4 @@ async def validation_exception_handler( }, dependencies=[Depends(forbid_extra_query_params)], ) +app.include_router(graphql_router) diff --git a/app/graphql/__init__.py b/app/graphql/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/graphql/filters/__init__.py b/app/graphql/filters/__init__.py new file mode 100644 index 000000000..7602fe15c --- /dev/null +++ b/app/graphql/filters/__init__.py @@ -0,0 +1,7 @@ +from app.graphql.filters import common, species, morphology # noqa: I001 + +__all__ = [ + "common", + "morphology", + "species", +] diff --git a/app/graphql/filters/common.py b/app/graphql/filters/common.py new file mode 100644 index 000000000..bd6cc13f5 --- /dev/null +++ b/app/graphql/filters/common.py @@ -0,0 +1,18 @@ +import strawberry + +from app.filters.common import AgentFilter, MTypeClassFilter, StrainFilter + + +@strawberry.experimental.pydantic.input(model=MTypeClassFilter, all_fields=True) +class MTypeClassFilterInput: + pass + + +@strawberry.experimental.pydantic.input(model=StrainFilter, all_fields=True) +class StrainFilterInput: + pass + + +@strawberry.experimental.pydantic.input(model=AgentFilter, all_fields=True) +class AgentFilterInput: + pass diff --git a/app/graphql/filters/morphology.py b/app/graphql/filters/morphology.py new file mode 100644 index 000000000..f04df22f7 --- /dev/null +++ b/app/graphql/filters/morphology.py @@ -0,0 +1,8 @@ +import strawberry + +from app.filters.morphology import MorphologyFilter + + +@strawberry.experimental.pydantic.input(model=MorphologyFilter, all_fields=True) +class MorphologyFilterInput: + pass diff --git a/app/graphql/filters/species.py b/app/graphql/filters/species.py new file mode 100644 index 000000000..0b97c7ade --- /dev/null +++ b/app/graphql/filters/species.py @@ -0,0 +1,8 @@ +import strawberry + +from app.filters.common import SpeciesFilter + + +@strawberry.experimental.pydantic.input(model=SpeciesFilter, all_fields=True) +class SpeciesFilterInput: + pass diff --git a/app/graphql/resolvers/__init__.py b/app/graphql/resolvers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/graphql/resolvers/morphology.py b/app/graphql/resolvers/morphology.py new file mode 100644 index 000000000..ee5cbe9b6 --- /dev/null +++ b/app/graphql/resolvers/morphology.py @@ -0,0 +1,74 @@ +import uuid +from typing import TYPE_CHECKING + +import strawberry + +import app.service.morphology +from app.dependencies.auth import user_with_project_id +from app.filters.morphology import MorphologyFilter +from app.graphql.filters.morphology import MorphologyFilterInput +from app.graphql.types.morphology import MorphologyInput, MorphologyType +from app.graphql.types.pagination import ListResponseType, PaginationRequestInput +from app.schemas.morphology import ReconstructionMorphologyCreate, ReconstructionMorphologyRead +from app.schemas.types import PaginationRequest + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from app.schemas.auth import UserContext, UserContextWithProjectId + + +@strawberry.type +class MorphologyQuery: + @strawberry.field + def read_many_morphologies( + self, + *, + info: strawberry.Info, + pagination_request: PaginationRequestInput, + morphology_filter: MorphologyFilterInput, + search: str | None = None, + with_facets: bool = False, + ) -> ListResponseType[MorphologyType, ReconstructionMorphologyRead]: + # for proper validation, validate the input and return a pydantic model + db: Session = info.context.db + user_context: UserContext = info.context.user_context + validated_pagination_request = PaginationRequest.model_validate( + pagination_request, from_attributes=True + ) + validated_morphology_filter = MorphologyFilter.model_validate( + morphology_filter, from_attributes=True + ) + result = app.service.morphology.read_many( + user_context=user_context, + db=db, + pagination_request=validated_pagination_request, + morphology_filter=validated_morphology_filter, + search=search, + with_facets=with_facets, + ) + return ListResponseType[MorphologyType, ReconstructionMorphologyRead].from_pydantic(result) + + @strawberry.field + def read_morphology(self, id_: uuid.UUID, info: strawberry.Info) -> MorphologyType | None: + # for proper validation, validate the input and return a pydantic model + db: Session = info.context.db + user_context: UserContext = info.context.user_context + result = app.service.morphology.read_one(user_context=user_context, db=db, id_=id_) + return MorphologyType.from_pydantic(result) if result else None + + +@strawberry.type +class MorphologyMutation: + @strawberry.mutation + def create_morphology( + self, morphology: MorphologyInput, info: strawberry.Info + ) -> MorphologyType: + # for proper validation, validate the input and return a pydantic model + db: Session = info.context.db + user_context: UserContextWithProjectId = user_with_project_id(info.context.user_context) + validated = ReconstructionMorphologyCreate.model_validate(morphology) + result = app.service.morphology.create_one( + user_context=user_context, db=db, reconstruction=validated + ) + return MorphologyType.from_pydantic(result) diff --git a/app/graphql/resolvers/species.py b/app/graphql/resolvers/species.py new file mode 100644 index 000000000..56b7c2500 --- /dev/null +++ b/app/graphql/resolvers/species.py @@ -0,0 +1,59 @@ +import uuid +from typing import TYPE_CHECKING + +import strawberry + +import app.service.species +from app.filters.common import SpeciesFilter +from app.graphql.filters.species import SpeciesFilterInput +from app.graphql.types.common import SpeciesInput, SpeciesType +from app.graphql.types.pagination import ListResponseType, PaginationRequestInput +from app.schemas.base import SpeciesCreate, SpeciesRead +from app.schemas.types import PaginationRequest + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + +@strawberry.type +class SpeciesQuery: + @strawberry.field + def read_many_species( + self, + *, + pagination_request: PaginationRequestInput, + species_filter: SpeciesFilterInput, + info: strawberry.Info, + ) -> ListResponseType[SpeciesType, SpeciesRead]: + # for proper validation, validate the input and return a pydantic model + db: Session = info.context.db + validated_pagination_request = PaginationRequest.model_validate( + pagination_request, from_attributes=True + ) + validated_species_filter = SpeciesFilter.model_validate( + species_filter, from_attributes=True + ) + result = app.service.species.read_many( + db=db, + pagination_request=validated_pagination_request, + species_filter=validated_species_filter, + ) + return ListResponseType[SpeciesType, SpeciesRead].from_pydantic(result) + + @strawberry.field + def read_species(self, *, id_: uuid.UUID, info: strawberry.Info) -> SpeciesType | None: + # for proper validation, validate the input and return a pydantic model + db: Session = info.context.db + result = app.service.species.read_one(db=db, id_=id_) + return SpeciesType.from_pydantic(result) if result else None + + +@strawberry.type +class SpeciesMutation: + @strawberry.mutation + def create_species(self, *, species: SpeciesInput, info: strawberry.Info) -> SpeciesType: + # for proper validation, validate the input and return a pydantic model + db: Session = info.context.db + validated = SpeciesCreate.model_validate(species) + result = app.service.species.create_one(db=db, species=validated) + return SpeciesType.from_pydantic(result) diff --git a/app/graphql/router.py b/app/graphql/router.py new file mode 100644 index 000000000..dd88005cb --- /dev/null +++ b/app/graphql/router.py @@ -0,0 +1,30 @@ +from fastapi import Depends +from sqlalchemy.orm import Session +from strawberry.fastapi import BaseContext, GraphQLRouter + +from app.dependencies.auth import UserContextDep, user_with_service_admin_role +from app.dependencies.db import SessionDep +from app.graphql.schema import schema +from app.schemas.auth import UserContext + + +class Context(BaseContext): + def __init__(self, *, db: Session, user_context: UserContext) -> None: + """Initialize a new Context.""" + super().__init__() + self.db = db + self.user_context = user_context + + +def get_context(db: SessionDep, user_context: UserContextDep) -> Context: + return Context(db=db, user_context=user_context) + + +graphql_router = GraphQLRouter( + schema, + prefix="/graphql", + graphql_ide="apollo-sandbox", + context_getter=get_context, + dependencies=[Depends(user_with_service_admin_role)], + include_in_schema=False, +) diff --git a/app/graphql/schema.py b/app/graphql/schema.py new file mode 100644 index 000000000..e88978ab9 --- /dev/null +++ b/app/graphql/schema.py @@ -0,0 +1,44 @@ +import strawberry +from graphql import GraphQLError +from strawberry.schema.config import StrawberryConfig +from strawberry.types import ExecutionContext + +from app.graphql.resolvers.morphology import MorphologyMutation, MorphologyQuery +from app.graphql.resolvers.species import SpeciesMutation, SpeciesQuery + + +@strawberry.type +class Query( + MorphologyQuery, + SpeciesQuery, +): + pass + + +@strawberry.type +class Mutation( + MorphologyMutation, + SpeciesMutation, +): + pass + + +class CustomSchema(strawberry.Schema): + def process_errors( + self, + errors: list[GraphQLError], + execution_context: ExecutionContext | None = None, + ) -> None: + for error in errors: + # temporary workaround to propagate the exception to the FastAPI handlers, + # although this does not follow the GraphQL specifications + if err := getattr(error, "original_error", None): + raise err + super().process_errors(errors, execution_context) + + +schema = CustomSchema( + query=Query, + mutation=Mutation, + config=StrawberryConfig(auto_camel_case=False), +) diff --git a/app/graphql/types/__init__.py b/app/graphql/types/__init__.py new file mode 100644 index 000000000..b3fed1935 --- /dev/null +++ b/app/graphql/types/__init__.py @@ -0,0 +1,10 @@ +from app.graphql.types import common, agent, annotation, asset, role, contribution # noqa: I001 + +__all__ = [ + "agent", + "annotation", + "asset", + "common", + "contribution", + "role", +] diff --git a/app/graphql/types/agent.py b/app/graphql/types/agent.py new file mode 100644 index 000000000..65311711b --- /dev/null +++ b/app/graphql/types/agent.py @@ -0,0 +1,16 @@ +import strawberry + +from app.schemas.agent import OrganizationRead, PersonRead + + +@strawberry.experimental.pydantic.type(model=PersonRead, all_fields=True) +class PersonReadType: + pass + + +@strawberry.experimental.pydantic.type(model=OrganizationRead, all_fields=True) +class OrganizationReadType: + pass + + +AgentReadType = PersonReadType | OrganizationReadType diff --git a/app/graphql/types/annotation.py b/app/graphql/types/annotation.py new file mode 100644 index 000000000..14e43a822 --- /dev/null +++ b/app/graphql/types/annotation.py @@ -0,0 +1,8 @@ +import strawberry + +from app.schemas.annotation import Annotation + + +@strawberry.experimental.pydantic.type(model=Annotation, all_fields=True) +class AnnotationType: + pass diff --git a/app/graphql/types/asset.py b/app/graphql/types/asset.py new file mode 100644 index 000000000..dfb99b4ac --- /dev/null +++ b/app/graphql/types/asset.py @@ -0,0 +1,16 @@ +import strawberry + +from app.schemas.asset import AssetRead + + +@strawberry.experimental.pydantic.type(model=AssetRead) +class AssetReadType: + id: strawberry.auto + status: strawberry.auto + path: strawberry.auto + full_path: strawberry.auto + is_directory: strawberry.auto + content_type: strawberry.auto + size: strawberry.auto + sha256_digest: strawberry.auto + meta: strawberry.scalars.JSON diff --git a/app/graphql/types/common.py b/app/graphql/types/common.py new file mode 100644 index 000000000..6d123a03b --- /dev/null +++ b/app/graphql/types/common.py @@ -0,0 +1,45 @@ +import strawberry + +from app.schemas.base import ( + BrainRegionRead, + LicenseRead, + PointLocationBase, + SpeciesCreate, + SpeciesRead, + StrainRead, +) + + +@strawberry.experimental.pydantic.type(model=LicenseRead, all_fields=True) +class LicenseReadType: + pass + + +@strawberry.experimental.pydantic.type(model=PointLocationBase, all_fields=True) +class PointLocationBaseType: + pass + + +@strawberry.experimental.pydantic.input(model=PointLocationBase, all_fields=True) +class PointLocationBaseInput: + pass + + +@strawberry.experimental.pydantic.type(model=BrainRegionRead, all_fields=True) +class BrainRegionReadType: + pass + + +@strawberry.experimental.pydantic.type(model=StrainRead, all_fields=True) +class StrainReadType: + pass + + +@strawberry.experimental.pydantic.type(model=SpeciesRead, all_fields=True) +class SpeciesType: + pass + + +@strawberry.experimental.pydantic.input(model=SpeciesCreate, all_fields=True) +class SpeciesInput: + pass diff --git a/app/graphql/types/contribution.py b/app/graphql/types/contribution.py new file mode 100644 index 000000000..c2d2a1613 --- /dev/null +++ b/app/graphql/types/contribution.py @@ -0,0 +1,13 @@ +import strawberry + +from app.graphql.types.agent import AgentReadType +from app.schemas.contribution import ContributionReadWithoutEntity + + +@strawberry.experimental.pydantic.type(model=ContributionReadWithoutEntity) +class ContributionReadWithoutEntityType: + id: strawberry.auto + agent: AgentReadType + role: strawberry.auto + creation_date: strawberry.auto + update_date: strawberry.auto diff --git a/app/graphql/types/morphology.py b/app/graphql/types/morphology.py new file mode 100644 index 000000000..47c51a1ab --- /dev/null +++ b/app/graphql/types/morphology.py @@ -0,0 +1,13 @@ +import strawberry + +from app.schemas.morphology import ReconstructionMorphologyCreate, ReconstructionMorphologyRead + + +@strawberry.experimental.pydantic.type(model=ReconstructionMorphologyRead, all_fields=True) +class MorphologyType: + pass + + +@strawberry.experimental.pydantic.input(model=ReconstructionMorphologyCreate, all_fields=True) +class MorphologyInput: + pass diff --git a/app/graphql/types/pagination.py b/app/graphql/types/pagination.py new file mode 100644 index 000000000..21c9768cc --- /dev/null +++ b/app/graphql/types/pagination.py @@ -0,0 +1,65 @@ +import uuid + +import strawberry +from pydantic import BaseModel + +from app.schemas.types import Facet, ListResponse, PaginationRequest, PaginationResponse + + +@strawberry.experimental.pydantic.input(model=PaginationRequest, all_fields=True) +class PaginationRequestInput: + pass + + +@strawberry.experimental.pydantic.type(model=PaginationResponse, all_fields=True) +class PaginationResponseType: + pass + + +@strawberry.experimental.pydantic.type(model=Facet) +class FacetType: + id: uuid.UUID # not working with BrainRegion because using int id + label: strawberry.auto + count: strawberry.auto + type: strawberry.auto + + +@strawberry.type +class KeyValue[T]: + key: str + value: list[T] + + +@strawberry.type +class ListResponseType[T, M: BaseModel]: + """ListResponseType, with facets redefined to be compatible with GraphQL.""" + + data: list[T] + pagination: PaginationResponseType + facets: list[KeyValue[FacetType]] | None = None # Replaces dict[str, list[Facet]] + + @classmethod + def from_pydantic(cls, instance: ListResponse[M]) -> "ListResponseType[T, M]": + if instance.facets: + facets = [ + KeyValue[FacetType]( + key=facet_type, + value=[FacetType.from_pydantic(facet) for facet in facet_list], + ) + for facet_type, facet_list in instance.facets.items() + ] + else: + facets = None + return cls( + data=instance.data, # type: ignore[arg-type] + pagination=PaginationResponseType.from_pydantic(instance.pagination), + facets=facets, + ) + + def to_pydantic(self) -> ListResponse[M]: + facets = {kv.key: kv.value for kv in self.facets} if self.facets else None + return ListResponse[M]( + data=self.data, + pagination=self.pagination, + facets=facets, + ) diff --git a/app/graphql/types/role.py b/app/graphql/types/role.py new file mode 100644 index 000000000..ea09efb69 --- /dev/null +++ b/app/graphql/types/role.py @@ -0,0 +1,8 @@ +import strawberry + +from app.schemas.role import RoleRead + + +@strawberry.experimental.pydantic.type(model=RoleRead, all_fields=True) +class RoleReadType: + pass diff --git a/pyproject.toml b/pyproject.toml index 51387e195..cbf35bc21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "pydantic-settings>=2.7.1", "pyjwt>=2.10.1", "python-multipart>=0.0.20", + "strawberry-graphql[fastapi]>=0.262.5", ] requires-python = "==3.12.*" readme = "README.md" @@ -117,6 +118,9 @@ builtins-ignorelist = ["license"] "app/cli/utils.py" = [ "B006", # Do not use mutable data structures for argument defaults ] +"app/graphql/resolvers/**.py" = [ + "PLR6301", # Method could be a function, class method, or static method +] "test_dump/**.py" = ["ALL"] "test_legacy/**.py" = ["ALL"] "tests/**.py" = [ @@ -187,3 +191,7 @@ omit = [ show_error_codes = true ignore_missing_imports = true allow_redefinition = true +plugins = [ + "pydantic.mypy", + "strawberry.ext.mypy_plugin", +] diff --git a/uv.lock b/uv.lock index c60bae150..0a9f39efb 100644 --- a/uv.lock +++ b/uv.lock @@ -285,6 +285,7 @@ dependencies = [ { name = "pyjwt" }, { name = "python-multipart" }, { name = "sqlalchemy" }, + { name = "strawberry-graphql", extra = ["fastapi"] }, { name = "tqdm" }, { name = "uvicorn", extra = ["standard"] }, ] @@ -324,6 +325,7 @@ requires-dist = [ { name = "pyjwt", specifier = ">=2.10.1" }, { name = "python-multipart", specifier = ">=0.0.20" }, { name = "sqlalchemy" }, + { name = "strawberry-graphql", extras = ["fastapi"], specifier = ">=0.262.5" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "uvicorn", extras = ["standard"] }, ] @@ -392,6 +394,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/0b/0d7fee5919bccc1fdc1c2a7528b98f65c6f69b223a3fd8f809918c142c36/freezegun-1.5.1-py3-none-any.whl", hash = "sha256:bf111d7138a8abe55ab48a71755673dbaa4ab87f4cff5634a4442dfec34c15f1", size = 17569 }, ] +[[package]] +name = "graphql-core" +version = "3.2.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/16/7574029da84834349b60ed71614d66ca3afe46e9bf9c7b9562102acb7d4f/graphql_core-3.2.6.tar.gz", hash = "sha256:c08eec22f9e40f0bd61d805907e3b3b1b9a320bc606e23dc145eebca07c8fbab", size = 505353 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/4f/7297663840621022bc73c22d7d9d80dbc78b4db6297f764b545cd5dd462d/graphql_core-3.2.6-py3-none-any.whl", hash = "sha256:78b016718c161a6fb20a7d97bbf107f331cd1afe53e45566c59f776ed7f0b45f", size = 203416 }, +] + [[package]] name = "greenlet" version = "3.1.1" @@ -940,6 +951,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/4b/528ccf7a982216885a1ff4908e886b8fb5f19862d1962f56a3fce2435a70/starlette-0.46.1-py3-none-any.whl", hash = "sha256:77c74ed9d2720138b25875133f3a2dae6d854af2ec37dceb56aef370c1d8a227", size = 71995 }, ] +[[package]] +name = "strawberry-graphql" +version = "0.262.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "graphql-core" }, + { name = "packaging" }, + { name = "python-dateutil" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/9f/77a2611aeeef2b01dbfeea3d4a48be2517ba73935c87ee000e9c14844fd6/strawberry_graphql-0.262.5.tar.gz", hash = "sha256:92a5403133fb22ea4f31a09df9aa70567cbd7c860dc34afe92a32103125c6f26", size = 202428 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/6b/715835515ff21ab9de351df401a769a51b051dad6b67fa43778682de13a5/strawberry_graphql-0.262.5-py3-none-any.whl", hash = "sha256:7bc62e19326d3f5294f473c2ca3418bd01297e6abfd4a5a133f33fc9a5fcd5e1", size = 296015 }, +] + +[package.optional-dependencies] +fastapi = [ + { name = "fastapi" }, + { name = "python-multipart" }, +] + [[package]] name = "tqdm" version = "4.67.1"