Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 75 additions & 23 deletions custom_components/reflex_clerk_api/clerk_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class ClerkState(rx.State):
"nbf": {"essential": True},
# "azp": {"essential": False, "values": ["http://localhost:3000", "https://example.com"]},
}
_jwt_validate_leeway_seconds: ClassVar[int] = 60
"""Clock-skew leeway (seconds) for validating JWT claims like exp/nbf."""

@classmethod
def register_dependent_handler(cls, handler: EventCallback) -> None:
Expand All @@ -85,6 +87,29 @@ def set_claims_options(cls, claims_options: dict[str, Any]) -> None:
"""Set the claims options for the JWT claims validation."""
cls._claims_options = claims_options

@classmethod
def set_jwt_validate_leeway_seconds(cls, seconds: int) -> None:
"""Set clock-skew leeway (seconds) for JWT exp/nbf validation.

Default is 60 seconds. Increase if you see intermittent ExpiredTokenError
due to clock drift between Clerk servers and your backend.

Args:
seconds: Non-negative integer, max 3600 (1 hour).

Raises:
ValueError: If seconds is negative or exceeds 3600.
"""
if not isinstance(seconds, int) or isinstance(seconds, bool) or seconds < 0:
raise ValueError(
f"jwt_validate_leeway_seconds must be a non-negative integer, got {seconds!r}"
)
if seconds > 3600:
raise ValueError(
f"jwt_validate_leeway_seconds exceeds maximum of 3600 (1 hour), got {seconds}"
)
cls._jwt_validate_leeway_seconds = seconds

@property
def client(self) -> clerk_backend_api.Clerk:
if self._client is None:
Expand Down Expand Up @@ -116,9 +141,13 @@ async def set_clerk_session(self, token: str) -> EventType:
return ClerkState.clear_clerk_session
try:
# Validate the token according to the claim options (e.g. iss, exp, nbf, azp.)
decoded.validate()
except (jose_errors.InvalidClaimError, jose_errors.MissingClaimError) as e:
logging.warning(f"JWT token is invalid: {e}")
decoded.validate(leeway=self._jwt_validate_leeway_seconds)
except (
jose_errors.ExpiredTokenError,
jose_errors.InvalidClaimError,
jose_errors.MissingClaimError,
) as e:
logging.warning(f"JWT token validation failed: {type(e).__name__}: {e}")
return ClerkState.clear_clerk_session

async with self:
Expand Down Expand Up @@ -364,7 +393,7 @@ def add_imports(
) -> rx.ImportDict:
addl_imports: rx.ImportDict = {
"@clerk/clerk-react": ["useAuth"],
"react": ["useContext", "useEffect"],
"react": ["useContext", "useEffect", "useRef"],
"$/utils/context": ["EventLoopContext"],
"$/utils/state": ["ReflexEvent"],
}
Expand All @@ -375,28 +404,51 @@ def add_custom_code(self) -> list[str]:

return [
"""
function ClerkSessionSynchronizer({ children }) {
const { getToken, isLoaded, isSignedIn } = useAuth()
const [ addEvents, connectErrors ] = useContext(EventLoopContext)

useEffect(() => {
if (isLoaded && !!addEvents) {
if (isSignedIn) {
getToken().then(token => {
addEvents([ReflexEvent("%s.set_clerk_session", {token})])
})
} else {
addEvents([ReflexEvent("%s.clear_clerk_session")])
}
}
}, [isSignedIn])
function ClerkSessionSynchronizer({{ children }}) {{
const {{ getToken, isLoaded, isSignedIn }} = useAuth()
const [ addEvents ] = useContext(EventLoopContext)
const lastSentRef = useRef({{ stateKey: null, addEvents: null }})

useEffect(() => {{
// Wait for all dependencies to be ready.
if (!isLoaded || !addEvents) return

// Deduplicate rapid calls, but remain reconnect-safe:
// addEvents identity changes across websocket reconnects, so include it in the key.
const stateKey = isSignedIn ? "signed_in" : "signed_out"
if (
lastSentRef.current?.stateKey === stateKey &&
lastSentRef.current?.addEvents === addEvents
) return
lastSentRef.current = {{ stateKey, addEvents }}
Comment on lines +418 to +423
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lastSentRef.current is updated before the async getToken() chain runs. If getToken fails/returns a falsy token and you send clear_clerk_session, the effect will not run again while isSignedIn and addEvents stay the same (deps unchanged), so the backend session can remain cleared even though Clerk is still signed in. Consider only updating the dedupe ref after successfully dispatching set_clerk_session, and/or introduce an explicit retry trigger (e.g., useState/useReducer + backoff) when token retrieval fails.

Copilot uses AI. Check for mistakes.

if (isSignedIn) {{
// Prefer a fresh token; cached tokens can be close to expiry.
// If this Clerk version doesn't support skipCache, fall back to the default call.
Promise.resolve()
.then(() => getToken({{ skipCache: true }}))
.catch(() => getToken())
.then(token => {{
if (token) {{
addEvents([ReflexEvent("{state}.set_clerk_session", {{token}})])
}} else {{
// Token unavailable despite isSignedIn - clear to avoid stuck auth state.
addEvents([ReflexEvent("{state}.clear_clerk_session")])
}}
}}).catch(() => {{
// Token retrieval failed - clear to avoid stuck auth state.
addEvents([ReflexEvent("{state}.clear_clerk_session")])
}})
Comment on lines +425 to +441
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the isSignedIn branch, token fetch errors (including the skipCache attempt and the fallback) immediately result in dispatching clear_clerk_session. For transient token retrieval failures this can force a backend logout even though Clerk remains signed in, and combined with the current dedupe logic it may prevent the backend from ever re-syncing. Prefer retrying token retrieval (possibly with a short delay/backoff) before clearing, or ensure the failure path triggers a subsequent re-sync attempt.

Copilot uses AI. Check for mistakes.
}} else {{
addEvents([ReflexEvent("{state}.clear_clerk_session")])
}}
}}, [isLoaded, isSignedIn, addEvents, getToken])

return (
<>{children}</>
<>{{children}}</>
)
}
"""
% (clerk_state_name, clerk_state_name)
}}
""".format(state=clerk_state_name)
]


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ homepage = "https://reflex-clerk-api-demo.adventuresoftim.com"

[tool.pytest.ini_options]
# addopts = "--headed"
pythonpath = ["custom_components"]

[tool.pyright]
venvPath = "."
Expand Down
50 changes: 50 additions & 0 deletions tests/test_clerk_provider_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import asyncio

import authlib.jose.errors as jose_errors


def test_set_clerk_session_expired_token_clears(monkeypatch):
"""Expired tokens should not crash the handler; they should clear session."""
# Import inside the test so the module is importable in different test layouts.
# We need the actual module object (not just ClerkState) to monkeypatch jwt.decode
# where it's used. importlib is required because reflex_clerk_api.clerk_provider
# resolves to the function via __init__.py re-exports.
import importlib

clerk_provider_module = importlib.import_module("reflex_clerk_api.clerk_provider")
from reflex_clerk_api.clerk_provider import ClerkState

# Instantiate state in a framework-safe way for tests.
state = ClerkState(_reflex_internal_init=True)

async def fake_get_jwk_keys(self):
return {}

monkeypatch.setattr(ClerkState, "_get_jwk_keys", fake_get_jwk_keys, raising=True)

validate_calls: dict[str, object] = {}

class FakeClaims:
def validate(self, leeway=None):
validate_calls["leeway"] = leeway
raise jose_errors.ExpiredTokenError()

monkeypatch.setattr(
clerk_provider_module.jwt,
"decode",
lambda *args, **kwargs: FakeClaims(),
raising=True,
)

result = asyncio.run(ClerkState.set_clerk_session.fn(state, token="fake"))
assert validate_calls["leeway"] == 60
assert result == ClerkState.clear_clerk_session


def test_clerk_session_synchronizer_js_contains_reconnect_safe_deps_and_skipcache():
"""String-based regression test for the generated JS."""
from reflex_clerk_api.clerk_provider import ClerkSessionSynchronizer

js = ClerkSessionSynchronizer.create().add_custom_code()[0]
assert "[isLoaded, isSignedIn, addEvents, getToken]" in js
assert "skipCache: true" in js
Loading