Summary
I am building a Litestar app that will use Microsoft Entra ID for authentication. Entra ID signs JWTs with public/private key pairs that rotate regularly, so applications are expected to fetch the current public keys dynamically. However, the Litestar documentation and code samples I've found all appear to use symmetric JWT algorithms, which expect a pre-shared, confidential secret.
As a developer building with Litestar, I would like the framework to support the RS256 JWT algorithm, along with better ergonomics for the pattern of fetching the public keys dynamically.
In addition to Microsoft Entra ID, this would also improve integration with services such as Keycloak.
There are two separate pain points I've found:
- While the underlying `jwt` package (PyJWT) accepts either a string or a `jwt.PyJWK` instance as the key in its `decode()` function (pyjwt docs, relevant Litestar call), Litestar's type hints restrict the JWT `secret` or `token_secret` to a string.
  - I traced the logic all the way from the `JWTAuthenticationMiddleware` class to the call to `jwt.decode`, and placing a `PyJWK` instance in this field does work.
  - Fixing this might be as simple as changing the type annotation on `token` to `str | jwt.PyJWK`, unless there's a more idiomatic way to do it in Litestar.
  - I also confirmed that, by default, `jwt.PyJWKClient` caches the keys it fetches. This means the first fetch can happen either at app initialization or the first time a JWT needs to be validated.
- The `jwt.PyJWKClient` class needs to load a URL at runtime that includes configuration data (in my case, the Microsoft Entra tenant ID). I've had trouble finding a good place to inject that configuration into the authentication middleware.
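Both pain points come down to a key lookup that depends on runtime configuration. The sketch below is a simplified, hypothetical stand-in for the behavior described above, not PyJWT's implementation of `jwt.PyJWKClient`: the JWKS URL and issuer are derived from the tenant ID, the key set is fetched lazily on first use and cached for a fixed lifetime, and keys are looked up by `kid`. The fetcher and clock are injected so it runs offline; `EntraEndpoints`, `CachedKeySet`, and the 300-second lifetime are assumptions for illustration.

```python
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional

# A JWKS document: {"keys": [{"kid": ..., "kty": ..., ...}, ...]}
JwkSet = Dict[str, Any]


@dataclass(frozen=True)
class EntraEndpoints:
    """Tenant-specific URLs derived from configuration (hypothetical helper)."""

    tenant_id: str

    @property
    def jwks_uri(self) -> str:
        return f"https://login.microsoftonline.com/{self.tenant_id}/discovery/v2.0/keys"

    @property
    def issuer(self) -> str:
        return f"https://login.microsoftonline.com/{self.tenant_id}/v2.0"


@dataclass
class CachedKeySet:
    """Hypothetical lazily fetched, time-limited cache of a JWKS document."""

    fetch: Callable[[], JwkSet]  # e.g. an HTTP GET of EntraEndpoints.jwks_uri
    lifespan: float = 300.0  # seconds before the key set is refetched (assumed)
    clock: Callable[[], float] = time.monotonic
    _keys: Dict[str, Dict[str, Any]] = field(default_factory=dict, init=False)
    _fetched_at: Optional[float] = field(default=None, init=False)

    def _refresh(self) -> None:
        jwks = self.fetch()
        self._keys = {jwk["kid"]: jwk for jwk in jwks["keys"]}
        self._fetched_at = self.clock()

    def get(self, kid: str) -> Dict[str, Any]:
        stale = (
            self._fetched_at is None
            or self.clock() - self._fetched_at > self.lifespan
        )
        # Refetch when the cache is empty or stale, or when the kid is unknown,
        # which is what happens right after the provider rotates its keys.
        if stale or kid not in self._keys:
            self._refresh()
        if kid not in self._keys:
            raise LookupError(f"no signing key with kid={kid!r}")
        return self._keys[kid]


# Usage with an offline fetcher and a controllable clock.
fetch_count = [0]
now = [0.0]


def fake_fetch() -> JwkSet:
    fetch_count[0] += 1
    return {"keys": [{"kid": "key-1", "kty": "RSA", "use": "sig"}]}


endpoints = EntraEndpoints(tenant_id="my-tenant-id")
key_set = CachedKeySet(fetch=fake_fetch, clock=lambda: now[0])

key_set.get("key-1")  # first use triggers the fetch
key_set.get("key-1")  # served from the cache
now[0] = 301.0
key_set.get("key-1")  # lifespan exceeded, so the key set is refetched
print(endpoints.jwks_uri, fetch_count[0])
```

Keeping the tenant-dependent URLs in one small object makes it clear that they are plain data computed once from settings, separate from the per-request key lookup.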
This is also loosely related to a discussion from a couple of years ago.
Basic Example
Here's the code snippet I arrived at to make this work:
```python
from typing import Any

import jwt
from litestar import Litestar, get
from litestar.connection import ASGIConnection
from litestar.middleware import AuthenticationResult
from litestar.security.jwt import JWTAuth, JWTAuthenticationMiddleware, Token
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import override

# DummyUser and _retrieve_user_from_entra_token are application code not shown here.


class EntraIDAuthenticationMiddleware(JWTAuthenticationMiddleware):
    """Middleware to authenticate Entra ID JWT tokens.

    Litestar assumes that the token's secret is a static string
    (suitable for symmetric algorithms like HS256). Entra ID uses the
    asymmetric RS256 algorithm instead. The underlying PyJWT library
    supports this use case, but Litestar's default implementation
    types the `token_secret` attribute as a string. By typing it as
    `str | jwt.PyJWK` here instead, we can assign the public key
    retrieved from Microsoft's JWKS endpoint at runtime.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Instantiate Entra ID authentication middleware.

        This constructor does not take any special parameters; it only
        needs to override the type of the `token_secret` attribute.
        """
        self.token_secret: str | jwt.PyJWK = ""
        self._jwks_client: jwt.PyJWKClient | None = None
        super().__init__(*args, **kwargs)

    @override
    async def authenticate_token(
        self, encoded_token: str, connection: ASGIConnection[Any, Any, Any, Any]
    ) -> AuthenticationResult:
        if self._jwks_client is None:
            raise RuntimeError("JWKS client is not initialized")
        # Picks the key matching the token's `kid`. On a cache miss this makes a
        # synchronous HTTP request, which blocks the event loop while it runs.
        signing_key = self._jwks_client.get_signing_key_from_jwt(encoded_token)
        # Caution: token_secret is shared by all requests; this is only safe if
        # the parent decodes the token before it first awaits.
        self.token_secret = signing_key
        return await super().authenticate_token(encoded_token, connection)


class Settings(BaseSettings):
    entra_client_id: str
    entra_tenant_id: str

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
    )


def build_jwt_auth(settings: Settings) -> JWTAuth[DummyUser, Token]:
    class ConfiguredEntraIDMiddleware(EntraIDAuthenticationMiddleware):
        """Locally defined subclass that injects application settings."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            # The JWKS URL depends on the tenant ID, known only at runtime.
            self._jwks_client = jwt.PyJWKClient(
                f"https://login.microsoftonline.com/{settings.entra_tenant_id}/discovery/v2.0/keys"
            )

    expected_audience = settings.entra_client_id
    expected_issuer = f"https://login.microsoftonline.com/{settings.entra_tenant_id}/v2.0"

    jwt_auth = JWTAuth[DummyUser, Token](
        # Placeholder: the middleware replaces token_secret on every request.
        token_secret="not-used-for-entra",
        retrieve_user_handler=_retrieve_user_from_entra_token,
        algorithm="RS256",
        exclude=["/schema"],
        authentication_middleware_class=ConfiguredEntraIDMiddleware,
        accepted_audiences=[expected_audience],
        accepted_issuers=[expected_issuer],
    )
    return jwt_auth


@get("/")
async def index() -> str:
    return "Hello, world!"


settings = Settings()
jwt_auth = build_jwt_auth(settings)
app = Litestar(
    [index],
    on_app_init=[jwt_auth.on_app_init],
)
```

Drawbacks and Impact
No response
Unresolved questions
Is there a simpler, more idiomatic Litestar way to approach this issue, especially the configuration management? I've always found the idea of creating a subclass inside a function to be pretty awkward.
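One possible alternative to defining a subclass inside a function is to declare the middleware once at module level, give its constructor an explicit keyword for the extra dependency, and bind that keyword with `functools.partial`. This assumes the framework only ever *calls* the configured middleware class with keyword arguments; I have not confirmed that this holds for Litestar's `authentication_middleware_class`, and a type checker will likely reject a `partial` where a class is annotated. The `FrameworkMiddleware`, `KeyAwareMiddleware`, and `build_middleware` names below are hypothetical stand-ins that only mimic that calling convention.

```python
from functools import partial
from typing import Any, Callable


class FrameworkMiddleware:
    """Hypothetical stand-in for a framework base class configured via kwargs."""

    def __init__(self, app: Any, *, token_secret: str, algorithm: str) -> None:
        self.app = app
        self.token_secret = token_secret
        self.algorithm = algorithm


class KeyAwareMiddleware(FrameworkMiddleware):
    """Defined once at module level; the extra dependency is an explicit kwarg."""

    def __init__(self, *args: Any, jwks_uri: str, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.jwks_uri = jwks_uri


def build_middleware(factory: Callable[..., FrameworkMiddleware]) -> FrameworkMiddleware:
    """Mimics a framework calling the configured class with its own kwargs."""
    return factory(app="asgi-app", token_secret="unused", algorithm="RS256")


tenant_id = "my-tenant-id"  # would come from Settings in the real app
configured = partial(
    KeyAwareMiddleware,
    jwks_uri=f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys",
)
middleware = build_middleware(configured)
print(type(middleware).__name__, middleware.algorithm)  # KeyAwareMiddleware RS256
```

Compared with the closure-based subclass, the configuration travels as an ordinary constructor argument, so the middleware can be unit-tested by passing `jwks_uri` directly.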