From 268fbe76be8c873389da0efce22048e9dcf29930 Mon Sep 17 00:00:00 2001 From: Dwayne Thomson Date: Thu, 5 Oct 2023 17:19:19 -0400 Subject: [PATCH 1/5] support for Azure OpenAI --- promptimize/prompt_cases.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/promptimize/prompt_cases.py b/promptimize/prompt_cases.py index 6f61889..0967009 100644 --- a/promptimize/prompt_cases.py +++ b/promptimize/prompt_cases.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, List, Optional, Union -from langchain.llms import OpenAI +from langchain.llms import OpenAI, AzureOpenAI from langchain.callbacks import get_openai_callback from box import Box @@ -69,6 +69,8 @@ def get_prompt_executor(self): model_name = os.environ.get("OPENAI_MODEL") or "text-davinci-003" openai_api_key = os.environ.get("OPENAI_API_KEY") self.prompt_executor_kwargs = {"model_name": model_name} + if os.environ.get('OPENAI_API_TYPE') == 'azure': + return AzureOpenAI(model_name=model_name, deployment_name=os.environ.get('AZURE_DEPLOYMENT_NAME')) return OpenAI(model_name=model_name, openai_api_key=openai_api_key) def execute_prompt(self, prompt_str): From efd7ee9398f28977ec3d54245e7d648cbe657505 Mon Sep 17 00:00:00 2001 From: Dwayne Thomson Date: Thu, 5 Oct 2023 17:26:15 -0400 Subject: [PATCH 2/5] added validation --- promptimize/prompt_cases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/promptimize/prompt_cases.py b/promptimize/prompt_cases.py index 0967009..91f0f20 100644 --- a/promptimize/prompt_cases.py +++ b/promptimize/prompt_cases.py @@ -70,6 +70,8 @@ def get_prompt_executor(self): openai_api_key = os.environ.get("OPENAI_API_KEY") self.prompt_executor_kwargs = {"model_name": model_name} if os.environ.get('OPENAI_API_TYPE') == 'azure': + if not os.environ.get("AZURE_DEPLOYMENT_NAME"): + raise Exception("Environment variable with key name 'AZURE_DEPLOYMENT_NAME' is required when OPEN_API_TYPE=='azure'.") return AzureOpenAI(model_name=model_name, deployment_name=os.environ.get('AZURE_DEPLOYMENT_NAME')) return OpenAI(model_name=model_name, openai_api_key=openai_api_key) From 1bd8682d71d58a4d4001733615a4cf67becf5da5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 21:40:39 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- promptimize/prompt_cases.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/promptimize/prompt_cases.py b/promptimize/prompt_cases.py index 91f0f20..f2b5059 100644 --- a/promptimize/prompt_cases.py +++ b/promptimize/prompt_cases.py @@ -69,10 +69,14 @@ def get_prompt_executor(self): model_name = os.environ.get("OPENAI_MODEL") or "text-davinci-003" openai_api_key = os.environ.get("OPENAI_API_KEY") self.prompt_executor_kwargs = {"model_name": model_name} - if os.environ.get('OPENAI_API_TYPE') == 'azure': + if os.environ.get("OPENAI_API_TYPE") == "azure": if not os.environ.get("AZURE_DEPLOYMENT_NAME"): - raise Exception("Environment variable with key name 'AZURE_DEPLOYMENT_NAME' is required when OPEN_API_TYPE=='azure'.") - return AzureOpenAI(model_name=model_name, deployment_name=os.environ.get('AZURE_DEPLOYMENT_NAME')) + raise Exception( + "Environment variable with key name 'AZURE_DEPLOYMENT_NAME' is required when OPEN_API_TYPE=='azure'." + ) + return AzureOpenAI( + model_name=model_name, deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME") + ) return OpenAI(model_name=model_name, openai_api_key=openai_api_key) def execute_prompt(self, prompt_str): From 148ecdd6c120c69fb5eb7f5e2050ff2b0600389a Mon Sep 17 00:00:00 2001 From: Dwayne Thomson Date: Thu, 5 Oct 2023 17:45:45 -0400 Subject: [PATCH 4/5] shorten line to pass upstrem validation --- promptimize/prompt_cases.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/promptimize/prompt_cases.py b/promptimize/prompt_cases.py index f2b5059..dcb8b93 100644 --- a/promptimize/prompt_cases.py +++ b/promptimize/prompt_cases.py @@ -72,7 +72,8 @@ def get_prompt_executor(self): if os.environ.get("OPENAI_API_TYPE") == "azure": if not os.environ.get("AZURE_DEPLOYMENT_NAME"): raise Exception( - "Environment variable with key name 'AZURE_DEPLOYMENT_NAME' is required when OPEN_API_TYPE=='azure'." + "Environment variable with key name 'AZURE_DEPLOYMENT_NAME'"\ + "is required when OPEN_API_TYPE=='azure'." ) return AzureOpenAI( model_name=model_name, deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME") From c6c2e267f79d7fcacc7c2fc5d4bced7446ccc2a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 21:45:55 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- promptimize/prompt_cases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/promptimize/prompt_cases.py b/promptimize/prompt_cases.py index dcb8b93..7e4ed6b 100644 --- a/promptimize/prompt_cases.py +++ b/promptimize/prompt_cases.py @@ -72,7 +72,7 @@ def get_prompt_executor(self): if os.environ.get("OPENAI_API_TYPE") == "azure": if not os.environ.get("AZURE_DEPLOYMENT_NAME"): raise Exception( - "Environment variable with key name 'AZURE_DEPLOYMENT_NAME'"\ + "Environment variable with key name 'AZURE_DEPLOYMENT_NAME'" "is required when OPEN_API_TYPE=='azure'." ) return AzureOpenAI(