
Commit 230d9d8

Fix Tool Call API & Minor Change (#1080)
- Fix #1062
- Support the new function call formats (DeepSeek, Kimi-K2)
- Make the Tool Call API compatible with the latest OpenAI version

Co-authored-by: shihaobai <[email protected]>
1 parent 82df7a1 commit 230d9d8
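For context, the change is easiest to see from the client side. Below is a minimal, hypothetical sketch of calling lightllm's OpenAI-compatible chat completions endpoint with tools; the base URL, model name, and tool definition are illustrative placeholders, not values taken from this commit.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder server address

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

resp = client.chat.completions.create(
    model="kimi-k2",  # placeholder model name
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)

# With this commit, choices[0].finish_reason is "tool_calls" (not "function_call")
# and each tool call carries an id, an index, and a function payload.
print(resp.choices[0].message.tool_calls)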

File tree: 6 files changed, +1133 −278 lines


lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py

Lines changed: 2 additions & 2 deletions
@@ -138,10 +138,10 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
             topk_ids=topk_ids,
             inplace=True,
             use_fp8_w8a8=use_fp8_w8a8,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
             w1_bias=self.w1_bias,
             w2_bias=self.w2_bias / self.tp_world_size_,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
             layout="interleaved",
             alpha=self.alpha,
             limit=self.limit,

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 19 additions & 19 deletions
@@ -764,13 +764,13 @@ def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_bias: Optional[torch.Tensor],
-    w2_bias: Optional[torch.Tensor],
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
@@ -890,13 +890,13 @@ def inplace_fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    # optional bias for w1 and w2
-    w1_bias: Optional[torch.Tensor],
-    w2_bias: Optional[torch.Tensor],
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    # optional bias for w1 and w2
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
@@ -909,13 +909,13 @@ def inplace_fused_experts_impl(
         hidden_states,
         w1,
         w2,
-        w1_bias,
-        w2_bias,
         topk_weights,
         topk_ids,
         True,
         use_fp8_w8a8,
         use_int8_w8a16,
+        w1_bias,
+        w2_bias,
         w1_scale,
         w2_scale,
         a1_scale,
@@ -930,13 +930,13 @@ def inplace_fused_experts_impl_fake(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    # optional bias for w1 and w2
-    w1_bias: Optional[torch.Tensor],
-    w2_bias: Optional[torch.Tensor],
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    # optional bias for w1 and w2
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
@@ -960,13 +960,13 @@ def outplace_fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    # optional bias for w1 and w2
-    w1_bias: Optional[torch.Tensor],
-    w2_bias: Optional[torch.Tensor],
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    # optional bias for w1 and w2
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
@@ -979,13 +979,13 @@ def outplace_fused_experts_impl(
         hidden_states,
         w1,
         w2,
-        w1_bias,
-        w2_bias,
         topk_weights,
         topk_ids,
         False,
         use_fp8_w8a8,
         use_int8_w8a16,
+        w1_bias,
+        w2_bias,
         w1_scale,
         w2_scale,
         a1_scale,
@@ -1051,12 +1051,12 @@ def fused_experts(
         hidden_states,
         w1,
         w2,
-        w1_bias,
-        w2_bias,
         topk_weights,
         topk_ids,
         use_fp8_w8a8,
         use_int8_w8a16,
+        w1_bias,
+        w2_bias,
         w1_scale,
         w2_scale,
         a1_scale,
@@ -1071,12 +1071,12 @@ def fused_experts(
         hidden_states,
         w1,
         w2,
-        w1_bias,
-        w2_bias,
         topk_weights,
         topk_ids,
         use_fp8_w8a8,
         use_int8_w8a16,
+        w1_bias,
+        w2_bias,
         w1_scale,
         w2_scale,
         a1_scale,
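The pattern in every hunk above is the same: w1_bias and w2_bias move from required positional parameters to trailing keyword parameters with None defaults, and the internal call sites are reordered to match. A minimal sketch of the resulting signature, abridged to the parameters visible in this diff (the real fused_experts_impl takes more), shows why bias-free callers keep working unchanged:

from typing import Optional
import torch

def fused_experts_impl_sketch(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    # biases are now optional keyword parameters instead of required positionals
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Bias-free callers can keep their existing positional arguments; bias-aware
    # callers (e.g. the gpt-oss MoE weight above) pass the biases by keyword.
    ...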

lightllm/server/api_models.py

Lines changed: 27 additions & 1 deletion
@@ -101,6 +101,7 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
         default="auto", examples=["none"]
     )  # noqa
+    parallel_tool_calls: Optional[bool] = True

     # Additional parameters supported by LightLLM
     do_sample: Optional[bool] = False
@@ -122,11 +123,36 @@ class FunctionResponse(BaseModel):
 class ToolCall(BaseModel):
     """Tool call response."""

-    id: str
+    id: Optional[str] = None
+    index: Optional[int] = None
     type: Literal["function"] = "function"
     function: FunctionResponse


+class ChatCompletionMessageGenericParam(BaseModel):
+    role: Literal["system", "assistant", "tool", "function"]
+    content: Union[str, List[MessageContent], None] = Field(default=None)
+    tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+    reasoning_content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+
+    @field_validator("role", mode="before")
+    @classmethod
+    def _normalize_role(cls, v):
+        if isinstance(v, str):
+            v_lower = v.lower()
+            if v_lower not in {"system", "assistant", "tool", "function"}:
+                raise ValueError(
+                    "'role' must be one of 'system', 'assistant', 'tool', or 'function' (case-insensitive)."
+                )
+            return v_lower
+        raise ValueError("'role' must be a string")
+
+
+ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message]
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: Optional[int] = 0
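A standalone sketch (assuming pydantic v2, which the field_validator usage above implies) mirrors the new models just enough to show the two behavioral changes: role strings are normalized to lower case, and ToolCall.id / ToolCall.index are now optional so streaming argument deltas can omit them. The FunctionResponse fields here are inferred from how the class is used elsewhere in this commit, not copied from the module.

from typing import List, Literal, Optional
from pydantic import BaseModel, field_validator

class FunctionResponseSketch(BaseModel):
    # fields assumed from FunctionResponse(name=..., arguments=...) calls in this commit
    name: Optional[str] = None
    arguments: Optional[str] = None

class ToolCallSketch(BaseModel):
    id: Optional[str] = None        # was a required str before this commit
    index: Optional[int] = None     # new: position of the call within the message
    type: Literal["function"] = "function"
    function: FunctionResponseSketch

class GenericMessageSketch(BaseModel):
    role: Literal["system", "assistant", "tool", "function"]
    tool_calls: Optional[List[ToolCallSketch]] = None

    @field_validator("role", mode="before")
    @classmethod
    def _normalize_role(cls, v):
        # mirrors the validator above: accept any casing, store lower case
        if isinstance(v, str):
            return v.lower()
        raise ValueError("'role' must be a string")

print(GenericMessageSketch(role="Assistant").role)  # -> "assistant"
print(ToolCallSketch(function=FunctionResponseSketch(name="get_weather", arguments="{}")).id)  # -> None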

lightllm/server/api_openai.py

Lines changed: 74 additions & 12 deletions
@@ -9,7 +9,7 @@
 import pickle
 import uuid

-from .function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
+from .function_call_parser import TOOLS_TAG_LIST, FunctionCallParser, ToolCallItem
 from .build_prompt import build_prompt, init_tokenizer

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -63,6 +63,52 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse
     return JSONResponse({"message": message}, status_code=status_code.value)


+def _process_tool_call_id(
+    tool_call_parser,
+    call_item: ToolCallItem,
+    history_tool_calls_cnt: int,
+) -> str:
+    """Process for generating a new and unique `tool_call_id`"""
+    if tool_call_parser != "kimi_k2":
+        # A simple uuid is sufficient for all models except for Kimi-K2.
+        tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+        return tool_call_id
+    else:
+        # Align with Kimi-K2 format: functions.{name}:{index}
+        # Kimi-K2 allows multiple tool_calls in one message;
+        # SGLang sets call_item.tool_index to the *local* position inside that message.
+        # Therefore, the index must be corrected by using
+        # `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
+        tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
+        logger.debug(
+            f"Process tool call idx, parser: {tool_call_parser}, \
+            tool_call_id: {tool_call_id}, \
+            history_cnt: {history_tool_calls_cnt}"
+        )
+        return tool_call_id
+
+
+def _get_history_tool_calls_cnt(request: ChatCompletionRequest) -> int:
+    """Counts the number of tool calls in the request's message history.
+
+    NOTE: This method is only useful for models that include self-increasing
+    history tool call idx in tool calls id, such as kimi-k2
+
+    Args:
+        request: The chat completion request object.
+
+    Returns:
+        The total number of tool calls in the history, or 0 if not applicable.
+    """
+    messages = getattr(request, "messages", [])
+    idx = 0
+    for msg in messages:
+        if msg.role == "assistant":
+            tool_calls = getattr(msg, "tool_calls", None)
+            idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
+    return idx
+
+
 async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response:
     from .api_http import g_objs

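The two helpers behave differently per parser. A tiny standalone reproduction of the id logic (not importing lightllm) makes the two formats concrete:

import uuid

def make_tool_call_id(tool_call_parser: str, name: str, tool_index: int, history_cnt: int) -> str:
    # Mirrors _process_tool_call_id above: uuid-style ids for most parsers,
    # Kimi-K2's "functions.{name}:{global index}" format otherwise.
    if tool_call_parser != "kimi_k2":
        return f"call_{uuid.uuid4().hex[:24]}"
    return f"functions.{name}:{history_cnt + tool_index}"

# e.g. with two tool calls already present in the assistant history:
print(make_tool_call_id("llama3", "get_weather", 0, 2))   # -> call_<24 hex chars>
print(make_tool_call_id("kimi_k2", "get_weather", 0, 2))  # -> functions.get_weather:2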
@@ -180,26 +226,31 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req

         if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
             if finish_reason == "stop":
-                finish_reason = "function_call"
+                finish_reason = "tool_calls"
             try:
                 # provide a default value for tool_call_parser
                 tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
                 parser = FunctionCallParser(tools, tool_parser)
                 full_normal_text, call_info_list = parser.parse_non_stream(text)
-                tool_calls = [
-                    ToolCall(
-                        id=str(call_info.tool_index),
-                        function=FunctionResponse(name=call_info.name, arguments=call_info.parameters),
+                tool_calls = []
+                history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
+                for call_info in call_info_list:
+                    tool_id = _process_tool_call_id(tool_parser, call_info, history_tool_calls_cnt)
+                    tool_calls.append(
+                        ToolCall(
+                            id=tool_id,
+                            index=getattr(call_info, "tool_index", None),
+                            function=FunctionResponse(name=call_info.name, arguments=call_info.parameters),
+                        )
                     )
-                    for call_info in call_info_list
-                ]
             except Exception as e:
                 logger.error(f"Exception: {e}")
                 return create_error_response(
                     HTTPStatus.BAD_REQUEST,
                     "Failed to parse fc related info to json format!",
                 )
-
+            if finish_reason == "tool_calls":
+                text = ""
             chat_message = ChatMessage(role="assistant", content=text, tool_calls=tool_calls)
             choice = ChatCompletionResponseChoice(
                 index=i,
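With these changes, a non-streaming response that triggers a tool call carries finish_reason "tool_calls" and an empty content string. An illustrative choice is sketched below as a plain dict; the values are made up, not captured from a real run.

choice_example = {
    "index": 0,
    "message": {
        "role": "assistant",
        "content": "",  # text is cleared when finish_reason is "tool_calls"
        "tool_calls": [
            {
                "id": "call_0123456789abcdef01234567",  # uuid-based id (non kimi_k2 parsers)
                "index": 0,
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    },
    "finish_reason": "tool_calls",
}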
@@ -261,6 +312,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                 yield f"data: {chunk.model_dump_json()}\n\n"

             # 2) if we found calls, we output them as separate chunk(s)
+            history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
             for call_item in calls:
                 # transform call_item -> FunctionResponse + ToolCall
                 if finish_reason == "stop":
@@ -278,17 +330,27 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                     remaining_call = expected_call.replace(actual_call, "", 1)
                     call_item.parameters = remaining_call

+                if call_item.name:
+                    # First chunk: include ID and function name
+                    tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt)
+                    function_name = call_item.name
+                else:
+                    # Subsequent chunks: null ID and name for argument deltas
+                    tool_call_id = None
+                    function_name = None
+
                 tool_call = ToolCall(
-                    id=str(call_item.tool_index),
+                    id=tool_call_id,
+                    index=getattr(call_item, "tool_index", None),
                     function=FunctionResponse(
-                        name=call_item.name,
+                        name=function_name,
                         arguments=call_item.parameters,
                     ),
                 )
                 choice_data = ChatCompletionStreamResponseChoice(
                     index=0,
                     delta=DeltaMessage(role="assistant", tool_calls=[tool_call]),
-                    finish_reason="function_call",
+                    finish_reason="tool_calls",
                 )
                 chunk = ChatCompletionStreamResponse(
                     id=group_request_id,
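In streaming mode the same information is split across delta chunks: the first chunk for a call carries the id and function name, later chunks carry only argument fragments with id and name set to None, and the chunk's finish_reason becomes "tool_calls". An illustrative sequence of tool_call deltas, again with made-up values:

stream_tool_call_deltas = [
    # first chunk: id and function name are present
    {"index": 0, "id": "call_0123456789abcdef01234567", "type": "function",
     "function": {"name": "get_weather", "arguments": ""}},
    # later chunks: argument deltas only, id and name are null
    {"index": 0, "id": None, "type": "function",
     "function": {"name": None, "arguments": '{"city": '}},
    {"index": 0, "id": None, "type": "function",
     "function": {"name": None, "arguments": '"Paris"}'}},
]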

lightllm/server/core/objs/start_args_type.py

Lines changed: 3 additions & 1 deletion
@@ -30,7 +30,9 @@ class StartArgs:
     mem_fraction: float = field(default=0.9)
     batch_max_tokens: Optional[int] = field(default=None)
     eos_id: List[int] = field(default_factory=list)
-    tool_call_parser: Optional[str] = field(default=None, metadata={"choices": ["llama3", "qwen25", "mistral"]})
+    tool_call_parser: Optional[str] = field(
+        default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]}
+    )
     running_max_req_size: int = field(default=1000)
     tp: int = field(default=1)
     dp: int = field(default=1)
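A small standalone check of the extended choices metadata, mirroring only the field shown in this hunk rather than importing lightllm (the real StartArgs has many more fields):

from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class StartArgsSketch:
    # mirrors the tool_call_parser field above
    tool_call_parser: Optional[str] = field(
        default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]}
    )

choices = {f.name: f.metadata.get("choices") for f in fields(StartArgsSketch)}
assert "kimi_k2" in choices["tool_call_parser"]
assert "deepseekv3" in choices["tool_call_parser"]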
