-
Notifications
You must be signed in to change notification settings - Fork 8
fix: improve auth session sync reliability #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,6 +59,8 @@ class ClerkState(rx.State): | |
| "nbf": {"essential": True}, | ||
| # "azp": {"essential": False, "values": ["http://localhost:3000", "https://example.com"]}, | ||
| } | ||
| _jwt_validate_leeway_seconds: ClassVar[int] = 60 | ||
| """Clock-skew leeway (seconds) for validating JWT claims like exp/nbf.""" | ||
|
|
||
| @classmethod | ||
| def register_dependent_handler(cls, handler: EventCallback) -> None: | ||
|
|
@@ -85,6 +87,29 @@ def set_claims_options(cls, claims_options: dict[str, Any]) -> None: | |
| """Set the claims options for the JWT claims validation.""" | ||
| cls._claims_options = claims_options | ||
|
|
||
| @classmethod | ||
| def set_jwt_validate_leeway_seconds(cls, seconds: int) -> None: | ||
| """Set clock-skew leeway (seconds) for JWT exp/nbf validation. | ||
|
|
||
| Default is 60 seconds. Increase if you see intermittent ExpiredTokenError | ||
| due to clock drift between Clerk servers and your backend. | ||
|
|
||
| Args: | ||
| seconds: Non-negative integer, max 3600 (1 hour). | ||
|
|
||
| Raises: | ||
| ValueError: If seconds is negative or exceeds 3600. | ||
| """ | ||
| if not isinstance(seconds, int) or isinstance(seconds, bool) or seconds < 0: | ||
| raise ValueError( | ||
| f"jwt_validate_leeway_seconds must be a non-negative integer, got {seconds!r}" | ||
| ) | ||
| if seconds > 3600: | ||
| raise ValueError( | ||
| f"jwt_validate_leeway_seconds exceeds maximum of 3600 (1 hour), got {seconds}" | ||
| ) | ||
| cls._jwt_validate_leeway_seconds = seconds | ||
|
|
||
| @property | ||
| def client(self) -> clerk_backend_api.Clerk: | ||
| if self._client is None: | ||
|
|
@@ -116,9 +141,13 @@ async def set_clerk_session(self, token: str) -> EventType: | |
| return ClerkState.clear_clerk_session | ||
| try: | ||
| # Validate the token according to the claim options (e.g. iss, exp, nbf, azp.) | ||
| decoded.validate() | ||
| except (jose_errors.InvalidClaimError, jose_errors.MissingClaimError) as e: | ||
| logging.warning(f"JWT token is invalid: {e}") | ||
| decoded.validate(leeway=self._jwt_validate_leeway_seconds) | ||
| except ( | ||
| jose_errors.ExpiredTokenError, | ||
| jose_errors.InvalidClaimError, | ||
| jose_errors.MissingClaimError, | ||
| ) as e: | ||
| logging.warning(f"JWT token validation failed: {type(e).__name__}: {e}") | ||
| return ClerkState.clear_clerk_session | ||
|
|
||
| async with self: | ||
|
|
@@ -364,7 +393,7 @@ def add_imports( | |
| ) -> rx.ImportDict: | ||
| addl_imports: rx.ImportDict = { | ||
| "@clerk/clerk-react": ["useAuth"], | ||
| "react": ["useContext", "useEffect"], | ||
| "react": ["useContext", "useEffect", "useRef"], | ||
| "$/utils/context": ["EventLoopContext"], | ||
| "$/utils/state": ["ReflexEvent"], | ||
| } | ||
|
|
@@ -375,28 +404,51 @@ def add_custom_code(self) -> list[str]: | |
|
|
||
| return [ | ||
| """ | ||
| function ClerkSessionSynchronizer({ children }) { | ||
| const { getToken, isLoaded, isSignedIn } = useAuth() | ||
| const [ addEvents, connectErrors ] = useContext(EventLoopContext) | ||
|
|
||
| useEffect(() => { | ||
| if (isLoaded && !!addEvents) { | ||
| if (isSignedIn) { | ||
| getToken().then(token => { | ||
| addEvents([ReflexEvent("%s.set_clerk_session", {token})]) | ||
| }) | ||
| } else { | ||
| addEvents([ReflexEvent("%s.clear_clerk_session")]) | ||
| } | ||
| } | ||
| }, [isSignedIn]) | ||
| function ClerkSessionSynchronizer({{ children }}) {{ | ||
| const {{ getToken, isLoaded, isSignedIn }} = useAuth() | ||
| const [ addEvents ] = useContext(EventLoopContext) | ||
| const lastSentRef = useRef({{ stateKey: null, addEvents: null }}) | ||
|
|
||
| useEffect(() => {{ | ||
| // Wait for all dependencies to be ready. | ||
| if (!isLoaded || !addEvents) return | ||
|
|
||
| // Deduplicate rapid calls, but remain reconnect-safe: | ||
| // addEvents identity changes across websocket reconnects, so include it in the key. | ||
| const stateKey = isSignedIn ? "signed_in" : "signed_out" | ||
| if ( | ||
| lastSentRef.current?.stateKey === stateKey && | ||
| lastSentRef.current?.addEvents === addEvents | ||
| ) return | ||
| lastSentRef.current = {{ stateKey, addEvents }} | ||
|
|
||
| if (isSignedIn) {{ | ||
| // Prefer a fresh token; cached tokens can be close to expiry. | ||
| // If this Clerk version doesn't support skipCache, fall back to the default call. | ||
| Promise.resolve() | ||
| .then(() => getToken({{ skipCache: true }})) | ||
| .catch(() => getToken()) | ||
| .then(token => {{ | ||
| if (token) {{ | ||
| addEvents([ReflexEvent("{state}.set_clerk_session", {{token}})]) | ||
| }} else {{ | ||
| // Token unavailable despite isSignedIn - clear to avoid stuck auth state. | ||
| addEvents([ReflexEvent("{state}.clear_clerk_session")]) | ||
| }} | ||
| }}).catch(() => {{ | ||
| // Token retrieval failed - clear to avoid stuck auth state. | ||
| addEvents([ReflexEvent("{state}.clear_clerk_session")]) | ||
| }}) | ||
|
Comment on lines
+425
to
+441
|
||
| }} else {{ | ||
| addEvents([ReflexEvent("{state}.clear_clerk_session")]) | ||
| }} | ||
| }}, [isLoaded, isSignedIn, addEvents, getToken]) | ||
|
|
||
| return ( | ||
| <>{children}</> | ||
| <>{{children}}</> | ||
| ) | ||
| } | ||
| """ | ||
| % (clerk_state_name, clerk_state_name) | ||
| }} | ||
| """.format(state=clerk_state_name) | ||
| ] | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import asyncio | ||
|
|
||
| import authlib.jose.errors as jose_errors | ||
|
|
||
|
|
||
| def test_set_clerk_session_expired_token_clears(monkeypatch): | ||
| """Expired tokens should not crash the handler; they should clear session.""" | ||
| # Import inside the test so the module is importable in different test layouts. | ||
| # We need the actual module object (not just ClerkState) to monkeypatch jwt.decode | ||
| # where it's used. importlib is required because reflex_clerk_api.clerk_provider | ||
| # resolves to the function via __init__.py re-exports. | ||
| import importlib | ||
|
|
||
| clerk_provider_module = importlib.import_module("reflex_clerk_api.clerk_provider") | ||
| from reflex_clerk_api.clerk_provider import ClerkState | ||
|
|
||
| # Instantiate state in a framework-safe way for tests. | ||
| state = ClerkState(_reflex_internal_init=True) | ||
|
|
||
| async def fake_get_jwk_keys(self): | ||
| return {} | ||
|
|
||
| monkeypatch.setattr(ClerkState, "_get_jwk_keys", fake_get_jwk_keys, raising=True) | ||
|
|
||
| validate_calls: dict[str, object] = {} | ||
|
|
||
| class FakeClaims: | ||
| def validate(self, leeway=None): | ||
| validate_calls["leeway"] = leeway | ||
| raise jose_errors.ExpiredTokenError() | ||
|
|
||
| monkeypatch.setattr( | ||
| clerk_provider_module.jwt, | ||
| "decode", | ||
| lambda *args, **kwargs: FakeClaims(), | ||
| raising=True, | ||
| ) | ||
|
|
||
| result = asyncio.run(ClerkState.set_clerk_session.fn(state, token="fake")) | ||
| assert validate_calls["leeway"] == 60 | ||
| assert result == ClerkState.clear_clerk_session | ||
|
|
||
|
|
||
| def test_clerk_session_synchronizer_js_contains_reconnect_safe_deps_and_skipcache(): | ||
| """String-based regression test for the generated JS.""" | ||
| from reflex_clerk_api.clerk_provider import ClerkSessionSynchronizer | ||
|
|
||
| js = ClerkSessionSynchronizer.create().add_custom_code()[0] | ||
| assert "[isLoaded, isSignedIn, addEvents, getToken]" in js | ||
| assert "skipCache: true" in js |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lastSentRef.currentis updated before the asyncgetToken()chain runs. IfgetTokenfails/returns a falsy token and you sendclear_clerk_session, the effect will not run again whileisSignedInandaddEventsstay the same (deps unchanged), so the backend session can remain cleared even though Clerk is still signed in. Consider only updating the dedupe ref after successfully dispatchingset_clerk_session, and/or introduce an explicit retry trigger (e.g., useState/useReducer + backoff) when token retrieval fails.