Skip to content

Commit 92ce2c5

Browse files
feat: implemented backend changes
2 parents 14f77e5 + 241528a commit 92ce2c5

File tree

11 files changed

+179
-128
lines changed

11 files changed

+179
-128
lines changed

backend/pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ dependencies = [
1717
"fastapi-pagination>=0.12.34",
1818
"bcrypt==4.0.1",
1919
"google-genai>=1.5.0",
20-
"itsdangerous (>=2.2.0,<3.0.0)",
21-
"authlib (>=1.5.2,<2.0.0)",
20+
"starlette (>=0.46.2,<0.47.0)",
2221
]
2322

2423
[tool.uv]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Add auth0_id field to Users table
2+
3+
Revision ID: 1425c896d3ef
4+
Revises: cb16ae472c1e
5+
Create Date: 2025-05-10 00:12:00.358973
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '1425c896d3ef'
15+
down_revision = 'cb16ae472c1e'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.alter_column('user', 'hashed_password',
23+
existing_type=sa.VARCHAR(),
24+
nullable=True)
25+
op.drop_index('ix_user_auth0_id', table_name='user')
26+
op.create_index(op.f('ix_user_auth0_id'), 'user', ['auth0_id'], unique=False)
27+
# ### end Alembic commands ###
28+
29+
30+
def downgrade():
31+
# ### commands auto generated by Alembic - please adjust! ###
32+
op.drop_index(op.f('ix_user_auth0_id'), table_name='user')
33+
op.create_index('ix_user_auth0_id', 'user', ['auth0_id'], unique=True)
34+
op.alter_column('user', 'hashed_password',
35+
existing_type=sa.VARCHAR(),
36+
nullable=False)
37+
# ### end Alembic commands ###

backend/src/auth/auth0_api.py

Lines changed: 46 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,58 @@
1-
from typing import Annotated, Any
2-
3-
from fastapi import APIRouter, Depends, HTTPException, Request
4-
from sqlmodel import Session
5-
from starlette.responses import RedirectResponse
6-
7-
from src.core.db import get_db
8-
from src.dependencies.auth0 import (
9-
get_auth0_service,
10-
get_current_user,
11-
get_current_user_claims,
1+
from fastapi import APIRouter, Request
2+
from fastapi.responses import RedirectResponse
3+
from authlib.integrations.starlette_client import OAuth
4+
from src.core.config import settings
5+
6+
router = APIRouter(tags=["auth"])
7+
8+
oauth = OAuth()
9+
oauth.register(
10+
name="auth0",
11+
client_id=settings.AUTH0_CLIENT_ID,
12+
client_secret=settings.AUTH0_CLIENT_SECRET,
13+
server_metadata_url=f"https://{settings.AUTH0_DOMAIN}/.well-known/openid-configuration",
14+
client_kwargs={"scope": "openid profile email"},
1215
)
13-
from src.services.auth0 import Auth0Service
14-
from src.users.auth0 import get_or_create_user_from_auth0
15-
from src.users.models import User
16-
from src.users.schemas import UserPublic
17-
18-
router = APIRouter(prefix="/auth0", tags=["auth0"])
1916

2017

2118
@router.get("/login")
22-
async def login(
23-
request: Request, auth_service: Annotated[Auth0Service, Depends(get_auth0_service)]
24-
) -> RedirectResponse:
25-
return await auth_service.login(request)
19+
async def login(request: Request):
20+
redirect_uri = request.url_for("auth0_callback")
21+
return await oauth.auth0.authorize_redirect(
22+
request,
23+
redirect_uri,
24+
prompt="select_account",
25+
connection="google-oauth2"
26+
)
2627

2728

28-
@router.get("/callback")
29-
async def callback(
30-
request: Request,
31-
session: Annotated[Session, Depends(get_db)],
32-
auth_service: Annotated[Auth0Service, Depends(get_auth0_service)],
33-
) -> RedirectResponse:
34-
try:
35-
# Exchange auth code for tokens
36-
token_response = await auth_service.callback(request)
37-
access_token = token_response.get("access_token")
29+
@router.get("/callback", name="auth0_callback")
30+
async def auth0_callback(request: Request):
31+
token = await oauth.auth0.authorize_access_token(request)
3832

39-
# Store access token in session for later use
40-
request.session["access_token"] = access_token
33+
user = token.get("userinfo") or await oauth.auth0.userinfo(token=token)
4134

42-
# Get user info from Auth0
43-
user_info = await auth_service.get_user_info(access_token)
35+
request.session["user"] = {
36+
"email": user["email"],
37+
"name": user.get("name"),
38+
"picture": user.get("picture"),
39+
"sub": user.get("sub"),
40+
}
4441

45-
# Get or create user in our database
46-
db_user = await get_or_create_user_from_auth0(session, user_info)
47-
48-
# Store user ID in session
49-
request.session["user_id"] = str(db_user.id)
50-
51-
# Redirect to the frontend after successful authentication
52-
return RedirectResponse(url="/")
53-
except Exception as e:
54-
# Log the error and redirect to error page
55-
return RedirectResponse(url=f"/auth0/error?message={str(e)}")
42+
return RedirectResponse(url="http://localhost:5173/collections")
5643

5744

5845
@router.get("/logout")
59-
async def logout(
60-
auth_service: Annotated[Auth0Service, Depends(get_auth0_service)],
61-
) -> RedirectResponse:
62-
return auth_service.logout()
63-
64-
65-
@router.get("/me", response_model=UserPublic)
66-
async def read_users_me(
67-
current_user: Annotated[User, Depends(get_current_user)],
68-
) -> User:
69-
return current_user
70-
71-
72-
@router.get("/validate")
73-
async def validate_token(
74-
claims: Annotated[dict[str, Any], Depends(get_current_user_claims)],
75-
) -> dict[str, Any]:
76-
return claims
77-
78-
79-
@router.get("/error")
80-
async def auth_error(message: str = "Authentication error"):
81-
raise HTTPException(status_code=401, detail=message)
46+
async def logout(request: Request):
47+
request.session.clear()
48+
return RedirectResponse(
49+
url=f"https://{settings.AUTH0_DOMAIN}/v2/logout"
50+
f"?client_id={settings.AUTH0_CLIENT_ID}"
51+
f"&returnTo=http://localhost:5173"
52+
)
53+
54+
55+
@router.get("/me")
56+
async def me(request: Request):
57+
user = request.session.get("user")
58+
return {"authenticated": bool(user), "user": user}

backend/src/auth/services.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22
from typing import Annotated, Any
33

44
import jwt
5-
from fastapi import Depends, HTTPException, status
5+
from fastapi import Depends, HTTPException, status, Request
66
from fastapi.security import OAuth2PasswordBearer
77
from jwt.exceptions import InvalidTokenError
88
from passlib.context import CryptContext
99
from pydantic import ValidationError
10-
from sqlmodel import Session
10+
from sqlmodel import Session, select
1111

1212
from src.auth.schemas import TokenPayload
1313
from src.core.config import settings
1414
from src.core.db import get_db
1515
from src.users.models import User
16+
from src.users.schemas import UserPublic
1617

1718
ALGORITHM = "HS256"
1819

@@ -23,21 +24,50 @@
2324
TokenDep = Annotated[str, Depends(reusable_oauth2)]
2425

2526

26-
def get_current_user(session: SessionDep, token: TokenDep) -> User:
27+
def get_user_from_session(request: Request, session: SessionDep) -> User:
28+
session_user = request.session.get("user")
29+
if not session_user:
30+
raise HTTPException(status_code=401, detail="Not authenticated (no session)")
31+
32+
user = session.exec(select(User).where(User.email == session_user["email"])).first()
33+
if not user or not user.is_active:
34+
raise HTTPException(status_code=401, detail="Invalid session user")
35+
return UserPublic.model_validate(user)
36+
37+
38+
def get_user_from_token(
39+
session: SessionDep,
40+
token: Annotated[str, Depends(reusable_oauth2)],
41+
) -> User:
2742
try:
2843
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
2944
token_data = TokenPayload(**payload)
45+
user = session.get(User, token_data.sub)
46+
if not user or not user.is_active:
47+
raise HTTPException(status_code=401, detail="Invalid user")
48+
return user
3049
except (InvalidTokenError, ValidationError):
31-
raise HTTPException(
32-
status_code=status.HTTP_403_FORBIDDEN,
33-
detail="Could not validate credentials",
34-
)
35-
user = session.get(User, token_data.sub)
36-
if not user:
37-
raise HTTPException(status_code=404, detail="User not found")
38-
if not user.is_active:
39-
raise HTTPException(status_code=400, detail="Inactive user")
40-
return user
50+
raise HTTPException(status_code=403, detail="Invalid token")
51+
52+
53+
def get_current_user(
54+
request: Request,
55+
session: SessionDep,
56+
token: Annotated[str | None, Depends(OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/tokens", auto_error=False))] = None,
57+
) -> User:
58+
print("in get current user")
59+
# Prefer session (Auth0 flow)
60+
session_user = request.session.get("user")
61+
if session_user:
62+
print("Session user found:", session_user["email"])
63+
res = get_user_from_session(request, session)
64+
print("User from session:", res)
65+
return res
66+
# Fallback to token (JWT flow)
67+
if token:
68+
return get_user_from_token(session, token)
69+
70+
raise HTTPException(status_code=401, detail="Not authenticated")
4171

4272

4373
CurrentUser = Annotated[User, Depends(get_current_user)]
@@ -53,8 +83,15 @@ def authenticate(*, session: Session, email: str, password: str) -> User | None:
5383
db_user = get_user_by_email(session=session, email=email)
5484
if not db_user:
5585
return None
86+
87+
# Auth0 users may not have a password
88+
if not db_user.hashed_password:
89+
# Return None for users without a password when using password authentication
90+
return None
91+
5692
if not verify_password(password, db_user.hashed_password):
5793
return None
94+
5895
return db_user
5996

6097

@@ -67,3 +104,15 @@ def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
67104

68105
def get_password_hash(password: str) -> str:
69106
return pwd_context.hash(password)
107+
108+
109+
def get_or_create_user_by_email(session: Session, email: str, defaults: dict = {}) -> User:
110+
user = session.exec(select(User).where(User.email == email)).first()
111+
if user:
112+
return user
113+
user = User(email=email, **defaults)
114+
session.add(user)
115+
session.commit()
116+
session.refresh(user)
117+
return user
118+

backend/src/core/config.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,12 @@ class Settings(BaseSettings):
4545
POSTGRES_USER: str
4646
POSTGRES_PASSWORD: str
4747
POSTGRES_DB: str = ""
48+
# Session Configuration
49+
SESSION_MAX_AGE: int = 60 * 60 * 24 * 7 # 7 days
4850

49-
# Auth0 Configuration
5051
AUTH0_CLIENT_ID: str
5152
AUTH0_CLIENT_SECRET: str
5253
AUTH0_DOMAIN: str
53-
AUTH0_ISSUER: str = "" # Default empty string to make it optional
54-
AUTH0_CALLBACK_URL: str
55-
AUTH0_LOGOUT_URL: str = "" # Default empty string to make it optional
56-
AUTH0_AUDIENCE: str
57-
58-
# Session Configuration
59-
SESSION_MAX_AGE: int = 60 * 60 * 24 * 7 # 7 days
6054

6155
@computed_field # type: ignore[misc]
6256
@property
@@ -107,29 +101,5 @@ def _enforce_non_default_secrets(self) -> Self:
107101

108102
return self
109103

110-
@computed_field # type: ignore[prop-decorator]
111-
@property
112-
def auth0_jwks_url(self) -> str:
113-
"""Get the JWKS URL for token validation."""
114-
return f"https://{self.AUTH0_DOMAIN}/.well-known/jwks.json"
115-
116-
@computed_field # type: ignore[prop-decorator]
117-
@property
118-
def auth0_authorization_url(self) -> str:
119-
"""Get the authorization URL for login."""
120-
return f"https://{self.AUTH0_DOMAIN}/authorize"
121-
122-
@computed_field # type: ignore[prop-decorator]
123-
@property
124-
def auth0_token_url(self) -> str:
125-
"""Get the token URL for token exchange."""
126-
return f"https://{self.AUTH0_DOMAIN}/oauth/token"
127-
128-
@computed_field # type: ignore[prop-decorator]
129-
@property
130-
def auth0_userinfo_url(self) -> str:
131-
"""Get the userinfo URL for fetching user data."""
132-
return f"https://{self.AUTH0_DOMAIN}/userinfo"
133-
134104

135105
settings = Settings() # type: ignore

backend/src/routers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
api_router.include_router(user_router, prefix="/users", tags=["users"])
1414
api_router.include_router(flashcards_router, tags=["flashcards"])
1515
api_router.include_router(stats_router, tags=["stats"])
16+
api_router.include_router(auth0_router, prefix="/auth0", tags=["auth0"])

backend/src/users/api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any
22

3-
from fastapi import APIRouter, HTTPException
3+
from fastapi import APIRouter, HTTPException, Request
44

55
from src.auth.services import CurrentUser, SessionDep
66
from src.core.config import settings
@@ -12,10 +12,11 @@
1212

1313

1414
@router.get("/me", response_model=UserPublic)
15-
def read_user_me(current_user: CurrentUser) -> Any:
15+
def read_user_me(request: Request, current_user: CurrentUser) -> Any:
1616
"""
1717
Get current user.
1818
"""
19+
print("session content:", request.session)
1920
return current_user
2021

2122

backend/src/users/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import uuid
2-
from typing import TYPE_CHECKING
2+
from typing import TYPE_CHECKING, Optional
33

44
from sqlmodel import Field, Relationship
55

@@ -11,7 +11,8 @@
1111

1212
class User(UserBase, table=True):
1313
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
14-
hashed_password: str
14+
auth0_id: Optional[str] = Field(default=None, index=True)
15+
hashed_password: Optional[str] = Field(default=None)
1516
collections: list["Collection"] = Relationship(
1617
back_populates="user",
1718
cascade_delete=True,

0 commit comments

Comments
 (0)