44oauth.py and mcp_oauth_middleware.py.
55"""
66
7+ import threading
78import time
89from typing import Dict , Optional
910
@@ -19,10 +20,14 @@ class TokenStore:
1920 - mcp_oauth_middleware.completed_tokens (client_id -> token_id)
2021
2122 Into a single, well-defined storage mechanism.
23+
24+ Thread-safe: All operations are protected by a re-entrant lock to handle
25+ concurrent access from FastAPI async handlers and MCP tool worker threads.
2226 """
2327
2428 def __init__ (self ) -> None :
2529 """Initialize token store."""
30+ self ._lock = threading .RLock () # Re-entrant lock for thread safety
2631 self ._tokens : Dict [str , OAuthToken ] = {}
2732 self ._client_tokens : Dict [str , str ] = {} # client_id -> token_id
2833
@@ -32,14 +37,15 @@ def store_token(self, token: OAuthToken) -> None:
3237 Args:
3338 token: OAuthToken to store
3439 """
35- self ._tokens [token .token_id ] = token
36- self ._client_tokens [token .client_id ] = token .token_id
37- log .debug (
38- "Stored token %s for client %s (expires at %s)" ,
39- token .token_id ,
40- token .client_id ,
41- token .expires_at ,
42- )
40+ with self ._lock :
41+ self ._tokens [token .token_id ] = token
42+ self ._client_tokens [token .client_id ] = token .token_id
43+ log .debug (
44+ "Stored token %s for client %s (expires at %s)" ,
45+ token .token_id ,
46+ token .client_id ,
47+ token .expires_at ,
48+ )
4349
4450 def get_token_by_id (self , token_id : str ) -> Optional [OAuthToken ]:
4551 """Get token by token ID.
@@ -50,16 +56,17 @@ def get_token_by_id(self, token_id: str) -> Optional[OAuthToken]:
5056 Returns:
5157 OAuthToken if found and valid, None otherwise
5258 """
53- token = self ._tokens .get (token_id )
54- if not token :
55- return None
59+ with self ._lock :
60+ token = self ._tokens .get (token_id )
61+ if not token :
62+ return None
5663
57- if token .is_expired ():
58- log .debug ("Token %s is expired, removing from store" , token_id )
59- self .remove_token (token_id )
60- return None
64+ if token .is_expired ():
65+ log .debug ("Token %s is expired, removing from store" , token_id )
66+ self ._remove_token_unsafe (token_id )
67+ return None
6168
62- return token
69+ return token
6370
6471 def get_token_by_client (self , client_id : str ) -> Optional [OAuthToken ]:
6572 """Get token for a client.
@@ -70,11 +77,13 @@ def get_token_by_client(self, client_id: str) -> Optional[OAuthToken]:
7077 Returns:
7178 OAuthToken if found and valid, None otherwise
7279 """
73- token_id = self ._client_tokens .get (client_id )
74- if not token_id :
75- return None
80+ with self ._lock :
81+ token_id = self ._client_tokens .get (client_id )
82+ if not token_id :
83+ return None
7684
77- return self .get_token_by_id (token_id )
85+ # Re-use get_token_by_id which also acquires lock (RLock allows re-entrance)
86+ return self .get_token_by_id (token_id )
7887
7988 def get_access_token_by_id (self , token_id : str ) -> Optional [str ]:
8089 """Get access token string by token ID.
@@ -118,24 +127,37 @@ def update_token(
118127 Returns:
119128 True if token was updated, False if token not found
120129 """
121- token = self ._tokens .get (token_id )
122- if not token :
123- return False
130+ with self ._lock :
131+ token = self ._tokens .get (token_id )
132+ if not token :
133+ return False
124134
125- token .access_token = access_token
126- if refresh_token :
127- token .refresh_token = refresh_token
128- token .expires_at = expires_at
135+ token .access_token = access_token
136+ if refresh_token :
137+ token .refresh_token = refresh_token
138+ token .expires_at = expires_at
129139
130- log .debug ("Updated token %s (new expiry: %s)" , token_id , expires_at )
131- return True
140+ log .debug ("Updated token %s (new expiry: %s)" , token_id , expires_at )
141+ return True
132142
133143 def remove_token (self , token_id : str ) -> None :
134144 """Remove a token and its associations.
135145
136146 Args:
137147 token_id: Token identifier to remove
138148 """
149+ with self ._lock :
150+ self ._remove_token_unsafe (token_id )
151+
152+ def _remove_token_unsafe (self , token_id : str ) -> None :
153+ """Remove a token without acquiring lock (internal use only).
154+
155+ Args:
156+ token_id: Token identifier to remove
157+
158+ Note:
159+ Caller must hold self._lock before calling this method.
160+ """
139161 token = self ._tokens .pop (token_id , None )
140162 if token :
141163 self ._client_tokens .pop (token .client_id , None )
@@ -147,43 +169,47 @@ def remove_client_token(self, client_id: str) -> None:
147169 Args:
148170 client_id: Client identifier
149171 """
150- token_id = self ._client_tokens .get (client_id )
151- if token_id :
152- self .remove_token (token_id )
172+ with self ._lock :
173+ token_id = self ._client_tokens .get (client_id )
174+ if token_id :
175+ self ._remove_token_unsafe (token_id )
153176
154177 def cleanup_expired_tokens (self ) -> int :
155178 """Clean up expired tokens.
156179
157180 Returns:
158181 Number of tokens removed
159182 """
160- current_time = time .time ()
161- expired_token_ids = [
162- token_id
163- for token_id , token in self ._tokens .items ()
164- if current_time >= token .expires_at
165- ]
183+ with self ._lock :
184+ current_time = time .time ()
185+ expired_token_ids = [
186+ token_id
187+ for token_id , token in self ._tokens .items ()
188+ if current_time >= token .expires_at
189+ ]
166190
167- for token_id in expired_token_ids :
168- self .remove_token (token_id )
191+ for token_id in expired_token_ids :
192+ self ._remove_token_unsafe (token_id )
169193
170- if expired_token_ids :
171- log .info ("Cleaned up %d expired tokens" , len (expired_token_ids ))
194+ if expired_token_ids :
195+ log .info ("Cleaned up %d expired tokens" , len (expired_token_ids ))
172196
173- return len (expired_token_ids )
197+ return len (expired_token_ids )
174198
175199 def get_all_tokens (self ) -> Dict [str , OAuthToken ]:
176200 """Get all stored tokens (for debugging/monitoring).
177201
178202 Returns:
179203 Dictionary of token_id -> OAuthToken
180204 """
181- return self ._tokens .copy ()
205+ with self ._lock :
206+ return self ._tokens .copy ()
182207
183208 def get_token_count (self ) -> int :
184209 """Get number of stored tokens.
185210
186211 Returns:
187212 Number of tokens in store
188213 """
189- return len (self ._tokens )
214+ with self ._lock :
215+ return len (self ._tokens )
0 commit comments