diff --git a/extensions/mssql/src/constants/constants.ts b/extensions/mssql/src/constants/constants.ts index 8c3a83e481..01fa4df82c 100644 --- a/extensions/mssql/src/constants/constants.ts +++ b/extensions/mssql/src/constants/constants.ts @@ -15,6 +15,8 @@ export const mssqlChatParticipantName = "mssql"; // must be the same as the one export const noneProviderName = "None"; export const objectExplorerId = "objectExplorer"; export const queryHistory = "queryHistory"; +export const queryHistoryEncryptionKeySecretStorageKey = "mssql.queryHistoryTreeEncryptionKey"; +export const queryHistoryGlobalStorageFileName = "queryHistoryTree.enc"; export const connectionApplicationName = "vscode-mssql"; export const outputChannelName = "MSSQL"; export const connectionConfigFilename = "settings.json"; diff --git a/extensions/mssql/src/controllers/mainController.ts b/extensions/mssql/src/controllers/mainController.ts index 011f7358f9..6848354bb5 100644 --- a/extensions/mssql/src/controllers/mainController.ts +++ b/extensions/mssql/src/controllers/mainController.ts @@ -2243,6 +2243,7 @@ export default class MainController implements vscode.Disposable { this._sqlDocumentService, this._statusview, this._prompter, + this._context, ); this._context.subscriptions.push( diff --git a/extensions/mssql/src/queryHistory/queryHistoryNode.ts b/extensions/mssql/src/queryHistory/queryHistoryNode.ts index d35831513d..6a32804f6e 100644 --- a/extensions/mssql/src/queryHistory/queryHistoryNode.ts +++ b/extensions/mssql/src/queryHistory/queryHistoryNode.ts @@ -90,4 +90,8 @@ export class QueryHistoryNode extends vscode.TreeItem { public get connectionLabel(): string { return this._connectionLabel; } + + public get isSuccess(): boolean { + return this._isSuccess; + } } diff --git a/extensions/mssql/src/queryHistory/queryHistoryProvider.ts b/extensions/mssql/src/queryHistory/queryHistoryProvider.ts index 04c94a082b..ab004e5e01 100644 --- a/extensions/mssql/src/queryHistory/queryHistoryProvider.ts +++ b/extensions/mssql/src/queryHistory/queryHistoryProvider.ts @@ -19,16 +19,31 @@ import { IPrompter } from "../prompts/question"; import { QueryHistoryUI, QueryHistoryAction } from "../views/queryHistoryUI"; import { getUriKey } from "../utils/utils"; import { Deferred } from "../protocol"; +import * as vscodeMssql from "vscode-mssql"; +import { + decryptData, + type EncryptedData, + encryptData, + generateEncryptionKey, +} from "../utils/encryptionUtils"; -export class QueryHistoryProvider implements vscode.TreeDataProvider { - private _onDidChangeTreeData: vscode.EventEmitter = new vscode.EventEmitter< - any | undefined - >(); - readonly onDidChangeTreeData: vscode.Event = this._onDidChangeTreeData.event; +type QueryHistoryTreeNode = QueryHistoryNode | EmptyHistoryNode; - private _queryHistoryNodes: vscode.TreeItem[] = [new EmptyHistoryNode()]; +export class QueryHistoryProvider implements vscode.TreeDataProvider { + private _onDidChangeTreeData: vscode.EventEmitter = + new vscode.EventEmitter(); + readonly onDidChangeTreeData: vscode.Event = + this._onDidChangeTreeData.event; + + private _queryHistoryNodes: QueryHistoryTreeNode[] = [new EmptyHistoryNode()]; private _queryHistoryLimit: number; private _queryHistoryUI: QueryHistoryUI; + private _queryHistoryMutationId = 0; + + /** + * Version number for the persisted query history. Increment this if there are breaking changes to the persisted format to ensure old formats are not loaded. + */ + private static readonly _queryHistoryStorageVersion = 1; constructor( private _connectionManager: ConnectionManager, @@ -37,18 +52,22 @@ export class QueryHistoryProvider implements vscode.TreeDataProvider { private _sqlDocumentService: SqlDocumentService, private _statusView: StatusView, private _prompter: IPrompter, + private _context: vscode.ExtensionContext, ) { const config = this._vscodeWrapper.getConfiguration(Constants.extensionConfigSectionName); this._queryHistoryLimit = config.get(Constants.configQueryHistoryLimit); this._queryHistoryUI = new QueryHistoryUI(this._prompter); + void this.restoreQueryHistory(); } clearAll(): void { + this._queryHistoryMutationId++; this._queryHistoryNodes = [new EmptyHistoryNode()]; this._onDidChangeTreeData.fire(undefined); + void this.persistQueryHistory(); } - refresh(ownerUri: string, timeStamp: Date, hasError): void { + refresh(ownerUri: string, timeStamp: Date, hasError: boolean): void { const timeStampString = timeStamp.toLocaleString(); const historyNodeLabel = this.createHistoryNodeLabel(ownerUri); const tooltip = this.createHistoryNodeTooltip(ownerUri, timeStampString); @@ -64,31 +83,32 @@ export class QueryHistoryProvider implements vscode.TreeDataProvider { connectionLabel, !hasError, ); + + this._queryHistoryMutationId++; if (this._queryHistoryNodes.length === 1) { if (this._queryHistoryNodes[0] instanceof EmptyHistoryNode) { this._queryHistoryNodes = []; } } this._queryHistoryNodes.push(node); - // sort the query history sorted by timestamp this._queryHistoryNodes.sort((a, b) => { return ( (b as QueryHistoryNode).timeStamp.getTime() - (a as QueryHistoryNode).timeStamp.getTime() ); }); - // Remove old entries if we are over the limit. if (this._queryHistoryNodes.length > this._queryHistoryLimit) { this._queryHistoryNodes.pop(); } this._onDidChangeTreeData.fire(undefined); + void this.persistQueryHistory(); } getTreeItem(node: QueryHistoryNode): QueryHistoryNode { return node; } - getChildren(element?: any): vscode.TreeItem[] { + getChildren(_element?: QueryHistoryTreeNode): QueryHistoryTreeNode[] { if (this._queryHistoryNodes.length === 0) { this._queryHistoryNodes.push(new EmptyHistoryNode()); } @@ -191,8 +211,10 @@ export class QueryHistoryProvider implements vscode.TreeDataProvider { let historyNode = n as QueryHistoryNode; return historyNode === node; }); + this._queryHistoryMutationId++; this._queryHistoryNodes.splice(index, 1); this._onDidChangeTreeData.fire(undefined); + void this.persistQueryHistory(); } /** @@ -246,4 +268,222 @@ export class QueryHistoryProvider implements vscode.TreeDataProvider { const connectionLabel = this.getConnectionLabel(ownerUri); return `${connectionLabel}${os.EOL}${os.EOL}${timeStamp}${os.EOL}${os.EOL}${queryString}`; } + + private createPersistedHistoryNodeLabel(queryString: string, connectionLabel: string): string { + const limitedQueryString = Utils.limitStringSize(queryString).trim(); + const limitedConnectionLabel = Utils.limitStringSize(connectionLabel).trim(); + return `${limitedQueryString} : ${limitedConnectionLabel}`; + } + + private createPersistedHistoryNodeTooltip( + queryString: string, + connectionLabel: string, + timeStamp: string, + ): string { + return `${connectionLabel}${os.EOL}${os.EOL}${timeStamp}${os.EOL}${os.EOL}${queryString}`; + } + + private async restoreQueryHistory(): Promise { + const restoreMutationId = this._queryHistoryMutationId; + + try { + const serializedHistory = await this.readEncryptedPersistedQueryHistory(); + if (!serializedHistory) { + return; + } + + const persistedHistory = JSON.parse(serializedHistory) as PersistedQueryHistory; + if ( + !persistedHistory || + persistedHistory.version !== QueryHistoryProvider._queryHistoryStorageVersion || + !Array.isArray(persistedHistory.nodes) + ) { + return; + } + + const restoredNodes = persistedHistory.nodes + .map((node) => this.createNodeFromPersisted(node)) + .filter((node): node is QueryHistoryNode => node !== undefined); + + if (restoreMutationId !== this._queryHistoryMutationId) { + return; + } + + if (restoredNodes.length === 0) { + this._queryHistoryNodes = [new EmptyHistoryNode()]; + } else { + restoredNodes.sort((a, b) => b.timeStamp.getTime() - a.timeStamp.getTime()); + this._queryHistoryNodes = restoredNodes.slice(0, this._queryHistoryLimit); + } + + this._onDidChangeTreeData.fire(undefined); + } catch { + if (restoreMutationId === this._queryHistoryMutationId) { + this._queryHistoryNodes = [new EmptyHistoryNode()]; + } + } + } + + private createNodeFromPersisted(node: PersistedQueryHistoryNode): QueryHistoryNode | undefined { + if ( + !node || + typeof node.queryString !== "string" || + typeof node.connectionLabel !== "string" || + typeof node.timeStamp !== "number" || + typeof node.isSuccess !== "boolean" + ) { + return undefined; + } + + const restoredTimestamp = new Date(node.timeStamp); + if (Number.isNaN(restoredTimestamp.getTime())) { + return undefined; + } + + const label = this.createPersistedHistoryNodeLabel(node.queryString, node.connectionLabel); + const tooltip = this.createPersistedHistoryNodeTooltip( + node.queryString, + node.connectionLabel, + restoredTimestamp.toLocaleString(), + ); + + return new QueryHistoryNode( + label, + tooltip, + node.queryString, + node.ownerUri ?? "", + node.credentials, + restoredTimestamp, + node.connectionLabel, + node.isSuccess, + ); + } + + private async persistQueryHistory(): Promise { + const historyNodes = this._queryHistoryNodes.filter( + (node): node is QueryHistoryNode => node instanceof QueryHistoryNode, + ); + + if (historyNodes.length === 0) { + await this.clearPersistedQueryHistoryContent(); + return; + } + + const payload: PersistedQueryHistory = { + version: QueryHistoryProvider._queryHistoryStorageVersion, + nodes: historyNodes.slice(0, this._queryHistoryLimit).map((node) => ({ + queryString: node.queryString, + ownerUri: node.ownerUri, + credentials: this.sanitizeCredentialsForPersistence(node.credentials), + timeStamp: node.timeStamp.getTime(), + connectionLabel: node.connectionLabel, + isSuccess: node.isSuccess, + })), + }; + + await this.writePersistedQueryHistoryContent(JSON.stringify(payload)); + } + + private sanitizeCredentialsForPersistence( + credentials?: vscodeMssql.IConnectionInfo, + ): vscodeMssql.IConnectionInfo | undefined { + if (!credentials) { + return undefined; + } + + const persistedCredentials = { ...credentials }; + if ((credentials as IConnectionProfile).savePassword === false) { + persistedCredentials.password = ""; + } + + return persistedCredentials; + } + + private async readEncryptedPersistedQueryHistory(): Promise { + const storageFileUri = this.getPersistedQueryHistoryFileUri(); + if (!(await this.persistedQueryHistoryFileExists(storageFileUri))) { + return undefined; + } + + const encryptionKey = await this._context.secrets.get( + Constants.queryHistoryEncryptionKeySecretStorageKey, + ); + if (!encryptionKey) { + return undefined; + } + + const encryptedFileContents = await vscode.workspace.fs.readFile(storageFileUri); + const encryptedData = JSON.parse( + new TextDecoder().decode(encryptedFileContents), + ) as EncryptedData; + + return decryptData(encryptedData, encryptionKey); + } + + private async writePersistedQueryHistoryContent(serializedHistory: string): Promise { + const storageFileUri = this.getPersistedQueryHistoryFileUri(); + const encryptionKey = await this.getOrCreateQueryHistoryEncryptionKey(); + const encryptedData = encryptData(serializedHistory, encryptionKey); + + await vscode.workspace.fs.createDirectory(this._context.globalStorageUri); + await vscode.workspace.fs.writeFile( + storageFileUri, + new TextEncoder().encode(JSON.stringify(encryptedData)), + ); + } + + private async clearPersistedQueryHistoryContent(): Promise { + try { + await vscode.workspace.fs.delete(this.getPersistedQueryHistoryFileUri(), { + useTrash: false, + }); + } catch { + // Ignore missing file errors when clearing persisted history. + } + } + + private async getOrCreateQueryHistoryEncryptionKey(): Promise { + let encryptionKey = await this._context.secrets.get( + Constants.queryHistoryEncryptionKeySecretStorageKey, + ); + if (!encryptionKey) { + encryptionKey = generateEncryptionKey(); + await this._context.secrets.store( + Constants.queryHistoryEncryptionKeySecretStorageKey, + encryptionKey, + ); + } + + return encryptionKey; + } + + private async persistedQueryHistoryFileExists(storageFileUri: vscode.Uri): Promise { + try { + await vscode.workspace.fs.stat(storageFileUri); + return true; + } catch { + return false; + } + } + + private getPersistedQueryHistoryFileUri(): vscode.Uri { + return vscode.Uri.joinPath( + this._context.globalStorageUri, + Constants.queryHistoryGlobalStorageFileName, + ); + } +} + +interface PersistedQueryHistoryNode { + queryString: string; + ownerUri?: string; + credentials?: vscodeMssql.IConnectionInfo; + timeStamp: number; + connectionLabel: string; + isSuccess: boolean; +} + +interface PersistedQueryHistory { + version: number; + nodes: PersistedQueryHistoryNode[]; } diff --git a/extensions/mssql/src/utils/encryptionUtils.ts b/extensions/mssql/src/utils/encryptionUtils.ts new file mode 100644 index 0000000000..94df7f4976 --- /dev/null +++ b/extensions/mssql/src/utils/encryptionUtils.ts @@ -0,0 +1,66 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as crypto from "crypto"; + +const encryptionAlgorithm = "aes-256-gcm"; +const encryptionKeyLength = 32; +const initializationVectorLength = 12; + +export interface EncryptedData { + version: 1; + algorithm: typeof encryptionAlgorithm; + iv: string; + authTag: string; + ciphertext: string; +} + +export function generateEncryptionKey(): string { + return crypto.randomBytes(encryptionKeyLength).toString("base64"); +} + +export function encryptData(data: string, key: string): EncryptedData { + const encryptionKey = getEncryptionKeyBuffer(key); + const initializationVector = crypto.randomBytes(initializationVectorLength); + const cipher = crypto.createCipheriv(encryptionAlgorithm, encryptionKey, initializationVector); + const ciphertext = Buffer.concat([cipher.update(data, "utf8"), cipher.final()]); + + return { + version: 1, + algorithm: encryptionAlgorithm, + iv: initializationVector.toString("base64"), + authTag: cipher.getAuthTag().toString("base64"), + ciphertext: ciphertext.toString("base64"), + }; +} + +export function decryptData(encryptedData: EncryptedData, key: string): string { + if (encryptedData.version !== 1 || encryptedData.algorithm !== encryptionAlgorithm) { + throw new Error("Unsupported encrypted payload."); + } + + const decipher = crypto.createDecipheriv( + encryptionAlgorithm, + getEncryptionKeyBuffer(key), + Buffer.from(encryptedData.iv, "base64"), + ); + decipher.setAuthTag(Buffer.from(encryptedData.authTag, "base64")); + + const plaintext = Buffer.concat([ + decipher.update(Buffer.from(encryptedData.ciphertext, "base64")), + decipher.final(), + ]); + + return plaintext.toString("utf8"); +} + +function getEncryptionKeyBuffer(key: string): Buffer { + const keyBuffer = Buffer.from(key, "base64"); + if (keyBuffer.length !== encryptionKeyLength) { + throw new Error("Invalid encryption key length."); + } + + return keyBuffer; +} diff --git a/extensions/mssql/test/unit/encryptionUtils.test.ts b/extensions/mssql/test/unit/encryptionUtils.test.ts new file mode 100644 index 0000000000..85f6300e5d --- /dev/null +++ b/extensions/mssql/test/unit/encryptionUtils.test.ts @@ -0,0 +1,36 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import { expect } from "chai"; +import { decryptData, encryptData, generateEncryptionKey } from "../../src/utils/encryptionUtils"; + +suite("encryptionUtils", () => { + test("generateEncryptionKey should return a 32-byte base64 key", () => { + const encryptionKey = generateEncryptionKey(); + + expect(Buffer.from(encryptionKey, "base64")).to.have.lengthOf(32); + }); + + test("encryptData and decryptData should round-trip plaintext", () => { + const encryptionKey = generateEncryptionKey(); + const plaintext = JSON.stringify({ + version: 1, + nodes: [{ queryString: "SELECT 1", isSuccess: true }], + }); + + const encryptedData = encryptData(plaintext, encryptionKey); + + expect(decryptData(encryptedData, encryptionKey)).to.equal(plaintext); + }); + + test("decryptData should reject tampered ciphertext", () => { + const encryptionKey = generateEncryptionKey(); + const encryptedData = encryptData("sensitive query history", encryptionKey); + + encryptedData.ciphertext = `A${encryptedData.ciphertext.slice(1)}`; + + expect(() => decryptData(encryptedData, encryptionKey)).to.throw(); + }); +}); diff --git a/extensions/mssql/test/unit/queryHistoryProvider.test.ts b/extensions/mssql/test/unit/queryHistoryProvider.test.ts new file mode 100644 index 0000000000..961383c6af --- /dev/null +++ b/extensions/mssql/test/unit/queryHistoryProvider.test.ts @@ -0,0 +1,507 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as sinon from "sinon"; +import sinonChai from "sinon-chai"; +import { expect } from "chai"; +import * as chai from "chai"; +import * as vscode from "vscode"; +import type { IConnectionInfo } from "vscode-mssql"; +import { QueryHistoryProvider } from "../../src/queryHistory/queryHistoryProvider"; +import { QueryHistoryNode, EmptyHistoryNode } from "../../src/queryHistory/queryHistoryNode"; +import ConnectionManager, { ConnectionInfo } from "../../src/controllers/connectionManager"; +import { SqlOutputContentProvider } from "../../src/models/sqlOutputContentProvider"; +import VscodeWrapper from "../../src/controllers/vscodeWrapper"; +import SqlDocumentService from "../../src/controllers/sqlDocumentService"; +import StatusView from "../../src/views/statusView"; +import * as Constants from "../../src/constants/constants"; +import type { IConnectionProfile } from "../../src/models/interfaces"; +import { stubVscodeWrapper, initializeIconUtils } from "./utils"; +import { createWorkspaceConfiguration } from "./stubs"; +import { IPrompter } from "../../src/prompts/question"; +import CodeAdapter from "../../src/prompts/adapter"; +import { + decryptData, + type EncryptedData, + encryptData, + generateEncryptionKey, +} from "../../src/utils/encryptionUtils"; + +chai.use(sinonChai); + +suite("QueryHistoryProvider persistence", () => { + type QueryRunnerStub = Pick< + ReturnType, + "getQueryString" + >; + type TestConnectionCredentials = Pick< + IConnectionProfile, + "server" | "database" | "authenticationType" | "user" + > & + Partial>; + type QueryHistoryProviderPrivate = QueryHistoryProvider & { + readEncryptedPersistedQueryHistory(): Promise; + writePersistedQueryHistoryContent(serializedHistory: string): Promise; + clearPersistedQueryHistoryContent(): Promise; + }; + + interface PersistedQueryHistoryPayload { + version: number; + nodes: Array<{ + queryString: string; + ownerUri?: string; + credentials?: Record; + timeStamp: number; + connectionLabel: string; + isSuccess: boolean; + }>; + } + + let sandbox: sinon.SinonSandbox; + let provider: QueryHistoryProvider; + let connectionManagerStub: sinon.SinonStubbedInstance; + let outputContentProviderStub: sinon.SinonStubbedInstance; + let vscodeWrapperStub: sinon.SinonStubbedInstance; + let sqlDocumentServiceStub: sinon.SinonStubbedInstance; + let statusViewStub: sinon.SinonStubbedInstance; + let prompterStub: sinon.SinonStubbedInstance; + let secretStorage: { + get: sinon.SinonStub<[string], Promise>; + store: sinon.SinonStub<[string, string], Promise>; + delete: sinon.SinonStub<[string], Promise>; + }; + let context: vscode.ExtensionContext; + let secretValues: Map; + let persistedFileContents: Uint8Array | undefined; + let readEncryptedPersistedQueryHistoryStub: sinon.SinonStub<[], Promise>; + + function createProvider(): QueryHistoryProvider { + return new QueryHistoryProvider( + connectionManagerStub as unknown as ConnectionManager, + outputContentProviderStub as unknown as SqlOutputContentProvider, + vscodeWrapperStub as unknown as VscodeWrapper, + sqlDocumentServiceStub as unknown as SqlDocumentService, + statusViewStub as unknown as StatusView, + prompterStub as unknown as IPrompter, + context, + ); + } + + function createTestNode( + queryString: string = "SELECT 1", + connectionLabel: string = "(localhost|master)", + timeStamp: Date = new Date(2025, 0, 15, 10, 30, 0), + isSuccess: boolean = true, + ownerUri: string = "file:///test.sql", + ): QueryHistoryNode { + const label = `${queryString} : ${connectionLabel}`; + const tooltip = `${connectionLabel}\n\n${timeStamp.toLocaleString()}\n\n${queryString}`; + return new QueryHistoryNode( + label, + tooltip, + queryString, + ownerUri, + undefined, + timeStamp, + connectionLabel, + isSuccess, + ); + } + + function waitForPersistedStorageWork(): Promise { + return new Promise((resolve) => setTimeout(resolve, 50)); + } + + function createQueryRunnerStub( + queryString: string, + ): ReturnType { + const queryRunner: QueryRunnerStub = { + getQueryString: sandbox.stub().returns(queryString), + }; + + return queryRunner as ReturnType; + } + + function createConnectionResult(credentials: TestConnectionCredentials): ConnectionInfo { + const connectionInfo = new ConnectionInfo(); + connectionInfo.credentials = credentials as unknown as IConnectionInfo; + return connectionInfo; + } + + async function readPersistedFileContents(): Promise { + return persistedFileContents; + } + + async function readEncryptedPersistedHistoryContent(): Promise { + const encryptedFileContents = await readPersistedFileContents(); + if (!encryptedFileContents) { + return undefined; + } + + const encryptionKey = secretValues.get(Constants.queryHistoryEncryptionKeySecretStorageKey); + if (!encryptionKey) { + return undefined; + } + + const encryptedData = JSON.parse( + new TextDecoder().decode(encryptedFileContents), + ) as EncryptedData; + + return decryptData(encryptedData, encryptionKey); + } + + async function writePersistedHistoryContent(serializedHistory: string): Promise { + let encryptionKey = secretValues.get(Constants.queryHistoryEncryptionKeySecretStorageKey); + if (!encryptionKey) { + encryptionKey = generateEncryptionKey(); + await secretStorage.store( + Constants.queryHistoryEncryptionKeySecretStorageKey, + encryptionKey, + ); + } + + persistedFileContents = new TextEncoder().encode( + JSON.stringify(encryptData(serializedHistory, encryptionKey)), + ); + } + + async function clearPersistedHistoryContent(): Promise { + persistedFileContents = undefined; + } + + async function setEncryptedPersistedHistoryContent( + serializedHistory: string, + encryptionKey: string = generateEncryptionKey(), + ): Promise { + secretValues.set(Constants.queryHistoryEncryptionKeySecretStorageKey, encryptionKey); + persistedFileContents = new TextEncoder().encode( + JSON.stringify(encryptData(serializedHistory, encryptionKey)), + ); + + return encryptionKey; + } + + async function setEncryptedPersistedHistory( + persistedData: unknown, + encryptionKey: string = generateEncryptionKey(), + ): Promise { + return setEncryptedPersistedHistoryContent(JSON.stringify(persistedData), encryptionKey); + } + + async function getPersistedHistoryPayload(): Promise { + const persistedFileContents = await readPersistedFileContents(); + expect(persistedFileContents).to.not.be.undefined; + + const encryptionKey = secretValues.get(Constants.queryHistoryEncryptionKeySecretStorageKey); + expect(encryptionKey).to.not.be.undefined; + + const encryptedData = JSON.parse( + new TextDecoder().decode(persistedFileContents), + ) as EncryptedData; + + return JSON.parse( + decryptData(encryptedData, encryptionKey!), + ) as PersistedQueryHistoryPayload; + } + + setup(() => { + sandbox = sinon.createSandbox(); + initializeIconUtils(); + secretValues = new Map(); + persistedFileContents = undefined; + + connectionManagerStub = sandbox.createStubInstance(ConnectionManager); + outputContentProviderStub = sandbox.createStubInstance(SqlOutputContentProvider); + vscodeWrapperStub = stubVscodeWrapper(sandbox); + sqlDocumentServiceStub = sandbox.createStubInstance(SqlDocumentService); + statusViewStub = sandbox.createStubInstance(StatusView); + prompterStub = sandbox.createStubInstance(CodeAdapter); + + const config = createWorkspaceConfiguration({ + [Constants.configQueryHistoryLimit]: 10, + }); + vscodeWrapperStub.getConfiguration.returns(config); + + secretStorage = { + get: sandbox + .stub<[string], Promise>() + .callsFake(async (key) => secretValues.get(key)), + store: sandbox.stub<[string, string], Promise>().callsFake(async (key, value) => { + secretValues.set(key, value); + }), + delete: sandbox.stub<[string], Promise>().callsFake(async (key) => { + secretValues.delete(key); + }), + }; + + context = { + secrets: secretStorage as unknown as vscode.SecretStorage, + subscriptions: [], + globalStorageUri: vscode.Uri.file("/query-history-tests"), + } as unknown as vscode.ExtensionContext; + + const queryHistoryProviderPrototype = + QueryHistoryProvider.prototype as unknown as QueryHistoryProviderPrivate; + readEncryptedPersistedQueryHistoryStub = sandbox.stub( + queryHistoryProviderPrototype, + "readEncryptedPersistedQueryHistory", + ) as sinon.SinonStub<[], Promise>; + readEncryptedPersistedQueryHistoryStub.callsFake(readEncryptedPersistedHistoryContent); + sandbox + .stub(queryHistoryProviderPrototype, "writePersistedQueryHistoryContent") + .callsFake(async (serializedHistory: string) => + writePersistedHistoryContent(serializedHistory), + ); + sandbox + .stub(queryHistoryProviderPrototype, "clearPersistedQueryHistoryContent") + .callsFake(async () => clearPersistedHistoryContent()); + }); + + teardown(() => { + sandbox.restore(); + persistedFileContents = undefined; + }); + + test("restores nodes from encrypted global storage", async () => { + await setEncryptedPersistedHistory({ + version: 1, + nodes: [ + { + queryString: "SELECT * FROM users", + ownerUri: "file:///test.sql", + timeStamp: new Date(2025, 0, 15, 10, 30, 0).getTime(), + connectionLabel: "(localhost|testdb)", + isSuccess: true, + }, + { + queryString: "INSERT INTO logs VALUES(1)", + ownerUri: "file:///test2.sql", + timeStamp: new Date(2025, 0, 14, 9, 0, 0).getTime(), + connectionLabel: "(localhost|master)", + isSuccess: false, + }, + ], + }); + + provider = createProvider(); + await waitForPersistedStorageWork(); + + const children = provider.getChildren(); + expect(children).to.have.lengthOf(2); + expect((children[0] as QueryHistoryNode).queryString).to.equal("SELECT * FROM users"); + expect((children[1] as QueryHistoryNode).queryString).to.equal( + "INSERT INTO logs VALUES(1)", + ); + expect(secretStorage.store).to.not.have.been.called; + }); + + test("restores persisted credentials", async () => { + await setEncryptedPersistedHistory({ + version: 1, + nodes: [ + { + queryString: "SELECT 1", + ownerUri: "file:///test.sql", + credentials: { + server: "localhost", + database: "master", + authenticationType: Constants.sqlAuthentication, + user: "sa", + password: "example-value", + savePassword: true, + }, + timeStamp: new Date(2025, 0, 15, 10, 30, 0).getTime(), + connectionLabel: "(localhost|master) : sa", + isSuccess: true, + }, + ], + }); + + provider = createProvider(); + await waitForPersistedStorageWork(); + + const node = provider.getChildren()[0] as QueryHistoryNode; + expect(node.credentials).to.deep.include({ + server: "localhost", + database: "master", + user: "sa", + password: "example-value", + }); + }); + + test("does not overwrite newer history when restore finishes later", async () => { + const persistedHistory = { + version: 1, + nodes: [ + { + queryString: "restored query", + ownerUri: "file:///restored.sql", + timeStamp: new Date(2025, 0, 10).getTime(), + connectionLabel: "(localhost|restoreddb)", + isSuccess: true, + }, + ], + }; + + let resolveStoredHistory: ((value: string | undefined) => void) | undefined; + readEncryptedPersistedQueryHistoryStub.resetBehavior(); + readEncryptedPersistedQueryHistoryStub.callsFake( + () => + new Promise((resolve) => { + resolveStoredHistory = resolve; + }), + ); + + provider = createProvider(); + + outputContentProviderStub.getQueryRunner.returns(createQueryRunnerStub("fresh query")); + connectionManagerStub.getConnectionInfo.returns( + createConnectionResult({ + server: "localhost", + database: "master", + authenticationType: Constants.sqlAuthentication, + user: "sa", + }), + ); + + provider.refresh("file:///fresh.sql", new Date(2025, 0, 20), false); + + expect(resolveStoredHistory).to.not.be.undefined; + resolveStoredHistory?.(JSON.stringify(persistedHistory)); + await waitForPersistedStorageWork(); + + const node = provider.getChildren()[0] as QueryHistoryNode; + expect(node.queryString).to.equal("fresh query"); + expect(node.ownerUri).to.equal("file:///fresh.sql"); + }); + + test("shows EmptyHistoryNode when encrypted storage has invalid JSON", async () => { + await setEncryptedPersistedHistoryContent("not valid json{{{"); + + provider = createProvider(); + await waitForPersistedStorageWork(); + + const children = provider.getChildren(); + expect(children).to.have.lengthOf(1); + expect(children[0]).to.be.instanceOf(EmptyHistoryNode); + }); + + test("stores history nodes in encrypted global storage", async () => { + provider = createProvider(); + await waitForPersistedStorageWork(); + + connectionManagerStub.getConnectionInfo.returns( + createConnectionResult({ + server: "localhost", + database: "master", + authenticationType: Constants.sqlAuthentication, + user: "sa", + password: "example-value", + savePassword: true, + }), + ); + outputContentProviderStub.getQueryRunner.returns(createQueryRunnerStub("SELECT 1")); + + provider.refresh("file:///test.sql", new Date(2025, 0, 15), false); + await waitForPersistedStorageWork(); + + expect(secretStorage.store).to.have.been.calledOnceWithExactly( + Constants.queryHistoryEncryptionKeySecretStorageKey, + sinon.match.string, + ); + + const payload = await getPersistedHistoryPayload(); + expect(payload.version).to.equal(1); + expect(payload.nodes).to.have.lengthOf(1); + expect(payload.nodes[0].queryString).to.equal("SELECT 1"); + expect(payload.nodes[0].connectionLabel).to.contain("localhost"); + expect(payload.nodes[0].credentials).to.deep.include({ + server: "localhost", + database: "master", + user: "sa", + password: "example-value", + }); + }); + + test("does not persist password when savePassword is false", async () => { + provider = createProvider(); + await waitForPersistedStorageWork(); + + connectionManagerStub.getConnectionInfo.returns( + createConnectionResult({ + server: "localhost", + database: "master", + authenticationType: Constants.sqlAuthentication, + user: "sa", + password: "example-value", + savePassword: false, + }), + ); + outputContentProviderStub.getQueryRunner.returns(createQueryRunnerStub("SELECT 1")); + + provider.refresh("file:///test.sql", new Date(2025, 0, 15), false); + await waitForPersistedStorageWork(); + + const payload = await getPersistedHistoryPayload(); + expect(payload.nodes[0].credentials).to.deep.include({ + server: "localhost", + database: "master", + user: "sa", + password: "", + }); + }); + + test("uses configured query history limit without truncating query length", async () => { + provider = createProvider(); + await waitForPersistedStorageWork(); + + const longQuery = "A".repeat(25000); + const queryHistoryProvider = provider as unknown as { + _queryHistoryNodes: Array; + persistQueryHistory: () => Promise; + }; + + queryHistoryProvider._queryHistoryNodes = [ + createTestNode(longQuery, "(localhost|db0)", new Date(2025, 0, 1)), + ...Array.from({ length: 259 }, (_, index) => + createTestNode( + `SELECT ${index + 1}`, + `(localhost|db${index + 1})`, + new Date(2025, 0, 1, 0, index + 1), + ), + ), + ]; + + await queryHistoryProvider.persistQueryHistory(); + + const payload = await getPersistedHistoryPayload(); + expect(payload.nodes).to.have.lengthOf(10); + expect(payload.nodes[0].queryString).to.equal(longQuery); + }); + + test("clears persisted file when no history nodes remain", async () => { + provider = createProvider(); + await waitForPersistedStorageWork(); + + connectionManagerStub.getConnectionInfo.returns( + createConnectionResult({ + server: "localhost", + database: "master", + authenticationType: Constants.sqlAuthentication, + user: "sa", + }), + ); + outputContentProviderStub.getQueryRunner.returns(createQueryRunnerStub("SELECT 1")); + + provider.refresh("file:///test.sql", new Date(2025, 0, 15), false); + await waitForPersistedStorageWork(); + expect(await readPersistedFileContents()).to.not.be.undefined; + + provider.clearAll(); + await waitForPersistedStorageWork(); + + expect(await readPersistedFileContents()).to.be.undefined; + }); +});