diff --git a/.dockerignore b/.dockerignore index a546413..4ed29c5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,4 +1,4 @@ * -!requirements.txt -!requirements.prod.txt -!app/ +!marble_api/ +**/__pycache__ +!pyproject.toml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 058bf3b..e2e22eb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,6 +2,8 @@ name: Unit tests on: pull_request: types: [opened, synchronize, reopened, ready_for_review] +env: + MONGODB_URI: mongodb://localhost:27017 jobs: test: if: github.event.pull_request.draft == false @@ -18,7 +20,11 @@ jobs: cache: 'pip' - name: Install python test dependencies run: | - pip install requirements.test.txt + pip install .[test] + - name: Start MongoDB + uses: supercharge/mongodb-github-action@1.12.0 + with: + mongodb-version: latest - name: Test with pytest run: | pytest ./test/ diff --git a/.gitignore b/.gitignore index 7f79574..ceafac3 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,10 @@ __pycache__/ # virtual environments venv/ +# build files +build/ +*.egg-info/ + # sqlite files *.sqlite *.db diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fafb789..8a9e74d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.7.1 + rev: v0.13.2 hooks: # Run the linter. 
- id: ruff diff --git a/Dockerfile b/Dockerfile index 33141eb..339ad06 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,11 +1,10 @@ FROM python:3.13-alpine -COPY requirements.txt requirements.prod.txt /app/ - -RUN python -m pip install -r /app/requirements.prod.txt - -COPY app/ /app/app/ +COPY marble_api/ /app/marble_api/ +COPY pyproject.toml /app/pyproject.toml WORKDIR /app -CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--root-path", ""] +RUN pip install .[prod] && rm pyproject.toml + +CMD ["uvicorn", "marble_api:app", "--host", "0.0.0.0", "--port", "8000", "--root-path", ""] diff --git a/README.md b/README.md index d281813..969e914 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,31 @@ An API for the Marble platform. +## Requirement + +- MongoDB server + ## Developing To start a development server: ```sh -python3 -m pip install -r requirements.dev.txt -fastapi dev app +python3 -m pip install .[dev] +MONGODB_URI="mongodb://localhost:27017" fastapi dev marble_api +``` + +This assumes that you have a mongodb service running at `mongodb://localhost:27017`. + +Or to start developing using docker: + +```sh +docker compose -f docker-compose.dev.yml up ``` +This will start a dedicated mongodb container for use with your app. Note that this will +track changes you make dynamically so you don't have to restart the container if you make +changes to the source code while the container is running. + ### Contributing We welcome any contributions to this code. To submit suggested changes, please do the following: @@ -80,6 +96,16 @@ Then the `/test` route will not be available from `/v3` onwards. To run tests: ```sh -python3 -m pip install -r requirements.test.txt -pytest test/ +python3 -m pip install .[dev] +MONGODB_URI="mongodb://localhost:27017" pytest ./test +``` + +This assumes that you have a mongodb service running at `mongodb://localhost:27017`. 
+ +Alternatively you can run start up the development stack with docker compose and then +run tests in the docker container: + +```sh +docker compose -f docker-compose.dev.yml up -d +docker compose exec marble_api pytest ./test ``` diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index 4535d67..0000000 --- a/app/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from app.main import app - -__all__ = ["app"] diff --git a/app/database/__init__.py b/app/database/__init__.py deleted file mode 100644 index dec992c..0000000 --- a/app/database/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -import os - -DB_BACKEND = os.getenv("MARBLE_API_DB_BACKEND", "sqlite") - -if DB_BACKEND == "sqlite": - from app.database.sqlite import SessionDep, on_start -else: - raise RuntimeError(f"Database backend '{DB_BACKEND}' is not supported.") - -__all__ = ["SessionDep", "on_start"] diff --git a/app/database/sqlite.py b/app/database/sqlite.py deleted file mode 100644 index 9d78d7d..0000000 --- a/app/database/sqlite.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -from typing import TYPE_CHECKING, Annotated - -from fastapi import Depends -from sqlalchemy import create_engine -from sqlmodel import Session, SQLModel - -if TYPE_CHECKING: - from typing import Iterable - -sqlite_file_name = os.getenv("MARBLE_API_SQLITE_DB_FILE", "database.db") -sqlite_url = f"sqlite:///{sqlite_file_name}" - -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) - - -def on_start() -> None: - """Run when the app starts to initialize the database.""" - SQLModel.metadata.create_all(engine) - - -def _get_session() -> "Iterable[Session]": - """Yield the database session.""" - with Session(engine) as session: - yield session - - -SessionDep = Annotated[Session, Depends(_get_session)] diff --git a/app/versions/v1/app.py b/app/versions/v1/app.py deleted file mode 100644 index 8b6329a..0000000 --- a/app/versions/v1/app.py +++ /dev/null @@ -1,91 +0,0 @@ -from 
typing import Annotated - -from fastapi import FastAPI, HTTPException, Query, Request -from pydantic import AfterValidator, BaseModel -from sqlmodel import select - -from app.database import SessionDep -from app.versions.v1.models import DataRequest, DataRequestPublic, DataRequestUpdate - -app = FastAPI(version="1") - - -@app.post("/data-publish-request") -async def post_data_request( - data_request: Annotated[DataRequest, AfterValidator(DataRequest.model_validate)], session: SessionDep -) -> DataRequestPublic: - """Create a new data request and return the newly created data request.""" - session.add(data_request) - session.commit() - session.refresh(data_request) - return data_request - - -@app.patch("/data-publish-request/{request_id}") -async def patch_data_request( - request_id: str, data_request: DataRequestUpdate, session: SessionDep -) -> DataRequestPublic: - """Update fields of data request and return the updated data request.""" - data_request = session.get(DataRequest, request_id) - if not data_request: - raise HTTPException(status_code=404, detail="data publish request not found") - data_request.sqlmodel_update(data_request.model_dump(exclude_unset=True)) - session.add(data_request) - session.commit() - session.refresh(data_request) - return data_request - - -@app.get("/data-publish-request/{request_id}") -async def get_data_request(request_id: str, session: SessionDep) -> DataRequestPublic: - """Get a data request with the given request_id.""" - request = session.get(DataRequest, request_id) - if not request: - raise HTTPException(status_code=404, detail="data publish request not found") - return request - - -class _LinksResponse(BaseModel): - rel: str - href: str - type: str - - -class _DataRequestMultiResponse(BaseModel): - data_publish_requests: list[DataRequestPublic] - links: list[_LinksResponse] - - -@app.get("/data-publish-request") -async def get_data_requests( - session: SessionDep, - request: Request, - offset: Annotated[int, Query(ge=0)] = 0, 
- limit: Annotated[int, Query(le=100, gt=0)] = 10, -) -> _DataRequestMultiResponse: - """ - Return all data requests. - - This response is paginated and will only return at most limit objects at a time (maximum 100). - Use the offset and limit parameters to select specific ranges of data requests. - """ - requests = session.exec(select(DataRequest).offset(offset).limit(limit + 1)).all() - links = [] - if len(requests) > limit: - links.append( - { - "rel": "next", - "type": "application/json", - "href": str(request.url.replace_query_params(offset=offset + len(requests))), - } - ) - requests.pop() - if offset: - links.append( - { - "ref": "prev", - "type": "application/json", - "href": str(request.url.replace_query_params(offset=max(0, offset - limit))), - } - ) - return {"data_publish_requests": requests, "links": links} diff --git a/app/versions/v1/models.py b/app/versions/v1/models.py deleted file mode 100644 index 389d1cd..0000000 --- a/app/versions/v1/models.py +++ /dev/null @@ -1,110 +0,0 @@ -from datetime import datetime - -from pydantic import field_validator -from sqlmodel import Field, SQLModel - - -class DataRequestBase(SQLModel): - """ - SQL model base object for Data Requests. - - This object contains field validators. 
- """ - - @field_validator("start_date","end_date", check_fields=False) - def validate_timeperiod(cls, start_date: datetime, end_date: datetime) -> datetime: - """Check that the time periods are correct""" - if end_date < start_date: - raise ValueError("End date can not be earlier than start date") - return start_date, end_date - - @field_validator("longitude", "latitude", "myFile", check_fields=False) - def validate_title(cls, latitude: str, longitude: str, myFile:str) -> str: - """Ensure list lengths match""" - if len(latitude) != len(longitude): - raise ValueError("Latitude and longitude lists are different lengths") - """If there is no latitude and longitude, make sure there is a GeoJSON""" - if latitude == None: - if myFile == None: - raise ValueError("Must include either GeoJSON file or manually inputted latitude and longitudes") - """Check latitude and longitude ranges""" - for i in latitude: - if i > 90: - raise ValueError("Latitudes must be between -90 and 90 degrees") - if i < -90: - raise ValueError("Latitudes must be between -90 and 90 degrees") - for i in longitude: - if i > 180: - raise ValueError("Longitudes must be between -180 and 180 degrees") - if i < -180: - raise ValueError("Longitudes must be between -180 and 180 degrees") - - return latitude, longitude, myFile - - -class DataRequest(DataRequestBase, table=True): - """ - Database model for Data Requests. - - This object contains the representation of the data in the database. - """ - id: int | None = Field(default=None, primary_key=True) - username: str - title: str - desc: str | None - fname: str - lname: str - email: str - geometry: str - latitude: str | None - longitude: str | None - myFile: str | None - start_date: datetime - end_date: datetime - variables: str | None - models: str | None - path: str - input: str | None - link: str | None - - -class DataRequestPublic(DataRequestBase): - """ - Public model for Data Requests. - - This object contains all fields that are visible to users. 
- If a field defined in DataRequests should not be visible to users, it will not - be included in this object. - """ - - id: int | None = Field(default=None, primary_key=True) - title: str - desc: str | None - fname: str - lname: str - email: str - geometry: str - latitude: str | None - longitude: str | None - myFile: str | None - start_date: datetime - end_date: datetime - variables: str | None - models: str | None - path: str - input: str | None - link: str | None - - -class DataRequestUpdate(DataRequestBase): - """ - Update model for Data Requests. - - This object contains all fields that are updatable on the DataRequest model. - Fields should be optional unless they *must* be updated every time a change is made. - """ - - title: str | None = None - desc: str | None = None - date: datetime | None = None - # TODO: make sure parameters added in DataRequest are made optional here diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml new file mode 100644 index 0000000..e1d5f3d --- /dev/null +++ b/docker-compose.dev.yml @@ -0,0 +1,13 @@ +services: + marble_api: + image: python:3.13-alpine + volumes: + - .:/app + working_dir: /app + command: ["sh", "-c", "pip install -e .[dev,test] && fastapi dev marble_api --host 0.0.0.0"] + environment: + - MONGODB_URI=mongodb://mongo:27017 + ports: + - 8000:8000 + mongo: + image: mongo:5.0.4 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..bc35c87 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,9 @@ +services: + marble_api: + image: marbleclimate/marble-api:latest + environment: + - MONGODB_URI=mongodb://mongo:27017 + ports: + - 8000:8000 + mongo: + image: mongo:5.0.4 diff --git a/marble_api/__init__.py b/marble_api/__init__.py new file mode 100644 index 0000000..3d4f5eb --- /dev/null +++ b/marble_api/__init__.py @@ -0,0 +1,3 @@ +from marble_api.app import app + +__all__ = ["app"] diff --git a/app/main.py b/marble_api/app.py similarity index 61% rename from app/main.py rename to 
marble_api/app.py index e05c09d..36d8e4a 100644 --- a/app/main.py +++ b/marble_api/app.py @@ -1,27 +1,13 @@ -from contextlib import asynccontextmanager -from typing import TYPE_CHECKING - from fastapi import FastAPI, Request -from app.database import on_start -from app.utils import get_routes -from app.versions.v1.app import app as v1_app -from app.versions.versioning import add_fallback_routes +from marble_api.utils.routing import get_routes +from marble_api.versions.v1.app import app as v1_app +from marble_api.versions.versioning import add_fallback_routes VERSIONS = [("/v1", v1_app)] -if TYPE_CHECKING: - from typing import AsyncIterable - - -@asynccontextmanager -async def lifespan(app: FastAPI) -> "AsyncIterable": - """Execute setup/teardown code around when the app starts/stops.""" - on_start() - yield - -app = FastAPI(lifespan=lifespan) +app = FastAPI() @app.get("/") diff --git a/marble_api/database/__init__.py b/marble_api/database/__init__.py new file mode 100644 index 0000000..e494b44 --- /dev/null +++ b/marble_api/database/__init__.py @@ -0,0 +1,20 @@ +import os + +from pymongo import AsyncMongoClient +from pymongo.asynchronous.database import AsyncDatabase + + +class Client(AsyncMongoClient): + """AsyncMongoClient with different defaults.""" + + def get_default_database(self, default: str | None = "marble-api", **kwargs) -> AsyncDatabase: + """Override AsyncMongoClient.default_get_database but with a specific default.""" + return super().get_default_database(default, **kwargs) + + @property + def db(self) -> AsyncDatabase: + """Shortcut to get_default_database.""" + return self.get_default_database() + + +client = Client(os.environ["MONGODB_URI"], tz_aware=True) diff --git a/marble_api/utils/geojson.py b/marble_api/utils/geojson.py new file mode 100644 index 0000000..1ed5802 --- /dev/null +++ b/marble_api/utils/geojson.py @@ -0,0 +1,35 @@ +from collections.abc import Iterable +from itertools import zip_longest + +from geojson_pydantic import LineString, 
MultiLineString, MultiPoint, MultiPolygon, Point, Polygon +from geojson_pydantic.types import ( + BBox, + LineStringCoords, + MultiLineStringCoords, + MultiPointCoords, + MultiPolygonCoords, + PolygonCoords, + Position, +) + +type Geometry = LineString | MultiLineString | MultiPoint | MultiPolygon | Point | Polygon +type Coordinates = ( + LineStringCoords | MultiLineStringCoords | MultiPointCoords | MultiPolygonCoords | PolygonCoords | Position +) + + +def _coordinates_to_points(coordinates: Coordinates) -> Iterable[Position]: + if isinstance(coordinates[0], Iterable): + for coord in coordinates: + yield from _coordinates_to_points(coord) + else: + yield coordinates + + +def bbox_from_coordinates(coordinates: Coordinates) -> BBox: + """Return a bounding box from a set of coordinates.""" + min_max = [] + for values in zip_longest(*_coordinates_to_points(coordinates)): + real_values = [v or 0 for v in values] # coordinates without elevation are considered to be at elevation 0 + min_max.append((min(real_values), max(real_values))) + return [v for val in min_max for v in val] diff --git a/marble_api/utils/models.py b/marble_api/utils/models.py new file mode 100644 index 0000000..60594d3 --- /dev/null +++ b/marble_api/utils/models.py @@ -0,0 +1,58 @@ +from copy import deepcopy +from typing import Any + +import bson +from pydantic import BaseModel, create_model +from pydantic.fields import FieldInfo + + +def partial_model(model: type[BaseModel]) -> type[BaseModel]: + """ + Make all fields in a BaseModel class optional. + + This makes each field's default None but does not update the annotation or + validations so explicitly setting the value to None may still raise a + validation error. Also, if a field has validate_default=True this will + make validate_default=False for the partial model to ensure that the new + (None) default value is not validated. 
+ + >>> class C(BaseModel): + a: int + >>> C(a=2).a + 2 + >>> C() # validation error since a must be an integer + >>> @partial_model + ... class B(C): ... + >>> B().a # is None + >>> B(a=5).a + 5 + >>> B(a=None) # validation error since a must be an integer + + Adapted from https://stackoverflow.com/a/76560886/5992438 + """ + + def make_field_optional(field: FieldInfo) -> tuple[Any, FieldInfo]: + new_field = deepcopy(field) + new_field.validate_default = False + new_field.default = None + return new_field.annotation, new_field + + return create_model( + model.__name__, + __base__=model, + __module__=model.__module__, + **{name: make_field_optional(info) for name, info in model.model_fields.items()}, + ) + + +def object_id(id_: str, error: Exception | None) -> bson.ObjectId: + """ + Convert id_ to a bson.ObjectId. + + Raises error from bson.errors.InvalidId if error is provided + """ + try: + return bson.ObjectId(id_) + except bson.errors.InvalidId as err: + if error is not None: + raise error from err diff --git a/app/utils.py b/marble_api/utils/routing.py similarity index 64% rename from app/utils.py rename to marble_api/utils/routing.py index 6c5eff1..9c46887 100644 --- a/app/utils.py +++ b/marble_api/utils/routing.py @@ -1,23 +1,18 @@ -from typing import TYPE_CHECKING +from typing import Iterable -from starlette.routing import Mount - -if TYPE_CHECKING: - from typing import Iterable - - from fastapi import FastAPI - from starlette.routing import Route +from fastapi import FastAPI +from starlette.routing import Mount, Route def get_routes( - app_: "FastAPI | Mount", included_in_schema_only: bool = True -) -> "Iterable[dict[str, Route | Mount | FastAPI]]": + app_: FastAPI | Mount, included_in_schema_only: bool = True +) -> Iterable[dict[str, Route | Mount | FastAPI]]: """ Yield a dictionary containing information about routes contained in app_. This includes FastAPI applications recursively mounted as well. 
If included_in_schema_only is True, do not include routes who are not included in the schema - (ie. their included_in_schema attribute is False). + (ie. their included_in_schema attribute is False)._in_ """ for route in app_.routes: if isinstance(route, Mount): diff --git a/app/versions/__init__.py b/marble_api/versions/__init__.py similarity index 100% rename from app/versions/__init__.py rename to marble_api/versions/__init__.py diff --git a/marble_api/versions/v1/__init__.py b/marble_api/versions/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/marble_api/versions/v1/app.py b/marble_api/versions/v1/app.py new file mode 100644 index 0000000..2d44d95 --- /dev/null +++ b/marble_api/versions/v1/app.py @@ -0,0 +1,7 @@ +from fastapi import FastAPI + +from marble_api.versions.v1.data_request.routes import router as data_request_router + +app = FastAPI(version="1") + +app.include_router(data_request_router) diff --git a/marble_api/versions/v1/data_request/__init__.py b/marble_api/versions/v1/data_request/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/marble_api/versions/v1/data_request/models.py b/marble_api/versions/v1/data_request/models.py new file mode 100644 index 0000000..0def7b2 --- /dev/null +++ b/marble_api/versions/v1/data_request/models.py @@ -0,0 +1,121 @@ +from collections.abc import Sized +from typing import Required, TypedDict + +from bson import ObjectId +from pydantic import ( + AfterValidator, + AwareDatetime, + BaseModel, + ConfigDict, + EmailStr, + Field, + ValidationInfo, + field_validator, +) +from pydantic.functional_validators import BeforeValidator +from pydantic.json_schema import SkipJsonSchema +from stac_pydantic.item import Item +from stac_pydantic.links import Links +from typing_extensions import Annotated + +from marble_api.utils.geojson import Geometry, bbox_from_coordinates +from marble_api.utils.models import partial_model + +PyObjectId = Annotated[str, BeforeValidator(str)] +Temporal = 
Annotated[list[AwareDatetime], Field(..., min_length=1, max_length=2), AfterValidator(sorted)] + + +class Author(TypedDict, total=False): + """Author definition.""" + + first_name: str | None = None + last_name: Required[str] + email: EmailStr | None = None + + +class DataRequest(BaseModel): + """ + Database model for Data Requests. + + This object contains the representation of the data in the database. + """ + + id: SkipJsonSchema[PyObjectId | None] = Field(default=None, validation_alias="_id", exclude=True) + user: str + title: str + description: str | None = None + authors: list[Author] + geometry: Geometry | None + temporal: Temporal + links: Links + path: str + contact: EmailStr + additional_paths: list[str] = [] + variables: list[str] = [] + extra_properties: dict[str, str] = {} + model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) + + @field_validator("user", "title", "description", "authors", "path", "contact") + @classmethod + def min_length_if_set(cls, value: Sized | None, info: ValidationInfo) -> Sized | None: + """Raise an error if the value is not None and is empty.""" + assert value is None or len(value), f"{info.field_name} must be None or non-empty" + return value + + +@partial_model +class DataRequestUpdate(DataRequest): + """ + Update model for Data Requests. + + This object contains all fields that are updatable on the DataRequest model. + Fields should be optional unless they *must* be updated every time a change is made. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, json_encoders={ObjectId: str}) + + +class DataRequestPublic(DataRequest): + """ + Public model for Data Requests. + + This allows for the id field to be included in the response extra fields (like stac_item) so that they can be visible in API responses. 
+ """ + + id: Annotated[str, BeforeValidator(str)] = Field(..., validation_alias="_id") + model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True, extra="allow") + + @property + def stac_item(self) -> Item: + """Dynamically create a STAC item representation of this data.""" + item = { + "type": "Feature", + "stac_version": "1.1.0", + "geometry": self.geometry and self.geometry.model_dump(), + "stac_extensions": [], # TODO + "id": self.id, # TODO + "bbox": None, + "properties": dict(self.extra_properties), # TODO: add more + "links": self.links.model_dump(), + "assets": {}, # TODO: determine assets from other fields + } + + # STAC spec recommends including datetime even if using start_datetime and end_datetime + # See: https://github.com/radiantearth/stac-spec/blob/master/best-practices.md#datetime-selection + item["properties"]["datetime"] = self.temporal[0].isoformat() + + if len(set(self.temporal)) > 1: + item["properties"]["start_datetime"], item["properties"]["end_datetime"] = [ + t.isoformat() for t in self.temporal + ] + + if self.geometry: + item["bbox"] = item["geometry"].get("bbox") or bbox_from_coordinates(self.geometry.coordinates) + return item + + +class DataRequestsResponse(BaseModel): + """Response model for returning multiple data requests.""" + + data_requests: list[DataRequestPublic] + links: Links diff --git a/marble_api/versions/v1/data_request/routes.py b/marble_api/versions/v1/data_request/routes.py new file mode 100644 index 0000000..efccf6c --- /dev/null +++ b/marble_api/versions/v1/data_request/routes.py @@ -0,0 +1,145 @@ +from typing import Annotated + +import pymongo +from bson import ObjectId +from fastapi import APIRouter, HTTPException, Query, Request, Response, status +from pymongo import ReturnDocument + +from marble_api.database import client +from marble_api.utils.models import object_id +from marble_api.versions.v1.data_request.models import ( + DataRequest, + DataRequestPublic, + DataRequestsResponse, + 
DataRequestUpdate, +) + +router = APIRouter(prefix="/data-requests") + + +def _data_request_id(id_: str) -> ObjectId: + return object_id(id_, HTTPException(status_code=404, detail=f"data publish request with id={id_} not found")) + + +@router.post("/") +async def post_data_request(data_request: DataRequest) -> DataRequestPublic: + """Create a new data request and return the newly created data request.""" + new_data_request = data_request.model_dump(by_alias=True) + result = await client.db["data-request"].insert_one(new_data_request) + new_data_request["id"] = str(result.inserted_id) + return new_data_request + + +@router.patch("/{request_id}") +async def patch_data_request(request_id: str, data_request: DataRequestUpdate) -> DataRequestPublic: + """Update fields of data request and return the updated data request.""" + updated_fields = data_request.model_dump(exclude_unset=True, by_alias=True) + selector = {"_id": _data_request_id(request_id)} + if updated_fields: + updated_fields.update(data_request.model_dump(include="stac_item")) + result = await client.db["data-request"].find_one_and_update( + selector, {"$set": updated_fields}, return_document=ReturnDocument.AFTER + ) + if result is not None: + return result + else: + if (result := await client.db["data-request"].find_one(selector)) is not None: + return result + + raise HTTPException(status_code=404, detail="data publish request not found") + + +@router.get("/{request_id}", response_model_by_alias=False) +async def get_data_request(request_id: str, stac: bool = False) -> DataRequestPublic: + """Get a data request with the given request_id.""" + if (result := await client.db["data-request"].find_one({"_id": _data_request_id(request_id)})) is not None: + if stac: + try: + result["stac_item"] = DataRequestPublic(**result).stac_item + except Exception as e: + raise Exception(result) from e + return result + + raise HTTPException(status_code=404, detail="data publish request not found") + + 
@router.delete("/{request_id}")
async def delete_data_request(request_id: str) -> Response:
    """Delete the data request with the given request_id; 204 on success, 404 otherwise."""
    result = await client.db["data-request"].delete_one({"_id": _data_request_id(request_id)})
    if result.deleted_count == 1:
        return Response(status_code=status.HTTP_204_NO_CONTENT)
    raise HTTPException(status_code=404, detail="data publish request not found")


@router.get("/")
async def get_data_requests(
    request: Request,
    after: str | None = None,
    before: str | None = None,
    limit: Annotated[int, Query(le=100, gt=0)] = 10,
    stac: bool = False,
) -> DataRequestsResponse:
    """
    Return all data requests.

    The response is paginated: at most ``limit`` objects (maximum 100) are
    returned at a time. Use the ``after``/``before`` cursor parameters — taken
    from the "next"/"prev" links of a previous response — to page through the
    collection. When stac is True each returned request also carries its STAC
    item representation.
    """
    reverse_it = False
    if after:
        db_request = (
            client.db["data-request"].find({"_id": {"$gte": _data_request_id(after)}}).sort("_id", pymongo.ASCENDING)
        )
    elif before:
        db_request = (
            client.db["data-request"].find({"_id": {"$lte": _data_request_id(before)}}).sort("_id", pymongo.DESCENDING)
        )
        reverse_it = True  # put the eventual result back in ascending order for consistency
    else:
        db_request = client.db["data-request"].find({}).sort("_id", pymongo.ASCENDING)

    # Fetch one extra document so we can tell whether another page exists.
    data_requests = await db_request.limit(limit + 1).to_list()
    if reverse_it:
        # BUG FIX: the previous `data_requests = reversed(data_requests)` bound an
        # iterator, so the len(), pop() and indexing calls below raised TypeError
        # on every `before=` query; reverse the list in place instead.
        data_requests.reverse()

    query_params = {}
    over_limit = len(data_requests) > limit

    if data_requests:
        if after:
            if over_limit:
                query_params["after"] = data_requests[-1]["_id"]
            query_params["before"] = data_requests.pop(0)["_id"]
        elif before:
            if over_limit:
                query_params["before"] = data_requests[0]["_id"]
            query_params["after"] = data_requests.pop()["_id"]
        elif over_limit:
            # NOTE(review): this branch pops a document that is never returned to
            # the client, while the `after` branch above pops the inclusive ($gte)
            # cursor element it assumes was already returned — these conventions
            # look inconsistent and may skip an item at page boundaries; confirm
            # against the pagination tests before changing.
            query_params["after"] = data_requests.pop()["_id"]

    links = []
    base_url = request.url.remove_query_params(["after", "before"])
    if query_params.get("after"):
        links.append(
            {
                "rel": "next",
                "type": "application/json",
                "href": str(base_url.include_query_params(after=query_params["after"])),
            }
        )
    if query_params.get("before"):
        links.append(
            {
                "rel": "prev",
                "type": "application/json",
                "href": str(base_url.include_query_params(before=query_params["before"])),
            }
        )
    if stac:
        data_requests = [{**r, "stac_item": DataRequestPublic(**r).stac_item} for r in data_requests]
    return {"data_requests": data_requests, "links": links}
# NOTE(review): this span of the flattened patch deletes requirements.dev.txt,
# requirements.prod.txt, requirements.test.txt and requirements.txt — dependency
# management moves to pyproject.toml extras (.[dev], .[prod], .[test]), matching
# the Dockerfile/README/workflow changes earlier in the diff.

# --- added file: test/conftest.py ---
import pytest
from faker_providers import DataRequestProvider, GeoJsonProvider


@pytest.fixture(scope="session")
def anyio_backend():
    """Select the asyncio backend for anyio-marked async tests."""
    return "asyncio"


@pytest.fixture(scope="session")
def faker_providers():
    """Expose the custom Faker provider classes so nested conftests can build
    their own `fake` fixtures without importing test-package modules directly."""
    return {"DataRequestProvider": DataRequestProvider, "GeoJsonProvider": GeoJsonProvider}


# --- added file: test/faker_providers.py ---
import datetime

import bson
import pytest
from faker import Faker
from faker.providers import BaseProvider

from marble_api.versions.v1.data_request.models import DataRequest, DataRequestPublic, DataRequestUpdate


class GeoJsonProvider(BaseProvider):
    """Faker provider generating random GeoJSON-shaped dicts (Point, MultiPoint,
    LineString, MultiLineString, Polygon, MultiPolygon, each with optional bbox)."""

    def point(self, dimensions=None):
        # lon in [-180, 180], lat in [-90, 90]; a third axis is appended when
        # dimensions == 3, or at random when dimensions is None.
        point = [self.generator.random.uniform(-180, 180), self.generator.random.uniform(-90, 90)]
        if dimensions == 3 or (dimensions is None and self.generator.pybool()):
            point.append(self.generator.random.uniform(-100, 100))
        return point

    def bbox(self, dimensions=None):
        # NOTE(review): pybool() yields True/False, neither of which equals 3
        # inside point(), so a None `dimensions` always produces a 2D bbox here —
        # confirm whether 3D bboxes were intended to be possible.
        if dimensions is None:
            dimensions = self.generator.pybool()
        # Pair up two random points axis-wise, sort each axis pair, then
        # interleave back into [min..., max...] bbox order.
        return [a for b in zip(*(sorted(x) for x in zip(self.point(dimensions), self.point(dimensions)))) for a in b]

    def line(self, dimensions=None):
        # 2-12 points, per GeoJSON LineString's minimum of two positions.
        return [self.point(dimensions) for _ in range(self.generator.pyint(min_value=2, max_value=12))]

    def linear_ring(self, dimensions=None):
        # A ring must be closed: repeat the first position at the end.
        ring = [self.point(dimensions) for _ in range(self.generator.random.randint(3, 100))]
        ring.append(list(ring[0]))
        return ring

    def _geo_base(self):
        # Half of the generated geometries carry an (independent) bbox member.
        base = {}
        if self.generator.random.random() < 0.5:
            base["bbox"] = self.bbox()
        return base

    def geo_point(self, dimensions=None):
        return {**self._geo_base(), "type": "Point", "coordinates": self.point(dimensions)}

    def geo_multipoint(self, dimensions=None):
        return {
            **self._geo_base(),
            "type": "MultiPoint",
            "coordinates": [self.point(dimensions) for _ in range(self.generator.pyint(min_value=1, max_value=12))],
        }

    def geo_linestring(self, dimensions=None):
        return {**self._geo_base(), "type": "LineString", "coordinates": self.line(dimensions)}

    def geo_multilinestring(self, dimensions=None):
        return {
            **self._geo_base(),
            "type": "MultiLineString",
            "coordinates": [self.line(dimensions) for _ in range(self.generator.pyint(min_value=1, max_value=12))],
        }

    def geo_polygon(self, dimensions=None):
        return {**self._geo_base(), "type": "Polygon", "coordinates": [self.linear_ring(dimensions)]}

    def geo_multipolygon(self, dimensions=None):
        return {
            **self._geo_base(),
            "type": "MultiPolygon",
            "coordinates": [
                [self.linear_ring(dimensions) for _ in range(self.generator.pyint(min_value=1, max_value=12))]
            ],
        }

    def geometry(self, dimensions=None):
        """Return a random geometry of a random type; `dimensions` may be 2, 3,
        or None (mixed per-point dimensionality)."""
        if dimensions is None:
            dimensions = self.generator.random.choice([3, 2, None])
        return self.generator.random.choice(
            [
                self.geo_point,
                self.geo_multipoint,
                self.geo_linestring,
                self.geo_multilinestring,
                self.geo_polygon,
                self.geo_multipolygon,
            ]
        )(dimensions)


class DataRequestProvider(GeoJsonProvider):
    """Faker provider building valid DataRequest / DataRequestPublic /
    DataRequestUpdate model instances with randomized field values."""

    def author(self):
        # last_name is the only required Author field; others appear at random.
        author_ = {"last_name": self.generator.last_name()}
        if self.generator.pybool():
            author_["first_name"] = self.generator.first_name()
        if self.generator.pybool():
            author_["email"] = self.generator.email()
        return author_

    def utc_date_time_seconds_precision(self):
        # Models round-trip through JSON, so drop sub-second precision up front.
        return self.generator.date_time(tzinfo=datetime.timezone.utc).replace(microsecond=0)

    def temporal(self):
        # One third each: sorted two-element range, degenerate (equal) range,
        # single instant.
        opt = self.generator.random.random()
        if opt < 1 / 3:
            return sorted(
                [
                    self.utc_date_time_seconds_precision(),
                    self.utc_date_time_seconds_precision(),
                ]
            )
        elif opt < 2 / 3:
            return [self.utc_date_time_seconds_precision()] * 2
        else:
            return [self.utc_date_time_seconds_precision()]

    def link(self):
        return {"href": self.generator.uri(), "rel": self.generator.word(), "type": self.generator.mime_type()}

    def _data_request_inputs(self, unset=None):
        """Build the full kwargs dict for a DataRequest-family model; fields
        named in `unset` are removed so model defaults apply."""
        inputs = dict(
            id=bson.ObjectId(),
            user=self.generator.profile("username")["username"],
            title=self.generator.sentence(),
            description=(None if self.generator.pybool(30) else self.generator.paragraph()),
            authors=[self.author() for _ in range(self.generator.random.randint(1, 10))],
            geometry=self.geometry(),
            temporal=self.temporal(),
            links=[self.link() for _ in range(self.generator.random.randint(0, 10))],
            path=self.generator.file_path(),
            contact=self.generator.email(),
            additional_paths=[self.generator.file_path() for _ in range(self.generator.random.randint(0, 10))],
            variables=([] if self.generator.pybool(10) else self.generator.pylist(allowed_types=[str])),
            extra_properties=({} if self.generator.pybool(10) else self.generator.pydict(allowed_types=[str])),
        )
        if unset:
            for field in unset:
                inputs.pop(field)
        return inputs

    def data_request(self, unset=None, **kwargs):
        # kwargs override the randomly generated inputs.
        return DataRequest(**{**self._data_request_inputs(unset=unset), **kwargs})

    def data_request_public(self, unset=None, **kwargs):
        return DataRequestPublic(**{**self._data_request_inputs(unset=unset), **kwargs})

    def data_request_update(self, unset=None, **kwargs):
        return DataRequestUpdate(**{**self._data_request_inputs(unset=unset), **kwargs})


# NOTE(review): this fixture lives in a non-conftest module, so pytest will not
# collect it unless this module is imported by a conftest/plugin; the nested
# conftests define their own `fake` fixtures instead — confirm this one is used.
@pytest.fixture(scope="session")
def fake():
    fake_ = Faker()
    fake_.add_provider(DataRequestProvider)
    return fake_


# --- added file: test/integration/conftest.py ---
import functools

import pytest
from httpx import ASGITransport, AsyncClient

from marble_api import app
from marble_api.database import client


@pytest.fixture(scope="session", autouse=True)
def init_test_db():
    # use different default database name from prod/dev in order to minimize the chance
    # of accidentally using a prod/dev database.
    client.get_default_database = functools.partial(client.get_default_database, default="marble-api-test")


@pytest.fixture(scope="session", autouse=True)
async def check_empty_test_db(init_test_db):
    """Refuse to run the integration suite against a non-empty database, since
    teardown drops it wholesale."""
    database = client.db
    if await database.list_collection_names():
        raise RuntimeError(
            f"Database {database.name} contains some collections. Tests must be run on an empty database."
        )


@pytest.fixture(autouse=True)
async def refresh_database(request):
    """Drop the test database after each test unless marked no_db_cleanup."""
    try:
        yield
    finally:
        if "no_db_cleanup" not in request.keywords:
            await client.drop_database(client.db.name)


@pytest.fixture
async def async_client():
    # NOTE(review): the local name `client` shadows the imported database
    # client within this fixture — harmless here, but worth renaming.
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        yield client


# --- added file: test/integration/test_app.py ---
import pytest

from marble_api.app import VERSIONS

pytestmark = [pytest.mark.anyio, pytest.mark.no_db_cleanup]


async def test_root_in_root(async_client):
    # The root route should list itself in the route index it returns.
    resp = await async_client.get("/")
    assert {"methods": ["GET"], "path": "/"} in resp.json()["routes"]


async def test_version_roots_in_root(async_client):
    # Every mounted API version's root should appear in the route index.
    resp = await async_client.get("/")
    for version, _ in VERSIONS:
        assert {"methods": ["GET"], "path": f"{version}/"} in resp.json()["routes"]


# --- added file: test/integration/versions/v1/data_request/conftest.py ---
import pytest
from faker import Faker


@pytest.fixture(scope="session")
def fake(faker_providers):
    """Session-scoped Faker with the DataRequest provider installed."""
    fake_ = Faker()
    fake_.add_provider(faker_providers["DataRequestProvider"])
    return fake_
# --- added file: test/integration/versions/v1/data_request/test_routes.py ---
import inspect
import json
from urllib.parse import parse_qs, urlparse

import bson
import pytest
from stac_pydantic import Item

from marble_api.database import client
from marble_api.versions.v1.data_request.models import DataRequest, DataRequestPublic
from marble_api.versions.v1.data_request.routes import get_data_requests

pytestmark = pytest.mark.anyio


class _TestGet:
    """Shared setup for GET tests: preload `n_data_requests` documents for the
    whole class and drop the database once the class is done."""

    n_data_requests = 2

    # NOTE(review): pytest fixtures defined as classmethods require a recent
    # pytest release — confirm the pinned version supports this.
    @pytest.fixture(scope="class", autouse=True)
    @classmethod
    async def load_data(cls, fake):
        data = [fake.data_request().model_dump() for _ in range(cls.n_data_requests)]
        await client.db.get_collection("data-request").insert_many(data)

    @pytest.fixture(scope="class", autouse=True)
    @classmethod
    async def cleanup(cls):
        try:
            yield
        finally:
            await client.drop_database(client.db.name)

    @pytest.fixture(scope="class")
    @classmethod
    async def data_requests(cls):
        # Raw documents as stored, including Mongo's `_id`.
        yield await client.db.get_collection("data-request").find({}).to_list()


@pytest.mark.no_db_cleanup
class TestGetOne(_TestGet):
    async def test_get(self, async_client, data_requests):
        resp = await async_client.get(f"/v1/data-requests/{data_requests[0]['_id']}")
        assert resp.status_code == 200
        assert DataRequestPublic(**data_requests[0]) == DataRequestPublic(**resp.json())

    async def test_get_stac(self, async_client, data_requests):
        # stac=true should embed a STAC item that validates against the spec.
        resp = await async_client.get(f"/v1/data-requests/{data_requests[0]['_id']}?stac=true")
        assert resp.status_code == 200
        assert (item := resp.json().get("stac_item"))
        Item(**item)

    async def test_bad_id(self, async_client):
        resp = await async_client.get("/v1/data-requests/id-does-not-exist")
        assert resp.status_code == 404


@pytest.mark.no_db_cleanup
class TestGetMany(_TestGet):
    # Mirror the route's default page size, and load enough documents to force
    # pagination (one full page plus two).
    default_link_limit = inspect.signature(get_data_requests).parameters["limit"].default
    n_data_requests = default_link_limit + 2

    async def test_get(self, async_client, data_requests):
        response = await async_client.get("/v1/data-requests/")
        models = {str(req["_id"]): DataRequestPublic(**req) for req in data_requests}
        for req in response.json()["data_requests"]:
            assert DataRequestPublic(**req) == models[req["id"]]

    async def test_get_stac(self, async_client):
        resp = await async_client.get("/v1/data-requests/?stac=true")
        for req in resp.json()["data_requests"]:
            assert (item := req.get("stac_item"))
            Item(**item)

    async def test_get_limit_default(self, async_client):
        response = await async_client.get("/v1/data-requests/")
        assert len(response.json()["data_requests"]) == self.default_link_limit

    async def test_get_limit_non_default(self, async_client):
        response = await async_client.get("/v1/data-requests/?limit=5")
        assert len(response.json()["data_requests"]) == 5

    async def test_get_limit_more(self, async_client):
        # Asking for more than exists returns everything, not an error.
        response = await async_client.get(f"/v1/data-requests/?limit={self.n_data_requests + 1}")
        assert len(response.json()["data_requests"]) == self.n_data_requests

    async def test_get_limit_none(self, async_client):
        response = await async_client.get("/v1/data-requests/?limit=0")
        assert response.status_code == 422

    async def test_get_limit_over_max(self, async_client):
        response = await async_client.get("/v1/data-requests/?limit=200")
        assert response.status_code == 422

    async def test_get_first_page_links(self, async_client):
        response = await async_client.get("/v1/data-requests/")
        links = response.json()["links"]
        assert len(links) == 1
        link = links[0]
        assert link["rel"] == "next"
        assert link["type"] == "application/json"
        assert link["href"].startswith(str(response.url))
        assert (after_id := parse_qs(urlparse(link["href"]).query).get("after"))
        # FIX: parse_qs returns {key: [values]} — compare the scalar cursor id,
        # not the list (a list is never equal to an id string, so the previous
        # `after_id not in [...]` assertion was vacuously true).
        assert after_id[0] not in [r["id"] for r in response.json()["data_requests"]]

    async def test_get_last_page_links(self, async_client):
        response = await async_client.get("/v1/data-requests/")
        next_link = next(link for link in response.json()["links"] if link["rel"] == "next")
        response2 = await async_client.get(next_link["href"])
        links = response2.json()["links"]
        assert len(links) == 1
        link = links[0]
        assert link["rel"] == "prev"
        assert link["type"] == "application/json"
        assert link["href"].startswith(str(response.url))
        assert (before_id := parse_qs(urlparse(link["href"]).query).get("before"))
        # FIX: compare the scalar cursor id (see test_get_first_page_links).
        assert before_id[0] not in [r["id"] for r in response.json()["data_requests"]]

    async def test_get_mid_page_links(self, async_client):
        response = await async_client.get("/v1/data-requests/?limit=4")
        next_link = next(link for link in response.json()["links"] if link["rel"] == "next")
        response2 = await async_client.get(next_link["href"])
        links = response2.json()["links"]
        assert len(links) == 2
        assert {link["rel"] for link in links} == {"prev", "next"}
        for link in links:
            assert link["type"] == "application/json"
            assert link["href"].startswith(str(response.url))
            assert parse_qs(urlparse(link["href"]).query).get("limit") == ["4"]
            if link["rel"] == "prev":
                assert (id_ := parse_qs(urlparse(link["href"]).query).get("before"))
            elif link["rel"] == "next":
                assert (id_ := parse_qs(urlparse(link["href"]).query).get("after"))
            # FIX: moved inside the loop (previously only the last link was
            # checked) and compare the scalar cursor id, not the parse_qs list.
            assert id_[0] not in [r["id"] for r in response.json()["data_requests"]]


class TestPost:
    async def test_valid(self, fake, async_client):
        data = fake.data_request().model_dump_json()
        response = await async_client.post("/v1/data-requests/", json=json.loads(data))
        assert response.status_code == 200
        assert (id_ := response.json().get("id"))
        bson.ObjectId(id_)  # check that the id is a valid object id
        assert json.loads(data) == json.loads(DataRequest(**response.json()).model_dump_json())

    async def test_invalid(self, fake, async_client):
        # At least one author is required.
        data = json.loads(fake.data_request().model_dump_json())
        data["authors"] = []
        response = await async_client.post("/v1/data-requests/", json=data)
        assert response.status_code == 422


class _TestUpdate:
    """Shared fixture inserting one document and returning its public form."""

    @pytest.fixture
    async def loaded_data(self, fake):
        model = json.loads(fake.data_request().model_dump_json())
        resp = await client.db.get_collection("data-request").insert_one(model)
        model.pop("_id")
        model["id"] = str(resp.inserted_id)
        return model


class TestPatch(_TestUpdate):
    async def test_valid(self, loaded_data, async_client, fake):
        title = fake.sentence()
        update = {"title": title}
        response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json=update)
        assert response.status_code == 200
        loaded_data.update(update)
        assert loaded_data == response.json()

    async def test_valid_multiple(self, loaded_data, async_client, fake):
        title = fake.sentence()
        authors = [fake.author(), fake.author()]
        update = {"title": title, "authors": authors}
        response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json=update)
        assert response.status_code == 200
        loaded_data.update(update)
        assert loaded_data == response.json()

    async def test_update_nothing(self, loaded_data, async_client):
        response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json={})
        assert response.status_code == 200
        assert loaded_data == response.json()

    async def test_update_everything(self, loaded_data, async_client, fake):
        update = json.loads(fake.data_request().model_dump_json())
        response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json=update)
        assert response.status_code == 200
        update["id"] = loaded_data["id"]
        assert update == response.json()

    async def test_no_id_update(self, loaded_data, async_client):
        # The id field is immutable: a PATCH supplying one is ignored.
        update = {"id": str(bson.ObjectId())}
        response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json=update)
        assert response.status_code == 200
        assert response.json()["id"] == loaded_data["id"]
        assert response.json()["id"] != update["id"]
        assert loaded_data == response.json()

    async def test_invalid_unset_value(self, loaded_data, async_client):
        # Required fields cannot be nulled out via PATCH.
        response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json={"title": None})
        assert response.status_code == 422

    async def test_invalid_bad_type(self, loaded_data, async_client):
        response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json={"title": 10})
        assert response.status_code == 422

    async def test_bad_id(self, async_client):
        resp = await async_client.patch("/v1/data-requests/id-does-not-exist", json={})
        assert resp.status_code == 404, resp.json()


class TestDelete(_TestUpdate):
    async def test_exists(self, loaded_data, async_client):
        response = await async_client.delete(f"/v1/data-requests/{loaded_data['id']}")
        assert response.status_code == 204
        resp = await client.db.get_collection("data-request").find_one({"_id": bson.ObjectId(loaded_data["id"])})
        assert resp is None

    async def test_bad_id(self, async_client):
        resp = await async_client.delete("/v1/data-requests/id-does-not-exist")
        assert resp.status_code == 404, resp.json()


# --- added file: test/unit/conftest.py ---
from unittest.mock import MagicMock

import pytest

from marble_api.database import client


@pytest.fixture(autouse=True)
def mock_db():
    """Replace the shared client's get_default_database so unit tests never
    touch a real MongoDB.

    NOTE(review): the original attribute is never restored after the test —
    the MagicMock persists on the shared client; confirm this cannot leak into
    other suites in the same session.
    """
    mock = MagicMock()
    client.get_default_database = mock
    return mock


# --- added file: test/unit/database/test_database.py ---
from marble_api.database import Client, client


class TestClient:
    def test_default_database(self):
        assert Client("mongodb://example.com").get_default_database().name == "marble-api"

    def test_default_database_from_uri(self):
        # A database named in the URI path overrides the fallback default.
        assert Client("mongodb://example.com/other-db").get_default_database().name == "other-db"

    def test_db(self):
        assert Client("mongodb://example.com").db.name == "marble-api"

    def test_db_from_uri(self):
        assert Client("mongodb://example.com/other-db").db.name == "other-db"


def test_client_singleton():
    # The module-level client is the shared singleton used across the app.
    assert isinstance(client, Client)
# --- added file: test/unit/utils/test_utils_geojson.py ---
from marble_api.utils.geojson import bbox_from_coordinates


class TestBboxFromCoordinates:
    """bbox_from_coordinates flattens arbitrarily nested coordinate arrays into
    a [min1, max1, min2, max2, (min3, max3)] bounding box; a missing third axis
    is treated as 0 when mixed with 3D points."""

    def test_2d_point(self):
        assert bbox_from_coordinates([1, 2]) == [1, 1, 2, 2]

    def test_3d_point(self):
        assert bbox_from_coordinates([1, 2, 3]) == [1, 1, 2, 2, 3, 3]

    def test_2d_line(self):
        assert bbox_from_coordinates([[1, 2], [-1, -3]]) == [-1, 1, -3, 2]

    def test_3d_line(self):
        assert bbox_from_coordinates([[1, 2, 4], [-1, -3, 33]]) == [-1, 1, -3, 2, 4, 33]

    def test_mixed_d_line(self):
        # 2D points contribute 0 on the third axis when any point is 3D.
        assert bbox_from_coordinates([[1, 2], [-1, -3, 33]]) == [-1, 1, -3, 2, 0, 33]

    def test_deeply_nested(self):
        assert bbox_from_coordinates([[[[1, 2], [-1, -3, 33]]]]) == [-1, 1, -3, 2, 0, 33]

    def test_different_nested(self):
        # Nesting depth may vary between siblings.
        assert bbox_from_coordinates([[1, 2], [[[-1, -3, 33]]]]) == [-1, 1, -3, 2, 0, 33]


# --- added file: test/unit/utils/test_utils_models.py ---
import bson
import pytest
from pydantic import BaseModel, Field, ValidationError, field_validator

from marble_api.utils.models import object_id, partial_model


class TestPartial:
    """partial_model makes every field optional (default None) while keeping
    field validators active for explicitly supplied values."""

    @pytest.fixture
    def partial_class(self):
        @partial_model
        class PModel(BaseModel):
            a: int
            b: str = Field(None, validate_default=True)

            @field_validator("b")
            @classmethod
            def not_none(cls, value):
                assert value is not None
                return value

        return PModel

    def test_nothing_required(self, partial_class):
        assert partial_class().model_dump() == {"a": None, "b": None}

    def test_field_without_default_settable(self, partial_class):
        assert partial_class(a=10).model_dump() == {"a": 10, "b": None}

    def test_field_with_default_settable(self, partial_class):
        assert partial_class(b="other").model_dump() == {"a": None, "b": "other"}

    def test_validations_still_work(self, partial_class):
        with pytest.raises(ValidationError):
            partial_class(a="some string")

    def test_no_default_validation(self, partial_class):
        # Explicitly passing None still runs the not_none validator.
        with pytest.raises(ValidationError):
            partial_class(b=None)


class TestObjectId:
    """object_id parses a string into a bson.ObjectId, raising the supplied
    exception (class or instance) on invalid input."""

    def test_valid_id(self):
        id_ = bson.ObjectId()
        assert id_ == object_id(str(id_), Exception)

    def test_invalid_id(self):
        with pytest.raises(Exception):
            object_id("invalid string", Exception)

    def test_invalid_id_custom_exception(self):
        class MyCustomException(Exception): ...

        with pytest.raises(MyCustomException):
            object_id("invalid string", MyCustomException)

    def test_invalid_id_custom_message(self):
        msg = "here is a message"
        with pytest.raises(Exception) as e:
            object_id("invalid string", Exception(msg))
        # FIX: `e` is pytest's ExceptionInfo wrapper; str(e) is the wrapper's
        # repr, not the exception message. The raised exception is `e.value`.
        assert msg == str(e.value)


# --- added file: test/unit/versions/test_versioning.py ---
# TODO: test this when we have a version 2 to test


# --- added file: test/unit/versions/v1/data_request/conftest.py ---
import pytest
from faker import Faker


@pytest.fixture(scope="session")
def fake(faker_providers) -> Faker:
    """Session-scoped Faker with the DataRequest provider installed."""
    fake_ = Faker()
    fake_.add_provider(faker_providers["DataRequestProvider"])
    return fake_
# --- added file: test/unit/versions/v1/data_request/test_models.py ---
import datetime

import pytest
from pydantic import TypeAdapter, ValidationError
from pystac import Item

from marble_api.versions.v1.data_request.models import Author, DataRequestUpdate


class TestAuthor:
    validator = TypeAdapter(Author)

    def test_all(self):
        author = Author(first_name="first", last_name="last", email="email@example.com")
        self.validator.validate_python(author)

    def test_minimal(self):
        # last_name is the only required field.
        author = Author(last_name="last")
        self.validator.validate_python(author)

    def test_invalid_email(self):
        author = Author(last_name="last", email="not an email")
        with pytest.raises(ValidationError):
            self.validator.validate_python(author)


class TestDataRequest:
    """Validation tests for DataRequest; subclasses swap in the Public/Update
    variants via the fake_class fixture and inherit these tests."""

    @pytest.fixture
    def fake_class(self, fake):
        return fake.data_request

    def test_id_dumped(self, fake_class):
        # The internal model excludes id from dumps (overridden in Public).
        assert "id" not in fake_class().model_dump()

    @pytest.mark.parametrize("field", ["user", "title", "description", "authors", "path", "contact"])
    def test_text_fields_not_empty(self, fake_class, field):
        with pytest.raises(ValidationError):
            fake_class(**{field: ""})

    @pytest.mark.parametrize(
        "field",
        [
            "user",
            "title",
            "authors",
            "temporal",
            "links",
            "path",
            "contact",
            "additional_paths",
            "variables",
            "extra_properties",
        ],
    )
    def test_fields_not_nullable(self, fake_class, field):
        with pytest.raises(ValidationError):
            fake_class(**{field: None})

    @pytest.mark.parametrize(
        "field",
        [
            "description",
            "additional_paths",
            "variables",
            "extra_properties",
        ],
    )
    def test_fields_default_if_unset(self, fake_class, field):
        # Omitting an optional field falls back to the model's declared default.
        request = fake_class(unset=[field])
        assert request.model_dump()[field] == type(request).model_fields[field].default

    def test_id_is_str(self, fake_class):
        # ObjectId inputs are coerced to their string form on the model.
        assert isinstance(fake_class().id, str)

    def test_temporal_sorted(self, fake_class):
        # Out-of-order temporal ranges are sorted ascending by the validator.
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        temporal = [now, now - datetime.timedelta(hours=1)]
        request = fake_class(temporal=temporal)
        assert request.temporal == temporal[::-1]

    def test_temporal_tzaware(self, fake_class):
        # Naive datetimes are rejected.
        with pytest.raises(ValidationError):
            fake_class(temporal=[datetime.datetime.now()])


class TestDataRequestPublic(TestDataRequest):
    @pytest.fixture
    def fake_class(self, fake):
        return fake.data_request_public

    def test_id_dumped(self, fake_class):
        # Unlike the internal model, the public model exposes id in dumps.
        assert "id" in fake_class().model_dump()

    class TestStacItem:
        """The stac_item property renders the request as a STAC Item dict."""

        def test_valid(self, fake_class):
            assert Item.from_dict(fake_class().stac_item)

        def test_geometry(self, fake_class):
            request = fake_class()
            assert request.stac_item["geometry"] == request.geometry.model_dump()
            assert request.stac_item["bbox"]

        def test_null_geometry(self, fake_class):
            request = fake_class(geometry=None)
            assert request.stac_item["geometry"] is None
            assert request.stac_item["bbox"] is None

        def test_single_temporal(self, fake_class):
            # A single instant maps to `datetime` only, no start/end range.
            now = datetime.datetime.now(tz=datetime.timezone.utc)
            request = fake_class(temporal=[now])
            assert request.stac_item["properties"]["datetime"] == now.isoformat()
            assert "start_datetime" not in request.stac_item["properties"]
            assert "end_datetime" not in request.stac_item["properties"]

        def test_range_temporal(self, fake_class):
            now = datetime.datetime.now(tz=datetime.timezone.utc)
            temporal = [now, now + datetime.timedelta(hours=1)]
            request = fake_class(temporal=temporal)
            assert request.stac_item["properties"]["datetime"] == temporal[0].isoformat()
            assert request.stac_item["properties"]["start_datetime"] == temporal[0].isoformat()
            assert request.stac_item["properties"]["end_datetime"] == temporal[1].isoformat()

        def test_extra_properties(self, fake_class):
            request = fake_class()
            item = request.stac_item
            assert set(request.extra_properties) & set(item["properties"]) == set(request.extra_properties)

        def test_links(self, fake_class):
            request = fake_class()
            # NOTE(review): this assumes `links` is itself a pydantic model
            # with model_dump(); if it is a plain list of Link models this
            # would raise AttributeError — confirm against the model definition.
            assert request.links.model_dump() == request.stac_item["links"]

        def test_id(self, fake_class):
            request = fake_class()
            assert request.id == request.stac_item["id"]


class TestDataRequestUpdate(TestDataRequest):
    @pytest.fixture
    def fake_class(self, fake):
        return fake.data_request_update

    def test_all_fields_optional(self):
        # The partial update model must be constructible with no arguments.
        DataRequestUpdate()

    def test_all_defaults_none(self):
        assert all(field.default is None for field in DataRequestUpdate.model_fields.values())