Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ ANTHROPIC_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=
FIRECRAWL_API_KEY=
API_KEY_SECRET_PREFIX=
12 changes: 10 additions & 2 deletions app/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.modules.auth.api_key_service import APIKeyService
from app.modules.auth.api_key_service import APIKeyService, InvalidAPIKeyFormatError
from app.modules.conversations.conversation.conversation_controller import (
ConversationController,
)
Expand Down Expand Up @@ -63,7 +63,15 @@ async def get_api_key_user(
)
return {"user_id": user.uid, "email": user.email, "auth_type": "api_key"}

user = await APIKeyService.validate_api_key(x_api_key, db)
try:
user = await APIKeyService.validate_api_key(x_api_key, db)
except InvalidAPIKeyFormatError as exc:
raise HTTPException(
status_code=401,
detail="Invalid API key format",
headers={"WWW-Authenticate": "ApiKey"},
) from exc

if not user:
raise HTTPException(
status_code=401,
Expand Down
25 changes: 20 additions & 5 deletions app/modules/auth/api_key_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,23 @@

from app.modules.users.user_model import User
from app.modules.users.user_preferences_model import UserPreferences
from google.api_core.exceptions import NotFound


class APIKeyServiceError(Exception):
"""Base exception class for APIKeyService errors."""


class InvalidAPIKeyFormatError(APIKeyServiceError):
"""Raised when the API key format is invalid."""


class APIKeyNotFoundError(APIKeyServiceError):
"""Raised when the API key is not found in the database."""


class APIKeyService:
SECRET_PREFIX = "sk-"
SECRET_PREFIX = os.getenv("API_KEY_SECRET_PREFIX", "sk-")
KEY_LENGTH = 32

@staticmethod
Expand Down Expand Up @@ -98,10 +111,12 @@ async def create_api_key(user_id: str, db: Session) -> str:
return api_key

@staticmethod
async def validate_api_key(api_key: str, db: Session) -> Optional[dict]:
async def validate_api_key(api_key: str, db: Session) -> dict:
"""Validate an API key and return user info if valid."""
if not api_key.startswith(APIKeyService.SECRET_PREFIX):
return None
raise InvalidAPIKeyFormatError(
"API key format is invalid. Expected prefix missing."
)

hashed_key = APIKeyService.hash_api_key(api_key)

Expand All @@ -115,7 +130,7 @@ async def validate_api_key(api_key: str, db: Session) -> Optional[dict]:
)

if not result:
return None
raise APIKeyNotFoundError("API key not found in the database.")

user_pref, email = result
return {"user_id": user_pref.user_id, "email": email, "auth_type": "api_key"}
Expand Down Expand Up @@ -144,7 +159,7 @@ async def revoke_api_key(user_id: str, db: Session) -> bool:

try:
client.delete_secret(request={"name": name})
except Exception:
except NotFound:
pass # Ignore if secret doesn't exist

return True
Expand Down