Skip to content

Commit 9d56ad9

Browse files
committed
update gpt-5
1 parent 7f4a55c commit 9d56ad9

File tree

7 files changed

+150
-20
lines changed

7 files changed

+150
-20
lines changed

src/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
Model,
77
parse_json_if_needed,
88
agglomerate_stream_deltas,
9-
CODEAGENT_RESPONSE_FORMAT
9+
CODEAGENT_RESPONSE_FORMAT,
1010
)
1111
from .litellm import LiteLLMModel
1212
from .openaillm import OpenAIServerModel
1313
from .models import ModelManager
14+
from .message_manager import MessageManager
1415

1516
model_manager = ModelManager()
1617

@@ -23,4 +24,5 @@
2324
"parse_json_if_needed",
2425
"model_manager",
2526
"ModelManager",
27+
"MessageManager",
2628
]

src/models/litellm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
self,
3737
model_id: Optional[str] = None,
3838
api_base=None,
39+
api_type: str = "chat/completions",
3940
api_key=None,
4041
custom_role_conversions: dict[str, str] | None = None,
4142
flatten_messages_as_text: bool | None = None,
@@ -52,6 +53,7 @@ def __init__(
5253
model_id = "anthropic/claude-3-5-sonnet-20240620"
5354
self.model_id = model_id
5455
self.api_base = api_base
56+
self.api_type = api_type
5557
self.api_key = api_key
5658
flatten_messages_as_text = (
5759
flatten_messages_as_text

src/models/message_manager.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@
1919
]
2020

2121
class MessageManager():
22-
def __init__(self, model_id: str):
22+
def __init__(self, model_id: str, api_type: str = "chat/completions"):
2323
self.model_id = model_id
24+
self.api_type = api_type
2425

2526
def get_clean_message_list(self,
2627
message_list: list[ChatMessage],
2728
role_conversions: dict[MessageRole, MessageRole] | dict[str, str] = {},
2829
convert_images_to_image_urls: bool = False,
2930
flatten_messages_as_text: bool = False,
31+
api_type: str = "chat/completions",
3032
) -> list[dict[str, Any]]:
3133
"""
3234
Creates a list of messages to give as input to the LLM. These messages are dictionaries and chat template compatible with transformers LLM chat template.
@@ -38,6 +40,25 @@ def get_clean_message_list(self,
3840
convert_images_to_image_urls (`bool`, default `False`): Whether to convert images to image URLs.
3941
flatten_messages_as_text (`bool`, default `False`): Whether to flatten messages as text.
4042
"""
43+
api_type = api_type or self.api_type
44+
if api_type == "responses":
45+
return self._get_responses_message_list(
46+
message_list, role_conversions, convert_images_to_image_urls, flatten_messages_as_text
47+
)
48+
else:
49+
return self._get_chat_completions_message_list(
50+
message_list, role_conversions, convert_images_to_image_urls, flatten_messages_as_text
51+
)
52+
53+
def _get_chat_completions_message_list(self,
54+
message_list: list[ChatMessage],
55+
role_conversions: dict[MessageRole, MessageRole] | dict[str, str] = {},
56+
convert_images_to_image_urls: bool = False,
57+
flatten_messages_as_text: bool = False,
58+
) -> list[dict[str, Any]]:
59+
"""
60+
Creates a list of messages in chat completions format.
61+
"""
4162
output_message_list: list[dict[str, Any]] = []
4263
message_list = deepcopy(message_list) # Avoid modifying the original list
4364
for message in message_list:
@@ -87,6 +108,106 @@ def get_clean_message_list(self,
87108
)
88109
return output_message_list
89110

111+
def _get_responses_message_list(self,
112+
message_list: list[ChatMessage],
113+
role_conversions: dict[MessageRole, MessageRole] | dict[str, str] = {},
114+
convert_images_to_image_urls: bool = False,
115+
flatten_messages_as_text: bool = False,
116+
) -> list[dict[str, Any]]:
117+
"""
118+
Creates a list of messages in responses format (OpenAI responses API).
119+
"""
120+
output_message_list: list[dict[str, Any]] = []
121+
message_list = deepcopy(message_list) # Avoid modifying the original list
122+
123+
for message in message_list:
124+
role = message.role
125+
if role not in MessageRole.roles():
126+
raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
127+
128+
if role in role_conversions:
129+
message.role = role_conversions[role] # type: ignore
130+
131+
# Handle content processing
132+
if isinstance(message.content, list):
133+
# Process each content element
134+
processed_content = []
135+
for element in message.content:
136+
assert isinstance(element, dict), "Error: this element should be a dict:" + str(element)
137+
138+
if element["type"] == "image":
139+
assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}"
140+
if convert_images_to_image_urls:
141+
processed_content.append({
142+
"type": "image_url",
143+
"image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))},
144+
})
145+
else:
146+
processed_content.append({
147+
"type": "image",
148+
"image": encode_image_base64(element["image"])
149+
})
150+
elif element["type"] == "text":
151+
processed_content.append(element)
152+
else:
153+
processed_content.append(element)
154+
155+
content = processed_content
156+
else:
157+
# Handle string content
158+
if flatten_messages_as_text:
159+
content = message.content
160+
else:
161+
content = [{"type": "text", "text": message.content}] if message.content else []
162+
163+
# Handle tool calls for responses format
164+
tool_calls = None
165+
if message.tool_calls:
166+
tool_calls = []
167+
for tool_call in message.tool_calls:
168+
tool_calls.append({
169+
"id": tool_call.id,
170+
"type": tool_call.type,
171+
"function": {
172+
"name": tool_call.function.name,
173+
"arguments": tool_call.function.arguments,
174+
"description": tool_call.function.description
175+
}
176+
})
177+
178+
# Create message in responses format
179+
message_dict = {
180+
"role": message.role,
181+
"content": content,
182+
}
183+
184+
if tool_calls:
185+
message_dict["tool_calls"] = tool_calls
186+
187+
# Merge consecutive messages with same role
188+
if len(output_message_list) > 0 and message.role == output_message_list[-1]["role"]:
189+
if flatten_messages_as_text:
190+
if isinstance(content, list) and content and content[0]["type"] == "text":
191+
output_message_list[-1]["content"] += "\n" + content[0]["text"]
192+
else:
193+
output_message_list[-1]["content"] += "\n" + str(content)
194+
else:
195+
# Merge content lists
196+
if isinstance(output_message_list[-1]["content"], list) and isinstance(content, list):
197+
output_message_list[-1]["content"].extend(content)
198+
else:
199+
output_message_list[-1]["content"] = content
200+
201+
# Merge tool calls
202+
if tool_calls and "tool_calls" in output_message_list[-1]:
203+
output_message_list[-1]["tool_calls"].extend(tool_calls)
204+
elif tool_calls:
205+
output_message_list[-1]["tool_calls"] = tool_calls
206+
else:
207+
output_message_list.append(message_dict)
208+
209+
return output_message_list
210+
90211
def get_tool_json_schema(self,
91212
tool: Any,
92213
model_id: Optional[str] = None

src/models/models.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -154,26 +154,31 @@ def _register_openai_models(self, use_local_proxy: bool = False):
154154
# deep research
155155
model_name = "o3-deep-research"
156156
model_id = "o3-deep-research"
157-
model = RestfulResponseModel(
158-
api_base=self._check_local_api_base(local_api_base_name="SKYWORK_SHUBIAOBIAO_API_BASE",
159-
remote_api_base_name="OPENAI_API_BASE"),
160-
api_type="responses",
157+
client = AsyncOpenAI(
161158
api_key=api_key,
159+
base_url=self._check_local_api_base(local_api_base_name="SKYWORK_API_BASE",
160+
remote_api_base_name="SKYWORK_API_BASE"),
161+
http_client=ASYNC_HTTP_CLIENT,
162+
)
163+
model = LiteLLMModel(
162164
model_id=model_id,
163-
http_client=HTTP_CLIENT,
165+
http_client=client,
164166
custom_role_conversions=custom_role_conversions,
165167
)
166168
self.registed_models[model_name] = model
167-
168-
model_name = "o4-mini-deep-research"
169-
model_id = "o4-mini-deep-research"
170-
model = RestfulResponseModel(
171-
api_base=self._check_local_api_base(local_api_base_name="SKYWORK_SHUBIAOBIAO_API_BASE",
172-
remote_api_base_name="OPENAI_API_BASE"),
173-
api_type="responses",
169+
170+
# gpt-5
171+
model_name = "gpt-5"
172+
model_id = "openai/gpt-5"
173+
client = AsyncOpenAI(
174174
api_key=api_key,
175+
base_url=self._check_local_api_base(local_api_base_name="SKYWORK_AZURE_US_API_BASE",
176+
remote_api_base_name="OPENAI_API_BASE"),
177+
http_client=ASYNC_HTTP_CLIENT,
178+
)
179+
model = LiteLLMModel(
175180
model_id=model_id,
176-
http_client=HTTP_CLIENT,
181+
http_client=client,
177182
custom_role_conversions=custom_role_conversions,
178183
)
179184
self.registed_models[model_name] = model

src/models/openaillm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import warnings
2-
from typing import Dict, List, Optional, Any
3-
from copy import deepcopy
1+
from typing import Any
42
from collections.abc import Generator
53

64
from src.models.base import (ApiModel,

src/models/restful.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,8 +984,6 @@ async def generate(
984984

985985
# Async call to the LiteLLM client for completion
986986
response = self.client.completion(**completion_kwargs)
987-
print(response)
988-
exit()
989987

990988
response = ChatCompletion.model_validate(response)
991989

tests/test_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ async def video_generation():
8989
messages = [
9090
ChatMessage(role="user", content="What is the capital of France?"),
9191
]
92+
93+
response = asyncio.run(model_manager.registed_models["gpt-5"](
94+
messages=messages,
95+
))
9296

9397
response = asyncio.run(model_manager.registed_models["o3-deep-research"](
9498
messages=messages,

0 commit comments

Comments
 (0)