diff --git a/chainfury/components/openai/__init__.py b/chainfury/components/openai/__init__.py index db8a0c1..44955ce 100644 --- a/chainfury/components/openai/__init__.py +++ b/chainfury/components/openai/__init__.py @@ -1,6 +1,7 @@ import requests from pydantic import BaseModel from typing import Any, List, Union, Dict, Optional +from litellm import completion from chainfury import Secret, model_registry, exponential_backoff, Model, UnAuthException from chainfury.components.const import Env @@ -195,31 +196,25 @@ def openai_chat( messages = [x.dict(skip_defaults=True) for x in messages] def _fn(): - r = requests.post( - "https://api.openai.com/v1/chat/completions", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {openai_api_key}", - }, - json={ - "model": model, - "messages": messages, - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "n": n, - "stop": stop, - "presence_penalty": presence_penalty, - "frequency_penalty": frequency_penalty, - "logit_bias": logit_bias, - "user": user, - }, - ) - if r.status_code == 401: - raise UnAuthException(r.text) - if r.status_code != 200: - raise Exception(f"OpenAI API returned status code {r.status_code}: {r.text}") - return r.json() + try: + r = completion( + api_key = openai_api_key, + model= model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + n= n, + stop=stop, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + + ) + except Exception as e: + raise Exception(f"OpenAI API returned exception: {e} ") + return r return exponential_backoff(_fn, max_retries=retry_count, retry_delay=retry_delay) diff --git a/pyproject.toml b/pyproject.toml index c0266c0..2062fbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ Jinja2 = "3.1.2" jinja2schema = "0.1.4" pydantic = "^1.10.7" requests = "^2.28.2" +litellm = { version = "0.1.531", optional = true } stability-sdk = { version = "0.8.3", optional = true } qdrant-client = { version = "1.3.1", optional = true }