diff --git a/README.md b/README.md
index 29577bb..bba0db2 100644
--- a/README.md
+++ b/README.md
@@ -96,7 +96,7 @@ Without oh-my-zsh:
 ; Primary service configuration
 ; Set 'service' to match one of the defined sections below.
 [service]
-service = groq_service
+service = azure_openai_service
 
 ; Example configuration for a self-hosted Ollama service.
 [my_ollama]
@@ -124,6 +124,14 @@ model = gemma2-9b-it
 api_type = mistral
 api_key =
 model = mistral-small-latest
+
+; Azure OpenAI configuration
+; Provide the 'api_key' and 'endpoint'.
+[azure_openai_service]
+api_type = azure_openai
+api_key = api_key
+endpoint = endpoint
+deployment = deployment_name
 ```
 
 In this configuration file, you can define multiple services with their own configurations. The required and optional parameters of the `api_type` are specified in `services/services.py`. Choose which service to use in the `[service]` section.
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ec838c5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+openai
diff --git a/services/services.py b/services/services.py
index aa54991..8ceaf06 100644
--- a/services/services.py
+++ b/services/services.py
@@ -61,6 +61,50 @@ def get_completion(self, full_command: str) -> str:
         return response.choices[0].message.content
 
 
+class AzureOpenAIClient(BaseClient):
+    """
+    config keys:
+    - api_type="azure_openai"
+    - api_key (required)
+    - endpoint (required): The Azure OpenAI endpoint.
+    - deployment (optional): defaults to "gpt-4o-mini"
+    - api_version (optional): defaults to "2024-12-01-preview"
+    - temperature (optional): defaults to 1.0.
+    """
+
+    api_type = "azure_openai"
+    default_deployment = os.getenv("AZUREOPENAI_DEFAULT_DEPLOYMENT", "gpt-4o-mini")
+    default_api_version = "2024-12-01-preview"
+
+    def __init__(self, config: dict):
+        try:
+            from openai import AzureOpenAI
+        except ImportError:
+            print(
+                "OpenAI library is not installed. Please install it using 'pip install openai'"
+            )
+            sys.exit(1)
+
+        self.config = config
+        self.config["deployment"] = self.config.get("deployment", self.default_deployment)
+        self.config["api_version"] = self.config.get("api_version", self.default_api_version)
+        self.client = AzureOpenAI(
+            api_version=self.config["api_version"],
+            azure_endpoint=self.config["endpoint"],
+            api_key=self.config["api_key"],
+        )
+
+    def get_completion(self, full_command: str) -> str:
+        response = self.client.chat.completions.create(
+            model=self.config["deployment"],
+            messages=[
+                {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": full_command},
+            ],
+            temperature=float(self.config.get("temperature", 1.0)),
+        )
+        return response.choices[0].message.content
+
 class GoogleGenAIClient(BaseClient):
     """
     config keys:
@@ -239,7 +283,7 @@ def get_completion(self, full_command: str) -> str:
 
 
 class ClientFactory:
-    api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type]
+    api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type, AzureOpenAIClient.api_type]
 
     @classmethod
     def create(cls):
@@ -263,6 +307,8 @@ def create(cls):
                 return MistralClient(config)
             case AmazonBedrock.api_type:
                 return AmazonBedrock(config)
+            case AzureOpenAIClient.api_type:
+                return AzureOpenAIClient(config)
             case _:
                 raise KeyError(
                     f"Specified API type {api_type} is not one of the supported services {cls.api_types}"