diff --git a/promptimize/prompt_cases.py b/promptimize/prompt_cases.py index 6f61889..7e4ed6b 100644 --- a/promptimize/prompt_cases.py +++ b/promptimize/prompt_cases.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, List, Optional, Union -from langchain.llms import OpenAI +from langchain.llms import OpenAI, AzureOpenAI from langchain.callbacks import get_openai_callback from box import Box @@ -69,6 +69,15 @@ def get_prompt_executor(self): model_name = os.environ.get("OPENAI_MODEL") or "text-davinci-003" openai_api_key = os.environ.get("OPENAI_API_KEY") self.prompt_executor_kwargs = {"model_name": model_name} + if os.environ.get("OPENAI_API_TYPE") == "azure": + if not os.environ.get("AZURE_DEPLOYMENT_NAME"): + raise Exception( + "Environment variable with key name 'AZURE_DEPLOYMENT_NAME'" + "is required when OPEN_API_TYPE=='azure'." + ) + return AzureOpenAI( + model_name=model_name, deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME") + ) return OpenAI(model_name=model_name, openai_api_key=openai_api_key) def execute_prompt(self, prompt_str):