Skip to content

Commit 561d0f4

Browse files
committed
Thread-safe TokenStore
- Add threading.RLock() to TokenStore for concurrent access protection - Protect all dict operations (store, get, update, remove, cleanup) with lock - Add _remove_token_unsafe() internal method for re-entrant calls - Prevents race conditions between FastAPI handlers and MCP tool threads - RLock allows nested locking (e.g., get_token_by_client -> get_token_by_id)
1 parent e7560de commit 561d0f4

File tree

1 file changed

+71
-45
lines changed

1 file changed

+71
-45
lines changed

assisted_service_mcp/src/oauth/store.py

Lines changed: 71 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
oauth.py and mcp_oauth_middleware.py.
55
"""
66

7+
import threading
78
import time
89
from typing import Dict, Optional
910

@@ -19,10 +20,14 @@ class TokenStore:
1920
- mcp_oauth_middleware.completed_tokens (client_id -> token_id)
2021
2122
Into a single, well-defined storage mechanism.
23+
24+
Thread-safe: All operations are protected by a re-entrant lock to handle
25+
concurrent access from FastAPI async handlers and MCP tool worker threads.
2226
"""
2327

2428
def __init__(self) -> None:
2529
"""Initialize token store."""
30+
self._lock = threading.RLock() # Re-entrant lock for thread safety
2631
self._tokens: Dict[str, OAuthToken] = {}
2732
self._client_tokens: Dict[str, str] = {} # client_id -> token_id
2833

@@ -32,14 +37,15 @@ def store_token(self, token: OAuthToken) -> None:
3237
Args:
3338
token: OAuthToken to store
3439
"""
35-
self._tokens[token.token_id] = token
36-
self._client_tokens[token.client_id] = token.token_id
37-
log.debug(
38-
"Stored token %s for client %s (expires at %s)",
39-
token.token_id,
40-
token.client_id,
41-
token.expires_at,
42-
)
40+
with self._lock:
41+
self._tokens[token.token_id] = token
42+
self._client_tokens[token.client_id] = token.token_id
43+
log.debug(
44+
"Stored token %s for client %s (expires at %s)",
45+
token.token_id,
46+
token.client_id,
47+
token.expires_at,
48+
)
4349

4450
def get_token_by_id(self, token_id: str) -> Optional[OAuthToken]:
4551
"""Get token by token ID.
@@ -50,16 +56,17 @@ def get_token_by_id(self, token_id: str) -> Optional[OAuthToken]:
5056
Returns:
5157
OAuthToken if found and valid, None otherwise
5258
"""
53-
token = self._tokens.get(token_id)
54-
if not token:
55-
return None
59+
with self._lock:
60+
token = self._tokens.get(token_id)
61+
if not token:
62+
return None
5663

57-
if token.is_expired():
58-
log.debug("Token %s is expired, removing from store", token_id)
59-
self.remove_token(token_id)
60-
return None
64+
if token.is_expired():
65+
log.debug("Token %s is expired, removing from store", token_id)
66+
self._remove_token_unsafe(token_id)
67+
return None
6168

62-
return token
69+
return token
6370

6471
def get_token_by_client(self, client_id: str) -> Optional[OAuthToken]:
6572
"""Get token for a client.
@@ -70,11 +77,13 @@ def get_token_by_client(self, client_id: str) -> Optional[OAuthToken]:
7077
Returns:
7178
OAuthToken if found and valid, None otherwise
7279
"""
73-
token_id = self._client_tokens.get(client_id)
74-
if not token_id:
75-
return None
80+
with self._lock:
81+
token_id = self._client_tokens.get(client_id)
82+
if not token_id:
83+
return None
7684

77-
return self.get_token_by_id(token_id)
85+
# Re-use get_token_by_id which also acquires lock (RLock allows re-entrance)
86+
return self.get_token_by_id(token_id)
7887

7988
def get_access_token_by_id(self, token_id: str) -> Optional[str]:
8089
"""Get access token string by token ID.
@@ -118,24 +127,37 @@ def update_token(
118127
Returns:
119128
True if token was updated, False if token not found
120129
"""
121-
token = self._tokens.get(token_id)
122-
if not token:
123-
return False
130+
with self._lock:
131+
token = self._tokens.get(token_id)
132+
if not token:
133+
return False
124134

125-
token.access_token = access_token
126-
if refresh_token:
127-
token.refresh_token = refresh_token
128-
token.expires_at = expires_at
135+
token.access_token = access_token
136+
if refresh_token:
137+
token.refresh_token = refresh_token
138+
token.expires_at = expires_at
129139

130-
log.debug("Updated token %s (new expiry: %s)", token_id, expires_at)
131-
return True
140+
log.debug("Updated token %s (new expiry: %s)", token_id, expires_at)
141+
return True
132142

133143
def remove_token(self, token_id: str) -> None:
134144
"""Remove a token and its associations.
135145
136146
Args:
137147
token_id: Token identifier to remove
138148
"""
149+
with self._lock:
150+
self._remove_token_unsafe(token_id)
151+
152+
def _remove_token_unsafe(self, token_id: str) -> None:
153+
"""Remove a token without acquiring lock (internal use only).
154+
155+
Args:
156+
token_id: Token identifier to remove
157+
158+
Note:
159+
Caller must hold self._lock before calling this method.
160+
"""
139161
token = self._tokens.pop(token_id, None)
140162
if token:
141163
self._client_tokens.pop(token.client_id, None)
@@ -147,43 +169,47 @@ def remove_client_token(self, client_id: str) -> None:
147169
Args:
148170
client_id: Client identifier
149171
"""
150-
token_id = self._client_tokens.get(client_id)
151-
if token_id:
152-
self.remove_token(token_id)
172+
with self._lock:
173+
token_id = self._client_tokens.get(client_id)
174+
if token_id:
175+
self._remove_token_unsafe(token_id)
153176

154177
def cleanup_expired_tokens(self) -> int:
155178
"""Clean up expired tokens.
156179
157180
Returns:
158181
Number of tokens removed
159182
"""
160-
current_time = time.time()
161-
expired_token_ids = [
162-
token_id
163-
for token_id, token in self._tokens.items()
164-
if current_time >= token.expires_at
165-
]
183+
with self._lock:
184+
current_time = time.time()
185+
expired_token_ids = [
186+
token_id
187+
for token_id, token in self._tokens.items()
188+
if current_time >= token.expires_at
189+
]
166190

167-
for token_id in expired_token_ids:
168-
self.remove_token(token_id)
191+
for token_id in expired_token_ids:
192+
self._remove_token_unsafe(token_id)
169193

170-
if expired_token_ids:
171-
log.info("Cleaned up %d expired tokens", len(expired_token_ids))
194+
if expired_token_ids:
195+
log.info("Cleaned up %d expired tokens", len(expired_token_ids))
172196

173-
return len(expired_token_ids)
197+
return len(expired_token_ids)
174198

175199
def get_all_tokens(self) -> Dict[str, OAuthToken]:
176200
"""Get all stored tokens (for debugging/monitoring).
177201
178202
Returns:
179203
Dictionary of token_id -> OAuthToken
180204
"""
181-
return self._tokens.copy()
205+
with self._lock:
206+
return self._tokens.copy()
182207

183208
def get_token_count(self) -> int:
184209
"""Get number of stored tokens.
185210
186211
Returns:
187212
Number of tokens in store
188213
"""
189-
return len(self._tokens)
214+
with self._lock:
215+
return len(self._tokens)

0 commit comments

Comments
 (0)