Skip to content

Commit e229944

Browse files
committed
Authentication Flow Improvements: fix token refresh handling
- Implement automatic token refresh when tokens expire - Tokens now refresh transparently without user re-authentication - Add comprehensive refresh token tests
1 parent 561d0f4 commit e229944

File tree

5 files changed

+317
-22
lines changed

5 files changed

+317
-22
lines changed

assisted_service_mcp/src/mcp.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def _create_oauth_token(self) -> Callable[[Any], Optional[str]]:
6565
Function that can extract OAuth tokens from MCP context
6666
"""
6767

68-
def get_oauth_token(mcp: Any) -> Optional[str]:
68+
def get_oauth_token(
69+
mcp: Any,
70+
) -> Optional[str]: # pylint: disable=too-many-locals
6971
if not settings.OAUTH_ENABLED:
7072
return None
7173

@@ -86,11 +88,24 @@ def get_oauth_token(mcp: Any) -> Optional[str]:
8688
# Get client identifier for OAuth flow
8789
client_id = self._get_mcp_client_identifier(mcp)
8890

89-
# Check if we have a completed OAuth token for this client
90-
token = oauth_manager.token_store.get_access_token_by_client(client_id)
91-
if token:
92-
log.info("Using cached OAuth token for MCP client %s", client_id)
93-
return token
91+
# Try to get token (with automatic refresh if expired)
92+
# Use async method which handles refresh, wrapped in asyncio.run for sync context
93+
try:
94+
token = asyncio.run(
95+
oauth_manager.get_access_token_by_client(client_id)
96+
)
97+
if token:
98+
log.info(
99+
"Using cached OAuth token for MCP client %s", client_id
100+
)
101+
return token
102+
except Exception as e:
103+
log.debug(
104+
"Could not get/refresh token for client %s: %s", client_id, e
105+
)
106+
# Fall through to start new OAuth flow
107+
108+
# No valid token found - need to start OAuth flow
94109

95110
# Check if OAuth flow is already in progress for this client
96111
for (

assisted_service_mcp/src/oauth/manager.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
)
2424
from assisted_service_mcp.src.settings import settings
2525

26+
# Token expiration constants
27+
DEFAULT_TOKEN_EXPIRES_IN = 900 # 15 minutes in seconds
28+
TOKEN_EXPIRY_BUFFER = 300 # 5 minutes safety margin before expiration
29+
2630

2731
class OAuthManager:
2832
"""Manages OAuth authentication flow for the MCP server.
@@ -214,14 +218,14 @@ async def exchange_code_for_token(
214218

215219
# Create token object
216220
token_id = secrets.token_hex(16)
217-
expires_in = token_data.get("expires_in", 3600)
221+
expires_in = token_data.get("expires_in", DEFAULT_TOKEN_EXPIRES_IN)
218222

219223
token = OAuthToken(
220224
token_id=token_id,
221225
client_id=state.client_id,
222226
access_token=token_data["access_token"],
223227
refresh_token=token_data.get("refresh_token"),
224-
expires_at=time.time() + expires_in - 300, # 5 min safety margin
228+
expires_at=time.time() + expires_in - TOKEN_EXPIRY_BUFFER,
225229
)
226230

227231
# Store token
@@ -248,7 +252,8 @@ async def get_access_token_by_id(self, token_id: str) -> Optional[str]:
248252
Returns:
249253
Access token if found and valid, None otherwise
250254
"""
251-
token = self.token_store.get_token_by_id(token_id)
255+
# Get token including expired ones (for refresh purposes)
256+
token = self.token_store.get_token_by_id(token_id, include_expired=True)
252257
if not token:
253258
return None
254259

@@ -260,6 +265,8 @@ async def get_access_token_by_id(self, token_id: str) -> Optional[str]:
260265
token = self.token_store.get_token_by_id(token_id)
261266
return token.access_token if token else None
262267
log.warning("Failed to refresh token %s", token_id)
268+
# Clean up expired token if refresh failed
269+
self.token_store.remove_token(token_id)
263270
return None
264271

265272
return token.access_token
@@ -287,7 +294,8 @@ async def get_access_token_by_client(self, client_id: str) -> Optional[str]:
287294
Returns:
288295
Access token if found and valid, None otherwise
289296
"""
290-
token = self.token_store.get_token_by_client(client_id)
297+
# Get token including expired ones (for refresh purposes)
298+
token = self.token_store.get_token_by_client(client_id, include_expired=True)
291299
if not token:
292300
return None
293301

@@ -299,6 +307,8 @@ async def get_access_token_by_client(self, client_id: str) -> Optional[str]:
299307
token = self.token_store.get_token_by_client(client_id)
300308
return token.access_token if token else None
301309
log.warning("Failed to refresh token for client %s", client_id)
310+
# Clean up expired token if refresh failed
311+
self.token_store.remove_client_token(client_id)
302312
return None
303313

304314
return token.access_token
@@ -333,8 +343,8 @@ async def _refresh_token(self, token: OAuthToken) -> bool:
333343
# Update token in store
334344
new_access_token = token_data["access_token"]
335345
new_refresh_token = token_data.get("refresh_token", token.refresh_token)
336-
expires_in = token_data.get("expires_in", 3600)
337-
new_expires_at = time.time() + expires_in - 300
346+
expires_in = token_data.get("expires_in", DEFAULT_TOKEN_EXPIRES_IN)
347+
new_expires_at = time.time() + expires_in - TOKEN_EXPIRY_BUFFER
338348

339349
self.token_store.update_token(
340350
token.token_id, new_access_token, new_refresh_token, new_expires_at
@@ -536,7 +546,7 @@ async def oauth_token_handler(request: Request) -> Dict[str, Any]:
536546
return {
537547
"access_token": token.get("access_token"),
538548
"token_type": token.get("token_type", "Bearer"),
539-
"expires_in": token.get("expires_in", 3600),
549+
"expires_in": token.get("expires_in", DEFAULT_TOKEN_EXPIRES_IN),
540550
"refresh_token": token.get("refresh_token"),
541551
"scope": token.get("scope", "openid profile email"),
542552
}

assisted_service_mcp/src/oauth/middleware.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ async def handle_mcp_request(self, request: Request, call_next: Any) -> Response
4242

4343
client_id = self._get_client_identifier(request)
4444

45-
# Try to use existing token
46-
token = oauth_manager.token_store.get_access_token_by_client(client_id)
45+
# Try to use existing token (with automatic refresh if expired)
46+
token = await oauth_manager.get_access_token_by_client(client_id)
4747
if token:
4848
log.info("Using cached token for client %s", client_id)
4949
return await self._create_authenticated_request(request, call_next, token)
@@ -176,7 +176,7 @@ async def _wait_for_oauth_completion(
176176
waited_time += poll_interval
177177

178178
# Check if OAuth completed (token available for client)
179-
token = oauth_manager.token_store.get_access_token_by_client(client_id)
179+
token = await oauth_manager.get_access_token_by_client(client_id)
180180
if token:
181181
log.info(
182182
"OAuth completed for client %s, proceeding with request", client_id

assisted_service_mcp/src/oauth/store.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,43 +47,49 @@ def store_token(self, token: OAuthToken) -> None:
4747
token.expires_at,
4848
)
4949

50-
def get_token_by_id(self, token_id: str) -> Optional[OAuthToken]:
50+
def get_token_by_id(
51+
self, token_id: str, include_expired: bool = False
52+
) -> Optional[OAuthToken]:
5153
"""Get token by token ID.
5254
5355
Args:
5456
token_id: Token identifier
57+
include_expired: If True, return expired tokens (for refresh purposes)
5558
5659
Returns:
57-
OAuthToken if found and valid, None otherwise
60+
OAuthToken if found and valid (or expired if include_expired=True), None otherwise
5861
"""
5962
with self._lock:
6063
token = self._tokens.get(token_id)
6164
if not token:
6265
return None
6366

64-
if token.is_expired():
67+
if token.is_expired() and not include_expired:
6568
log.debug("Token %s is expired, removing from store", token_id)
6669
self._remove_token_unsafe(token_id)
6770
return None
6871

6972
return token
7073

71-
def get_token_by_client(self, client_id: str) -> Optional[OAuthToken]:
74+
def get_token_by_client(
75+
self, client_id: str, include_expired: bool = False
76+
) -> Optional[OAuthToken]:
7277
"""Get token for a client.
7378
7479
Args:
7580
client_id: Client identifier
81+
include_expired: If True, return expired tokens (for refresh purposes)
7682
7783
Returns:
78-
OAuthToken if found and valid, None otherwise
84+
OAuthToken if found and valid (or expired if include_expired=True), None otherwise
7985
"""
8086
with self._lock:
8187
token_id = self._client_tokens.get(client_id)
8288
if not token_id:
8389
return None
8490

8591
# Re-use get_token_by_id which also acquires lock (RLock allows re-entrance)
86-
return self.get_token_by_id(token_id)
92+
return self.get_token_by_id(token_id, include_expired=include_expired)
8793

8894
def get_access_token_by_id(self, token_id: str) -> Optional[str]:
8995
"""Get access token string by token ID.

0 commit comments

Comments
 (0)