-
Notifications
You must be signed in to change notification settings - Fork 577
Adding centralized cache for Entra tokens #21434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5137f74
bdca23e
c0e8b2f
469cded
42967d4
7ff77a6
527833c
4a2a27a
c83f2e9
4d8987a
4a6eb48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,7 @@ import { | |
| TelemetryActions, | ||
| TelemetryViews, | ||
| } from "../sharedInterfaces/telemetry"; | ||
| import { ApiStatus, isStatus, Status } from "../sharedInterfaces/webview"; | ||
| import { ObjectExplorerUtils } from "../objectExplorer/objectExplorerUtils"; | ||
| import { changeLanguageServiceForFile } from "../languageservice/utils"; | ||
| import { AddFirewallRuleWebviewController } from "./addFirewallRuleWebviewController"; | ||
|
|
@@ -120,6 +121,11 @@ export default class ConnectionManager { | |
| Deferred<ConnectionContracts.ConnectionCompleteParams> | ||
| >; | ||
| private _keyVaultTokenCache: Map<string, IToken> = new Map<string, IToken>(); | ||
| private _entraSqlTokenCache: Map<string, IToken> = new Map<string, IToken>(); | ||
| private _entraSqlTokenRefreshInFlight: Map<string, Promise<IToken>> = new Map< | ||
| string, | ||
| Promise<IToken> | ||
| >(); | ||
| private _accountService: AccountService; | ||
| private _firewallService: FirewallService; | ||
| public azureController: AzureController; | ||
|
|
@@ -1028,19 +1034,23 @@ export default class ConnectionManager { | |
| * Does nothing if connection is not using Entra auth. | ||
| * throws if token refresh fails or if account/profile cannot be found. | ||
| */ | ||
| public async confirmEntraTokenValidity(connectionInfo: IConnectionInfo) { | ||
| public async refreshEntraTokenIfNeeded(connectionInfo: IConnectionInfo) { | ||
| // 1. Validate that the connection is using Entra auth | ||
| if (connectionInfo.authenticationType !== Constants.azureMfa) { | ||
| // Connection not using Entra auth, nothing to validate | ||
| return; | ||
| } | ||
|
|
||
| // 2. Validate that the token needs refreshing (isn't expired) | ||
| if ( | ||
| AzureController.isTokenValid(connectionInfo.azureAccountToken, connectionInfo.expiresOn) | ||
| ) { | ||
| // Token not expired, nothing to refresh | ||
| this._logger?.verbose( | ||
| `Entra token for account ${connectionInfo.user} (${connectionInfo.email}) is still valid until ${connectionInfo.expiresOn}. No refresh needed.`, | ||
| ); | ||
| return; | ||
| } | ||
|
|
||
| // 3. Collect Entra account information | ||
| let account: IAccount; | ||
| let profile: ConnectionProfile; | ||
|
|
||
|
|
@@ -1052,22 +1062,53 @@ export default class ConnectionManager { | |
| sendErrorEvent( | ||
| TelemetryViews.ConnectionManager, | ||
| TelemetryActions.Connect, | ||
| new Error("Azure MFA connection missing accountId in confirmEntraTokenValidity"), | ||
| new Error("Azure MFA connection missing accountId in refreshEntraTokenIfNeeded"), | ||
| true, // includeErrorMessage | ||
| ); | ||
| throw new Error(LocalizedConstants.cannotConnect); | ||
| } | ||
|
|
||
| if (!account) { | ||
| this._logger?.verbose( | ||
| `No account found in account store for accountId ${connectionInfo.accountId}. Cannot refresh Entra token.`, | ||
| ); | ||
| throw new Error(LocalizedConstants.msgAccountNotFound); | ||
| } | ||
|
|
||
| // Always set username | ||
| connectionInfo.user = account.displayInfo.displayName; | ||
| connectionInfo.email = account.displayInfo.email; | ||
| profile.user = account.displayInfo.displayName; | ||
| profile.email = account.displayInfo.email; | ||
|
|
||
| // 4. Use cached token if present and valid/unexpired | ||
| const cacheKey = this.getEntraSqlTokenCacheKey( | ||
| connectionInfo, | ||
| account.properties?.owningTenant?.id, | ||
| ); | ||
| const cachedToken = this._entraSqlTokenCache.get(cacheKey); | ||
|
|
||
| this._logger?.verbose( | ||
| `Cached token ${cachedToken ? "found" : "not found"} for cache key ${cacheKey}.`, | ||
| ); | ||
|
|
||
| if (cachedToken) { | ||
| // If there's a cached token, use it if still valid, or remove it from cache if expired | ||
| if (AzureController.isTokenValid(cachedToken.token, cachedToken.expiresOn)) { | ||
| this.applyEntraToken(connectionInfo, cachedToken); | ||
| this._logger?.verbose( | ||
| `Using cached Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}. Cached token expires on ${cachedToken.expiresOn}. (currently ${Date.now() / 1000})`, | ||
| ); | ||
|
|
||
| return; | ||
Benjin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } else { | ||
| this._logger?.verbose( | ||
| `Cached token for cache key ${cacheKey} is expired. Removing from cache. (currently ${Date.now() / 1000})`, | ||
| ); | ||
| this._entraSqlTokenCache.delete(cacheKey); | ||
| } | ||
| } | ||
|
|
||
| // 5. Lastly, refresh the token, cache the new token, and update the connection info with it | ||
| const refreshTask = async () => { | ||
| return await this.azureController.refreshAccessToken( | ||
| account, | ||
|
|
@@ -1077,63 +1118,109 @@ export default class ConnectionManager { | |
| ); | ||
| }; | ||
|
|
||
| /** | ||
| * Token refresh code cannot figure out if the user closed the browser window, | ||
| * so we wrap it in a cancellable progress dialog to allow the user to cancel | ||
| * the operation. If the user cancels, we resolve with undefined and handle | ||
| * that case below. | ||
| */ | ||
| const azureAccountToken = await new Promise<IToken | undefined>((resolve) => { | ||
| vscode.window.withProgress( | ||
| { | ||
| location: vscode.ProgressLocation.Notification, | ||
| title: LocalizedConstants.ObjectExplorer.AzureSignInMessage, | ||
| cancellable: true, | ||
| }, | ||
| async (progress, token) => { | ||
| token.onCancellationRequested(() => { | ||
| this._logger.verbose("Azure sign in cancelled by user."); | ||
| resolve(undefined); | ||
| }); | ||
| try { | ||
| resolve(await refreshTask()); | ||
| } catch (error) { | ||
| this._logger.error("Error refreshing account: " + error); | ||
| this._vscodeWrapper.showErrorMessage(error.message); | ||
| resolve(undefined); | ||
| } | ||
| }, | ||
| ); | ||
| }); | ||
| // Dedupe concurrent token refresh requests for the same account into a single request, and share the result | ||
| let refreshPromise = this._entraSqlTokenRefreshInFlight.get(cacheKey); | ||
| if (!refreshPromise) { | ||
| // Token refresh code cannot figure out if the user closed the browser window, | ||
| // so we wrap it in a cancellable progress dialog to allow the user to cancel | ||
| // the operation. | ||
| refreshPromise = new Promise<IToken>((resolve, reject) => { | ||
| void vscode.window.withProgress( | ||
| { | ||
| location: vscode.ProgressLocation.Notification, | ||
| title: LocalizedConstants.ObjectExplorer.AzureSignInMessage( | ||
| account.displayInfo.displayName || account.displayInfo.email, | ||
| ), | ||
| cancellable: true, | ||
| }, | ||
| async (_progress, token) => { | ||
| token.onCancellationRequested(() => { | ||
| reject({ | ||
| status: ApiStatus.Cancelled, | ||
| message: "Azure sign in cancelled by user.", | ||
| } as Status); | ||
| }); | ||
| try { | ||
| const refreshedToken = await refreshTask(); | ||
| if (!refreshedToken) { | ||
| reject({ | ||
| status: ApiStatus.Error, | ||
| message: LocalizedConstants.msgAccountRefreshFailed(), | ||
| } as Status); | ||
| return; | ||
| } | ||
| this._logger?.verbose( | ||
| `Successfully refreshed Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}; now expires on ${refreshedToken.expiresOn} (currently ${Date.now() / 1000}).`, | ||
| ); | ||
| resolve(refreshedToken); | ||
Benjin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } catch (error) { | ||
| const refreshErrorStatus: Status = { | ||
| status: ApiStatus.Error, | ||
| message: getErrorMessage(error), | ||
| }; | ||
| this._logger?.error( | ||
| `Error refreshing Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}: ${refreshErrorStatus.message}`, | ||
| ); | ||
| reject(refreshErrorStatus); | ||
| } | ||
|
Comment on lines
+1137
to
+1165
|
||
| }, | ||
| ); | ||
| }).finally(() => { | ||
| this._entraSqlTokenRefreshInFlight.delete(cacheKey); | ||
| }); | ||
| this._entraSqlTokenRefreshInFlight.set(cacheKey, refreshPromise); | ||
| } | ||
|
|
||
| if (!azureAccountToken) { | ||
| let errorMessage = LocalizedConstants.msgAccountRefreshFailed; | ||
| let refreshResult = await this.vscodeWrapper.showErrorMessage( | ||
| errorMessage, | ||
| LocalizedConstants.refreshTokenLabel, | ||
| try { | ||
| const azureAccountToken = await refreshPromise; | ||
| this.applyEntraToken(connectionInfo, azureAccountToken); | ||
| // Save refreshed token so other connections for the same account+tenant can reuse it. | ||
| this._entraSqlTokenCache.set(cacheKey, azureAccountToken); | ||
| this._logger?.verbose( | ||
| `Successfully refreshed Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}. Cached token for future use with cache key ${cacheKey}.`, | ||
| ); | ||
| if (refreshResult === LocalizedConstants.refreshTokenLabel) { | ||
| await this.azureController.populateAccountProperties( | ||
| profile, | ||
| this.accountStore, | ||
| getCloudProviderSettings(account.key.providerId).settings.sqlResource!, | ||
| ); | ||
| } catch (error) { | ||
| this._logger?.verbose( | ||
| `Failed to refresh Entra token for account ${account.displayInfo.displayName} (${account.displayInfo.email}) and tenant ${profile.tenantId}. Error: ${getErrorMessage(error)}`, | ||
| ); | ||
| if (isStatus(error)) { | ||
| if (error.status === ApiStatus.Cancelled) { | ||
| this._logger.verbose("Refresh cancelled: " + error.message); | ||
| throw new Error(LocalizedConstants.cannotConnect); | ||
| } | ||
|
|
||
| connectionInfo.azureAccountToken = profile.azureAccountToken; | ||
| connectionInfo.expiresOn = profile.expiresOn; | ||
| connectionInfo.accountId = profile.accountId; | ||
| connectionInfo.tenantId = profile.tenantId; | ||
| connectionInfo.user = profile.user; | ||
| connectionInfo.email = profile.email; | ||
| } else { | ||
| throw new Error(LocalizedConstants.cannotConnect); | ||
| if (error.status === ApiStatus.Error) { | ||
| const message = LocalizedConstants.msgAccountRefreshFailed(error.message); | ||
| this._logger.error("Error refreshing account: " + message); | ||
| await this.vscodeWrapper.showErrorMessage(message); | ||
| throw new Error(message); | ||
Benjin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| } else { | ||
| connectionInfo.azureAccountToken = azureAccountToken.token; | ||
| connectionInfo.expiresOn = azureAccountToken.expiresOn; | ||
|
|
||
| throw error; | ||
| } | ||
| } | ||
|
|
||
| private getEntraSqlTokenCacheKey( | ||
| connectionInfo: IConnectionInfo, | ||
| defaultTenantId?: string, | ||
| ): string { | ||
| return `${connectionInfo.accountId ?? ""}|${connectionInfo.tenantId ?? defaultTenantId ?? ""}`; | ||
| } | ||
|
|
||
| private applyEntraToken(connectionInfo: IConnectionInfo, token: IToken): void { | ||
| connectionInfo.azureAccountToken = token.token; | ||
| connectionInfo.expiresOn = token.expiresOn; | ||
| } | ||
|
|
||
| /** | ||
| * Clears both token entries and any in-flight refresh promises. | ||
| */ | ||
| private clearEntraSqlTokenCache(): void { | ||
| this._entraSqlTokenCache.clear(); | ||
| this._entraSqlTokenRefreshInFlight.clear(); | ||
| } | ||
|
|
||
| /** | ||
| * Handles password-based credential authentication by prompting for password if needed. | ||
| * This method checks if a password is required and prompts the user if it's not saved or available. | ||
|
|
@@ -1422,7 +1509,7 @@ export default class ConnectionManager { | |
| // Handle Entra token validity | ||
| if (credentials.authenticationType === Constants.azureMfa) { | ||
| try { | ||
| await this.confirmEntraTokenValidity(credentials); | ||
| await this.refreshEntraTokenIfNeeded(credentials); | ||
| } catch (error) { | ||
| telemetryActivity?.endFailed( | ||
| error, | ||
|
|
@@ -1653,14 +1740,17 @@ export default class ConnectionManager { | |
| }; | ||
| } else if (errorType === SqlConnectionErrorType.EntraTokenExpired) { | ||
| try { | ||
| await this.confirmEntraTokenValidity(credentials); | ||
| await this.refreshEntraTokenIfNeeded(credentials); | ||
| return { | ||
| isHandled: true, | ||
| updatedCredentials: credentials, | ||
| errorHandled: SqlConnectionErrorType.EntraTokenExpired, | ||
| }; | ||
| } catch (error) { | ||
| Utils.showErrorMsg(getErrorMessage(error)); | ||
| const errorMessage = getErrorMessage(error); | ||
| if (errorMessage !== LocalizedConstants.cannotConnect) { | ||
| Utils.showErrorMsg(errorMessage); | ||
| } | ||
Benjin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return { | ||
| isHandled: false, | ||
| updatedCredentials: credentials, | ||
|
|
@@ -1853,7 +1943,7 @@ export default class ConnectionManager { | |
| // No connection for this URI, nothing to do | ||
| return; | ||
| } | ||
| await this.confirmEntraTokenValidity(connectionInfo.credentials); | ||
| await this.refreshEntraTokenIfNeeded(connectionInfo.credentials); | ||
| } | ||
|
|
||
| public async addAccount(): Promise<IAccount> { | ||
|
|
@@ -1907,6 +1997,7 @@ export default class ConnectionManager { | |
|
|
||
| public onClearAzureTokenCache(): void { | ||
| this.azureController.clearTokenCache(); | ||
| this.clearEntraSqlTokenCache(); | ||
| this.vscodeWrapper.showInformationMessage( | ||
| LocalizedConstants.Accounts.clearedEntraTokenCache, | ||
| ); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.