diff --git a/packages/responses-server/examples/streaming.js b/packages/responses-server/examples/streaming.js
new file mode 100644
index 0000000000..2d342d67de
--- /dev/null
+++ b/packages/responses-server/examples/streaming.js
@@ -0,0 +1,17 @@
+import { OpenAI } from "openai";
+const openai = new OpenAI({ baseURL: "http://localhost:3000/v1", apiKey: process.env.HF_TOKEN });
+
+const stream = await openai.responses.create({
+	model: "Qwen/Qwen2.5-VL-7B-Instruct",
+	input: [
+		{
+			role: "user",
+			content: "Say 'double bubble bath' ten times fast.",
+		},
+	],
+	stream: true,
+});
+
+for await (const event of stream) {
+	console.log(event);
+}
diff --git a/packages/responses-server/src/routes/responses.ts b/packages/responses-server/src/routes/responses.ts
index 7350b90bf0..ce3181545f 100644
--- a/packages/responses-server/src/routes/responses.ts
+++ b/packages/responses-server/src/routes/responses.ts
@@ -5,7 +5,12 @@ import { generateUniqueId } from "../lib/generateUniqueId.js";
 import { InferenceClient } from "@huggingface/inference";
 import type { ChatCompletionInputMessage, ChatCompletionInputMessageChunkType } from "@huggingface/tasks";
 
-import { type Response as OpenAIResponse } from "openai/resources/responses/responses";
+import type {
+	Response,
+	ResponseStreamEvent,
+	ResponseOutputItem,
+	ResponseContentPartAddedEvent,
+} from "openai/resources/responses/responses";
 
 export const postCreateResponse = async (
 	req: ValidatedRequest<CreateResponseParams>,
@@ -33,27 +38,189 @@ export const postCreateResponse = async (
 				content:
 					typeof item.content === "string"
 						? item.content
-						: item.content.map((content) => {
-								if (content.type === "input_image") {
-									return {
-										type: "image_url" as ChatCompletionInputMessageChunkType,
-										image_url: {
-											url: content.image_url,
-										},
-									};
-								}
-								// content.type must be "input_text" at this point
-								return {
-									type: "text" as ChatCompletionInputMessageChunkType,
-									text: content.text,
-								};
-							}),
+						: item.content
+								.map((content) => {
+									switch (content.type) {
+										case "input_image":
+											return {
+												type: "image_url" as ChatCompletionInputMessageChunkType,
+												image_url: {
+													url: content.image_url,
+												},
+											};
+										case "output_text":
+											return {
+												type: "text" as ChatCompletionInputMessageChunkType,
+												text: content.text,
+											};
+										case "refusal":
+											return undefined;
+										case "input_text":
+											return {
+												type: "text" as ChatCompletionInputMessageChunkType,
+												text: content.text,
+											};
+									}
+								})
+								.filter((item) => item !== undefined),
 			}))
 		);
 	} else {
 		messages.push({ role: "user", content: req.body.input });
 	}
 
+	const payload = {
+		model: req.body.model,
+		messages: messages,
+		temperature: req.body.temperature,
+		top_p: req.body.top_p,
+		stream: req.body.stream,
+	};
+
+	const responseObject: Omit<
+		Response,
+		"incomplete_details" | "metadata" | "output_text" | "parallel_tool_calls" | "tool_choice" | "tools"
+	> = {
+		object: "response",
+		id: generateUniqueId("resp"),
+		status: "in_progress",
+		error: null,
+		instructions: req.body.instructions,
+		model: req.body.model,
+		temperature: req.body.temperature,
+		top_p: req.body.top_p,
+		created_at: Math.floor(Date.now() / 1000), // Unix seconds, matching chatCompletionResponse.created below
+		output: [],
+	};
+
+	if (req.body.stream) {
+		res.setHeader("Content-Type", "text/event-stream");
+		res.setHeader("Connection", "keep-alive");
+		let sequenceNumber = 0;
+
+		// Emit events in sequence
+		const emitEvent = (event: ResponseStreamEvent) => {
+			res.write(`data: ${JSON.stringify(event)}\n\n`);
+		};
+
+		try {
+			// Response created event
+			emitEvent({
+				type: "response.created",
+				response: responseObject as Response,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response in progress event
+			emitEvent({
+				type: "response.in_progress",
+				response: responseObject as Response,
+				sequence_number: sequenceNumber++,
+			});
+
+			const stream = client.chatCompletionStream(payload);
+
+			const outputObject: ResponseOutputItem = {
+				id: generateUniqueId("msg"),
+				type: "message",
+				role: "assistant",
+				status: "in_progress",
+				content: [],
+			};
+			responseObject.output = [outputObject];
+
+			// Response output item added event
+			emitEvent({
+				type: "response.output_item.added",
+				output_index: 0,
+				item: outputObject,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response content part added event
+			const contentPart: ResponseContentPartAddedEvent["part"] = {
+				type: "output_text",
+				text: "",
+				annotations: [],
+			};
+			outputObject.content.push(contentPart);
+
+			emitEvent({
+				type: "response.content_part.added",
+				item_id: outputObject.id,
+				output_index: 0,
+				content_index: 0,
+				part: contentPart,
+				sequence_number: sequenceNumber++,
+			});
+
+			for await (const chunk of stream) {
+				if (chunk.choices[0].delta.content) {
+					contentPart.text += chunk.choices[0].delta.content;
+
+					// Response output text delta event
+					emitEvent({
+						type: "response.output_text.delta",
+						item_id: outputObject.id,
+						output_index: 0,
+						content_index: 0,
+						delta: chunk.choices[0].delta.content,
+						sequence_number: sequenceNumber++,
+					});
+				}
+			}
+
+			// Response output text done event
+			emitEvent({
+				type: "response.output_text.done",
+				item_id: outputObject.id,
+				output_index: 0,
+				content_index: 0,
+				text: contentPart.text,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response content part done event
+			emitEvent({
+				type: "response.content_part.done",
+				item_id: outputObject.id,
+				output_index: 0,
+				content_index: 0,
+				part: contentPart,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response output item done event
+			outputObject.status = "completed";
+			emitEvent({
+				type: "response.output_item.done",
+				output_index: 0,
+				item: outputObject,
+				sequence_number: sequenceNumber++,
+			});
+
+			// Response completed event
+			responseObject.status = "completed";
+			emitEvent({
+				type: "response.completed",
+				response: responseObject as Response,
+				sequence_number: sequenceNumber++,
+			});
+		} catch (streamError: any) {
+			console.error("Error in streaming chat completion:", streamError);
+
+			emitEvent({
+				type: "error",
+				code: null,
+				message: streamError.message || "An error occurred while streaming from the inference server.",
+				param: null,
+				sequence_number: sequenceNumber++,
+			});
+		}
+		res.end();
+		return;
+	}
+
 	try {
 		const chatCompletionResponse = await client.chatCompletion({
 			model: req.body.model,
@@ -62,37 +229,24 @@ export const postCreateResponse = async (
 			top_p: req.body.top_p,
 		});
 
-		const responseObject: Omit<
-			OpenAIResponse,
-			"incomplete_details" | "metadata" | "output_text" | "parallel_tool_calls" | "tool_choice" | "tools"
-		> = {
-			object: "response",
-			id: generateUniqueId("resp"),
-			status: "completed",
-			error: null,
-			instructions: req.body.instructions,
-			model: req.body.model,
-			temperature: req.body.temperature,
-			top_p: req.body.top_p,
-			created_at: chatCompletionResponse.created,
-			output: chatCompletionResponse.choices[0].message.content
-				? [
-						{
-							id: generateUniqueId("msg"),
-							type: "message",
-							role: "assistant",
-							status: "completed",
-							content: [
-								{
-									type: "output_text",
-									text: chatCompletionResponse.choices[0].message.content,
-									annotations: [],
-								},
-							],
-						},
-					]
-				: [],
-		};
+		responseObject.status = "completed";
+		responseObject.output = chatCompletionResponse.choices[0].message.content
+			? [
+					{
+						id: generateUniqueId("msg"),
+						type: "message",
+						role: "assistant",
+						status: "completed",
+						content: [
+							{
+								type: "output_text",
+								text: chatCompletionResponse.choices[0].message.content,
+								annotations: [],
+							},
+						],
+					},
+				]
+			: [];
 
 		res.json(responseObject);
 	} catch (error) {
diff --git a/packages/responses-server/src/schemas.ts b/packages/responses-server/src/schemas.ts
index 4e47301aec..c1c8509257 100644
--- a/packages/responses-server/src/schemas.ts
+++ b/packages/responses-server/src/schemas.ts
@@ -4,65 +4,91 @@ import { z } from "zod";
  * https://platform.openai.com/docs/api-reference/responses/create
  * commented out properties are not supported by the server
  */
+
+const inputContentSchema = z.array(
+	z.union([
+		z.object({
+			type: z.literal("input_text"),
+			text: z.string(),
+		}),
+		z.object({
+			type: z.literal("input_image"),
+			// file_id: z.string().nullable().default(null),
+			image_url: z.string(),
+			// detail: z.enum(["auto", "low", "high"]).default("auto"),
+		}),
+		// z.object({
+		// 	type: z.literal("input_file"),
+		// 	file_data: z.string().nullable().default(null),
+		// 	file_id: z.string().nullable().default(null),
+		// 	filename: z.string().nullable().default(null),
+		// }),
+	])
+);
+
 export const createResponseParamsSchema = z.object({
 	// background: z.boolean().default(false),
 	// include:
 	input: z.union([
 		z.string(),
 		z.array(
-			// z.union([
-			z.object({
-				content: z.union([
-					z.string(),
-					z.array(
+			z.union([
+				z.object({
+					content: z.union([z.string(), inputContentSchema]),
+					role: z.enum(["user", "assistant", "system", "developer"]),
+					type: z.enum(["message"]).default("message"),
+				}),
+				z.object({
+					role: z.enum(["user", "system", "developer"]),
+					status: z.enum(["in_progress", "completed", "incomplete"]).nullable().default(null),
+					content: inputContentSchema,
+					type: z.enum(["message"]).default("message"),
+				}),
+				z.object({
+					id: z.string().optional(),
+					role: z.enum(["assistant"]),
+					status: z.enum(["in_progress", "completed", "incomplete"]).optional(),
+					type: z.enum(["message"]).default("message"),
+					content: z.array(
 						z.union([
 							z.object({
-								type: z.literal("input_text"),
+								type: z.literal("output_text"),
 								text: z.string(),
+								annotations: z.array(z.object({})).optional(), // TODO: incomplete
+								logprobs: z.array(z.object({})).optional(), // TODO: incomplete
 							}),
 							z.object({
-								type: z.literal("input_image"),
-								// file_id: z.string().nullable(),
-								image_url: z.string(),
-								// detail: z.enum(["auto", "low", "high"]).default("auto"),
+								type: z.literal("refusal"),
+								refusal: z.string(),
 							}),
-							// z.object({
-							// 	type: z.literal("input_file"),
-							// 	file_data: z.string().nullable(),
-							// 	file_id: z.string().nullable(),
-							// 	filename: z.string().nullable(),
-							// }),
+							// TODO: many more item types: File search tool call, Computer tool call, Computer tool call output, Web search tool call, Function tool call, Function tool call output, Reasoning, Image generation call, Code interpreter tool call, Local shell call, Local shell call output, MCP list tools, MCP approval request, MCP approval response, MCP tool call
 						])
 					),
-				]),
-				role: z.enum(["user", "assistant", "system", "developer"]),
-				type: z.enum(["message"]).default("message"),
z.enum(["message"]).default("message"), - }) - // z.object({}), // An item representing part of the context for the response to be generated by the model - // z.object({ - // id: z.string(), - // type: z.enum(["item_reference"]).default("item_reference"), - // }), - // ]) + }), + // z.object({ + // id: z.string(), + // type: z.enum(["item_reference"]).default("item_reference"), + // }), + ]) ), ]), - instructions: z.string().nullable(), - // max_output_tokens: z.number().min(0).nullable(), - // max_tool_calls: z.number().min(0).nullable(), - // metadata: z.record(z.string().max(64), z.string().max(512)).nullable(), // + 16 items max + instructions: z.string().nullable().default(null), + // max_output_tokens: z.number().min(0).nullable().default(null), + // max_tool_calls: z.number().min(0).nullable().default(null), + // metadata: z.record(z.string().max(64), z.string().max(512)).nullable().default(null), // + 16 items max model: z.string(), - // previous_response_id: z.string().nullable(), + // previous_response_id: z.string().nullable().default(null), // reasoning: z.object({ // effort: z.enum(["low", "medium", "high"]).default("medium"), - // summary: z.enum(["auto", "concise", "detailed"]).nullable(), + // summary: z.enum(["auto", "concise", "detailed"]).nullable().default(null), // }), // store: z.boolean().default(true), - // stream: z.boolean().default(false), + stream: z.boolean().default(false), temperature: z.number().min(0).max(2).default(1), // text: // tool_choice: // tools: - // top_logprobs: z.number().min(0).max(20).nullable(), + // top_logprobs: z.number().min(0).max(20).nullable().default(null), top_p: z.number().min(0).max(1).default(1), // truncation: z.enum(["auto", "disabled"]).default("disabled"), // user