|
7 | 7 |
|
8 | 8 | import openai |
9 | 9 | from openai import AsyncAzureOpenAI, AsyncOpenAI |
10 | | -from openai.types import CreateEmbeddingResponse |
11 | 10 | from openai.types.chat import ( |
12 | 11 | ChatCompletionMessageParam, |
13 | 12 | ChatCompletionSystemMessageParam, |
@@ -204,41 +203,6 @@ async def prompt_file( |
204 | 203 | user_prompt = self._load_prompt_file(user_prompt_file) |
205 | 204 | return await self.prompt(user_prompt, system_prompt, params, prompt_settings) |
206 | 205 |
|
207 | | - async def embeddings( |
208 | | - self, inputs: List[str], embedding_settings: Optional[Dict[str, Any]] = None |
209 | | - ) -> Dict[str, List[float]]: |
210 | | - settings = {"model": "text-embedding-ada-002-v2", **(embedding_settings or {})} |
211 | | - |
212 | | - embeddings = {} |
213 | | - |
214 | | - async def _run_single_embedding(input: str) -> int: |
215 | | - connection_errors = 0 |
216 | | - while True: |
217 | | - try: |
218 | | - result: CreateEmbeddingResponse = await self._client.embeddings.create( |
219 | | - input=[input], encoding_format="float", **settings |
220 | | - ) |
221 | | - embeddings[input] = result.data[0].embedding |
222 | | - return result.usage.prompt_tokens |
223 | | - except (openai.RateLimitError, openai.InternalServerError) as e: |
224 | | - logger.warning(f"Rate limit error on embeddings: {e}") |
225 | | - await asyncio.sleep(1) |
226 | | - except (openai.APIConnectionError, openai.APITimeoutError): |
227 | | - if connection_errors > 2: |
228 | | - if hasattr(self._config, "endpoint") and self._config.endpoint.startswith("http://localhost"): |
229 | | - logger.error("Azure OpenAI API unreachable - have failed to start a local proxy?") |
230 | | - raise |
231 | | - logger.warning("Connectivity error on embeddings, retrying...") |
232 | | - connection_errors += 1 |
233 | | - await asyncio.sleep(1) |
234 | | - |
235 | | - start_overall = time.time() |
236 | | - tokens = await asyncio.gather(*[_run_single_embedding(input) for input in inputs]) |
237 | | - elapsed = time.time() - start_overall |
238 | | - |
239 | | - logger.info(f"{len(inputs)} embeddings produced in {elapsed:.1f} seconds using {sum(tokens)} tokens.") |
240 | | - return embeddings |
241 | | - |
242 | 206 |
|
243 | 207 | class OpenAICacheClient(OpenAIClient): |
244 | 208 | def __init__(self, client_class: str, config: Dict[str, Any]) -> None: |
|
0 commit comments