diff --git a/global_methods.py b/global_methods.py index 7ada78d..4ff6eb5 100644 --- a/global_methods.py +++ b/global_methods.py @@ -23,6 +23,12 @@ def set_gemini_key(): def set_openai_key(): openai.api_key = os.environ['OPENAI_API_KEY'] + if openai.api_type == "azure" and openai.azure_ad_token_provider is None: + import azure.identity + openai.azure_ad_token_provider = azure.identity.get_bearer_token_provider( + azure.identity.DefaultAzureCredential(), + "https://cognitiveservices.azure.com/.default", + ) def run_json_trials(query, num_gen=1, num_tokens_request=1000,