diff --git a/README.md b/README.md index ba1be95..332ac17 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ For local development and testing, copy `env.example` to `.env` and populate it. Running the API locally: -`uvicorn app.main:app --reload --host 0.0.0.0 --port 8000` +`uv run --env-file=.env fastapi dev` To run in a docker container, make sure your .env file is setup then run: @@ -15,4 +15,4 @@ To run in a docker container, make sure your .env file is setup then run: Testing: -`pytest` \ No newline at end of file +`uv run --env-file=.env pytest --maxfail=4 --tb=short -v` \ No newline at end of file diff --git a/app/carto.py b/app/carto.py index d70af88..90caa14 100644 --- a/app/carto.py +++ b/app/carto.py @@ -27,13 +27,17 @@ def __init__(self): "CARTO_TOKEN" ) # token passed in at runtime as env variable. Available at Keeper record "CARTO - New Platform" assert self.public_token, "Carto token not provided" - self.auth_header = {"Authorization": f"Bearer {self.public_token}"} + self.headers = { + "Authorization": f"Bearer {self.public_token}", + "Cache-Control": "max-age=1800", # Default Carto Cache to 30 minutes to prevent stale responses + } async def get_count( self, table: str | None, where: str | None, timeout: float, + no_cache: bool, session: aiohttp.ClientSession, request: Request, **kwargs, @@ -47,8 +51,11 @@ async def get_count( q_where = psql.SQL(f"WHERE {where} ") query = query + q_where params = {"q": query.as_string()} + headers = self.headers + if no_cache: + headers['cache-control'] = "max-age=0" async with session.get( - self.base_url, params=params, headers=self.auth_header, timeout=timeout + self.base_url, params=params, headers=headers, timeout=timeout ) as response: return await self.normalize_rv_count(request, response) @@ -82,6 +89,7 @@ async def get( out_sr: int | None, sql: str | None, timeout: float, + no_cache: bool, session: aiohttp.ClientSession, request: Request, **kwargs, @@ -134,8 +142,11 @@ async def get( query = 
psql.SQL(sql) table_schema = None params = {"q": query.as_string()} + headers = self.headers + if no_cache: + headers['cache-control'] = "max-age=0" async with session.get( - self.base_url, params=params, headers=self.auth_header, timeout=timeout + self.base_url, params=params, headers=headers, timeout=timeout ) as response: return await self.normalize_rv(request, response, table_schema, limit, sql) diff --git a/app/main.py b/app/main.py index 1184b66..4a49bdc 100644 --- a/app/main.py +++ b/app/main.py @@ -34,6 +34,7 @@ async def lifespan(app: FastAPI): schema_cache.check_latest_commit() commit_check_task = create_task(schema_cache.loop_commit_check()) await session_manager.start() + assert schema_cache.cache yield commit_check_task.cancel() await session_manager.stop() @@ -87,25 +88,25 @@ async def get_data( table: Annotated[ str | None, Query( - description=f"""Name of table to retrieve. Either `table` or `sql` - parameter is required. Ignored if `sql` parameter is provided. Need an - example tale? Try `table=dor_parcel`. + description=f"""Name of table to retrieve. Either `table` or `sql` + parameter is required. Ignored if `sql` parameter is provided. Need an + example tale? Try `table=dor_parcel`. {make_param_api_descriptions(api_manager, "table")}""" ), ] = None, fields: Annotated[ str | None, Query( - description=f"""List of fields to retrieve, taking the form - _field_1_,_field_2_,... To receive all fields, do not include this parameter. - Writing fields=* will return an error. Ignored if `sql` or + description=f"""List of fields to retrieve, taking the form + _field_1_,_field_2_,... To receive all fields, do not include this parameter. + Writing fields=* will return an error. Ignored if `sql` or `count_only` parameters are provided.{make_param_api_descriptions(api_manager, "fields")}""" ), ] = None, where: Annotated[ str | None, Query( - description=f"""An SQL _WHERE_ clause to filter data. Ignored if + description=f"""An SQL _WHERE_ clause to filter data. 
Ignored if `sql` parameter is provided.{make_param_api_descriptions(api_manager, "where")}""" ), ] = None, @@ -113,53 +114,60 @@ async def get_data( int | None, Query( description=f"""Limit to the number of records to return. AGO enforces - a limit specific to each table (frequently 2,000 records); for Carto, this API - enforces a limit of 1,000 records as Carto otherwise does not have - limits. Any user-provided limit smaller than those takes precedence. + a limit specific to each table (frequently 2,000 records); for Carto, this API + enforces a limit of 1,000 records as Carto otherwise does not have + limits. Any user-provided limit smaller than those takes precedence. Ignored if `sql` or `count_only` paramaters are provided.{make_param_api_descriptions(api_manager, "limit")}""" ), ] = None, out_sr: Annotated[ int, Query( - description=f"""Spatial Reference to return geometric records in. - Default SRID is WGS84 (4326). Ignored if dataset is not geometric, or `sql` or `count_only` + description=f"""Spatial Reference to return geometric records in. + Default SRID is WGS84 (4326). Ignored if dataset is not geometric, or `sql` or `count_only` parameters are provided.{make_param_api_descriptions(api_manager, "count_only")}""" ), ] = AbstractWorker.DEFAULT_SRID, count_only: Annotated[ bool, Query( - description=f"""Return record count of provided query. Ignored if + description=f"""Return record count of provided query. Ignored if `sql` parameter is provided.{make_param_api_descriptions(api_manager, "count_only")}""" ), ] = False, sql: Annotated[ str | None, Query( - description=f"""Raw SQL string to use when retrieving data. Users - should request no more than ~2,000 rows to avoid an `HTTP 413` error. - Either `table` or `sql` parameter is required. Need an example? Try + description=f"""Raw SQL string to use when retrieving data. Users + should request no more than ~2,000 rows to avoid an `HTTP 413` error. + Either `table` or `sql` parameter is required. 
Need an example? Try `sql=SELECT * FROM DOR_PARCEL LIMIT 10`.{make_param_api_descriptions(api_manager, "sql")}""" ), ] = None, service: Annotated[ Service | None, Query( - description="""Name of API service to use. If not provided, the first - API service to locate the table will be used. Ignored if `sql` parameter + description="""Name of API service to use. If not provided, the first + API service to locate the table will be used. Ignored if `sql` parameter is provided.""" ), ] = None, timeout: Annotated[ float, Query( - description="""Amount of time in seconds to wait for response from downstream APIs + description="""Amount of time in seconds to wait for response from downstream APIs before raising a timeout error""", gt=0, lt=300, ), ] = 30, + no_cache: Annotated[ + bool, + Query( + description=f"""Request fresh results from downstream APIs, ignoring any HTTP caching. + {make_param_api_descriptions(api_manager, "no_cache")}""" + ), + ] = False, session: aiohttp.ClientSession = Depends(session_manager), ) -> ReturnJson | JSONResponse: """Use this endpoint to retrieve data from the available @@ -178,11 +186,12 @@ async def get_data( "out_sr": out_sr, "count_only": count_only, "sql": sql, - "token": token, "session": session, "timeout": timeout, + "no_cache": no_cache, "request": request, "schema_cache": schema_cache, + "token": token, } if sql: if service and service.lower() != "carto": diff --git a/app/test_main.py b/app/test_main.py index 22091fa..188119f 100644 --- a/app/test_main.py +++ b/app/test_main.py @@ -45,22 +45,25 @@ def token() -> str: @pytest.mark.parametrize("service", api_manager.map_str_to_api.keys()) def test_valid(client: TestClient, service: str, table: str): """Test that each service works""" - params = {"table": table, "service": service} + params = {"table": table, "service": service, "no_cache": True} response = client.get("/get", params=params) rv = response.json() assert response.status_code == 200 assert rv["links"]["self"] == 
response.url + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' @pytest.mark.parametrize("count_only", [True, False]) @pytest.mark.parametrize("service", ["ago"]) # Keeping this as a parameter for easier ID'ing which tests use which APIs def test_valid_private(client: TestClient, token: str, service: str, count_only: bool): """Test that a token passed in can access AGO private data""" + table = PRIVATE_TABLE params = { - "table": PRIVATE_TABLE, + "table": table, "limit": 5, "count_only": count_only, "service": service, + "no_cache": True, } response = client.get("/get", params=params) assert response.status_code >= 400 and response.status_code <= 500 @@ -71,26 +74,35 @@ def test_valid_private(client: TestClient, token: str, service: str, count_only: assert response.status_code == 200 assert rv["links"]["self"] == response.url assert "********" in rv["meta"]["service_url"] + if 'records_total' in rv['meta']: + assert rv["meta"]["records_total"] > 0, f"Service {service} found zero features in table {table}" + elif 'record_count' in rv['meta']: + assert rv["meta"]["record_count"] > 0, f"Service {service} found zero features in table {table}" @pytest.mark.parametrize("service", ["carto"]) def test_valid_private_no_interfere(client: TestClient, service: str, token: str): """Test that a private token doesn't interfere with other APIs""" - params = {"table": GOOD_TABLES[0], "limit": 5, "service": service} + table = GOOD_TABLES[0] + params = {"table": table, "limit": 5, "service": service, "no_cache": True} response = client.get( "/get", params=params, headers={"Authorization": f"Bearer {token}"} ) assert response.status_code == 200 + rv = response.json() + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' @pytest.mark.parametrize("service", api_manager.map_str_to_api.keys()) def test_valid_fields(client: TestClient, service: str): """Test that the `fields` parameter returns only those fields""" 
+ table = GOOD_TABLES[0] params = { - "table": GOOD_TABLES[0], + "table": table, "limit": 2, "fields": "objectid,document_id,document_type,display_date", "service": service, + "no_cache": True, } response = client.get("/get", params=params) assert response.status_code == 200 @@ -98,7 +110,8 @@ def test_valid_fields(client: TestClient, service: str): data = rv["data"] for feature in data["features"]: assert set(feature["properties"].keys()) == set(params["fields"].split(",")) - + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' + @pytest.mark.skip("""Skipping this test because if user does not request the "objectid" field and this API doesn't include it, then AGO will not provide feature IDs. @@ -112,6 +125,7 @@ def test_valid_fields2(client: TestClient, service: str): "limit": 2, "fields": "addr_std", "service": service, + "no_cache": True, } response = client.get("/get", params=params) assert response.status_code == 200 @@ -124,16 +138,19 @@ def test_valid_fields2(client: TestClient, service: str): @pytest.mark.parametrize("service", api_manager.map_str_to_api.keys()) def test_valid_where(client: TestClient, service: str): """Test that the `where` parameter works""" + table = GOOD_TABLES[1] params = { - "table": GOOD_TABLES[1], + "table": table, "where": "objectid <= 2", "service": service, + "no_cache": True, } response = client.get("/get", params=params) rv = response.json() assert response.status_code == 200 assert rv["meta"]["record_count"] == 2 assert len(rv["data"]["features"]) == 2 + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' @pytest.mark.parametrize("service", api_manager.map_str_to_api.keys()) @@ -142,21 +159,25 @@ def test_valid_where_parethesization(client: TestClient, service: str): AND doesn't decouple any existing WHERE clause, i.e. 
because SQL `AND` binds more tightly than `OR`""" LIMIT = 2 + table = GOOD_TABLES[1] params = { - "table": GOOD_TABLES[1], + "table": table, "where": "objectid >= 1 OR objectid >= 3", "limit": LIMIT, "service": service, + "no_cache": True, } response = client.get("/get", params=params) rv = response.json() assert response.status_code == 200 next_url = rv["links"]["next"] + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' response2 = client.get(next_url) rv2 = response2.json() rv2_first_objectid = int(rv2["data"]["features"][0]["id"]) assert rv2_first_objectid >= LIMIT + assert rv2["data"]["features"], f'Service {service} found zero features in table {table}' @pytest.mark.parametrize("table", [GOOD_TABLES[0]]) @@ -164,10 +185,11 @@ def test_valid_where_parethesization(client: TestClient, service: str): def test_valid_limit_next(client: TestClient, service: str, table: str): """Test that the `limit` parameter works and that the `next` url works""" LIMIT = 2 - params = {"table": table, "limit": LIMIT, "service": service} + params = {"table": table, "limit": LIMIT, "service": service, "no_cache": True} response = client.get("/get", params=params) rv = response.json() assert response.status_code == 200 + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' data = rv["data"] ids = [feature["id"] for feature in data["features"]] max_id = max(ids) @@ -178,6 +200,7 @@ def test_valid_limit_next(client: TestClient, service: str, table: str): response2 = client.get(next_url) assert response2.status_code == 200 rv2 = response2.json() + assert rv2["data"]["features"], f'Service {service} found zero features in table {table}' data2 = rv2["data"] ids2 = [feature["id"] for feature in data2["features"]] assert rv2["meta"]["record_count"] == LIMIT @@ -189,13 +212,15 @@ def test_valid_limit_next(client: TestClient, service: str, table: str): @pytest.mark.parametrize("service", 
api_manager.map_str_to_api.keys()) def test_valid_count_only(client: TestClient, service: str): """Test that the `count_only` parameter works""" + table = GOOD_TABLES[1] params = { - "table": GOOD_TABLES[1], + "table": table, "fields": "whatever,whatever", # Should have no effect "limit": 3, # Should have no effect "count_only": "true", "where": "objectid <= 5", "service": service, + "no_cache": True, } response = client.get("/get", params=params) assert response.status_code == 200 @@ -206,16 +231,19 @@ def test_valid_count_only(client: TestClient, service: str): @pytest.mark.parametrize("service", api_manager.map_str_to_api.keys()) def test_valid_srid(client: TestClient, service: str): """Test that the `srid` parameter works""" + table = GOOD_TABLES[0] params = { - "table": GOOD_TABLES[0], + "table": table, "limit": 2, "out_sr": 4326, "service": service, + "no_cache": True, } response = client.get("/get", params=params) rv = response.json() assert response.status_code == 200 data = rv["data"] + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' params["out_sr"] = 2272 response2 = client.get("/get", params=params) @@ -223,22 +251,26 @@ def test_valid_srid(client: TestClient, service: str): assert response2.status_code == 200 data2 = rv2["data"] assert data != data2 + assert rv2["data"]["features"], f'Service {service} found zero features in table {table}' @pytest.mark.parametrize("service", ["carto"]) def test_valid_sql(client: TestClient, service: str): """Test that the `sql` parameter works, only on Carto""" + table = GOOD_TABLES[0] params = { "table": "ANSTHES", # Should have no effect "fields": "whatever,whatever", # Should have no effect "limit": 3, # Should have no effect - "sql": f"SELECT * FROM {GOOD_TABLES[0]} LIMIT 5", + "sql": f"SELECT * FROM {table} LIMIT 5", "service": service, + "no_cache": True, } response = client.get("/get", params=params) assert response.status_code == 200 - data = response.json() - assert 
data["meta"]["record_count"] == 5 + rv = response.json() + assert rv["meta"]["record_count"] == 5 + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' @pytest.mark.parametrize("service", ["carto"]) @@ -247,6 +279,7 @@ def test_valid_sql_too_large(client: TestClient, service: str): is too large to handle but smaller than a timeout""" params = { "sql": f"SELECT * FROM {GOOD_TABLES[0]} LIMIT 50000", + "no_cache": True, } response = client.get("/get", params=params) assert response.status_code == 413 @@ -256,25 +289,33 @@ def test_valid_sql_too_large(client: TestClient, service: str): @pytest.mark.parametrize("service", [None]) def test_valid_no_service(client: TestClient, table: str, service: None): """Test that the API works if no `service` is provided""" - params = {"table": table, "limit": 1} + params = {"table": table, "limit": 1, "no_cache": True} response = client.get("/get", params=params) assert response.status_code == 200 + rv = response.json() + assert rv["data"]["features"], f'Service {service} found zero features in table {table}' @pytest.mark.parametrize("service", api_manager.map_str_to_api.keys()) def test_valid_timeout(client: TestClient, service: str): """Test that the API timeout parameter returns the correct error code""" - params = {"table": GOOD_TABLES[0], "timeout": 0.001, "service": service} + params = { + "table": GOOD_TABLES[0], + "timeout": 0.001, + "service": service, + "no_cache": True, + } response = client.get("/get", params=params) assert response.status_code == 408 + @pytest.mark.skip("""Skipping this test because these tables have differences both in timestamp fields and in geometry fields that are unrelated to this API. 
""") @pytest.mark.parametrize("table", GOOD_TABLES) def test_same_response(client: TestClient, table: str): """Test that the API timeout parameter returns the correct error code""" - params = {"table": table, "limit": 1} + params = {"table": table, "limit": 1, "no_cache": True} ago_params = params | {"service": "ago"} carto_params = params | {"service": "carto"} ago_response = client.get("/get", params=ago_params) @@ -416,6 +457,7 @@ def test_invalid_sql_large_payload(client: TestClient, service: None): """Test that the API fails if too large of a dataset is requsted""" params = { "sql": f"SELECT * FROM {GOOD_TABLES[0]}", + "no_cache": True, } response = client.get("/get", params=params) assert response.status_code >= 400 and response.status_code < 500 diff --git a/app/utils.py b/app/utils.py index a0b5344..073a078 100644 --- a/app/utils.py +++ b/app/utils.py @@ -2,7 +2,6 @@ import json import os -import re from asyncio import sleep from enum import Enum @@ -26,7 +25,7 @@ class SchemaCache: def __init__(self): self.folder = "/var/git/databridge-schemas" self.commit_check_delay = 300 - self.latest_commit: str = None + self.latest_repo_target: str = None self.cache: dict[str, TableSchema] = {} self.invalid_fields: list[str] = [ "shape", # Carto @@ -44,26 +43,27 @@ async def loop_commit_check(self): await sleep(self.commit_check_delay) def check_latest_commit(self): - """Check if the API has the latest commit of the schemas repository""" - print("Checking latest commit") - path = os.path.join(self.folder, ".git") - commit = None - if os.path.isdir(path): # Local development - with open(os.path.join(path, "refs", "heads", "main")) as f: - commit = f.readline().strip() - elif os.path.isfile(path): # Prod environment - print(f"DEBUG: {path}") - with open(path) as f: - content = f.read() - match = re.search(r"worktrees/([a-f0-9]{40})", content) - if match: - commit = match.group(1).strip() - assert commit - else: - print(f"Warning: {path} does not exist??") - if commit != 
self.latest_commit and commit: - self.update() - self.latest_commit = commit + # Resolve the symlink to its actual current directory + # Or if we're locally developing, to the full path of the repo. + # Either will work with realpath(). + # (e.g., /var/git/.worktrees/<commit-sha>/) + current_target = os.path.realpath(self.folder) + + if getattr(self, 'latest_repo_target', None) != current_target: + print(f"New symlink target detected: {current_target} Updating SchemaCache.") + try: + # Record the new target before refreshing; update() is called directly here, not in a separate thread + self.latest_repo_target = current_target + # Overwrite the folder path with the new target + self.folder = current_target + self.update() + + except FileNotFoundError: + # Expected race condition: git-sync swapped directories while we were reading. + # Abort this attempt. NOTE(review): latest_repo_target was already set above, so the next tick will not retry unless the target changes again - confirm this is intended. + print("Update aborted: git-sync modified files during read.") + except Exception as e: + print(f"Unexpected error updating cache: {e}") def update(self): """Call the functions necessary to update the geometry cache. Note these