1+ import json
12from typing import Dict , List , Optional , Any
23from collections .abc import Generator
34from openai .types .chat import ChatCompletion
@@ -58,6 +59,62 @@ def completion(self,
5859
5960 return response .json ()
6061
class RestfulResponseClient():
    """Minimal REST client for an OpenAI Responses-style endpoint.

    Posts a JSON payload to ``{api_base}/{api_type}`` and extracts the
    final ``response.completed`` event from the SSE-formatted reply body.
    """

    def __init__(self,
                 api_base: str,
                 api_key: str,
                 api_type: str = "responses",
                 model_id: str = "o3",
                 http_client=None):
        self.api_base = api_base
        self.api_key = api_key
        self.api_type = api_type
        self.model_id = model_id

        self.http_client = http_client

    def completion(self,
                   model,
                   input,
                   tools,
                   **kwargs):
        """Send a completion request and return the final parsed event.

        Args:
            model: Model identifier; only the part after the last "/" is sent.
            input: Request input payload (messages / prompt structure).
            tools: Tool definitions forwarded to the server.
            **kwargs: Extra fields merged into the request body.

        Returns:
            The JSON-decoded ``response.completed`` event, or ``None`` if no
            such event appears in the response body.
        """
        headers = {
            "app_key": self.api_key,
            "Content-Type": "application/json"
        }

        # The server expects the bare model name, not a provider-prefixed id.
        model = model.split("/")[-1]
        data = {
            "model": model,
            "input": input,
            "tools": tools,
            "stream": False,
        }

        # Add any additional kwargs to the data
        if kwargs:
            data.update(kwargs)

        response = requests.post(
            f"{self.api_base}/{self.api_type}",
            json=data,
            headers=headers,
        )

        # The body is SSE-like: scan for the terminal "response.completed"
        # data line and return its JSON payload. (Removed leftover debug
        # print of every line.)
        for line in response.text.split('\n'):
            if not line.strip():
                continue
            try:
                json_line = line.strip()
                if json_line.startswith("data: ") and "response.completed" in json_line:
                    json_line = json_line.replace("data: ", "").strip()
                    return json.loads(json_line)
            except Exception as e:
                # Best-effort parsing: log malformed lines and keep scanning.
                logger.error(f"Error parsing line: {line}, error: {e}")
61118
62119class RestfulTranscribeClient ():
63120 def __init__ (self ,
@@ -726,4 +783,226 @@ def __call__(self, *args, **kwargs) -> str:
726783 Call the model with the given arguments.
727784 This is a convenience method that calls `generate` with the same arguments.
728785 """
729- return self .generate (* args , ** kwargs )
786+ return self .generate (* args , ** kwargs )
787+
788+
class RestfulResponseModel(ApiModel):
    """This model connects to an OpenAI-compatible Responses API server.

    Parameters:
        model_id (`str`):
            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
        api_base (`str`, *optional*):
            The base URL of the OpenAI-compatible API server.
        api_type (`str`, default `"chat/completions"`):
            The endpoint path appended to `api_base`.
        api_key (`str`, *optional*):
            The API key to use for authentication.
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles in others.
            Useful for specific models that do not support specific message roles like "system".
        flatten_messages_as_text (`bool | None`, default `None`):
            Whether to flatten messages as text. When `None`, it is inferred
            from the model id (enabled for "ollama", "groq", "cerebras").
        http_client (*optional*):
            HTTP client forwarded to the underlying REST client.
        **kwargs:
            Additional keyword arguments to pass to the API.
    """

    def __init__(
        self,
        model_id: str,
        api_base: Optional[str] = None,
        api_type: str = "chat/completions",
        api_key: Optional[str] = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool | None = None,
        http_client=None,
        **kwargs,
    ):
        self.model_id = model_id
        self.api_base = api_base
        self.api_key = api_key
        self.api_type = api_type
        # FIX: the previous default of False made the `is not None` check
        # always true, so the model-id-based fallback below was dead code.
        # A default of None lets the fallback actually apply.
        flatten_messages_as_text = (
            flatten_messages_as_text
            if flatten_messages_as_text is not None
            else model_id.startswith(("ollama", "groq", "cerebras"))
        )

        self.http_client = http_client

        self.message_manager = MessageManager(model_id=model_id)

        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def create_client(self):
        """Build the REST client used for all completion calls."""
        return RestfulResponseClient(
            api_base=self.api_base,
            api_key=self.api_key,
            api_type=self.api_type,
            model_id=self.model_id,
            http_client=self.http_client,
        )

    def _prepare_completion_kwargs(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        tool_choice: str | dict | None = "required",  # Configurable tool_choice parameter
        **kwargs,
    ) -> dict[str, Any]:
        """
        Prepare parameters required for model invocation, handling parameter priorities.

        Parameter priority from high to low:
        1. Explicitly passed kwargs
        2. Specific parameters (stop_sequences, response_format, etc.)
        3. Default values in self.kwargs
        """
        # Clean and standardize the message list
        flatten_messages_as_text = kwargs.pop("flatten_messages_as_text", self.flatten_messages_as_text)
        messages_as_dicts = self.message_manager.get_clean_message_list(
            messages,
            role_conversions=custom_role_conversions or tool_role_conversions,
            convert_images_to_image_urls=convert_images_to_image_urls,
            flatten_messages_as_text=flatten_messages_as_text,
        )
        # Use self.kwargs as the base configuration; the Responses API takes
        # the message list under "input" rather than "messages".
        completion_kwargs = {
            **self.kwargs,
            "input": messages_as_dicts,
        }

        # Handle specific parameters
        if stop_sequences is not None:
            completion_kwargs["stop"] = stop_sequences
        if response_format is not None:
            completion_kwargs["response_format"] = response_format

        # Handle tools parameter
        if tools_to_call_from:
            tools_config = {
                "tools": [
                    self.message_manager.get_tool_json_schema(tool, model_id=self.model_id)
                    for tool in tools_to_call_from
                ],
            }
            if tool_choice is not None:
                tools_config["tool_choice"] = tool_choice
            completion_kwargs.update(tools_config)

        # Finally, use the passed-in kwargs to override all settings
        completion_kwargs.update(kwargs)

        completion_kwargs = self.message_manager.get_clean_completion_kwargs(completion_kwargs)

        return completion_kwargs

    def generate_stream(self,
                        messages: list[ChatMessage],
                        stop_sequences: list[str] | None = None,
                        response_format: dict[str, str] | None = None,
                        tools_to_call_from: list[Any] | None = None,
                        **kwargs,
                        ) -> Generator[ChatMessageStreamDelta]:
        """Stream deltas from the server, yielding content/tool-call chunks.

        NOTE(review): this iterates objects with `.usage`/`.choices`
        attributes, but RestfulResponseClient.completion returns a plain
        dict — confirm the streaming path is wired to a compatible client.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_id,
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )

        for event in self.client.completion(**completion_kwargs, stream=True, stream_options={"include_usage": True}):
            if getattr(event, "usage", None):
                self._last_input_token_count = event.usage.prompt_tokens
                self._last_output_token_count = event.usage.completion_tokens
                yield ChatMessageStreamDelta(
                    content="",
                    token_usage=TokenUsage(
                        input_tokens=event.usage.prompt_tokens,
                        output_tokens=event.usage.completion_tokens,
                    ),
                )
            if event.choices:
                choice = event.choices[0]
                if choice.delta:
                    yield ChatMessageStreamDelta(
                        content=choice.delta.content,
                        tool_calls=[
                            ChatMessageToolCallStreamDelta(
                                index=delta.index,
                                id=delta.id,
                                type=delta.type,
                                function=delta.function,
                            )
                            for delta in choice.delta.tool_calls
                        ]
                        if choice.delta.tool_calls
                        else None,
                    )
                else:
                    if not getattr(choice, "finish_reason", None):
                        raise ValueError(f"No content or tool calls in event: {event}")

    async def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Any] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Run a non-streaming completion and return the resulting ChatMessage.

        FIX: removed leftover debug `print(response)` and `exit()` that
        terminated the whole process before any value could be returned.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_id,
            convert_images_to_image_urls=True,
            custom_role_conversions=self.custom_role_conversions,
            **kwargs,
        )

        # NOTE(review): this unconditionally overrides any tools prepared
        # above with the built-in web-search tool — confirm this is intended.
        completion_kwargs['tools'] = [
            {"type": "web_search_preview"},
        ]

        # Synchronous call to the REST client for completion.
        response = self.client.completion(**completion_kwargs)

        # NOTE(review): validating a Responses API payload against the
        # ChatCompletion schema — verify the server returns that shape.
        response = ChatCompletion.model_validate(response)

        self._last_input_token_count = response.usage.prompt_tokens
        self._last_output_token_count = response.usage.completion_tokens
        return ChatMessage.from_dict(
            response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
            raw=response,
            token_usage=TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            ),
        )

    async def __call__(self, *args, **kwargs) -> ChatMessage:
        """
        Call the model with the given arguments.
        This is a convenience method that calls `generate` with the same arguments.
        """
        return await self.generate(*args, **kwargs)
0 commit comments