36 changes: 0 additions & 36 deletions lib/evagg/llm/aoai.py
@@ -7,7 +7,6 @@

import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types import CreateEmbeddingResponse
from openai.types.chat import (
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
@@ -204,41 +203,6 @@ async def prompt_file(
        user_prompt = self._load_prompt_file(user_prompt_file)
        return await self.prompt(user_prompt, system_prompt, params, prompt_settings)

    async def embeddings(
        self, inputs: List[str], embedding_settings: Optional[Dict[str, Any]] = None
    ) -> Dict[str, List[float]]:
        settings = {"model": "text-embedding-ada-002-v2", **(embedding_settings or {})}

        embeddings = {}

        async def _run_single_embedding(input: str) -> int:
            connection_errors = 0
            while True:
                try:
                    result: CreateEmbeddingResponse = await self._client.embeddings.create(
                        input=[input], encoding_format="float", **settings
                    )
                    embeddings[input] = result.data[0].embedding
                    return result.usage.prompt_tokens
                except (openai.RateLimitError, openai.InternalServerError) as e:
                    logger.warning(f"Rate limit or server error on embeddings: {e}")
                    await asyncio.sleep(1)
                except (openai.APIConnectionError, openai.APITimeoutError):
                    if connection_errors > 2:
                        if hasattr(self._config, "endpoint") and self._config.endpoint.startswith("http://localhost"):
                            logger.error("Azure OpenAI API unreachable - have you failed to start a local proxy?")
                        raise
                    logger.warning("Connectivity error on embeddings, retrying...")
                    connection_errors += 1
                    await asyncio.sleep(1)

        start_overall = time.time()
        tokens = await asyncio.gather(*[_run_single_embedding(input) for input in inputs])
        elapsed = time.time() - start_overall

        logger.info(f"{len(inputs)} embeddings produced in {elapsed:.1f} seconds using {sum(tokens)} tokens.")
        return embeddings


class OpenAICacheClient(OpenAIClient):
    def __init__(self, client_class: str, config: Dict[str, Any]) -> None:
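Note: with `embeddings()` gone from `OpenAIClient`, any downstream code that still needs embeddings has to call the SDK directly. A minimal sketch of what that might look like, mirroring the removed gather-and-retry behavior (the function name `get_embeddings`, the standalone client argument, and the default model id are illustrative assumptions, not part of this PR):

```python
import asyncio
from typing import Dict, List

import openai
from openai import AsyncAzureOpenAI


async def get_embeddings(
    client: AsyncAzureOpenAI,
    inputs: List[str],
    model: str = "text-embedding-ada-002-v2",
) -> Dict[str, List[float]]:
    """Embed each input string, retrying transient failures as the removed helper did."""
    embeddings: Dict[str, List[float]] = {}

    async def embed_one(text: str) -> None:
        while True:
            try:
                result = await client.embeddings.create(
                    input=[text], encoding_format="float", model=model
                )
                embeddings[text] = result.data[0].embedding
                return
            except (openai.RateLimitError, openai.InternalServerError):
                # Transient service error: back off briefly and retry.
                await asyncio.sleep(1)

    # Issue all requests concurrently, as the removed method did via asyncio.gather.
    await asyncio.gather(*(embed_one(text) for text in inputs))
    return embeddings
```

Callers that depended on the removed connectivity-error retry path would need to reproduce that handling as well.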
8 changes: 1 addition & 7 deletions lib/evagg/llm/interfaces.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, Optional, Protocol


class IPromptClient(Protocol):
@@ -21,9 +21,3 @@ async def prompt_file(
    ) -> str:
        """Get the response from a prompt with an input file."""
        ...  # pragma: no cover

    async def embeddings(
        self, inputs: List[str], embedding_settings: Optional[Dict[str, Any]] = None
    ) -> Dict[str, List[float]]:
        """Get embeddings for the given inputs."""
        ...  # pragma: no cover
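For reference, `IPromptClient` now covers prompting only. A sketch of the trimmed protocol, with parameter lists inferred from the `prompt_file` call site in `aoai.py` above (treat the exact signatures as assumptions):

```python
from typing import Any, Dict, Optional, Protocol


class IPromptClient(Protocol):
    async def prompt(
        self,
        user_prompt: str,
        system_prompt: Optional[str] = None,
        params: Optional[Dict[str, str]] = None,
        prompt_settings: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Get the response from a prompt."""
        ...  # pragma: no cover

    async def prompt_file(
        self,
        user_prompt_file: str,
        system_prompt: Optional[str] = None,
        params: Optional[Dict[str, str]] = None,
        prompt_settings: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Get the response from a prompt with an input file."""
        ...  # pragma: no cover
```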
24 changes: 0 additions & 24 deletions test/evagg/test_llm.py
@@ -44,27 +44,3 @@ async def test_openai_client_prompt(mock_openai, test_file_contents) -> None:
        temperature=1.5,
        model="gpt-8",
    )


@patch("lib.evagg.llm.aoai.AsyncAzureOpenAI", return_value=AsyncMock())
async def test_openai_client_embeddings(mock_openai) -> None:
    embedding = MagicMock(data=[MagicMock(embedding=[0.4, 0.5, 0.6])], usage=MagicMock(prompt_tokens=10))
    mock_openai.return_value.embeddings.create.return_value = embedding

    inputs = [f"input_{i}" for i in range(1)]
    client = OpenAIClient(
        "AsyncAzureOpenAI",
        {
            "deployment": "gpt-8",
            "endpoint": "https://ai",
            "api_key": "test",
            "api_version": "test",
            "timeout": 60,
        },
    )
    response = await client.embeddings(inputs)
    mock_openai.assert_called_once_with(azure_endpoint="https://ai", api_key="test", api_version="test", timeout=60)
    mock_openai.return_value.embeddings.create.assert_has_calls(
        [call(input=[input], encoding_format="float", model="text-embedding-ada-002-v2") for input in inputs]
    )
    assert response == {input: [0.4, 0.5, 0.6] for input in inputs}