diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index adba401a..eb21f9ec 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -11,7 +11,6 @@ import sys import time -import black import colorama import typechat @@ -57,6 +56,8 @@ def format_code(text: str, line_width=None) -> str: NOTE: The text must be a valid Python expression or code block. """ + import black + if line_width is None: # Use the terminal width, but cap it to 200 characters. line_width = min(200, shutil.get_terminal_size().columns) @@ -197,7 +198,11 @@ def parse_azure_endpoint( f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field" ) - return azure_endpoint, m.group(1) + # Strip query string — AsyncAzureOpenAI expects a clean base URL and + # receives api_version as a separate parameter. + clean_endpoint = azure_endpoint.split("?", 1)[0] + + return clean_endpoint, m.group(1) def get_azure_api_key(azure_api_key: str) -> str: diff --git a/src/typeagent/knowpro/answers.py b/src/typeagent/knowpro/answers.py index 2c300506..58a536ed 100644 --- a/src/typeagent/knowpro/answers.py +++ b/src/typeagent/knowpro/answers.py @@ -5,8 +5,6 @@ from dataclasses import dataclass from typing import Any -import black - import typechat from .answer_context_schema import AnswerContext, RelevantKnowledge, RelevantMessage @@ -127,6 +125,8 @@ def create_question_prompt(question: str) -> str: def create_context_prompt(context: AnswerContext) -> str: # TODO: Use a more compact representation of the context than JSON. + import black + prompt = [ "[ANSWER CONTEXT]", "===", diff --git a/tests/test_utils.py b/tests/test_utils.py index 5966af61..7f806f74 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -67,7 +67,7 @@ def test_api_version_after_question_mark( ) endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") assert version == "2025-01-01-preview" - assert endpoint.startswith("https://") + assert endpoint == "https://myhost.openai.azure.com/openai/deployments/gpt-4" def test_api_version_after_ampersand(self, monkeypatch: pytest.MonkeyPatch) -> None: """api-version preceded by & (not the first query parameter).""" @@ -84,6 +84,44 @@ def test_missing_env_var_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: with pytest.raises(RuntimeError, match="not found"): utils.parse_azure_endpoint("NONEXISTENT_ENDPOINT") + def test_query_string_stripped_from_endpoint( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Returned endpoint should not contain query string parameters.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com?api-version=2024-06-01", + ) + endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") + assert endpoint == "https://myhost.openai.azure.com" + assert version == "2024-06-01" + + def test_query_string_stripped_with_path( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Query string stripped even when endpoint includes a path.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com/openai/deployments/gpt-4?api-version=2025-01-01-preview", + ) + endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") + assert endpoint == "https://myhost.openai.azure.com/openai/deployments/gpt-4" + assert "?" not in endpoint + assert version == "2025-01-01-preview" + + def test_query_string_stripped_multiple_params( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """All query parameters stripped, not just api-version.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com?foo=bar&api-version=2024-06-01", + ) + endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") + assert endpoint == "https://myhost.openai.azure.com" + assert "foo" not in endpoint + assert version == "2024-06-01" + def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: """RuntimeError when the endpoint has no api-version field.""" monkeypatch.setenv(