diff --git a/macai.xcodeproj/project.pbxproj b/macai.xcodeproj/project.pbxproj index 7c2e054..52b46d4 100644 --- a/macai.xcodeproj/project.pbxproj +++ b/macai.xcodeproj/project.pbxproj @@ -11,6 +11,10 @@ 1001F81047D94F5484238553 /* APIServiceTemplateAddView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5774456B942348B88C536014 /* APIServiceTemplateAddView.swift */; }; 16E55B88D2D449948BD6C60F /* APIServiceTemplate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 62B9230773084F7E92FFB3AD /* APIServiceTemplate.swift */; }; 2DEB5B6E4EFE4D8C90FBA248 /* MacaiTextField.swift in Sources */ = {isa = PBXBuildFile; fileRef = 04F05D023BB34EF796C7F7D3 /* MacaiTextField.swift */; }; + 3B03E8042EF4FA190067D667 /* VertexCommon.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3B03E8052EF4FA190067D667 /* VertexCommon.swift */; }; + 3B03E8062EF4FA1A0067D667 /* VertexGeminiHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3B03E8072EF4FA1A0067D667 /* VertexGeminiHandler.swift */; }; + 3B03E8082EF4FA1B0067D667 /* VertexClaudeHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3B03E8092EF4FA1B0067D667 /* VertexClaudeHandler.swift */; }; + 3B924AFB2EF5B6F500FA7608 /* ADCCredentialsAccess.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3B924AFA2EF5B6F500FA7608 /* ADCCredentialsAccess.swift */; }; 4A0BCBBC2D5199140033AB96 /* ButtonWithStatusIndicator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4A0BCBBB2D5199140033AB96 /* ButtonWithStatusIndicator.swift */; }; 4A0E02852D0F571000D2FAF3 /* PerplexityHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4A0E02842D0F571000D2FAF3 /* PerplexityHandler.swift */; }; 4A1614112D430AE300A1EF8D /* ThinkingProcessView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4A1614102D430AE300A1EF8D /* ThinkingProcessView.swift */; }; @@ -89,6 +93,7 @@ 4AE3CEB72C93A88D00A9CF4C /* TabAPIServicesView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AE3CEB62C93A86F00A9CF4C /* TabAPIServicesView.swift 
*/; }; 4AE3CEB92C93AEA800A9CF4C /* APIServiceDetailView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AE3CEB82C93AEA200A9CF4C /* APIServiceDetailView.swift */; }; 4AE8691B2D5696A300E3B3AC /* GeminiHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AE8691A2D56969700E3B3AC /* GeminiHandler.swift */; }; + 4AE8691C2D5696A400E3B3AD /* GeminiHandlerBase.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AE8691D2D5696A500E3B3AE /* GeminiHandlerBase.swift */; }; 4AEB7E1B2C8D04A70004818C /* DatabasePatcher.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AEB7E1A2C8D04A30004818C /* DatabasePatcher.swift */; }; 4AFAC2022EF496C3004B125E /* SettingsIndicators.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AFAC2012EF496C3004B125E /* SettingsIndicators.swift */; }; 4AFDEB7C2AFF90C000BA8642 /* TabBackupRestoreView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AFDEB7B2AFF90C000BA8642 /* TabBackupRestoreView.swift */; }; @@ -139,6 +144,10 @@ 14461D2B9E6A48DEB427E5F1 /* NotificationPresenter.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NotificationPresenter.swift; sourceTree = ""; }; 16D75E6599084CC8A6C8587C /* APIServiceTemplateProviderTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServiceTemplateProviderTests.swift; sourceTree = ""; }; 25B74289407745C983DD1294 /* APIServiceTemplateAddViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServiceTemplateAddViewModel.swift; sourceTree = ""; }; + 3B03E8052EF4FA190067D667 /* VertexCommon.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VertexCommon.swift; sourceTree = ""; }; + 3B03E8072EF4FA1A0067D667 /* VertexGeminiHandler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VertexGeminiHandler.swift; sourceTree = ""; }; + 3B03E8092EF4FA1B0067D667 /* VertexClaudeHandler.swift */ = {isa = PBXFileReference; 
lastKnownFileType = sourcecode.swift; path = VertexClaudeHandler.swift; sourceTree = ""; }; + 3B924AFA2EF5B6F500FA7608 /* ADCCredentialsAccess.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ADCCredentialsAccess.swift; sourceTree = ""; }; 3EA7767A135E47FFAB9E3289 /* APIServiceTemplateProvider.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServiceTemplateProvider.swift; sourceTree = ""; }; 4A0BCBBB2D5199140033AB96 /* ButtonWithStatusIndicator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ButtonWithStatusIndicator.swift; sourceTree = ""; }; 4A0E02842D0F571000D2FAF3 /* PerplexityHandler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PerplexityHandler.swift; sourceTree = ""; }; @@ -218,6 +227,7 @@ 4AE3CEB62C93A86F00A9CF4C /* TabAPIServicesView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TabAPIServicesView.swift; sourceTree = ""; }; 4AE3CEB82C93AEA200A9CF4C /* APIServiceDetailView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIServiceDetailView.swift; sourceTree = ""; }; 4AE8691A2D56969700E3B3AC /* GeminiHandler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GeminiHandler.swift; sourceTree = ""; }; + 4AE8691D2D5696A500E3B3AE /* GeminiHandlerBase.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GeminiHandlerBase.swift; sourceTree = ""; }; 4AEB7E1A2C8D04A30004818C /* DatabasePatcher.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DatabasePatcher.swift; sourceTree = ""; }; 4AFAC2012EF496C3004B125E /* SettingsIndicators.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsIndicators.swift; sourceTree = ""; }; 4AFDEB7B2AFF90C000BA8642 /* TabBackupRestoreView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; 
path = TabBackupRestoreView.swift; sourceTree = ""; }; @@ -524,6 +534,7 @@ 4AB4B0902D778B1D00CA41B6 /* OpenRouterHandler.swift */, 4AB4B08C2D75099E00CA41B6 /* DeepseekHandler.swift */, 4AE8691A2D56969700E3B3AC /* GeminiHandler.swift */, + 4AE8691D2D5696A500E3B3AE /* GeminiHandlerBase.swift */, 4A0E02842D0F571000D2FAF3 /* PerplexityHandler.swift */, 4AADF71E2C9CE780001B7A26 /* ClaudeHandler.swift */, 4AFEE7202C472F53006F99FB /* APIProtocol.swift */, @@ -531,8 +542,12 @@ B2E47C1D2CF90E5C00ABCD12 /* OpenAIResponsesHandler.swift */, 4AFEE7222C472FA1006F99FB /* ChatGPTHandler.swift */, 4AFEE7242C472FB1006F99FB /* OllamaHandler.swift */, + 3B03E8052EF4FA190067D667 /* VertexCommon.swift */, + 3B03E8072EF4FA1A0067D667 /* VertexGeminiHandler.swift */, + 3B03E8092EF4FA1B0067D667 /* VertexClaudeHandler.swift */, 4AA6EF582C565F90003B6D41 /* APIServiceConfig.swift */, 4AA6EF5A2C565FDF003B6D41 /* APIServiceFactory.swift */, + 3B924AFA2EF5B6F500FA7608 /* ADCCredentialsAccess.swift */, ); path = APIHandlers; sourceTree = ""; @@ -707,6 +722,7 @@ 4A399C9B29BC8FEF00E98796 /* ContentView.swift in Sources */, 4A885C522CEA3B130018BBED /* HTMLPreviewView.swift in Sources */, 4A55DCA729C264BC00A3800C /* TableView.swift in Sources */, + 3B924AFB2EF5B6F500FA7608 /* ADCCredentialsAccess.swift in Sources */, 4A28424B2C87018C00E5C920 /* TokenManager.swift in Sources */, 0BA0FF405FC14D3FABF82621 /* APIServiceTemplateProvider.swift in Sources */, 4ACE2FB82D61656000DEEA6D /* SwipeModifier.swift in Sources */, @@ -762,9 +778,13 @@ 4A8A78C82D4C41A00057A6CC /* TabGeneralSettingsView.swift in Sources */, 4A1E16852E16DB4500E5CBA3 /* ChatInputView.swift in Sources */, 4A1E16862E16DB4500E5CBA3 /* ChatMessagesView.swift in Sources */, + 3B03E8042EF4FA190067D667 /* VertexCommon.swift in Sources */, + 3B03E8062EF4FA1A0067D667 /* VertexGeminiHandler.swift in Sources */, + 3B03E8082EF4FA1B0067D667 /* VertexClaudeHandler.swift in Sources */, 4A1E16872E16DB4600E5CBA3 /* ChatLogicHandler.swift in Sources */, 
4AFEE7252C472FB1006F99FB /* OllamaHandler.swift in Sources */, 4AE8691B2D5696A300E3B3AC /* GeminiHandler.swift in Sources */, + 4AE8691C2D5696A400E3B3AD /* GeminiHandlerBase.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/macai/Assets.xcassets/logo_vertex.imageset/Contents.json b/macai/Assets.xcassets/logo_vertex.imageset/Contents.json new file mode 100644 index 0000000..143ffa0 --- /dev/null +++ b/macai/Assets.xcassets/logo_vertex.imageset/Contents.json @@ -0,0 +1,24 @@ +{ + "images" : [ + { + "filename" : "gemini_12.svg", + "idiom" : "universal", + "scale" : "1x" + }, + { + "filename" : "gemini_24.svg", + "idiom" : "universal", + "scale" : "2x" + }, + { + "filename" : "gemini_36.svg", + "idiom" : "universal", + "scale" : "3x" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} + diff --git a/macai/Assets.xcassets/logo_vertex.imageset/gemini_12.svg b/macai/Assets.xcassets/logo_vertex.imageset/gemini_12.svg new file mode 100644 index 0000000..ba830ca --- /dev/null +++ b/macai/Assets.xcassets/logo_vertex.imageset/gemini_12.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/macai/Assets.xcassets/logo_vertex.imageset/gemini_24.svg b/macai/Assets.xcassets/logo_vertex.imageset/gemini_24.svg new file mode 100644 index 0000000..77258ef --- /dev/null +++ b/macai/Assets.xcassets/logo_vertex.imageset/gemini_24.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/macai/Assets.xcassets/logo_vertex.imageset/gemini_36.svg b/macai/Assets.xcassets/logo_vertex.imageset/gemini_36.svg new file mode 100644 index 0000000..8c9eee9 --- /dev/null +++ b/macai/Assets.xcassets/logo_vertex.imageset/gemini_36.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/macai/Configuration/AppConstants.swift b/macai/Configuration/AppConstants.swift index dd46f93..96f18a0 100644 --- a/macai/Configuration/AppConstants.swift +++ b/macai/Configuration/AppConstants.swift @@ -180,7 +180,7 @@ struct AppConstants { apiModelRef: 
"https://platform.openai.com/docs/models", defaultModel: "gpt-5", models: [ - "gpt-5", + "gpt-5" ], imageUploadsSupported: true, imageGenerationSupported: true @@ -303,11 +303,36 @@ struct AppConstants { "deepseek/deepseek-r1:free", ] ), + "vertex": defaultApiConfiguration( + name: "Google Vertex AI", + url: "https://us-central1-aiplatform.googleapis.com/v1", + apiKeyRef: "", + apiModelRef: "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models", + defaultModel: "gemini-2.5-pro", + models: [ + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-2.0-flash", + "gemini-1.5-pro", + "gemini-1.5-flash", + "claude-sonnet-4-5@20250929", + "claude-3-5-sonnet-v2@20241022", + "claude-3-opus@20240229", + "claude-3-haiku@20240307", + ], + modelsFetching: false, + imageUploadsSupported: true, + imageGenerationSupported: true, + autoEnableImageGenerationModels: ["gemini-2.0-flash-image-generation"] + ), ] static let apiTypes = [ "openai-responses", "chatgpt", "ollama", "claude", "xai", "gemini", "perplexity", "deepseek", "openrouter", + "vertex", ] + + static let defaultGcpRegion = "us-central1" static let newChatNotification = Notification.Name("newChatNotification") static let largeMessageSymbolsThreshold = 25000 static let thumbnailSize: CGFloat = 300 diff --git a/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel 2.xcdatamodel/contents b/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel 2.xcdatamodel/contents index e197045..307203f 100644 --- a/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel 2.xcdatamodel/contents +++ b/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel 2.xcdatamodel/contents @@ -15,6 +15,8 @@ + + diff --git a/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel 3.xcdatamodel/contents b/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel 3.xcdatamodel/contents index 5d27e21..667da3f 100644 --- a/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel 3.xcdatamodel/contents +++ 
b/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel 3.xcdatamodel/contents @@ -15,6 +15,8 @@ + + diff --git a/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel.xcdatamodel/contents b/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel.xcdatamodel/contents index 3a32066..b24e36c 100644 --- a/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel.xcdatamodel/contents +++ b/macai/Store/macaiDataModel.xcdatamodeld/macaiDataModel.xcdatamodel/contents @@ -15,6 +15,8 @@ + + diff --git a/macai/UI/Chat/ChatViewModel.swift b/macai/UI/Chat/ChatViewModel.swift index 162d8f7..cbf393f 100644 --- a/macai/UI/Chat/ChatViewModel.swift +++ b/macai/UI/Chat/ChatViewModel.swift @@ -248,7 +248,9 @@ class ChatViewModel: NSObject, ObservableObject, NSFetchedResultsControllerDeleg name: getApiServiceName(), apiUrl: apiServiceUrl, apiKey: apiKey, - model: chat.gptModel + model: chat.gptModel, + gcpProjectId: apiService.gcpProjectId, + gcpRegion: apiService.gcpRegion ) } diff --git a/macai/UI/Preferences/TabAPIServices/APIServiceDetailView.swift b/macai/UI/Preferences/TabAPIServices/APIServiceDetailView.swift index 806bac4..9ec4230 100644 --- a/macai/UI/Preferences/TabAPIServices/APIServiceDetailView.swift +++ b/macai/UI/Preferences/TabAPIServices/APIServiceDetailView.swift @@ -69,21 +69,82 @@ struct APIServiceDetailView: View { } .padding(.bottom, 8) - HStack { - Text("API URL:") - .frame(width: 100, alignment: .leading) + if viewModel.isVertexAI { + // Vertex AI URL is computed from project ID and region + HStack { + Text("API URL:") + .frame(width: 100, alignment: .leading) - TextField("Paste your URL here", text: $viewModel.url) - .textFieldStyle(RoundedBorderTextFieldStyle()) + Text(viewModel.url) + .foregroundColor(.secondary) + .lineLimit(1) + .truncationMode(.middle) + .frame(maxWidth: .infinity, alignment: .leading) + .padding(.horizontal, 6) + .padding(.vertical, 4) + .background(Color(NSColor.controlBackgroundColor)) + .cornerRadius(5) + .overlay( + 
RoundedRectangle(cornerRadius: 5) + .stroke(Color(NSColor.separatorColor), lineWidth: 1) + ) + } + } + else { + HStack { + Text("API URL:") + .frame(width: 100, alignment: .leading) - Button(action: { - viewModel.url = viewModel.defaultApiConfiguration!.url - }) { - Text("Default") + TextField("Paste your URL here", text: $viewModel.url) + .textFieldStyle(RoundedBorderTextFieldStyle()) + + Button(action: { + viewModel.url = viewModel.defaultApiConfiguration!.url + }) { + Text("Default") + } } } - if (viewModel.defaultApiConfiguration?.apiKeyRef ?? "") != "" { + if viewModel.isVertexAI { + // Vertex AI specific fields + HStack { + Text("Project ID:") + .frame(width: 100, alignment: .leading) + + TextField("your-gcp-project-id", text: $viewModel.gcpProjectId) + .textFieldStyle(RoundedBorderTextFieldStyle()) + } + + HStack { + Text("Region:") + .frame(width: 100, alignment: .leading) + + TextField("us-central1", text: $viewModel.gcpRegion) + .textFieldStyle(RoundedBorderTextFieldStyle()) + } + + HStack { + Spacer() + Button("Import ADC Credentials…") { + viewModel.importVertexADCCredentialsButtonTapped() + } + } + + if let adcStatus = viewModel.adcImportStatus { + Text(adcStatus) + .font(.footnote) + .foregroundColor(.secondary) + } + + HStack { + Spacer() + Text("Run `gcloud auth application-default login` to authenticate") + .font(.subheadline) + .foregroundColor(.secondary) + } + } + else if (viewModel.defaultApiConfiguration?.apiKeyRef ?? "") != "" { HStack { Text("API Key:") .frame(width: 100, alignment: .leading) @@ -172,7 +233,9 @@ struct APIServiceDetailView: View { gptModel: viewModel.model, apiUrl: viewModel.url, apiType: viewModel.type, - imageGenerationSupported: viewModel.imageGenerationSupported + imageGenerationSupported: viewModel.imageGenerationSupported, + gcpProjectId: viewModel.isVertexAI ? viewModel.gcpProjectId : nil, + gcpRegion: viewModel.isVertexAI ? 
viewModel.gcpRegion : nil ) } } diff --git a/macai/UI/Preferences/TabAPIServices/APIServiceDetailViewModel.swift b/macai/UI/Preferences/TabAPIServices/APIServiceDetailViewModel.swift index 0ea9a46..e1645ce 100644 --- a/macai/UI/Preferences/TabAPIServices/APIServiceDetailViewModel.swift +++ b/macai/UI/Preferences/TabAPIServices/APIServiceDetailViewModel.swift @@ -34,14 +34,18 @@ class APIServiceDetailViewModel: ObservableObject { @Published var fetchedModels: [AIModel] = [] @Published var isLoadingModels: Bool = false @Published var modelFetchError: String? = nil + @Published var gcpProjectId: String = "" + @Published var gcpRegion: String = AppConstants.defaultGcpRegion + + @Published var adcImportStatus: String? = nil init(viewContext: NSManagedObjectContext, apiService: APIServiceEntity?, preferredType: String? = nil) { self.viewContext = viewContext self.apiService = apiService if apiService == nil, - let preferredType, - let configuration = AppConstants.defaultApiConfigurations[preferredType] + let preferredType, + let configuration = AppConstants.defaultApiConfigurations[preferredType] { type = preferredType defaultApiConfiguration = configuration @@ -79,6 +83,14 @@ class APIServiceDetailViewModel: ObservableObject { print("Failed to get token: \(error.localizedDescription)") } } + + gcpProjectId = service.gcpProjectId ?? "" + gcpRegion = service.gcpRegion ?? 
AppConstants.defaultGcpRegion + + // For Vertex AI, rebuild URL from stored project/region + if type == "vertex" { + url = buildVertexAIUrl(projectId: gcpProjectId, region: gcpRegion) + } } else { if let config = AppConstants.defaultApiConfigurations[type] { @@ -123,6 +135,27 @@ class APIServiceDetailViewModel: ObservableObject { } } .store(in: &cancellables) + + // Update Vertex AI URL when project ID or region changes + Publishers.CombineLatest($gcpProjectId, $gcpRegion) + .sink { [weak self] projectId, region in + guard let self, self.isVertexAI else { return } + self.updateVertexAIUrl() + } + .store(in: &cancellables) + } + + private func updateVertexAIUrl() { + guard isVertexAI else { return } + url = buildVertexAIUrl(projectId: gcpProjectId, region: gcpRegion) + } + + private func buildVertexAIUrl(projectId: String, region: String) -> String { + let safeRegion = region.isEmpty ? AppConstants.defaultGcpRegion : region + if projectId.isEmpty { + return "https://\(safeRegion)-aiplatform.googleapis.com/v1/projects//locations/\(safeRegion)" + } + return "https://\(safeRegion)-aiplatform.googleapis.com/v1/projects/\(projectId)/locations/\(safeRegion)" } private func fetchModelsForService() { @@ -133,7 +166,9 @@ class APIServiceDetailViewModel: ObservableObject { name: type, apiUrl: URL(string: url)!, apiKey: apiKey, - model: "" + model: "", + gcpProjectId: gcpProjectId.isEmpty ? nil : gcpProjectId, + gcpRegion: gcpRegion.isEmpty ? 
nil : gcpRegion ) let apiService = APIServiceFactory.createAPIService( @@ -156,14 +191,15 @@ class APIServiceDetailViewModel: ObservableObject { } } } - + private func updateModelSelection() { let modelExists = self.availableModels.contains(self.selectedModel) - + if !modelExists && !self.selectedModel.isEmpty { self.isCustomModel = true self.selectedModel = "custom" - } else if modelExists { + } + else if modelExists { self.isCustomModel = false } } @@ -189,6 +225,8 @@ class APIServiceDetailViewModel: ObservableObject { serviceToSave.imageUploadsAllowed = imageUploadsAllowed serviceToSave.imageGenerationSupported = imageGenerationSupported serviceToSave.defaultPersona = defaultAiPersona + serviceToSave.gcpProjectId = gcpProjectId.isEmpty ? nil : gcpProjectId + serviceToSave.gcpRegion = gcpRegion.isEmpty ? nil : gcpRegion if serviceToSave.tokenIdentifier == nil || serviceToSave.tokenIdentifier?.isEmpty == true { serviceToSave.tokenIdentifier = UUID().uuidString } @@ -262,6 +300,17 @@ class APIServiceDetailViewModel: ObservableObject { using: self.defaultApiConfiguration! 
) + // Set defaults when switching to vertex and update URL + if type == "vertex" { + if gcpRegion.isEmpty { + gcpRegion = AppConstants.defaultGcpRegion + } + updateVertexAIUrl() + if model.isEmpty { + model = self.defaultApiConfiguration!.defaultModel + } + } + self.hasProcessedInitialModelSelection = false fetchModelsForService() @@ -285,6 +334,15 @@ class APIServiceDetailViewModel: ObservableObject { return configSupports || imageGenerationSupported } + var isVertexAI: Bool { + return type == "vertex" + } + + var requiresApiKey: Bool { + guard let config = AppConstants.defaultApiConfigurations[type] else { return true } + return !config.apiKeyRef.isEmpty + } + private static func supportedState( for model: String, using config: AppConstants.defaultApiConfiguration @@ -292,4 +350,26 @@ class APIServiceDetailViewModel: ObservableObject { guard config.imageGenerationSupported else { return false } return config.autoEnableImageGenerationModels.contains(model) } + + @MainActor + func importVertexADCCredentials() async { + do { + let data = try await ADCCredentialsAccess.promptAndStoreBookmark() + if !data.isEmpty { + adcImportStatus = "ADC credentials imported successfully." + } + else { + adcImportStatus = "Failed to import ADC credentials: Empty data received." 
+ } + } + catch { + adcImportStatus = "Failed to import ADC credentials: \(error.localizedDescription)" + } + } + + func importVertexADCCredentialsButtonTapped() { + Task { + await importVertexADCCredentials() + } + } } diff --git a/macai/UI/Preferences/TabAPIServices/APIServiceTemplateAddViewModel.swift b/macai/UI/Preferences/TabAPIServices/APIServiceTemplateAddViewModel.swift index 74ea1a8..09cd3f9 100644 --- a/macai/UI/Preferences/TabAPIServices/APIServiceTemplateAddViewModel.swift +++ b/macai/UI/Preferences/TabAPIServices/APIServiceTemplateAddViewModel.swift @@ -388,7 +388,9 @@ final class APIServiceTemplateAddViewModel: ObservableObject { name: type, apiUrl: apiURL, apiKey: apiKey, - model: model + model: model, + gcpProjectId: nil, + gcpRegion: nil ) let apiService = APIServiceFactory.createAPIService( config: config, diff --git a/macai/UI/Preferences/TabAPIServices/ButtonTestApiTokenAndModel.swift b/macai/UI/Preferences/TabAPIServices/ButtonTestApiTokenAndModel.swift index 7dd9b2b..281088e 100644 --- a/macai/UI/Preferences/TabAPIServices/ButtonTestApiTokenAndModel.swift +++ b/macai/UI/Preferences/TabAPIServices/ButtonTestApiTokenAndModel.swift @@ -14,6 +14,8 @@ struct ButtonTestApiTokenAndModel: View { var apiUrl: String = AppConstants.apiUrlOpenAIResponses var apiType: String = AppConstants.defaultApiType var imageGenerationSupported: Bool = false + var gcpProjectId: String? = nil + var gcpRegion: String? 
= nil @State var testOk: Bool = false @Environment(\.managedObjectContext) private var viewContext @@ -38,7 +40,9 @@ struct ButtonTestApiTokenAndModel: View { name: apiType, apiUrl: URL(string: apiUrl)!, apiKey: gptToken, - model: gptModel + model: gptModel, + gcpProjectId: gcpProjectId, + gcpRegion: gcpRegion ) let apiService = APIServiceFactory.createAPIService( config: config, diff --git a/macai/Utilities/APIHandlers/ADCCredentialsAccess.swift b/macai/Utilities/APIHandlers/ADCCredentialsAccess.swift new file mode 100644 index 0000000..98d681e --- /dev/null +++ b/macai/Utilities/APIHandlers/ADCCredentialsAccess.swift @@ -0,0 +1,121 @@ +import AppKit +import Foundation + +enum ADCCredentialsAccessError: Error { + case noBookmark + case userCancelled + case accessFailed + case invalidBookmark + case readFailed(String) +} + +final class ADCCredentialsAccess { + static func storedBookmarkURL() -> URL { + let fm = FileManager.default + do { + let appSupportURL = try fm.url( + for: .applicationSupportDirectory, + in: .userDomainMask, + appropriateFor: nil, + create: true + ) + let bundleId = Bundle.main.bundleIdentifier ?? "macai" + let folderURL = appSupportURL.appendingPathComponent(bundleId, isDirectory: true) + if !fm.fileExists(atPath: folderURL.path) { + try fm.createDirectory(at: folderURL, withIntermediateDirectories: true) + } + return folderURL.appendingPathComponent("adc.bookmark") + } + catch { + // fallback to ~/Library/Application Support//adc.bookmark without ensuring directory + let fallback = URL(fileURLWithPath: NSHomeDirectory()) + .appendingPathComponent("Library") + .appendingPathComponent("Application Support") + .appendingPathComponent(Bundle.main.bundleIdentifier ?? 
"macai") + .appendingPathComponent("adc.bookmark") + return fallback + } + } + + static func loadCredentialsData() throws -> Data { + let bookmarkURL = storedBookmarkURL() + let fm = FileManager.default + guard fm.fileExists(atPath: bookmarkURL.path) else { + throw ADCCredentialsAccessError.noBookmark + } + do { + let bookmarkData = try Data(contentsOf: bookmarkURL) + var stale = false + let fileURL = try URL( + resolvingBookmarkData: bookmarkData, + options: [.withSecurityScope, .withoutUI], + relativeTo: nil, + bookmarkDataIsStale: &stale + ) + if stale { + // re-save fresh bookmark + let freshBookmark = try fileURL.bookmarkData( + options: [.withSecurityScope], + includingResourceValuesForKeys: nil, + relativeTo: nil + ) + try freshBookmark.write(to: bookmarkURL) + } + guard fileURL.startAccessingSecurityScopedResource() else { + throw ADCCredentialsAccessError.accessFailed + } + defer { fileURL.stopAccessingSecurityScopedResource() } + do { + return try Data(contentsOf: fileURL) + } + catch { + throw ADCCredentialsAccessError.readFailed(error.localizedDescription) + } + } + catch { + if let localError = error as? 
ADCCredentialsAccessError { + throw localError + } + throw ADCCredentialsAccessError.invalidBookmark + } + } + + static func promptAndStoreBookmark() async throws -> Data { + return try await withCheckedThrowingContinuation { continuation in + DispatchQueue.main.async { + let panel = NSOpenPanel() + panel.allowedContentTypes = [.json] + panel.canChooseFiles = true + panel.canChooseDirectories = false + panel.allowsMultipleSelection = false + panel.title = "Select Google ADC Credentials" + panel.directoryURL = URL(fileURLWithPath: NSHomeDirectory()).appendingPathComponent(".config/gcloud") + + panel.begin { response in + guard response == .OK, let url = panel.url else { + continuation.resume(throwing: ADCCredentialsAccessError.userCancelled) + return + } + do { + let bookmarkData = try url.bookmarkData( + options: [.withSecurityScope], + includingResourceValuesForKeys: nil, + relativeTo: nil + ) + let bookmarkURL = storedBookmarkURL() + try bookmarkData.write(to: bookmarkURL) + guard url.startAccessingSecurityScopedResource() else { + throw ADCCredentialsAccessError.accessFailed + } + defer { url.stopAccessingSecurityScopedResource() } + let data = try Data(contentsOf: url) + continuation.resume(returning: data) + } + catch { + continuation.resume(throwing: ADCCredentialsAccessError.readFailed(error.localizedDescription)) + } + } + } + } + } +} diff --git a/macai/Utilities/APIHandlers/APIProtocol.swift b/macai/Utilities/APIHandlers/APIProtocol.swift index 75ad72e..f401cf1 100644 --- a/macai/Utilities/APIHandlers/APIProtocol.swift +++ b/macai/Utilities/APIHandlers/APIProtocol.swift @@ -41,6 +41,8 @@ protocol APIServiceConfiguration { var apiUrl: URL { get set } var apiKey: String { get set } var model: String { get set } + var gcpProjectId: String? { get set } + var gcpRegion: String? 
{ get set } } struct AIModel: Codable, Identifiable { diff --git a/macai/Utilities/APIHandlers/APIServiceConfig.swift b/macai/Utilities/APIHandlers/APIServiceConfig.swift index 91bbf7d..d5e4cce 100644 --- a/macai/Utilities/APIHandlers/APIServiceConfig.swift +++ b/macai/Utilities/APIHandlers/APIServiceConfig.swift @@ -12,4 +12,6 @@ struct APIServiceConfig: APIServiceConfiguration, Codable { var apiUrl: URL var apiKey: String var model: String + var gcpProjectId: String? + var gcpRegion: String? } diff --git a/macai/Utilities/APIHandlers/APIServiceFactory.swift b/macai/Utilities/APIHandlers/APIServiceFactory.swift index a49a803..6a70461 100644 --- a/macai/Utilities/APIHandlers/APIServiceFactory.swift +++ b/macai/Utilities/APIHandlers/APIServiceFactory.swift @@ -45,6 +45,13 @@ class APIServiceFactory { return DeepseekHandler(config: config, session: session) case "openrouter": return OpenRouterHandler(config: config, session: session) + case "vertex": + // Route to appropriate handler based on model prefix + if config.model.lowercased().hasPrefix("claude") { + return VertexClaudeHandler(config: config, session: session) + } else { + return VertexGeminiHandler(config: config, session: session) + } default: fatalError("Unsupported API service: \(config.name)") } diff --git a/macai/Utilities/APIHandlers/GeminiHandler.swift b/macai/Utilities/APIHandlers/GeminiHandler.swift index 8ab4b75..a990a39 100644 --- a/macai/Utilities/APIHandlers/GeminiHandler.swift +++ b/macai/Utilities/APIHandlers/GeminiHandler.swift @@ -5,10 +5,10 @@ // Created by Renat on 07.02.2025. // -import AppKit -import CoreData import Foundation +// MARK: - Gemini-specific Models Response + private struct GeminiModelList: Decodable { let models: [GeminiModel] } @@ -21,109 +21,21 @@ private struct GeminiModel: Decodable { } } -private struct GeminiGenerateRequest: Encodable { - let contents: [GeminiContentRequest] - let systemInstruction: GeminiContentRequest? 
- let generationConfig: GeminiGenerationConfig? -} - -struct GeminiContentRequest: Encodable { - let role: String? - let parts: [GeminiPartRequest] -} - -struct GeminiPartRequest: Codable { - let text: String? - let inlineData: GeminiInlineData? - let thoughtSignature: String? - - init(text: String, thoughtSignature: String? = nil) { - self.text = text - self.inlineData = nil - self.thoughtSignature = thoughtSignature - } - - init(inlineData: GeminiInlineData, thoughtSignature: String? = nil) { - self.text = nil - self.inlineData = inlineData - self.thoughtSignature = thoughtSignature - } -} - -struct GeminiInlineData: Codable { - let mimeType: String - let data: String -} - -private struct GeminiGenerationConfig: Encodable { - let temperature: Float -} - -private struct GeminiGenerateResponse: Decodable { - let candidates: [GeminiCandidate]? -} - -private struct GeminiCandidate: Decodable { - let content: GeminiContentResponse? - let finishReason: String? -} - -private struct GeminiContentResponse: Decodable { - let parts: [GeminiPartResponse]? -} - -private struct GeminiPartResponse: Decodable { - let text: String? - let inlineData: GeminiInlineDataResponse? - let thoughtSignature: String? -} - -struct PartsEnvelope: Codable { - let serviceType: String - let parts: [GeminiPartRequest] -} - -private struct GeminiInlineDataResponse: Decodable { - let mimeType: String? - let data: String? -} - -private struct GeminiErrorEnvelope: Decodable { - let error: GeminiAPIError -} - -private struct GeminiAPIError: Decodable { - let code: Int? - let message: String - let status: String? -} - -private struct GeminiStreamParseResult { - let delta: String? 
- let finished: Bool -} +// MARK: - GeminiHandler (Public Gemini API) -class GeminiHandler: APIService { - let name: String - let baseURL: URL +class GeminiHandler: GeminiHandlerBase { private let apiKey: String - let model: String - private let session: URLSession private let modelsEndpoint: URL - private var activeDataTask: URLSessionDataTask? - private var activeStreamTask: Task? - private var lastResponseParts: [GeminiPartResponse]? - init(config: APIServiceConfiguration, session: URLSession) { - self.name = config.name - self.baseURL = config.apiUrl + override init(config: APIServiceConfiguration, session: URLSession) { self.apiKey = config.apiKey - self.model = config.model - self.session = session self.modelsEndpoint = GeminiHandler.normalizeModelsEndpoint(from: config.apiUrl) + super.init(config: config, session: session) } - func fetchModels() async throws -> [AIModel] { + // MARK: - Fetch Models + + override func fetchModels() async throws -> [AIModel] { guard var components = URLComponents(url: modelsEndpoint, resolvingAgainstBaseURL: false) else { throw APIError.unknown("Invalid Gemini models URL") } @@ -159,160 +71,9 @@ class GeminiHandler: APIService { } } - func sendMessage( - _ requestMessages: [[String: String]], - temperature: Float, - completion: @escaping (Result) -> Void - ) { - let requestResult = prepareRequest( - requestMessages: requestMessages, - temperature: temperature, - stream: false - ) - - switch requestResult { - case .failure(let error): - DispatchQueue.main.async { - completion(.failure(error)) - } - return - - case .success(let request): - activeDataTask?.cancel() - let task = session.dataTask(with: request) { data, response, error in - DispatchQueue.main.async { - self.activeDataTask = nil - let result = self.handleAPIResponse(response, data: data, error: error) - - switch result { - case .success(let payload): - guard let payload = payload else { - completion(.failure(.invalidResponse)) - return - } - - do { - let decoder = 
self.makeDecoder() - let response = try decoder.decode(GeminiGenerateResponse.self, from: payload) - self.lastResponseParts = response.candidates?.first?.content?.parts - var inlineDataCache: [String: String] = [:] - if let message = self.renderMessage(from: response, inlineDataCache: &inlineDataCache) { - completion(.success(message)) - } - else { - completion(.failure(.decodingFailed("Empty Gemini response"))) - } - } - catch { - if let envelope = try? self.makeDecoder().decode(GeminiErrorEnvelope.self, from: payload) { - completion(.failure(.serverError(envelope.error.message))) - } - else { - completion(.failure(.decodingFailed("Failed to decode Gemini response: \(error.localizedDescription)"))) - } - } - - case .failure(let error): - completion(.failure(error)) - } - } - } - activeDataTask = task - task.resume() - } - } - - func sendMessageStream(_ requestMessages: [[String: String]], temperature: Float) async throws - -> AsyncThrowingStream - { - let requestResult = prepareRequest( - requestMessages: requestMessages, - temperature: temperature, - stream: true - ) - - switch requestResult { - case .failure(let error): - throw error - - case .success(let request): - return AsyncThrowingStream { continuation in - let streamTask = Task { - defer { self.activeStreamTask = nil } - do { - let (stream, response) = try await session.bytes(for: request) - let responseCheck = self.handleAPIResponse(response, data: nil, error: nil) - - switch responseCheck { - case .failure(let error): - var errorData = Data() - for try await byte in stream { - errorData.append(byte) - } - - if let envelope = try? self.makeDecoder().decode(GeminiErrorEnvelope.self, from: errorData) { - continuation.finish(throwing: APIError.serverError(envelope.error.message)) - } - else { - let message = String(data: errorData, encoding: .utf8) ?? 
error.localizedDescription - continuation.finish(throwing: APIError.serverError(message)) - } - return - - case .success: - break - } - - let decoder = self.makeDecoder() - var aggregatedText = "" - var inlineDataCache: [String: String] = [:] - - for try await line in stream.lines { - if line.isEmpty { continue } - - if line.hasPrefix("data:") { - let index = line.index(line.startIndex, offsetBy: "data:".count) - let payload = String(line[index...]).trimmingCharacters(in: .whitespacesAndNewlines) - - if payload.isEmpty { continue } - - let result = try self.parseStreamPayload( - payload, - decoder: decoder, - aggregatedText: &aggregatedText, - inlineDataCache: &inlineDataCache - ) - - if let delta = result.delta, !delta.isEmpty { - continuation.yield(delta) - } - - if result.finished { - continuation.finish() - return - } - } - } - - continuation.finish() - } - catch let apiError as APIError { - continuation.finish(throwing: apiError) - } - catch { - continuation.finish(throwing: APIError.requestFailed(error)) - } - } - activeStreamTask?.cancel() - activeStreamTask = streamTask - continuation.onTermination = { _ in - streamTask.cancel() - } - } - } - } + // MARK: - Request Preparation - private func prepareRequest( + override func prepareRequest( requestMessages: [[String: String]], temperature: Float, stream: Bool @@ -355,55 +116,9 @@ class GeminiHandler: APIService { return .success(request) } - private func transformMessages(_ requestMessages: [[String: String]]) -> ( - GeminiContentRequest?, - [GeminiContentRequest] - ) { - var systemInstruction: GeminiContentRequest? 
- var contents: [GeminiContentRequest] = [] - let requiresThoughtSignatures = model.lowercased().contains("gemini-3") - - for message in requestMessages { - guard let role = message["role"], let content = message["content"] else { continue } - let storedParts = decodeStoredParts(from: message["message_parts"]) - - switch role { - case "system": - let text = Self.stripImagePlaceholders(from: content).trimmingCharacters(in: .whitespacesAndNewlines) - guard !text.isEmpty else { continue } - systemInstruction = GeminiContentRequest( - role: "system", - parts: [GeminiPartRequest(text: text)] - ) - case "assistant": - let parts = storedParts ?? buildParts(from: content, allowInlineData: true) - // For Gemini 3 models, we must include thought signatures; if missing, skip to avoid INVALID_ARGUMENT. - if requiresThoughtSignatures, !parts.contains(where: { $0.thoughtSignature != nil }) { - continue - } - guard !parts.isEmpty else { continue } - contents.append( - GeminiContentRequest( - role: "model", - parts: parts - ) - ) - default: - let parts = storedParts ?? buildParts(from: content, allowInlineData: true) - guard !parts.isEmpty else { continue } - contents.append( - GeminiContentRequest( - role: "user", - parts: parts - ) - ) - } - } + // MARK: - URL Building - return (systemInstruction, contents) - } - - private func buildRequestURL(stream: Bool) -> URL? { + override func buildRequestURL(stream: Bool) -> URL? { let action = stream ? ":streamGenerateContent" : ":generateContent" var base = modelsEndpoint.absoluteString @@ -423,111 +138,7 @@ class GeminiHandler: APIService { return components?.url } - func consumeLastResponseParts() -> [GeminiPartRequest]? 
{ - defer { lastResponseParts = nil } - guard let parts = lastResponseParts else { return nil } - let requestParts: [GeminiPartRequest] = parts.compactMap { responsePart in - if let text = responsePart.text { - return GeminiPartRequest(text: text, thoughtSignature: responsePart.thoughtSignature) - } - if let mime = responsePart.inlineData?.mimeType, - let data = responsePart.inlineData?.data { - return GeminiPartRequest( - inlineData: GeminiInlineData(mimeType: mime, data: data), - thoughtSignature: responsePart.thoughtSignature - ) - } - return nil - } - return requestParts.isEmpty ? nil : requestParts - } - - private func handleAPIResponse(_ response: URLResponse?, data: Data?, error: Error?) -> Result { - if let error = error { - return .failure(.requestFailed(error)) - } - - guard let httpResponse = response as? HTTPURLResponse else { - return .failure(.invalidResponse) - } - - guard (200...299).contains(httpResponse.statusCode) else { - var message = "HTTP \(httpResponse.statusCode)" - - if let data = data { - if let decodedError = try? makeDecoder().decode(GeminiErrorEnvelope.self, from: data) { - message = decodedError.error.message - } - else if let raw = String(data: data, encoding: .utf8), !raw.isEmpty { - message = raw - } - } - - switch httpResponse.statusCode { - case 400: - return .failure(.serverError("Bad Request: \(message)")) - case 401, 403: - return .failure(.unauthorized) - case 429: - return .failure(.rateLimited) - case 500...599: - return .failure(.serverError("Gemini API Error: \(message)")) - default: - return .failure(.unknown(message)) - } - } - - return .success(data) - } - - private func renderMessage( - from response: GeminiGenerateResponse, - inlineDataCache: inout [String: String] - ) -> String? 
{ - guard let candidate = response.candidates?.first, - let parts = candidate.content?.parts - else { - return nil - } - - var message = "" - - for part in parts { - if let text = part.text, !text.isEmpty { - message += text - } - - if let inline = part.inlineData, - let base64 = inline.data - { - let normalized = base64.trimmingCharacters(in: .whitespacesAndNewlines) - guard !normalized.isEmpty else { continue } - - if let cached = inlineDataCache[normalized] { - appendInlinePlaceholder(cached, to: &message) - continue - } - - if let placeholder = storeInlineImage(base64: normalized, mimeType: inline.mimeType) { - inlineDataCache[normalized] = placeholder - appendInlinePlaceholder(placeholder, to: &message) - } - } - } - - while message.hasSuffix("\n") { - message.removeLast() - } - - return message.isEmpty ? nil : message - } - - private func normalizedModelIdentifier() -> String { - if model.hasPrefix("models/") { - return model.replacingOccurrences(of: "models/", with: "") - } - return model - } + // MARK: - URL Normalization private static func normalizeModelsEndpoint(from url: URL) -> URL { var endpoint = url @@ -553,323 +164,4 @@ class GeminiHandler: APIService { return endpoint.appendingPathComponent("models") } - - private func makeDecoder() -> JSONDecoder { - let decoder = JSONDecoder() - decoder.keyDecodingStrategy = .convertFromSnakeCase - return decoder - } - - private func appendInlinePlaceholder(_ placeholder: String, to message: inout String) { - if !message.isEmpty, !message.hasSuffix("\n") { - message += "\n" - } - - message += placeholder - - if !message.hasSuffix("\n") { - message += "\n" - } - } - - private static func stripImagePlaceholders(from content: String) -> String { - let pattern = ".*?" 
- return content.replacingOccurrences(of: pattern, with: "", options: .regularExpression) - .trimmingCharacters(in: .whitespacesAndNewlines) - } - - private func parseStreamPayload( - _ payload: String, - decoder: JSONDecoder, - aggregatedText: inout String, - inlineDataCache: inout [String: String] - ) throws -> GeminiStreamParseResult { - if payload == "[DONE]" { - return GeminiStreamParseResult(delta: nil, finished: true) - } - - guard let data = payload.data(using: .utf8) else { - return GeminiStreamParseResult(delta: nil, finished: false) - } - - if let response = try? decoder.decode(GeminiGenerateResponse.self, from: data) { - lastResponseParts = response.candidates?.first?.content?.parts - let delta = extractDelta( - from: response, - aggregatedText: &aggregatedText, - inlineDataCache: &inlineDataCache - ) - let finished = shouldFinish(after: response) - return GeminiStreamParseResult(delta: delta, finished: finished) - } - - if let envelope = try? decoder.decode(GeminiErrorEnvelope.self, from: data) { - throw APIError.serverError(envelope.error.message) - } - - if let stringPayload = try? decoder.decode(String.self, from: data) { - let delta = mergeRawDelta(stringPayload, aggregatedText: &aggregatedText) - return GeminiStreamParseResult(delta: delta, finished: false) - } - - if let plain = String(data: data, encoding: .utf8), !plain.isEmpty { - let delta = mergeRawDelta(plain, aggregatedText: &aggregatedText) - return GeminiStreamParseResult(delta: delta, finished: false) - } - - return GeminiStreamParseResult(delta: nil, finished: false) - } - - private func buildParts(from content: String, allowInlineData: Bool) -> [GeminiPartRequest] { - var parts: [GeminiPartRequest] = [] - - let pattern = "(.*?)" - let regex = try? 
NSRegularExpression(pattern: pattern, options: []) - let nsString = content as NSString - let fullRange = NSRange(location: 0, length: nsString.length) - - var currentLocation = 0 - let matches = regex?.matches(in: content, options: [], range: fullRange) ?? [] - - for match in matches { - let matchRange = match.range - let textLength = matchRange.location - currentLocation - if textLength > 0 { - let textSegment = nsString.substring(with: NSRange(location: currentLocation, length: textLength)) - if !textSegment.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { - parts.append(GeminiPartRequest(text: textSegment)) - } - } - - if allowInlineData, match.numberOfRanges > 1 { - let uuidRange = match.range(at: 1) - let uuidString = nsString.substring(with: uuidRange) - if let uuid = UUID(uuidString: uuidString), - let image = loadImageFromCoreData(uuid: uuid) - { - let base64 = image.data.base64EncodedString() - let inlineData = GeminiInlineData(mimeType: image.mimeType, data: base64) - parts.append(GeminiPartRequest(inlineData: inlineData)) - } - } - - currentLocation = matchRange.location + matchRange.length - } - - let remainingLength = nsString.length - currentLocation - if remainingLength > 0 { - let trailingText = nsString.substring(with: NSRange(location: currentLocation, length: remainingLength)) - if !trailingText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { - parts.append(GeminiPartRequest(text: trailingText)) - } - } - - if parts.isEmpty { - let trimmed = content.trimmingCharacters(in: .whitespacesAndNewlines) - if !trimmed.isEmpty { - parts.append(GeminiPartRequest(text: content)) - } - } - - return parts - } - - private func decodeStoredParts(from base64: String?) -> [GeminiPartRequest]? { - guard let base64 = base64, let data = Data(base64Encoded: base64) else { return nil } - // Expect a vendor-tagged envelope - if let envelope = try? 
JSONDecoder().decode(PartsEnvelope.self, from: data), - envelope.serviceType.lowercased() == "gemini" { - return envelope.parts - } - // Legacy direct array encoding - if let parts = try? JSONDecoder().decode([GeminiPartRequest].self, from: data) { - return parts - } - return nil - } - - private func extractDelta( - from response: GeminiGenerateResponse, - aggregatedText: inout String, - inlineDataCache: inout [String: String] - ) -> String? { - guard let message = renderMessage(from: response, inlineDataCache: &inlineDataCache) else { - return nil - } - - return mergeRawDelta(message, aggregatedText: &aggregatedText) - } - - private func mergeRawDelta(_ fragment: String, aggregatedText: inout String) -> String? { - guard !fragment.isEmpty else { return nil } - - if aggregatedText.isEmpty { - aggregatedText = fragment - return fragment - } - - if fragment == aggregatedText { - return nil - } - - if fragment.hasPrefix(aggregatedText) { - let delta = String(fragment.dropFirst(aggregatedText.count)) - aggregatedText = fragment - return delta.isEmpty ? nil : delta - } - - let commonPrefix = fragment.commonPrefix(with: aggregatedText) - if !commonPrefix.isEmpty { - let delta = String(fragment.dropFirst(commonPrefix.count)) - aggregatedText = fragment - return delta.isEmpty ? nil : delta - } - - aggregatedText.append(fragment) - return fragment - } - - private func shouldFinish(after response: GeminiGenerateResponse) -> Bool { - guard let candidates = response.candidates else { return false } - for candidate in candidates { - if let reason = candidate.finishReason, isTerminalFinishReason(reason) { - return true - } - } - return false - } - - private func isTerminalFinishReason(_ reason: String) -> Bool { - switch reason.uppercased() { - case "STOP", "MAX_TOKENS", "SAFETY", "RECITATION", "BLOCKLIST", "OTHER": - return true - default: - return false - } - } - - private func storeInlineImage(base64: String, mimeType: String?) -> String? 
{ - guard let data = Data(base64Encoded: base64) else { return nil } - let format = formatFromMimeType(mimeType) - - guard let uuid = saveImageData(data, format: format) else { - return nil - } - - return "\(uuid.uuidString)" - } - - private func saveImageData(_ data: Data, format: String) -> UUID? { - guard !data.isEmpty, let image = NSImage(data: data) else { return nil } - let uuid = UUID() - let context = PersistenceController.shared.container.viewContext - var saveSucceeded = false - - context.performAndWait { - let imageEntity = ImageEntity(context: context) - imageEntity.id = uuid - imageEntity.image = data - imageEntity.imageFormat = format - imageEntity.thumbnail = createThumbnailData(from: image) - - do { - try context.save() - saveSucceeded = true - } - catch { - print("Error saving generated image: \(error)") - context.rollback() - } - } - - return saveSucceeded ? uuid : nil - } - - private func createThumbnailData(from image: NSImage) -> Data? { - let thumbnailSize = CGFloat(AppConstants.thumbnailSize) - let originalSize = image.size - guard originalSize.width > 0, originalSize.height > 0 else { return nil } - - let aspectRatio = originalSize.width / originalSize.height - var targetSize = CGSize(width: thumbnailSize, height: thumbnailSize) - - if aspectRatio > 1 { - targetSize.height = thumbnailSize / aspectRatio - } - else { - targetSize.width = thumbnailSize * aspectRatio - } - - let thumbnail = NSImage(size: targetSize) - thumbnail.lockFocus() - NSGraphicsContext.current?.imageInterpolation = .high - image.draw(in: CGRect(origin: .zero, size: targetSize), from: .zero, operation: .copy, fraction: 1.0) - thumbnail.unlockFocus() - - guard let tiffData = thumbnail.tiffRepresentation, - let bitmap = NSBitmapImageRep(data: tiffData) - else { - return nil - } - - return bitmap.representation(using: .jpeg, properties: [.compressionFactor: 0.7]) - } - - private func formatFromMimeType(_ mimeType: String?) 
-> String { - switch mimeType?.lowercased() { - case "image/png": - return "png" - case "image/gif": - return "gif" - case "image/webp": - return "webp" - case "image/heic", "image/heif": - return "heic" - case "image/jpg": - return "jpeg" - default: - return "jpeg" - } - } - - private func loadImageFromCoreData(uuid: UUID) -> (data: Data, mimeType: String)? { - let viewContext = PersistenceController.shared.container.viewContext - - let fetchRequest: NSFetchRequest = ImageEntity.fetchRequest() - fetchRequest.predicate = NSPredicate(format: "id == %@", uuid as CVarArg) - fetchRequest.fetchLimit = 1 - - do { - let results = try viewContext.fetch(fetchRequest) - if let imageEntity = results.first, let imageData = imageEntity.image { - let format = imageEntity.imageFormat ?? "jpeg" - return (imageData, mimeTypeForImageFormat(format)) - } - } - catch { - print("Error fetching image from CoreData: \(error)") - } - - return nil - } - - private func mimeTypeForImageFormat(_ format: String) -> String { - switch format.lowercased() { - case "png": - return "image/png" - case "heic", "heif": - return "image/heic" - case "gif": - return "image/gif" - case "webp": - return "image/webp" - default: - return "image/jpeg" - } - } - - func cancelCurrentRequest() { - activeDataTask?.cancel() - activeStreamTask?.cancel() - } } diff --git a/macai/Utilities/APIHandlers/GeminiHandlerBase.swift b/macai/Utilities/APIHandlers/GeminiHandlerBase.swift new file mode 100644 index 0000000..b143219 --- /dev/null +++ b/macai/Utilities/APIHandlers/GeminiHandlerBase.swift @@ -0,0 +1,772 @@ +// +// GeminiHandlerBase.swift +// macai +// +// Created by Renat on 07.02.2025. +// + +import AppKit +import CoreData +import Foundation + +// MARK: - Shared Data Structures + +struct GeminiContentRequest: Encodable { + let role: String? + let parts: [GeminiPartRequest] +} + +struct GeminiPartRequest: Codable { + let text: String? + let inlineData: GeminiInlineData? + let thoughtSignature: String? 
+ + init(text: String, thoughtSignature: String? = nil) { + self.text = text + self.inlineData = nil + self.thoughtSignature = thoughtSignature + } + + init(inlineData: GeminiInlineData, thoughtSignature: String? = nil) { + self.text = nil + self.inlineData = inlineData + self.thoughtSignature = thoughtSignature + } +} + +struct GeminiInlineData: Codable { + let mimeType: String + let data: String +} + +struct PartsEnvelope: Codable { + let serviceType: String + let parts: [GeminiPartRequest] +} + +// MARK: - Internal Response Structures + +struct GeminiGenerateRequest: Encodable { + let contents: [GeminiContentRequest] + let systemInstruction: GeminiContentRequest? + let generationConfig: GeminiGenerationConfig? +} + +struct GeminiGenerationConfig: Encodable { + let temperature: Float +} + +struct GeminiGenerateResponse: Decodable { + let candidates: [GeminiCandidate]? +} + +struct GeminiCandidate: Decodable { + let content: GeminiContentResponse? + let finishReason: String? +} + +struct GeminiContentResponse: Decodable { + let parts: [GeminiPartResponse]? +} + +struct GeminiPartResponse: Decodable { + let text: String? + let inlineData: GeminiInlineDataResponse? + let thoughtSignature: String? +} + +struct GeminiInlineDataResponse: Decodable { + let mimeType: String? + let data: String? +} + +struct GeminiErrorEnvelope: Decodable { + let error: GeminiAPIError +} + +struct GeminiAPIError: Decodable { + let code: Int? + let message: String + let status: String? +} + +struct GeminiStreamParseResult { + let delta: String? + let finished: Bool +} + +// MARK: - Base Handler Class + +class GeminiHandlerBase: APIService { + let name: String + let baseURL: URL + let model: String + let session: URLSession + var activeDataTask: URLSessionDataTask? + var activeStreamTask: Task? + var lastResponseParts: [GeminiPartResponse]? 
+ + init(config: APIServiceConfiguration, session: URLSession) { + self.name = config.name + self.baseURL = config.apiUrl + self.model = config.model + self.session = session + } + + // MARK: - Abstract Methods (to be overridden by subclasses) + + func fetchModels() async throws -> [AIModel] { + fatalError("Subclasses must override fetchModels()") + } + + func prepareRequest( + requestMessages: [[String: String]], + temperature: Float, + stream: Bool + ) -> Result { + fatalError("Subclasses must override prepareRequest()") + } + + func buildRequestURL(stream: Bool) -> URL? { + fatalError("Subclasses must override buildRequestURL()") + } + + // MARK: - APIService Protocol Implementation + + func sendMessage( + _ requestMessages: [[String: String]], + temperature: Float, + completion: @escaping (Result) -> Void + ) { + let requestResult = prepareRequest( + requestMessages: requestMessages, + temperature: temperature, + stream: false + ) + + switch requestResult { + case .failure(let error): + DispatchQueue.main.async { + completion(.failure(error)) + } + return + + case .success(let request): + activeDataTask?.cancel() + let task = session.dataTask(with: request) { data, response, error in + DispatchQueue.main.async { + self.activeDataTask = nil + let result = self.handleAPIResponse(response, data: data, error: error) + + switch result { + case .success(let payload): + guard let payload = payload else { + completion(.failure(.invalidResponse)) + return + } + + do { + let decoder = self.makeDecoder() + let response = try decoder.decode(GeminiGenerateResponse.self, from: payload) + self.lastResponseParts = response.candidates?.first?.content?.parts + var inlineDataCache: [String: String] = [:] + if let message = self.renderMessage(from: response, inlineDataCache: &inlineDataCache) { + completion(.success(message)) + } + else { + completion(.failure(.decodingFailed("Empty Gemini response"))) + } + } + catch { + if let envelope = try? 
self.makeDecoder().decode(GeminiErrorEnvelope.self, from: payload) { + completion(.failure(.serverError(envelope.error.message))) + } + else { + completion(.failure(.decodingFailed("Failed to decode Gemini response: \(error.localizedDescription)"))) + } + } + + case .failure(let error): + completion(.failure(error)) + } + } + } + activeDataTask = task + task.resume() + } + } + + func sendMessageStream(_ requestMessages: [[String: String]], temperature: Float) async throws + -> AsyncThrowingStream + { + let requestResult = prepareRequest( + requestMessages: requestMessages, + temperature: temperature, + stream: true + ) + + switch requestResult { + case .failure(let error): + throw error + + case .success(let request): + return AsyncThrowingStream { continuation in + let streamTask = Task { + defer { self.activeStreamTask = nil } + do { + let (stream, response) = try await self.session.bytes(for: request) + let responseCheck = self.handleAPIResponse(response, data: nil, error: nil) + + switch responseCheck { + case .failure(let error): + var errorData = Data() + for try await byte in stream { + errorData.append(byte) + } + + if let envelope = try? self.makeDecoder().decode(GeminiErrorEnvelope.self, from: errorData) { + continuation.finish(throwing: APIError.serverError(envelope.error.message)) + } + else { + let message = String(data: errorData, encoding: .utf8) ?? 
error.localizedDescription + continuation.finish(throwing: APIError.serverError(message)) + } + return + + case .success: + break + } + + let decoder = self.makeDecoder() + var aggregatedText = "" + var inlineDataCache: [String: String] = [:] + + for try await line in stream.lines { + if line.isEmpty { continue } + + if line.hasPrefix("data:") { + let index = line.index(line.startIndex, offsetBy: "data:".count) + let payload = String(line[index...]).trimmingCharacters(in: .whitespacesAndNewlines) + + if payload.isEmpty { continue } + + let result = try self.parseStreamPayload( + payload, + decoder: decoder, + aggregatedText: &aggregatedText, + inlineDataCache: &inlineDataCache + ) + + if let delta = result.delta, !delta.isEmpty { + continuation.yield(delta) + } + + if result.finished { + continuation.finish() + return + } + } + } + + continuation.finish() + } + catch let apiError as APIError { + continuation.finish(throwing: apiError) + } + catch { + continuation.finish(throwing: APIError.requestFailed(error)) + } + } + activeStreamTask?.cancel() + activeStreamTask = streamTask + continuation.onTermination = { _ in + streamTask.cancel() + } + } + } + } + + func cancelCurrentRequest() { + activeDataTask?.cancel() + activeStreamTask?.cancel() + } + + // MARK: - Message Transformation + + func transformMessages(_ requestMessages: [[String: String]]) -> ( + GeminiContentRequest?, + [GeminiContentRequest] + ) { + var systemInstruction: GeminiContentRequest? 
+ var contents: [GeminiContentRequest] = [] + let requiresThoughtSignatures = model.lowercased().contains("gemini-3") + + for message in requestMessages { + guard let role = message["role"], let content = message["content"] else { continue } + let storedParts = decodeStoredParts(from: message["message_parts"]) + + switch role { + case "system": + let text = Self.stripImagePlaceholders(from: content).trimmingCharacters(in: .whitespacesAndNewlines) + guard !text.isEmpty else { continue } + systemInstruction = GeminiContentRequest( + role: "system", + parts: [GeminiPartRequest(text: text)] + ) + case "assistant": + let parts = storedParts ?? buildParts(from: content, allowInlineData: true) + // For Gemini 3 models, we must include thought signatures; if missing, skip to avoid INVALID_ARGUMENT. + if requiresThoughtSignatures, !parts.contains(where: { $0.thoughtSignature != nil }) { + continue + } + guard !parts.isEmpty else { continue } + contents.append( + GeminiContentRequest( + role: "model", + parts: parts + ) + ) + default: + let parts = storedParts ?? buildParts(from: content, allowInlineData: true) + guard !parts.isEmpty else { continue } + contents.append( + GeminiContentRequest( + role: "user", + parts: parts + ) + ) + } + } + + return (systemInstruction, contents) + } + + func buildParts(from content: String, allowInlineData: Bool) -> [GeminiPartRequest] { + var parts: [GeminiPartRequest] = [] + + let pattern = "(.*?)" + let regex = try? NSRegularExpression(pattern: pattern, options: []) + let nsString = content as NSString + let fullRange = NSRange(location: 0, length: nsString.length) + + var currentLocation = 0 + let matches = regex?.matches(in: content, options: [], range: fullRange) ?? 
[] + + for match in matches { + let matchRange = match.range + let textLength = matchRange.location - currentLocation + if textLength > 0 { + let textSegment = nsString.substring(with: NSRange(location: currentLocation, length: textLength)) + if !textSegment.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + parts.append(GeminiPartRequest(text: textSegment)) + } + } + + if allowInlineData, match.numberOfRanges > 1 { + let uuidRange = match.range(at: 1) + let uuidString = nsString.substring(with: uuidRange) + if let uuid = UUID(uuidString: uuidString), + let image = loadImageFromCoreData(uuid: uuid) + { + let base64 = image.data.base64EncodedString() + let inlineData = GeminiInlineData(mimeType: image.mimeType, data: base64) + parts.append(GeminiPartRequest(inlineData: inlineData)) + } + } + + currentLocation = matchRange.location + matchRange.length + } + + let remainingLength = nsString.length - currentLocation + if remainingLength > 0 { + let trailingText = nsString.substring(with: NSRange(location: currentLocation, length: remainingLength)) + if !trailingText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + parts.append(GeminiPartRequest(text: trailingText)) + } + } + + if parts.isEmpty { + let trimmed = content.trimmingCharacters(in: .whitespacesAndNewlines) + if !trimmed.isEmpty { + parts.append(GeminiPartRequest(text: content)) + } + } + + return parts + } + + func decodeStoredParts(from base64: String?) -> [GeminiPartRequest]? { + guard let base64 = base64, let data = Data(base64Encoded: base64) else { return nil } + // Expect a vendor-tagged envelope + if let envelope = try? JSONDecoder().decode(PartsEnvelope.self, from: data), + envelope.serviceType.lowercased() == "gemini" || envelope.serviceType.lowercased() == "vertex" { + return envelope.parts + } + // Legacy direct array encoding + if let parts = try? 
JSONDecoder().decode([GeminiPartRequest].self, from: data) { + return parts + } + return nil + } + + static func stripImagePlaceholders(from content: String) -> String { + let pattern = ".*?" + return content.replacingOccurrences(of: pattern, with: "", options: .regularExpression) + .trimmingCharacters(in: .whitespacesAndNewlines) + } + + // MARK: - Response Handling + + func handleAPIResponse(_ response: URLResponse?, data: Data?, error: Error?) -> Result { + if let error = error { + return .failure(.requestFailed(error)) + } + + guard let httpResponse = response as? HTTPURLResponse else { + return .failure(.invalidResponse) + } + + guard (200...299).contains(httpResponse.statusCode) else { + var message = "HTTP \(httpResponse.statusCode)" + + if let data = data { + if let decodedError = try? makeDecoder().decode(GeminiErrorEnvelope.self, from: data) { + message = decodedError.error.message + } + else if let raw = String(data: data, encoding: .utf8), !raw.isEmpty { + message = raw + } + } + + switch httpResponse.statusCode { + case 400: + return .failure(.serverError("Bad Request: \(message)")) + case 401, 403: + return .failure(.unauthorized) + case 429: + return .failure(.rateLimited) + case 500...599: + return .failure(.serverError("Gemini API Error: \(message)")) + default: + return .failure(.unknown(message)) + } + } + + return .success(data) + } + + func renderMessage( + from response: GeminiGenerateResponse, + inlineDataCache: inout [String: String] + ) -> String? 
{ + guard let candidate = response.candidates?.first, + let parts = candidate.content?.parts + else { + return nil + } + + var message = "" + + for part in parts { + if let text = part.text, !text.isEmpty { + message += text + } + + if let inline = part.inlineData, + let base64 = inline.data + { + let normalized = base64.trimmingCharacters(in: .whitespacesAndNewlines) + guard !normalized.isEmpty else { continue } + + if let cached = inlineDataCache[normalized] { + appendInlinePlaceholder(cached, to: &message) + continue + } + + if let placeholder = storeInlineImage(base64: normalized, mimeType: inline.mimeType) { + inlineDataCache[normalized] = placeholder + appendInlinePlaceholder(placeholder, to: &message) + } + } + } + + while message.hasSuffix("\n") { + message.removeLast() + } + + return message.isEmpty ? nil : message + } + + func parseStreamPayload( + _ payload: String, + decoder: JSONDecoder, + aggregatedText: inout String, + inlineDataCache: inout [String: String] + ) throws -> GeminiStreamParseResult { + if payload == "[DONE]" { + return GeminiStreamParseResult(delta: nil, finished: true) + } + + guard let data = payload.data(using: .utf8) else { + return GeminiStreamParseResult(delta: nil, finished: false) + } + + if let response = try? decoder.decode(GeminiGenerateResponse.self, from: data) { + lastResponseParts = response.candidates?.first?.content?.parts + let delta = extractDelta( + from: response, + aggregatedText: &aggregatedText, + inlineDataCache: &inlineDataCache + ) + let finished = shouldFinish(after: response) + return GeminiStreamParseResult(delta: delta, finished: finished) + } + + if let envelope = try? decoder.decode(GeminiErrorEnvelope.self, from: data) { + throw APIError.serverError(envelope.error.message) + } + + if let stringPayload = try? 
decoder.decode(String.self, from: data) { + let delta = mergeRawDelta(stringPayload, aggregatedText: &aggregatedText) + return GeminiStreamParseResult(delta: delta, finished: false) + } + + if let plain = String(data: data, encoding: .utf8), !plain.isEmpty { + let delta = mergeRawDelta(plain, aggregatedText: &aggregatedText) + return GeminiStreamParseResult(delta: delta, finished: false) + } + + return GeminiStreamParseResult(delta: nil, finished: false) + } + + func extractDelta( + from response: GeminiGenerateResponse, + aggregatedText: inout String, + inlineDataCache: inout [String: String] + ) -> String? { + guard let message = renderMessage(from: response, inlineDataCache: &inlineDataCache) else { + return nil + } + + return mergeRawDelta(message, aggregatedText: &aggregatedText) + } + + func mergeRawDelta(_ fragment: String, aggregatedText: inout String) -> String? { + guard !fragment.isEmpty else { return nil } + + if aggregatedText.isEmpty { + aggregatedText = fragment + return fragment + } + + if fragment == aggregatedText { + return nil + } + + if fragment.hasPrefix(aggregatedText) { + let delta = String(fragment.dropFirst(aggregatedText.count)) + aggregatedText = fragment + return delta.isEmpty ? nil : delta + } + + let commonPrefix = fragment.commonPrefix(with: aggregatedText) + if !commonPrefix.isEmpty { + let delta = String(fragment.dropFirst(commonPrefix.count)) + aggregatedText = fragment + return delta.isEmpty ? 
nil : delta + } + + aggregatedText.append(fragment) + return fragment + } + + func shouldFinish(after response: GeminiGenerateResponse) -> Bool { + guard let candidates = response.candidates else { return false } + for candidate in candidates { + if let reason = candidate.finishReason, isTerminalFinishReason(reason) { + return true + } + } + return false + } + + func isTerminalFinishReason(_ reason: String) -> Bool { + switch reason.uppercased() { + case "STOP", "MAX_TOKENS", "SAFETY", "RECITATION", "BLOCKLIST", "OTHER": + return true + default: + return false + } + } + + // MARK: - Response Parts + + func consumeLastResponseParts() -> [GeminiPartRequest]? { + defer { lastResponseParts = nil } + guard let parts = lastResponseParts else { return nil } + let requestParts: [GeminiPartRequest] = parts.compactMap { responsePart in + if let text = responsePart.text { + return GeminiPartRequest(text: text, thoughtSignature: responsePart.thoughtSignature) + } + if let mime = responsePart.inlineData?.mimeType, + let data = responsePart.inlineData?.data { + return GeminiPartRequest( + inlineData: GeminiInlineData(mimeType: mime, data: data), + thoughtSignature: responsePart.thoughtSignature + ) + } + return nil + } + return requestParts.isEmpty ? nil : requestParts + } + + // MARK: - Helpers + + func makeDecoder() -> JSONDecoder { + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + return decoder + } + + func appendInlinePlaceholder(_ placeholder: String, to message: inout String) { + if !message.isEmpty, !message.hasSuffix("\n") { + message += "\n" + } + + message += placeholder + + if !message.hasSuffix("\n") { + message += "\n" + } + } + + func normalizedModelIdentifier() -> String { + if model.hasPrefix("models/") { + return model.replacingOccurrences(of: "models/", with: "") + } + return model + } + + // MARK: - Image Handling + + func storeInlineImage(base64: String, mimeType: String?) -> String? 
{ + guard let data = Data(base64Encoded: base64) else { return nil } + let format = formatFromMimeType(mimeType) + + guard let uuid = saveImageData(data, format: format) else { + return nil + } + + return "\(uuid.uuidString)" + } + + func saveImageData(_ data: Data, format: String) -> UUID? { + guard !data.isEmpty, let image = NSImage(data: data) else { return nil } + let uuid = UUID() + let context = PersistenceController.shared.container.viewContext + var saveSucceeded = false + + context.performAndWait { + let imageEntity = ImageEntity(context: context) + imageEntity.id = uuid + imageEntity.image = data + imageEntity.imageFormat = format + imageEntity.thumbnail = createThumbnailData(from: image) + + do { + try context.save() + saveSucceeded = true + } + catch { + print("Error saving generated image: \(error)") + context.rollback() + } + } + + return saveSucceeded ? uuid : nil + } + + func createThumbnailData(from image: NSImage) -> Data? { + let thumbnailSize = CGFloat(AppConstants.thumbnailSize) + let originalSize = image.size + guard originalSize.width > 0, originalSize.height > 0 else { return nil } + + let aspectRatio = originalSize.width / originalSize.height + var targetSize = CGSize(width: thumbnailSize, height: thumbnailSize) + + if aspectRatio > 1 { + targetSize.height = thumbnailSize / aspectRatio + } + else { + targetSize.width = thumbnailSize * aspectRatio + } + + let thumbnail = NSImage(size: targetSize) + thumbnail.lockFocus() + NSGraphicsContext.current?.imageInterpolation = .high + image.draw(in: CGRect(origin: .zero, size: targetSize), from: .zero, operation: .copy, fraction: 1.0) + thumbnail.unlockFocus() + + guard let tiffData = thumbnail.tiffRepresentation, + let bitmap = NSBitmapImageRep(data: tiffData) + else { + return nil + } + + return bitmap.representation(using: .jpeg, properties: [.compressionFactor: 0.7]) + } + + func formatFromMimeType(_ mimeType: String?) 
-> String { + switch mimeType?.lowercased() { + case "image/png": + return "png" + case "image/gif": + return "gif" + case "image/webp": + return "webp" + case "image/heic", "image/heif": + return "heic" + case "image/jpg": + return "jpeg" + default: + return "jpeg" + } + } + + func loadImageFromCoreData(uuid: UUID) -> (data: Data, mimeType: String)? { + let viewContext = PersistenceController.shared.container.viewContext + + let fetchRequest: NSFetchRequest = ImageEntity.fetchRequest() + fetchRequest.predicate = NSPredicate(format: "id == %@", uuid as CVarArg) + fetchRequest.fetchLimit = 1 + + do { + let results = try viewContext.fetch(fetchRequest) + if let imageEntity = results.first, let imageData = imageEntity.image { + let format = imageEntity.imageFormat ?? "jpeg" + return (imageData, mimeTypeForImageFormat(format)) + } + } + catch { + print("Error fetching image from CoreData: \(error)") + } + + return nil + } + + func mimeTypeForImageFormat(_ format: String) -> String { + switch format.lowercased() { + case "png": + return "image/png" + case "heic", "heif": + return "image/heic" + case "gif": + return "image/gif" + case "webp": + return "image/webp" + default: + return "image/jpeg" + } + } +} + diff --git a/macai/Utilities/APIHandlers/VertexClaudeHandler.swift b/macai/Utilities/APIHandlers/VertexClaudeHandler.swift new file mode 100644 index 0000000..6625114 --- /dev/null +++ b/macai/Utilities/APIHandlers/VertexClaudeHandler.swift @@ -0,0 +1,334 @@ +// +// VertexClaudeHandler.swift +// macai +// +// Vertex AI handler for Claude models. +// + +import Foundation + +// MARK: - Vertex Claude Handler + +class VertexClaudeHandler: APIService { + let name: String + let baseURL: URL + private let model: String + private let session: URLSession + private let projectId: String + private let region: String + private var activeDataTask: Task? + private var activeStreamTask: Task? 
+ + init(config: APIServiceConfiguration, session: URLSession) { + self.name = config.name + self.baseURL = config.apiUrl + self.model = config.model + self.session = session + self.projectId = config.gcpProjectId ?? "" + self.region = config.gcpRegion ?? "us-central1" + } + + // MARK: - Fetch Models + + func fetchModels() async throws -> [AIModel] { + // Vertex AI does not provide a stable v1 REST API endpoint for listing publisher models. + // Return the curated list of Claude models from configuration. + let allModels = AppConstants.defaultApiConfigurations["vertex"]?.models ?? [] + // Filter to only Claude models + let claudeModels = allModels.filter { $0.lowercased().hasPrefix("claude") } + return claudeModels.map { AIModel(id: $0) } + } + + // MARK: - Send Message (Non-streaming) + + func sendMessage( + _ requestMessages: [[String: String]], + temperature: Float, + completion: @escaping (Result) -> Void + ) { + activeDataTask?.cancel() + let task = Task { + defer { self.activeDataTask = nil } + do { + let request = try await prepareRequest( + requestMessages: requestMessages, + temperature: temperature, + stream: false + ) + + let (data, response) = try await session.data(for: request) + + await MainActor.run { + let result = VertexResponseHandler.handleAPIResponse(response, data: data, error: nil) + + switch result { + case .success(let payload): + guard let payload = payload else { + completion(.failure(.invalidResponse)) + return + } + + if let (messageContent, _) = self.parseClaudeJSONResponse(data: payload) { + completion(.success(messageContent)) + } + else { + // Try to decode error envelope + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + if let envelope = try? 
decoder.decode(VertexErrorEnvelope.self, from: payload) { + completion(.failure(.serverError(envelope.error.message))) + } + else { + completion(.failure(.decodingFailed("Failed to parse Claude response"))) + } + } + + case .failure(let error): + completion(.failure(error)) + } + } + } + catch let apiError as APIError { + await MainActor.run { + completion(.failure(apiError)) + } + } + catch { + await MainActor.run { + completion(.failure(.requestFailed(error))) + } + } + } + activeDataTask = task + } + + // MARK: - Send Message Stream + + func sendMessageStream(_ requestMessages: [[String: String]], temperature: Float) async throws + -> AsyncThrowingStream + { + let request = try await prepareRequest( + requestMessages: requestMessages, + temperature: temperature, + stream: true + ) + + return AsyncThrowingStream { continuation in + let streamTask = Task { + defer { self.activeStreamTask = nil } + do { + let (stream, response) = try await self.session.bytes(for: request) + let responseCheck = VertexResponseHandler.handleAPIResponse(response, data: nil, error: nil) + + switch responseCheck { + case .failure(let error): + var errorData = Data() + for try await byte in stream { + errorData.append(byte) + } + + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + if let envelope = try? decoder.decode(VertexErrorEnvelope.self, from: errorData) { + continuation.finish(throwing: APIError.serverError(envelope.error.message)) + } + else { + let message = String(data: errorData, encoding: .utf8) ?? 
error.localizedDescription + continuation.finish(throwing: APIError.serverError(message)) + } + return + + case .success: + break + } + + // Handle Claude streaming response + for try await line in stream.lines { + let (finished, error, content, _) = self.parseClaudeSSEEvent(line) + + if let error = error { + continuation.finish(throwing: APIError.decodingFailed(error.localizedDescription)) + break + } + + if let content = content, !content.isEmpty { + continuation.yield(content) + } + + if finished { + continuation.finish() + break + } + } + } + catch let apiError as APIError { + continuation.finish(throwing: apiError) + } + catch { + continuation.finish(throwing: APIError.requestFailed(error)) + } + } + activeStreamTask?.cancel() + activeStreamTask = streamTask + continuation.onTermination = { _ in + streamTask.cancel() + } + } + } + + // MARK: - Request Preparation + + private func prepareRequest( + requestMessages: [[String: String]], + temperature: Float, + stream: Bool + ) async throws -> URLRequest { + guard !projectId.isEmpty else { + throw APIError.unknown("GCP Project ID is required for Vertex AI") + } + + guard !region.isEmpty else { + throw APIError.unknown("GCP Region is required for Vertex AI") + } + + let accessToken = try await VertexTokenManager.shared.getAccessToken(session: session) + + guard let url = VertexURLBuilder.claudeURL( + region: region, + projectId: projectId, + model: model, + stream: stream + ) else { + throw APIError.unknown("Invalid Vertex AI Claude request URL") + } + + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.timeoutInterval = AppConstants.requestTimeout + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization") + + if stream { + request.setValue("text/event-stream", forHTTPHeaderField: "Accept") + } + + // Use Anthropic Messages API format for Claude models + var systemMessage = "" + var 
updatedRequestMessages = requestMessages + + if let firstMessage = requestMessages.first, firstMessage["role"] == "system" { + systemMessage = firstMessage["content"] ?? "" + updatedRequestMessages.removeFirst() + } + + if updatedRequestMessages.isEmpty { + throw APIError.unknown("Claude request requires at least one user or assistant message") + } + + let jsonDict: [String: Any] = [ + "anthropic_version": "vertex-2023-10-16", + "messages": updatedRequestMessages, + "system": systemMessage, + "stream": stream, + "temperature": temperature, + "max_tokens": 4096, + ] + + request.httpBody = try JSONSerialization.data(withJSONObject: jsonDict, options: []) + + return request + } + + // MARK: - Claude Response Parsing + + private func parseClaudeJSONResponse(data: Data) -> (String, String)? { + do { + if let json = try JSONSerialization.jsonObject(with: data, options: []) as? [String: Any], + let role = json["role"] as? String, + let contentArray = json["content"] as? [[String: Any]] + { + let textContent = contentArray.compactMap { item -> String? in + if let type = item["type"] as? String, type == "text", + let text = item["text"] as? String + { + return text + } + return nil + }.joined(separator: "\n") + + if !textContent.isEmpty { + return (textContent, role) + } + } + } + catch { + print("Error parsing Claude JSON: \(error.localizedDescription)") + } + return nil + } + + private func parseClaudeSSEEvent(_ event: String) -> (Bool, Error?, String?, String?) { + var isFinished = false + var textContent = "" + var parseError: Error? + var jsonString: String? + + if event.hasPrefix("data: ") { + jsonString = event.replacingOccurrences(of: "data: ", with: "") + } + + guard let jsonString = jsonString else { + return (isFinished, parseError, nil, nil) + } + + guard let jsonData = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: jsonData, options: []) as? 
[String: Any] + else { + parseError = NSError( + domain: "SSEParsing", + code: 1, + userInfo: [NSLocalizedDescriptionKey: "Failed to parse JSON: \(jsonString)"] + ) + return (isFinished, parseError, nil, nil) + } + + if let eventType = json["type"] as? String { + switch eventType { + case "content_block_start": + if let contentBlock = json["content_block"] as? [String: Any], + let text = contentBlock["text"] as? String + { + textContent = text + } + case "content_block_delta": + if let delta = json["delta"] as? [String: Any], + let text = delta["text"] as? String + { + textContent += text + } + case "message_delta": + if let delta = json["delta"] as? [String: Any], + let stopReason = delta["stop_reason"] as? String + { + isFinished = stopReason == "end_turn" + } + case "message_stop": + isFinished = true + case "ping": + // Ignore ping events + break + default: + print("Unhandled Claude event type: \(eventType)") + } + } + return (isFinished, parseError, textContent.isEmpty ? nil : textContent, nil) + } + + // MARK: - Cancel + + func cancelCurrentRequest() { + activeDataTask?.cancel() + activeStreamTask?.cancel() + } +} + diff --git a/macai/Utilities/APIHandlers/VertexCommon.swift b/macai/Utilities/APIHandlers/VertexCommon.swift new file mode 100644 index 0000000..c7c57df --- /dev/null +++ b/macai/Utilities/APIHandlers/VertexCommon.swift @@ -0,0 +1,256 @@ +// +// VertexCommon.swift +// macai +// +// Shared utilities for Vertex AI handlers. 
//

import Foundation

// MARK: - ADC Credentials

/// The fields of an Application Default Credentials file that the token exchange needs.
struct ADCCredentials: Decodable {
    let clientId: String
    let clientSecret: String
    let refreshToken: String
    /// Credential kind; only "authorized_user" is accepted by `VertexTokenManager`.
    let type: String

    enum CodingKeys: String, CodingKey {
        case clientId = "client_id"
        case clientSecret = "client_secret"
        case refreshToken = "refresh_token"
        case type
    }
}

/// Response body of the OAuth token endpoint.
struct TokenResponse: Decodable {
    let accessToken: String
    /// Token lifetime in seconds, counted from the moment the response is received.
    let expiresIn: Int
    let tokenType: String

    enum CodingKeys: String, CodingKey {
        case accessToken = "access_token"
        case expiresIn = "expires_in"
        case tokenType = "token_type"
    }
}

// MARK: - Token Manager

/// Exchanges the ADC refresh token for access tokens and caches the latest one.
actor VertexTokenManager {
    static let shared = VertexTokenManager()

    private var cachedToken: String?
    private var tokenExpiry: Date?

    private let tokenURL = URL(string: "https://oauth2.googleapis.com/token")!

    /// RFC 3986 "unreserved" characters: the only ones left unescaped in a form field.
    /// `.urlQueryAllowed` is not suitable here because it keeps `&`, `=` and `+`, all of which
    /// are significant inside an `application/x-www-form-urlencoded` body.
    private static let formFieldAllowed = CharacterSet(
        charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
    )

    /// Returns a usable access token, refreshing it when none is cached or it is about to expire.
    ///
    /// - Parameter session: URL session used for the token request.
    /// - Throws: ADC access errors, decoding errors, or `APIError` from the token exchange.
    func getAccessToken(session: URLSession) async throws -> String {
        // Return cached token if still valid (with 5-minute buffer)
        if let token = cachedToken,
            let expiry = tokenExpiry,
            Date() < expiry.addingTimeInterval(-300)
        {
            return token
        }

        let credentials = try await readADCCredentials()

        // Exchange refresh token for access token
        return try await refreshAccessToken(credentials: credentials, session: session)
    }

    /// Loads the raw ADC file, prompting for access when no bookmark has been stored yet.
    private func ensureADCCredentials() async throws -> Data {
        do {
            return try ADCCredentialsAccess.loadCredentialsData()
        }
        catch let error as ADCCredentialsAccessError {
            switch error {
            case .noBookmark:
                return try await ADCCredentialsAccess.promptAndStoreBookmark()
            default:
                throw error
            }
        }
    }

    /// Loads and validates the ADC file.
    ///
    /// - Throws: `APIError.unknown` when the file is not of type "authorized_user".
    private func readADCCredentials() async throws -> ADCCredentials {
        // Single load path: `ensureADCCredentials` already covers the missing-bookmark prompt.
        let data = try await ensureADCCredentials()
        let credentials = try JSONDecoder().decode(ADCCredentials.self, from: data)

        guard credentials.type == "authorized_user" else {
            throw APIError.unknown(
                "ADC file is not an authorized_user type. Run 'gcloud auth application-default login'"
            )
        }

        return credentials
    }

    /// Percent-encodes one form field name or value.
    private static func formEncoded(_ text: String) -> String {
        text.addingPercentEncoding(withAllowedCharacters: formFieldAllowed) ?? text
    }

    /// Performs the refresh-token grant and caches the resulting access token.
    ///
    /// - Throws: `APIError.invalidResponse` for a non-HTTP response, `APIError.unauthorized`
    ///   for any non-2xx status, or a decoding error for an unexpected body.
    private func refreshAccessToken(credentials: ADCCredentials, session: URLSession) async throws -> String {
        var request = URLRequest(url: tokenURL)
        request.httpMethod = "POST"
        request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type")

        let fields: [(name: String, value: String)] = [
            ("client_id", credentials.clientId),
            ("client_secret", credentials.clientSecret),
            ("refresh_token", credentials.refreshToken),
            ("grant_type", "refresh_token"),
        ]

        request.httpBody =
            fields
            .map { field in "\(Self.formEncoded(field.name))=\(Self.formEncoded(field.value))" }
            .joined(separator: "&")
            .data(using: .utf8)

        let (data, response) = try await session.data(for: request)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw APIError.invalidResponse
        }

        guard (200...299).contains(httpResponse.statusCode) else {
            throw APIError.unauthorized
        }

        let tokenResponse = try JSONDecoder().decode(TokenResponse.self, from: data)

        // Cache the token
        cachedToken = tokenResponse.accessToken
        tokenExpiry = Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn))

        return tokenResponse.accessToken
    }

    /// Forgets the cached token so the next request performs a fresh exchange.
    func clearCache() {
        cachedToken = nil
        tokenExpiry = nil
    }
}

// MARK: - Vertex AI Error Structures

/// Top-level wrapper of a Vertex AI error body.
struct VertexErrorEnvelope: Decodable {
    let error: VertexAPIError
}

/// Error details inside a `VertexErrorEnvelope`.
struct VertexAPIError: Decodable {
    let code: Int?
    let message: String
    let status: String?
}

// MARK: - URL Builder

enum VertexURLBuilder {
    /// Removes a single leading "models/" from a model identifier, leaving the rest untouched.
    private static func strippingModelsPrefix(_ model: String) -> String {
        let prefix = "models/"
        return model.hasPrefix(prefix) ? String(model.dropFirst(prefix.count)) : model
    }

    /// Builds a Vertex AI URL for Gemini models
    ///
    /// Streaming requests use `:streamGenerateContent` with `?alt=sse`; others use `:generateContent`.
    static func geminiURL(
        region: String,
        projectId: String,
        model: String,
        stream: Bool
    ) -> URL? {
        let normalizedModel = strippingModelsPrefix(model)
        let action = stream ? ":streamGenerateContent" : ":generateContent"

        var urlString =
            "https://\(region)-aiplatform.googleapis.com/v1/projects/\(projectId)/locations/\(region)/publishers/google/models/\(normalizedModel)\(action)"

        if stream {
            urlString += "?alt=sse"
        }

        return URL(string: urlString)
    }

    /// Builds a Vertex AI URL for Claude models
    ///
    /// Streaming requests use `:streamRawPredict`; others use `:rawPredict`.
    static func claudeURL(
        region: String,
        projectId: String,
        model: String,
        stream: Bool
    ) -> URL? {
        let normalizedModel = strippingModelsPrefix(model)
        let action = stream ? ":streamRawPredict" : ":rawPredict"

        let urlString =
            "https://\(region)-aiplatform.googleapis.com/v1/projects/\(projectId)/locations/\(region)/publishers/anthropic/models/\(normalizedModel)\(action)"

        return URL(string: urlString)
    }
}

// MARK: - Shared Response Handling

enum VertexResponseHandler {
    /// Maps a transport result to either the response body or an `APIError`.
    ///
    /// - Parameters:
    ///   - response: The URL response; anything other than an `HTTPURLResponse` is invalid.
    ///   - data: The response body, if one was collected. Used for the error message on failure.
    ///   - error: A transport error, which takes precedence over everything else.
    /// - Returns: `.success(data)` for 2xx statuses, otherwise a failure chosen by status code.
    ///   A 401 or 403 also schedules a reset of the cached access token.
    static func handleAPIResponse(
        _ response: URLResponse?,
        data: Data?,
        error: Error?
    ) -> Result<Data?, APIError> {
        if let error = error {
            return .failure(.requestFailed(error))
        }

        guard let httpResponse = response as? HTTPURLResponse else {
            return .failure(.invalidResponse)
        }

        guard (200...299).contains(httpResponse.statusCode) else {
            var message = "HTTP \(httpResponse.statusCode)"

            // Prefer the server's structured message, then the raw body, then the bare status.
            if let data = data {
                let decoder = JSONDecoder()
                decoder.keyDecodingStrategy = .convertFromSnakeCase
                if let decodedError = try? decoder.decode(VertexErrorEnvelope.self, from: data) {
                    message = decodedError.error.message
                }
                else if let raw = String(data: data, encoding: .utf8), !raw.isEmpty {
                    message = raw
                }
            }

            switch httpResponse.statusCode {
            case 400:
                return .failure(.serverError("Bad Request: \(message)"))
            case 401, 403:
                // Clear token cache on auth failure
                Task { await VertexTokenManager.shared.clearCache() }
                return .failure(.unauthorized)
            case 429:
                return .failure(.rateLimited)
            case 500...599:
                return .failure(.serverError("Vertex AI Error: \(message)"))
            default:
                return .failure(.unknown(message))
            }
        }

        return .success(data)
    }
}

diff --git a/macai/Utilities/APIHandlers/VertexGeminiHandler.swift b/macai/Utilities/APIHandlers/VertexGeminiHandler.swift new file mode 100644 index 0000000..4eed9b5 --- /dev/null +++ b/macai/Utilities/APIHandlers/VertexGeminiHandler.swift @@ -0,0 +1,278 @@ +// +// VertexGeminiHandler.swift +// macai +// +// Vertex AI handler for Gemini models. 
//

import AppKit
import CoreData
import Foundation

// MARK: - Vertex Gemini Handler

/// Gemini handler that targets Vertex AI endpoints and authenticates with an access token
/// obtained from `VertexTokenManager`.
class VertexGeminiHandler: GeminiHandlerBase {
    private let projectId: String
    private let region: String
    private var vertexActiveDataTask: Task<Void, Never>?

    /// Reads the GCP project and region from `config`, then defers to the base initializer.
    override init(config: APIServiceConfiguration, session: URLSession) {
        self.projectId = config.gcpProjectId ?? ""
        self.region = config.gcpRegion ?? "us-central1"
        super.init(config: config, session: session)
    }

    // MARK: - Fetch Models

    /// Returns every entry of the curated "vertex" model list that is not a Claude model.
    override func fetchModels() async throws -> [AIModel] {
        // Vertex AI has no stable v1 REST endpoint for listing publisher models, so the
        // list comes from the app's default configuration instead.
        let configuredModels = AppConstants.defaultApiConfigurations["vertex"]?.models ?? []
        return configuredModels
            .filter { !$0.lowercased().hasPrefix("claude") }
            .map { AIModel(id: $0) }
    }

    // MARK: - Override sendMessage for async token handling

    /// Sends a single non-streaming request; `completion` is invoked inside `MainActor.run`.
    ///
    /// Any previously started non-streaming task of this handler is cancelled first.
    override func sendMessage(
        _ requestMessages: [[String: String]],
        temperature: Float,
        completion: @escaping (Result<String, APIError>) -> Void
    ) {
        vertexActiveDataTask?.cancel()
        let requestTask = Task {
            defer { self.vertexActiveDataTask = nil }
            do {
                let urlRequest = try await prepareRequestAsync(
                    requestMessages: requestMessages,
                    temperature: temperature,
                    stream: false
                )

                let (body, urlResponse) = try await session.data(for: urlRequest)

                await MainActor.run {
                    let outcome = VertexResponseHandler.handleAPIResponse(urlResponse, data: body, error: nil)

                    switch outcome {
                    case .failure(let failure):
                        completion(.failure(failure))

                    case .success(let maybePayload):
                        guard let payload = maybePayload else {
                            completion(.failure(.invalidResponse))
                            return
                        }

                        do {
                            let decoded = try self.makeDecoder().decode(GeminiGenerateResponse.self, from: payload)
                            // Remember the raw parts so the base class can hand them out later.
                            self.lastResponseParts = decoded.candidates?.first?.content?.parts
                            var inlineDataCache: [String: String] = [:]
                            guard
                                let message = self.renderMessage(from: decoded, inlineDataCache: &inlineDataCache)
                            else {
                                completion(.failure(.decodingFailed("Empty Vertex AI Gemini response")))
                                return
                            }
                            completion(.success(message))
                        }
                        catch {
                            // The body was not a Gemini response; it may be an error envelope.
                            guard
                                let envelope = try? self.makeDecoder().decode(VertexErrorEnvelope.self, from: payload)
                            else {
                                completion(
                                    .failure(
                                        .decodingFailed(
                                            "Failed to decode Vertex AI response: \(error.localizedDescription)"
                                        )
                                    )
                                )
                                return
                            }
                            completion(.failure(.serverError(envelope.error.message)))
                        }
                    }
                }
            }
            catch let apiError as APIError {
                await MainActor.run {
                    completion(.failure(apiError))
                }
            }
            catch {
                await MainActor.run {
                    completion(.failure(.requestFailed(error)))
                }
            }
        }
        vertexActiveDataTask = requestTask
    }

    // MARK: - Override sendMessageStream for async token handling

    /// Starts a streaming request and returns a stream of text deltas.
    ///
    /// The request is prepared (token included) before the stream is returned, so preparation
    /// errors are thrown from this call rather than delivered through the stream.
    override func sendMessageStream(_ requestMessages: [[String: String]], temperature: Float) async throws
        -> AsyncThrowingStream<String, Error>
    {
        let urlRequest = try await prepareRequestAsync(
            requestMessages: requestMessages,
            temperature: temperature,
            stream: true
        )

        return AsyncThrowingStream { continuation in
            let streamTask = Task {
                defer { self.activeStreamTask = nil }
                do {
                    let (byteStream, urlResponse) = try await self.session.bytes(for: urlRequest)
                    let statusCheck = VertexResponseHandler.handleAPIResponse(urlResponse, data: nil, error: nil)

                    if case .failure(let statusError) = statusCheck {
                        // Collect the whole body to report the server's own explanation.
                        var errorBody = Data()
                        for try await byte in byteStream {
                            errorBody.append(byte)
                        }

                        if let envelope = try? self.makeDecoder().decode(VertexErrorEnvelope.self, from: errorBody) {
                            continuation.finish(throwing: APIError.serverError(envelope.error.message))
                        }
                        else {
                            let message = String(data: errorBody, encoding: .utf8) ?? statusError.localizedDescription
                            continuation.finish(throwing: APIError.serverError(message))
                        }
                        return
                    }

                    // Handle Gemini streaming response
                    let decoder = self.makeDecoder()
                    var aggregatedText = ""
                    var inlineDataCache: [String: String] = [:]
                    let dataMarker = "data:"

                    for try await line in byteStream.lines {
                        // Only `data:` lines carry a payload; blank and other lines are skipped.
                        guard line.hasPrefix(dataMarker) else { continue }

                        let payload = String(line.dropFirst(dataMarker.count))
                            .trimmingCharacters(in: .whitespacesAndNewlines)
                        guard !payload.isEmpty else { continue }

                        let parsed = try self.parseStreamPayload(
                            payload,
                            decoder: decoder,
                            aggregatedText: &aggregatedText,
                            inlineDataCache: &inlineDataCache
                        )

                        if let delta = parsed.delta, !delta.isEmpty {
                            continuation.yield(delta)
                        }

                        if parsed.finished {
                            continuation.finish()
                            return
                        }
                    }

                    continuation.finish()
                }
                catch let apiError as APIError {
                    continuation.finish(throwing: apiError)
                }
                catch {
                    continuation.finish(throwing: APIError.requestFailed(error))
                }
            }
            activeStreamTask?.cancel()
            activeStreamTask = streamTask
            continuation.onTermination = { _ in
                streamTask.cancel()
            }
        }
    }

    // MARK: - Synchronous prepareRequest (not used directly, required by base class)

    /// Always fails: token retrieval is asynchronous, so this handler builds its requests in
    /// `prepareRequestAsync` instead.
    override func prepareRequest(
        requestMessages: [[String: String]],
        temperature: Float,
        stream: Bool
    ) -> Result<URLRequest, APIError> {
        .failure(.unknown("Use sendMessage or sendMessageStream instead"))
    }

    // MARK: - Async Request Preparation

    /// Builds the authenticated Gemini request for Vertex AI.
    ///
    /// - Throws: `APIError.unknown` when the project ID or region is empty, the URL cannot be
    ///   built, or the transformed history has no contents; any error from token retrieval or
    ///   JSON encoding.
    private func prepareRequestAsync(
        requestMessages: [[String: String]],
        temperature: Float,
        stream: Bool
    ) async throws -> URLRequest {
        if projectId.isEmpty {
            throw APIError.unknown("GCP Project ID is required for Vertex AI")
        }

        if region.isEmpty {
            throw APIError.unknown("GCP Region is required for Vertex AI")
        }

        let accessToken = try await VertexTokenManager.shared.getAccessToken(session: session)

        guard let endpoint = buildRequestURL(stream: stream) else {
            throw APIError.unknown("Invalid Vertex AI request URL")
        }

        var urlRequest = URLRequest(url: endpoint)
        urlRequest.httpMethod = "POST"
        urlRequest.timeoutInterval = AppConstants.requestTimeout
        urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
        urlRequest.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization")

        if stream {
            urlRequest.setValue("text/event-stream", forHTTPHeaderField: "Accept")
        }

        // Use Gemini format for Google models
        let (systemInstruction, contents) = transformMessages(requestMessages)

        guard !contents.isEmpty else {
            throw APIError.unknown("Vertex AI request requires at least one user or assistant message")
        }

        let requestBody = GeminiGenerateRequest(
            contents: contents,
            systemInstruction: systemInstruction,
            generationConfig: GeminiGenerationConfig(temperature: temperature)
        )

        urlRequest.httpBody = try JSONEncoder().encode(requestBody)

        return urlRequest
    }

    // MARK: - URL Building

    /// Returns the Vertex AI Gemini endpoint for this handler's region, project and model.
    override func buildRequestURL(stream: Bool) -> URL? {
        VertexURLBuilder.geminiURL(
            region: region,
            projectId: projectId,
            model: model,
            stream: stream
        )
    }

    // MARK: - Cancel

    /// Cancels this handler's non-streaming task, then lets the base class cancel its own work.
    override func cancelCurrentRequest() {
        vertexActiveDataTask?.cancel()
        super.cancelCurrentRequest()
    }
}