diff --git a/packages/responses-server/examples/function.js b/packages/responses-server/examples/function.js
new file mode 100644
index 0000000000..26893d5449
--- /dev/null
+++ b/packages/responses-server/examples/function.js
@@ -0,0 +1,32 @@
+import OpenAI from "openai";
+
+const openai = new OpenAI({ baseURL: "http://localhost:3000/v1", apiKey: process.env.HF_TOKEN });
+
+const tools = [
+	{
+		type: "function",
+		name: "get_current_weather",
+		description: "Get the current weather in a given location",
+		parameters: {
+			type: "object",
+			properties: {
+				location: {
+					type: "string",
+					description: "The city and state, e.g. San Francisco, CA",
+				},
+				unit: { type: "string", enum: ["celsius", "fahrenheit"] },
+			},
+			required: ["location", "unit"],
+		},
+	},
+];
+
+const response = await openai.responses.create({
+	model: "meta-llama/Llama-3.3-70B-Instruct",
+	provider: "cerebras",
+	tools: tools,
+	input: "What is the weather like in Boston today?",
+	tool_choice: "auto",
+});
+
+console.log(response);
diff --git a/packages/responses-server/examples/function_streaming.js b/packages/responses-server/examples/function_streaming.js
new file mode 100644
index 0000000000..3c6d557ef0
--- /dev/null
+++ b/packages/responses-server/examples/function_streaming.js
@@ -0,0 +1,33 @@
+import { OpenAI } from "openai";
+
+const openai = new OpenAI({ baseURL: "http://localhost:3000/v1", apiKey: process.env.HF_TOKEN });
+
+const tools = [
+	{
+		type: "function",
+		name: "get_weather",
+		description: "Get current temperature for provided coordinates in celsius.",
+		parameters: {
+			type: "object",
+			properties: {
+				latitude: { type: "number" },
+				longitude: { type: "number" },
+			},
+			required: ["latitude", "longitude"],
+			additionalProperties: false,
+		},
+		strict: true,
+	},
+];
+
+const stream = await openai.responses.create({
+	model: "meta-llama/Llama-3.3-70B-Instruct",
+	provider: "cerebras",
+	input: [{ role: "user", content: "What's the weather like in Paris today?" }],
+	tools,
+	stream: true,
+});
+
+for await (const event of stream) {
+	console.log(event);
+}
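The streaming example above only logs raw events. For reference (this is a usage sketch, not part of the patch), a client could reassemble the streamed tool call by concatenating the `response.function_call_arguments.delta` payloads emitted by the route changes below; it assumes the same local server, model, provider, and `get_weather` tool as the examples above:

```js
import { OpenAI } from "openai";

const openai = new OpenAI({ baseURL: "http://localhost:3000/v1", apiKey: process.env.HF_TOKEN });

const stream = await openai.responses.create({
	model: "meta-llama/Llama-3.3-70B-Instruct",
	provider: "cerebras",
	input: [{ role: "user", content: "What's the weather like in Paris today?" }],
	tools: [
		{
			type: "function",
			name: "get_weather",
			description: "Get current temperature for provided coordinates in celsius.",
			parameters: {
				type: "object",
				properties: { latitude: { type: "number" }, longitude: { type: "number" } },
				required: ["latitude", "longitude"],
				additionalProperties: false,
			},
			strict: true,
		},
	],
	stream: true,
});

// Accumulate the streamed tool call instead of printing raw events.
let functionName = "";
let functionArguments = "";
for await (const event of stream) {
	if (event.type === "response.output_item.added" && event.item.type === "function_call") {
		functionName = event.item.name;
	} else if (event.type === "response.function_call_arguments.delta") {
		functionArguments += event.delta;
	} else if (event.type === "response.function_call_arguments.done") {
		console.log(`Model asked for ${functionName}(${event.arguments})`);
	}
}
console.log({ functionName, functionArguments });
```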
diff --git a/packages/responses-server/src/routes/responses.ts b/packages/responses-server/src/routes/responses.ts
index 663383df25..bccc4146e2 100644
--- a/packages/responses-server/src/routes/responses.ts
+++ b/packages/responses-server/src/routes/responses.ts
@@ -12,10 +12,18 @@ import type {
 import type {
 	Response,
 	ResponseStreamEvent,
-	ResponseOutputItem,
 	ResponseContentPartAddedEvent,
+	ResponseOutputMessage,
+	ResponseFunctionToolCall,
 } from "openai/resources/responses/responses";
 
+class StreamingError extends Error {
+	constructor(message: string) {
+		super(message);
+		this.name = "StreamingError";
+	}
+}
+
 export const postCreateResponse = async (
 	req: ValidatedRequest<CreateResponseParams>,
 	res: ExpressResponse
@@ -74,13 +82,13 @@ export const postCreateResponse = async (
 	}
 
 	const payload: ChatCompletionInput = {
+		// main params
 		model: req.body.model,
 		provider: req.body.provider,
 		messages: messages,
-		max_tokens: req.body.max_output_tokens === null ? undefined : req.body.max_output_tokens,
-		temperature: req.body.temperature,
-		top_p: req.body.top_p,
 		stream: req.body.stream,
+		// options
+		max_tokens: req.body.max_output_tokens === null ? undefined : req.body.max_output_tokens,
 		response_format: req.body.text?.format
 			? {
 					type: req.body.text.format.type,
 					json_schema:
 						req.body.text.format.type === "json_schema"
 							? {
 									description: req.body.text.format.description,
 									name: req.body.text.format.name,
 									schema: req.body.text.format.schema,
 									strict: req.body.text.format.strict,
 								}
 							: undefined,
 				}
 			: undefined,
+		temperature: req.body.temperature,
+		tool_choice:
+			typeof req.body.tool_choice === "string"
+				? req.body.tool_choice
+				: req.body.tool_choice
+					? {
+							type: "function",
+							function: {
+								name: req.body.tool_choice.name,
+							},
+						}
+					: undefined,
+		tools: req.body.tools
+			? req.body.tools.map((tool) => ({
+					type: tool.type,
+					function: {
+						name: tool.name,
+						parameters: tool.parameters,
+						description: tool.description,
+						strict: tool.strict,
+					},
+				}))
+			: undefined,
+		top_p: req.body.top_p,
 	};
 
-	const responseObject: Omit<
-		Response,
-		"incomplete_details" | "output_text" | "parallel_tool_calls" | "tool_choice" | "tools"
-	> = {
+	const responseObject: Omit<Response, "incomplete_details" | "output_text" | "parallel_tool_calls"> = {
 		created_at: new Date().getTime(),
 		error: null,
 		id: generateUniqueId("resp"),
@@ -110,7 +139,11 @@ export const postCreateResponse = async (
 		model: req.body.model,
 		object: "response",
 		output: [],
+		// parallel_tool_calls: req.body.parallel_tool_calls,
 		status: "in_progress",
+		text: req.body.text,
+		tool_choice: req.body.tool_choice ?? "auto",
+		tools: req.body.tools ?? [],
 		temperature: req.body.temperature,
 		top_p: req.body.top_p,
 	};
@@ -142,45 +175,62 @@ export const postCreateResponse = async (
 			const stream = client.chatCompletionStream(payload);
 
-			const outputObject: ResponseOutputItem = {
-				id: generateUniqueId("msg"),
-				type: "message",
-				role: "assistant",
-				status: "in_progress",
-				content: [],
-			};
-			responseObject.output = [outputObject];
+			for await (const chunk of stream) {
+				if (chunk.choices[0].delta.content) {
+					if (responseObject.output.length === 0) {
+						const outputObject: ResponseOutputMessage = {
+							id: generateUniqueId("msg"),
+							type: "message",
+							role: "assistant",
+							status: "in_progress",
+							content: [],
+						};
+						responseObject.output = [outputObject];
 
-			// Response output item added event
-			emitEvent({
-				type: "response.output_item.added",
-				output_index: 0,
-				item: outputObject,
-				sequence_number: sequenceNumber++,
-			});
+						// Response output item added event
+						emitEvent({
+							type: "response.output_item.added",
+							output_index: 0,
+							item: outputObject,
+							sequence_number: sequenceNumber++,
+						});
+					}
 
-			// Response content part added event
-			const contentPart: ResponseContentPartAddedEvent["part"] = {
-				type: "output_text",
-				text: "",
-				annotations: [],
-			};
-			outputObject.content.push(contentPart);
+					const outputObject = responseObject.output.at(-1);
+					if (!outputObject || outputObject.type !== "message") {
+						throw new StreamingError("Not implemented: only single output item type is supported in streaming mode.");
+					}
 
-			emitEvent({
-				type: "response.content_part.added",
-				item_id: outputObject.id,
-				output_index: 0,
-				content_index: 0,
-				part: contentPart,
-				sequence_number: sequenceNumber++,
-			});
+					if (outputObject.content.length === 0) {
+						// Response content part added event
+						const contentPart: ResponseContentPartAddedEvent["part"] = {
+							type: "output_text",
+							text: "",
+							annotations: [],
+						};
+						outputObject.content.push(contentPart);
 
-			for await (const chunk of stream) {
-				if (chunk.choices[0].delta.content) {
-					contentPart.text += chunk.choices[0].delta.content;
+						emitEvent({
+							type: "response.content_part.added",
+							item_id: outputObject.id,
+							output_index: 0,
+							content_index: 0,
+							part: contentPart,
+							sequence_number: sequenceNumber++,
+						});
+					}
+
+					const contentPart = outputObject.content.at(-1);
+					if (!contentPart || contentPart.type !== "output_text") {
+						throw new StreamingError("Not implemented: only output_text is supported in streaming mode.");
+					}
+
-					// Response output text delta event
+					// Add text delta
+					contentPart.text += chunk.choices[0].delta.content;
 					emitEvent({
 						type: "response.output_text.delta",
 						item_id: outputObject.id,
@@ -189,37 +239,109 @@ export const postCreateResponse = async (
 						delta: chunk.choices[0].delta.content,
 						sequence_number: sequenceNumber++,
 					});
+				} else if (chunk.choices[0].delta.tool_calls) {
+					if (chunk.choices[0].delta.tool_calls.length > 1) {
+						throw new StreamingError("Not implemented: only single tool call is supported in streaming mode.");
+					}
+
+					if (responseObject.output.length === 0) {
+						if (!chunk.choices[0].delta.tool_calls[0].function.name) {
+							throw new StreamingError("Tool call function name is required.");
+						}
+
+						const outputObject: ResponseFunctionToolCall = {
+							type: "function_call",
+							id: generateUniqueId("fc"),
+							call_id: chunk.choices[0].delta.tool_calls[0].id,
+							name: chunk.choices[0].delta.tool_calls[0].function.name,
+							arguments: "",
+						};
+						responseObject.output = [outputObject];
+
+						// Response output item added event
+						emitEvent({
+							type: "response.output_item.added",
+							output_index: 0,
+							item: outputObject,
+							sequence_number: sequenceNumber++,
+						});
+					}
+
+					const outputObject = responseObject.output.at(-1);
+					if (!outputObject || !outputObject.id || outputObject.type !== "function_call") {
+						throw new StreamingError("Not implemented: can only support single output item type in streaming mode.");
+					}
+
+					outputObject.arguments += chunk.choices[0].delta.tool_calls[0].function.arguments;
+					emitEvent({
+						type: "response.function_call_arguments.delta",
+						item_id: outputObject.id,
+						output_index: 0,
+						delta: chunk.choices[0].delta.tool_calls[0].function.arguments,
+						sequence_number: sequenceNumber++,
+					});
 				}
 			}
 
-			// Response output text done event
-			emitEvent({
-				type: "response.output_text.done",
-				item_id: outputObject.id,
-				output_index: 0,
-				content_index: 0,
-				text: contentPart.text,
-				sequence_number: sequenceNumber++,
-			});
+			const lastOutputItem = responseObject.output.at(-1);
 
-			// Response content part done event
-			emitEvent({
-				type: "response.content_part.done",
-				item_id: outputObject.id,
-				output_index: 0,
-				content_index: 0,
-				part: contentPart,
-				sequence_number: sequenceNumber++,
-			});
+			if (lastOutputItem) {
+				if (lastOutputItem?.type === "message") {
+					const contentPart = lastOutputItem.content.at(-1);
+					if (contentPart?.type === "output_text") {
+						emitEvent({
+							type: "response.output_text.done",
+							item_id: lastOutputItem.id,
+							output_index: responseObject.output.length - 1,
+							content_index: lastOutputItem.content.length - 1,
+							text: contentPart.text,
+							sequence_number: sequenceNumber++,
+						});
 
-			// Response output item done event
-			outputObject.status = "completed";
-			emitEvent({
-				type: "response.output_item.done",
-				output_index: 0,
-				item: outputObject,
-				sequence_number: sequenceNumber++,
-			});
+						emitEvent({
+							type: "response.content_part.done",
+							item_id: lastOutputItem.id,
+							output_index: responseObject.output.length - 1,
+							content_index: lastOutputItem.content.length - 1,
+							part: contentPart,
+							sequence_number: sequenceNumber++,
+						});
+					} else {
+						throw new StreamingError("Not implemented: only output_text is supported in streaming mode.");
+					}
+
+					// Response output item done event
+					lastOutputItem.status = "completed";
+					emitEvent({
+						type: "response.output_item.done",
+						output_index: responseObject.output.length - 1,
+						item: lastOutputItem,
+						sequence_number: sequenceNumber++,
+					});
+				} else if (lastOutputItem?.type === "function_call") {
+					if (!lastOutputItem.id) {
+						throw new StreamingError("Function call id is required.");
+					}
+
+					emitEvent({
+						type: "response.function_call_arguments.done",
+						item_id: lastOutputItem.id,
+						output_index: responseObject.output.length - 1,
+						arguments: lastOutputItem.arguments,
+						sequence_number: sequenceNumber++,
+					});
+
+					lastOutputItem.status = "completed";
+					emitEvent({
+						type: "response.output_item.done",
+						output_index: responseObject.output.length - 1,
+						item: lastOutputItem,
+						sequence_number: sequenceNumber++,
+					});
+				} else {
+					throw new StreamingError("Not implemented: only message and function_call outputs are supported in streaming mode.");
+				}
+			}
 
 			// Response completed event
 			responseObject.status = "completed";
@@ -228,13 +350,25 @@ export const postCreateResponse = async (
 				response: responseObject as Response,
 				sequence_number: sequenceNumber++,
 			});
-		} catch (streamError: any) {
+		} catch (streamError) {
 			console.error("Error in streaming chat completion:", streamError);
 
+			let message = "An error occurred while streaming from inference server.";
+			if (streamError instanceof StreamingError) {
+				message = streamError.message;
+			} else if (
+				typeof streamError === "object" &&
+				streamError &&
+				"message" in streamError &&
+				typeof streamError.message === "string"
+			) {
+				message = streamError.message;
+			}
+
 			emitEvent({
 				type: "error",
 				code: null,
-				message: streamError.message || "An error occurred while streaming from inference server.",
+				message,
 				param: null,
 				sequence_number: sequenceNumber++,
 			});
@@ -263,7 +397,16 @@ export const postCreateResponse = async (
 					],
 				},
 			]
-		: [];
+		: chatCompletionResponse.choices[0].message.tool_calls
+			? chatCompletionResponse.choices[0].message.tool_calls.map((toolCall) => ({
+					type: "function_call",
+					id: generateUniqueId("fc"),
+					call_id: toolCall.id,
+					name: toolCall.function.name,
+					arguments: toolCall.function.arguments,
+					status: "completed",
+				}))
+			: [];
 
 		res.json(responseObject);
 	} catch (error) {
diff --git a/packages/responses-server/src/schemas.ts b/packages/responses-server/src/schemas.ts
index 65b437c671..427716c10e 100644
--- a/packages/responses-server/src/schemas.ts
+++ b/packages/responses-server/src/schemas.ts
@@ -83,6 +83,7 @@ export const createResponseParamsSchema = z.object({
 		.nullable()
 		.default(null),
 	model: z.string(),
+	// parallel_tool_calls: z.boolean().default(true), // TODO: how to handle this if chat completion doesn't support it?
 	provider: z.string().optional(),
 	// previous_response_id: z.string().nullable().default(null),
 	// reasoning: z.object({
@@ -114,8 +115,27 @@ export const createResponseParamsSchema = z.object({
 			]),
 		})
 		.optional(),
-	// tool_choice:
-	// tools:
+	tool_choice: z
+		.union([
+			z.enum(["auto", "none", "required"]),
+			z.object({
+				type: z.enum(["function"]),
+				name: z.string(),
+			}),
+			// TODO: also hosted tool and MCP tool
+		])
+		.optional(),
+	tools: z
+		.array(
+			z.object({
+				name: z.string(),
+				parameters: z.record(z.any()),
+				strict: z.boolean().default(true),
+				type: z.enum(["function"]),
+				description: z.string().optional(),
+			})
+		)
+		.optional(),
 	// top_logprobs: z.number().min(0).max(20).nullable().default(null),
 	top_p: z.number().min(0).max(1).default(1),
 	// truncation: z.enum(["auto", "disabled"]).default("disabled"),
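To close the loop on the schema changes, here is a sketch of a request that exercises the new `tool_choice`/`tools` fields against the non-streaming code path. It is illustrative only (same local server, model, and provider assumptions as the examples above), and the id and argument values in the comment are made-up placeholders:

```js
import OpenAI from "openai";

const openai = new OpenAI({ baseURL: "http://localhost:3000/v1", apiKey: process.env.HF_TOKEN });

// Force a specific function instead of letting the model decide ("auto").
const response = await openai.responses.create({
	model: "meta-llama/Llama-3.3-70B-Instruct",
	provider: "cerebras",
	input: "What is the weather like in Boston today?",
	tools: [
		{
			type: "function",
			name: "get_current_weather",
			description: "Get the current weather in a given location",
			parameters: {
				type: "object",
				properties: {
					location: { type: "string", description: "The city and state, e.g. San Francisco, CA" },
					unit: { type: "string", enum: ["celsius", "fahrenheit"] },
				},
				required: ["location", "unit"],
			},
		},
	],
	tool_choice: { type: "function", name: "get_current_weather" },
});

// With the non-streaming branch above, response.output should contain a single
// item shaped roughly like this (values are placeholders):
// {
//   type: "function_call",
//   id: "fc_...",
//   call_id: "...",
//   name: "get_current_weather",
//   arguments: '{"location": "Boston, MA", "unit": "fahrenheit"}',
//   status: "completed",
// }
console.log(response.output);
```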