
Enhancement: Support for JWKs and asynchronous JWT authentication #4471

@replicaJunction


Summary

I am building a Litestar app that will use Microsoft Entra ID for authentication. Entra ID signs JWTs with asymmetric (private/public) key pairs that rotate on a regular basis, so applications are expected to fetch the current public keys dynamically. However, the Litestar documentation and code samples I've found appear to use symmetric algorithms for JWTs, which expect a pre-shared, confidential secret.

As a developer building with Litestar, I would like the framework to support JWTs signed with the RS256 algorithm, along with better ergonomics for the pattern of fetching the public keys dynamically.

In addition to Microsoft Entra ID, this would also improve integration with services such as Keycloak.

There are two separate pain points I've found:

  1. While the underlying jwt package accepts either a string or a jwt.PyJWK instance as the key in its decode() method (pyjwt docs, relevant Litestar call), Litestar's type hints restrict the JWT secret (token_secret) to a string.
    • I traced the logic all the way from the JWTAuthenticationMiddleware class to the jwt.decode call, and passing a PyJWK instance in this field does work correctly.
    • Fixing this might be as simple as changing the type annotation on token_secret to str | jwt.PyJWK, unless there's a more idiomatic way to do it in Litestar.
    • I also confirmed that by default, jwt.PyJWKClient caches the key set it fetches. This means the first (uncached) fetch can happen either at app startup or the first time a JWT needs to be validated.
  2. jwt.PyJWKClient must be constructed with a JWKS URL that includes configuration data known only at runtime (in my case, the Microsoft Entra tenant ID). I've had trouble finding a good place to inject that configuration into the authentication middleware.
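
To show what the cached key lookup described above does conceptually, here is a standard-library-only sketch of the idea behind jwt.PyJWKClient.get_signing_key_from_jwt: read the `kid` (key ID) from the token's unverified header, then return the matching entry from a JWK set that is cached for a limited time. This is a simplified illustration, not PyJWT's implementation, and it performs no signature verification. `unverified_header`, `JwksCache`, the injected `fetch_jwks` callable, and the 300-second default lifetime are all assumptions made for the example.

```python
from __future__ import annotations

import base64
import json
import time
from typing import Any, Callable


def unverified_header(token: str) -> dict[str, Any]:
    """Decode the header segment of a compact JWT WITHOUT verifying anything."""
    header_b64 = token.split(".", 1)[0]
    # JWT segments are base64url without padding; restore it before decoding.
    padded = header_b64 + "=" * (-len(header_b64) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))


class JwksCache:
    """Hypothetical TTL cache mapping key IDs (`kid`) to JWKs (illustration only)."""

    def __init__(
        self, fetch_jwks: Callable[[], dict[str, Any]], lifespan: float = 300.0
    ) -> None:
        # fetch_jwks would normally GET the tenant's JWKS URL and parse the JSON.
        self._fetch_jwks = fetch_jwks
        self._lifespan = lifespan
        self._keys: dict[str, dict[str, Any]] = {}
        self._fetched_at: float | None = None

    def _refresh(self) -> None:
        jwks = self._fetch_jwks()
        self._keys = {k["kid"]: k for k in jwks.get("keys", []) if "kid" in k}
        self._fetched_at = time.monotonic()

    def _is_stale(self) -> bool:
        return (
            self._fetched_at is None
            or time.monotonic() - self._fetched_at > self._lifespan
        )

    def key_for_token(self, token: str) -> dict[str, Any]:
        kid = unverified_header(token).get("kid")
        if kid is None:
            raise KeyError("token header has no 'kid'")
        # Refetch when the cached set is stale, or when the kid is unknown
        # (the signing keys may have rotated since the last fetch).
        if self._is_stale() or kid not in self._keys:
            self._refresh()
        if kid not in self._keys:
            raise KeyError(f"no JWK with kid={kid!r} in the key set")
        return self._keys[kid]
```

Refetching on an unknown `kid` lets a cache like this pick up rotated keys before its lifetime expires. A production version would also rate-limit those refetches, so that tokens carrying random `kid` values cannot force a request to the JWKS endpoint every time.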

This is also loosely related to a discussion from a couple of years ago.

Basic Example

Here's the code snippet I arrived at to make this work:

```python
from dataclasses import dataclass
from typing import Any

import jwt
from anyio import to_thread
from litestar import Litestar, get
from litestar.connection import ASGIConnection
from litestar.middleware import AuthenticationResult
from litestar.security.jwt import JWTAuth, JWTAuthenticationMiddleware, Token
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import override


class EntraIDAuthenticationMiddleware(JWTAuthenticationMiddleware):
    """Middleware to authenticate Entra ID JWT tokens.

    Litestar assumes that the token's secret is a static string
    (appropriate for symmetric algorithms like HS256). Entra ID uses
    the asymmetric RS256 algorithm instead.

    The underlying PyJWT library supports this use case, but
    Litestar's default implementation types the `token_secret`
    property as a string. By typing it as `str | jwt.PyJWK` here
    instead, we can assign the public key retrieved from
    Microsoft's JWKS endpoint at runtime.
    """

    # Widen the `str` annotation inherited from JWTAuthenticationMiddleware.
    token_secret: str | jwt.PyJWK  # type: ignore[assignment]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Instantiate Entra ID authentication middleware.

        This constructor takes no extra parameters; a subclass is
        expected to assign `_jwks_client` after calling it.
        """
        super().__init__(*args, **kwargs)
        self._jwks_client: jwt.PyJWKClient | None = None

    @override
    async def authenticate_token(
        self, encoded_token: str, connection: ASGIConnection[Any, Any, Any, Any]
    ) -> AuthenticationResult:
        if self._jwks_client is None:
            raise RuntimeError("JWKS client is not initialized")
        # On a cache miss PyJWKClient fetches the JWKS with blocking HTTP,
        # so run the lookup in a worker thread to keep the event loop free.
        signing_key = await to_thread.run_sync(
            self._jwks_client.get_signing_key_from_jwt, encoded_token
        )
        # This instance is shared across requests; mutating it relies on the
        # base class decoding the token before its first `await`.
        self.token_secret = signing_key
        return await super().authenticate_token(encoded_token, connection)


class Settings(BaseSettings):
    entra_client_id: str
    entra_tenant_id: str
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
    )


@dataclass
class DummyUser:
    """Placeholder user model (its definition was not part of this issue)."""

    subject: str


async def _retrieve_user_from_entra_token(
    token: Token, connection: ASGIConnection[Any, Any, Any, Any]
) -> DummyUser:
    """Placeholder user lookup: wrap the token's `sub` claim."""
    return DummyUser(subject=token.sub)


def build_jwt_auth(settings: Settings) -> JWTAuth[DummyUser, Token]:
    # Create one client (and one key cache) shared by every middleware instance.
    jwks_client = jwt.PyJWKClient(
        f"https://login.microsoftonline.com/{settings.entra_tenant_id}/discovery/v2.0/keys"
    )

    class ConfiguredEntraIDMiddleware(EntraIDAuthenticationMiddleware):
        """Local subclass that closes over the application settings."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            self._jwks_client = jwks_client

    expected_audience = settings.entra_client_id
    expected_issuer = (
        f"https://login.microsoftonline.com/{settings.entra_tenant_id}/v2.0"
    )

    jwt_auth = JWTAuth[DummyUser, Token](
        # Required by JWTAuth, but unused: the key comes from the JWKS endpoint.
        token_secret="not-used-for-entra",
        retrieve_user_handler=_retrieve_user_from_entra_token,
        algorithm="RS256",
        exclude=["/schema"],
        authentication_middleware_class=ConfiguredEntraIDMiddleware,
        accepted_audiences=[expected_audience],
        accepted_issuers=[expected_issuer],
    )

    return jwt_auth


@get("/")
async def index() -> str:
    return "Hello, world!"


settings = Settings()
jwt_auth = build_jwt_auth(settings)
app = Litestar(
    [index],
    on_app_init=[jwt_auth.on_app_init],
)
```

Drawbacks and Impact

No response

Unresolved questions

Is there a simpler, more idiomatic Litestar way to approach this issue, especially the configuration management? I've always found the idea of creating a subclass inside a function to be pretty awkward.

Metadata

Assignees: No one assigned
Labels: Enhancement (This is a new feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests