Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/north_mcp_python_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware

from .auth import AuthContextMiddleware, NorthAuthBackend, on_auth_error
from .auth import (
AuthContextMiddleware,
HeadersContextMiddleware,
NorthAuthBackend,
on_auth_error,
)


def is_debug_mode() -> bool:
Expand Down Expand Up @@ -66,5 +71,6 @@ def _add_middleware(self, app: Starlette) -> None:
on_error=on_auth_error,
),
Middleware(AuthContextMiddleware, debug=self._debug),
Middleware(HeadersContextMiddleware, debug=self._debug),
]
app.user_middleware.extend(middleware)
39 changes: 39 additions & 0 deletions src/north_mcp_python_sdk/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import contextvars
import logging
from typing import Any

import jwt
from pydantic import BaseModel, Field, ValidationError
Expand Down Expand Up @@ -35,6 +36,10 @@ def __init__(
"north_auth_context", default=None
)

headers_context_var = contextvars.ContextVar[dict[str, Any] | None](
"north_headers_context", default=None
)


def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse:
return JSONResponse({"error": str(exc)}, status_code=401)
Expand All @@ -48,6 +53,14 @@ def get_authenticated_user() -> AuthenticatedNorthUser:
return user


def get_raw_headers() -> dict[str, Any]:
headers = headers_context_var.get()
if not headers:
raise Exception("headers not found in context")

return headers


class AuthContextMiddleware:
"""
Middleware that extracts the authenticated user from the request
Expand Down Expand Up @@ -83,6 +96,32 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
auth_context_var.reset(token)


class HeadersContextMiddleware:
"""
Middleware that sets the request headers in a contextvar for easy access
throughout the request lifecycle.
"""

def __init__(self, app: ASGIApp, debug: bool = False):
self.app = app
self.debug = debug
self.logger = logging.getLogger("NorthMCP.HeadersContext")
if debug:
self.logger.setLevel(logging.DEBUG)

async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] == "lifespan":
return await self.app(scope, receive, send)

headers = dict(scope.get("headers", {}))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the point of dict()?

self.logger.debug("Setting request headers in context: %s", headers)
token = headers_context_var.set(headers)
try:
await self.app(scope, receive, send)
finally:
headers_context_var.reset(token)


class NorthAuthBackend(AuthenticationBackend):
"""
Authentication backend that validates Bearer tokens.
Expand Down