diff --git a/packages/inference/README.md b/packages/inference/README.md
index 0ea60b2be7..0690431c7d 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -63,6 +63,7 @@ Currently, we support the following providers:
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
+- [Wavespeed.ai](https://wavespeed.ai/)

 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
@@ -97,6 +98,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
+- [Wavespeed.ai supported models](https://huggingface.co/api/partners/wavespeed-ai/models)

 ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type. This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 06e692aa72..177ecf3230 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -48,6 +48,7 @@ import type {
 import * as Replicate from "../providers/replicate.js";
 import * as Sambanova from "../providers/sambanova.js";
 import * as Together from "../providers/together.js";
+import * as WavespeedAI from "../providers/wavespeed-ai.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
 import { InferenceClientInputError } from "../errors.js";
@@ -151,6 +152,11 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
+	"wavespeed-ai": {
+		"text-to-image": new WavespeedAI.WavespeedAITextToImageTask(),
+		"text-to-video": new WavespeedAI.WavespeedAITextToVideoTask(),
+		"image-to-image": new WavespeedAI.WavespeedAIImageToImageTask(),
+	},
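Taken together, the README entry and the provider registration above are all a caller ever sees: passing `provider: "wavespeed-ai"` routes the request through the new helper. A minimal usage sketch (the model ID is borrowed from the tests below; it assumes `HF_TOKEN` holds a valid Hugging Face token):

```ts
import { InferenceClient } from "@huggingface/inference";

// Route a text-to-image call through the new provider.
// Assumes HF_TOKEN is set and the model is mapped to Wavespeed.ai.
const client = new InferenceClient(process.env.HF_TOKEN);

const image = await client.textToImage({
	model: "black-forest-labs/FLUX.1-schnell",
	provider: "wavespeed-ai",
	inputs: "A serene mountain lake at sunrise, photorealistic",
});

// `image` is a Blob containing the generated media bytes.
console.log(image.type, image.size);
```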
"", + status: resultResponse.status, + body: await resultResponse.text(), + } + ); + } + + const result: WaveSpeedAIResponse = await resultResponse.json(); + const taskResult = result.data; + + switch (taskResult.status) { + case "completed": { + // Get the media data from the first output URL + if (!taskResult.outputs?.[0]) { + throw new InferenceClientProviderOutputError( + "Received malformed response from WaveSpeed AI API: No output URL in completed response" + ); + } + const mediaResponse = await fetch(taskResult.outputs[0]); + if (!mediaResponse.ok) { + throw new InferenceClientProviderApiError( + "Failed to fetch generation output from WaveSpeed AI API", + { url: taskResult.outputs[0], method: "GET" }, + { + requestId: mediaResponse.headers.get("x-request-id") ?? "", + status: mediaResponse.status, + body: await mediaResponse.text(), + } + ); + } + return await mediaResponse.blob(); + } + case "failed": { + throw new InferenceClientProviderOutputError(taskResult.error || "Task failed"); + } + + default: { + // Wait before polling again + await delay(500); + continue; + } + } + } + } +} + +export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper { + constructor() { + super(WAVESPEEDAI_API_BASE_URL); + } +} + +export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper { + constructor() { + super(WAVESPEEDAI_API_BASE_URL); + } +} + +export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper { + constructor() { + super(WAVESPEEDAI_API_BASE_URL); + } + + async preparePayloadAsync(args: ImageToImageArgs): Promise { + return { + ...args, + inputs: args.parameters?.prompt, + image: base64FromBytes( + new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer()) + ), + }; + } +} diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index f48e9a011c..9c902fd6e8 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -54,6 +54,7 @@ export const INFERENCE_PROVIDERS = [ "replicate", "sambanova", "together", + "wavespeed-ai", ] as const; export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const; diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index e4b358b26b..0a2e8bf472 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -2107,4 +2107,116 @@ describe.skip("InferenceClient", () => { }, TIMEOUT ); + describe.concurrent( + "Wavespeed AI", + () => { + const client = new InferenceClient(env.HF_WAVESPEED_KEY ?? 
"dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["wavespeed-ai"] = { + "black-forest-labs/FLUX.1-schnell": { + provider: "wavespeed-ai", + hfModelId: "wavespeed-ai/flux-schnell", + providerId: "wavespeed-ai/flux-schnell", + status: "live", + task: "text-to-image", + }, + "Wan-AI/Wan2.1-T2V-14B": { + provider: "wavespeed-ai", + hfModelId: "wavespeed-ai/wan-2.1/t2v-480p", + providerId: "wavespeed-ai/wan-2.1/t2v-480p", + status: "live", + task: "text-to-video", + }, + "HiDream-ai/HiDream-E1-Full": { + provider: "wavespeed-ai", + hfModelId: "wavespeed-ai/hidream-e1-full", + providerId: "wavespeed-ai/hidream-e1-full", + status: "live", + task: "image-to-image", + }, + "openfree/flux-chatgpt-ghibli-lora": { + provider: "wavespeed-ai", + hfModelId: "openfree/flux-chatgpt-ghibli-lora", + providerId: "wavespeed-ai/flux-dev-lora", + status: "live", + task: "text-to-image", + adapter: "lora", + adapterWeightsPath: "openfree/flux-chatgpt-ghibli-lora", + }, + "linoyts/yarn_art_Flux_LoRA": { + provider: "wavespeed-ai", + hfModelId: "linoyts/yarn_art_Flux_LoRA", + providerId: "wavespeed-ai/flux-dev-lora-ultra-fast", + status: "live", + task: "text-to-image", + adapter: "lora", + adapterWeightsPath: "linoyts/yarn_art_Flux_LoRA", + }, + }; + it(`textToImage - black-forest-labs/FLUX.1-schnell`, async () => { + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", + provider: "wavespeed-ai", + inputs: + "Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it(`textToImage - openfree/flux-chatgpt-ghibli-lora`, async () => { + const res = await client.textToImage({ + model: "openfree/flux-chatgpt-ghibli-lora", + provider: "wavespeed-ai", + inputs: + "Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it(`textToImage - linoyts/yarn_art_Flux_LoRA`, async () => { + const res = await client.textToImage({ + model: "linoyts/yarn_art_Flux_LoRA", + provider: "wavespeed-ai", + inputs: + "Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it(`textToVideo - Wan-AI/Wan2.1-T2V-14B`, async () => { + const res = await client.textToVideo({ + model: "Wan-AI/Wan2.1-T2V-14B", + provider: "wavespeed-ai", + inputs: + "A cool street dancer, wearing a baggy hoodie and hip-hop pants, dancing in front of a graffiti wall, night neon background, quick camera cuts, urban trends.", + parameters: { + guidance_scale: 5, + num_inference_steps: 30, + seed: -1, + }, + duration: 5, + enable_safety_checker: true, + flow_shift: 2.9, + size: "480*832", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it(`imageToImage - HiDream-ai/HiDream-E1-Full`, async () => { + const res = await client.imageToImage({ + model: "HiDream-ai/HiDream-E1-Full", + provider: "wavespeed-ai", + inputs: new Blob([readTestFile("cheetah.png")], { type: "image / png" }), + parameters: { + prompt: "The leopard chases its prey", + guidance_scale: 5, + num_inference_steps: 30, + seed: -1, + }, + }); + expect(res).toBeInstanceOf(Blob); + }); + }, + 60000 * 5 + ); });