From 7b0797f19f68e9dbebabc10a19421fc497fad93c Mon Sep 17 00:00:00 2001
From: Arseny Yankovski
Date: Wed, 14 May 2025 17:29:13 +0200
Subject: [PATCH 1/4] added dat1.co as provider

---
 packages/inference/README.md                   |   2 +
 .../inference/src/lib/getProviderHelper.ts     |   6 +
 packages/inference/src/providers/dat1.ts       | 118 ++++++++++++++++++
 packages/inference/src/types.ts                |   1 +
 4 files changed, 127 insertions(+)
 create mode 100644 packages/inference/src/providers/dat1.ts

diff --git a/packages/inference/README.md b/packages/inference/README.md
index 55cff9429c..c7370c11d6 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -62,6 +62,7 @@ Currently, we support the following providers:
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
+- [Dat1](https://dat1.co)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts
@@ -91,6 +92,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [Dat1 supported models](https://huggingface.co/api/partners/dat1/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
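For context on the README addition above, this is what calling the new provider looks like from user code. A minimal sketch, not part of the patch: the model ID is borrowed from the tests in PATCH 2/4, and `DAT1_KEY` is assumed to be an environment variable holding a valid Dat1 API key (matching the name the test suite uses later in this series).

```ts
import { InferenceClient } from "@huggingface/inference";

// Minimal sketch: route a chat completion through the new "dat1" provider.
// DAT1_KEY is an assumed env var name; the model ID comes from PATCH 2/4's tests.
const client = new InferenceClient(process.env.DAT1_KEY);

const out = await client.chatCompletion({
	model: "unsloth/Llama-3.2-3B-Instruct-GGUF",
	provider: "dat1",
	messages: [{ role: "user", content: "What is one plus one?" }],
});

console.log(out.choices[0]?.message?.content);
```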
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 4e9e3ddbe5..8870d4bfcd 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -47,6 +47,7 @@ import type {
 import * as Replicate from "../providers/replicate";
 import * as Sambanova from "../providers/sambanova";
 import * as Together from "../providers/together";
+import * as Dat1 from "../providers/dat1";
 import type { InferenceProvider, InferenceTask } from "../types";
 
 export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask | "conversational", TaskProviderHelper>>> = {
@@ -59,6 +60,11 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask
 	cohere: {
 		conversational: new Cohere.CohereConversationalTask(),
 	},
+	dat1: {
+		conversational: new Dat1.Dat1ConversationalTask(),
+		"text-generation": new Dat1.Dat1TextGenerationTask(),
+		"text-to-image": new Dat1.Dat1TextToImageTask(),
+	},
 	"fal-ai": {
diff --git a/packages/inference/src/providers/dat1.ts b/packages/inference/src/providers/dat1.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/dat1.ts
@@ -0,0 +1,118 @@
+/**
+ * See the registered mapping of HF model ID => Dat1 model ID here:
+ *
+ * https://huggingface.co/api/partners/dat1/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Dat1 and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Dat1, please open an issue on the present repo
+ *   and we will tag Dat1 team members.
+ *
+ * Thanks! 🔥
+ */
+import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
+import { InferenceOutputError } from "../lib/InferenceOutputError";
+import type { BodyParams } from "../types";
+import { omit } from "../utils/omit";
+import {
+	BaseConversationalTask,
+	BaseTextGenerationTask,
+	TaskProviderHelper,
+	type TextToImageTaskHelper,
+} from "./providerHelper";
+
+const DAT1_API_BASE_URL = "https://api.dat1.co/api/v1/hf";
+
+interface Dat1TextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
+	choices: Array<{
+		text: string;
+		finish_reason: TextGenerationOutputFinishReason;
+		seed: number;
+		logprobs: unknown;
+		index: number;
+	}>;
+}
+
+interface Dat1Base64ImageGeneration {
+	data: Array<{
+		b64_json: string;
+	}>;
+}
+
+export class Dat1ConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("dat1", DAT1_API_BASE_URL);
+	}
+}
+
+export class Dat1TextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("dat1", DAT1_API_BASE_URL);
+	}
+
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			model: params.model,
+			...params.args,
+			prompt: params.args.inputs,
+		};
+	}
+
+	override async getResponse(response: Dat1TextCompletionOutput): Promise<TextGenerationOutput> {
+		if (
+			typeof response === "object" &&
+			"choices" in response &&
+			Array.isArray(response?.choices) &&
+			typeof response?.model === "string"
+		) {
+			const completion = response.choices[0];
+			return {
+				generated_text: completion.text,
+			};
+		}
+		throw new InferenceOutputError("Expected Dat1 text generation response format");
+	}
+}
+
+export class Dat1TextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
+	constructor() {
+		super("dat1", DAT1_API_BASE_URL);
+	}
+
+	makeRoute(): string {
+		return "v1/images/generations";
+	}
+
+	preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters as Record<string, unknown>),
+			prompt: params.args.inputs,
+			response_format: "base64",
+			model: params.model,
+		};
+	}
+
+	async getResponse(response: Dat1Base64ImageGeneration, outputType?: "url" | "blob"): Promise<string | Blob> {
+		if (
+			typeof response === "object" &&
+			"data" in response &&
+			Array.isArray(response.data) &&
+			response.data.length > 0 &&
+			"b64_json" in response.data[0] &&
+			typeof response.data[0].b64_json === "string"
+		) {
+			const base64Data = response.data[0].b64_json;
+			if (outputType === "url") {
+				return `data:image/jpeg;base64,${base64Data}`;
+			}
+			return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
+		}
+
+		throw new InferenceOutputError("Expected Dat1 text-to-image response format");
+	}
+}
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index a6df1ba3e4..c328f65923 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -41,6 +41,7 @@ export const INFERENCE_PROVIDERS = [
 	"black-forest-labs",
 	"cerebras",
 	"cohere",
+	"dat1",
 	"fal-ai",
 	"featherless-ai",
 	"fireworks-ai",
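A side note on `Dat1TextToImageTask.getResponse` in the new file above: it turns a base64 payload into a `Blob` by wrapping it in a `data:` URL and letting `fetch` decode it, which works both in browsers and in Node 18+. A standalone sketch of the same trick (the helper name and the `image/jpeg` default are illustrative, not part of the patch):

```ts
// Decode a base64 image payload into a Blob via a data: URL round-trip,
// the same pattern Dat1TextToImageTask.getResponse uses.
async function base64ToBlob(b64: string, mime = "image/jpeg"): Promise<Blob> {
	const res = await fetch(`data:${mime};base64,${b64}`);
	return res.blob();
}

// Usage: decode a 1x1 transparent PNG into a small binary Blob.
const blob = await base64ToBlob(
	"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==",
	"image/png"
);
console.log(blob.size, blob.type); // e.g. a few dozen bytes, "image/png"
```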
From 0a857bc8a4a587f52e5bd404c7e5d5d8b372c3d2 Mon Sep 17 00:00:00 2001
From: Arseny Yankovski
Date: Sun, 18 May 2025 12:41:10 +0200
Subject: [PATCH 2/4] added dat1.co as provider

---
 .../inference/src/lib/getProviderHelper.ts    |  1 -
 packages/inference/src/providers/consts.ts    |  1 +
 packages/inference/src/providers/dat1.ts      | 45 ++-----------
 .../inference/test/InferenceClient.spec.ts    | 65 ++++++++++++++++++-
 4 files changed, 67 insertions(+), 45 deletions(-)

diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 8870d4bfcd..7196ded342 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -63,7 +63,6 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask
 	},
 	dat1: {
 		conversational: new Dat1.Dat1ConversationalTask(),
-		"text-generation": new Dat1.Dat1TextGenerationTask(),
 		"text-to-image": new Dat1.Dat1TextToImageTask(),
 	},
 	"fal-ai": {
diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts
--- a/packages/inference/src/providers/consts.ts
+++ b/packages/inference/src/providers/consts.ts
@@ ... @@ export const HARDCODED_MODEL_INFERENCE_MAPPING
 	cohere: {},
+	dat1: {},
 	"fal-ai": {},
diff --git a/packages/inference/src/providers/dat1.ts b/packages/inference/src/providers/dat1.ts
--- a/packages/inference/src/providers/dat1.ts
+++ b/packages/inference/src/providers/dat1.ts
@@ ... @@
-import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
 import { InferenceOutputError } from "../lib/InferenceOutputError";
 import type { BodyParams } from "../types";
 import { omit } from "../utils/omit";
 import {
 	BaseConversationalTask,
-	BaseTextGenerationTask,
 	TaskProviderHelper,
 	type TextToImageTaskHelper,
 } from "./providerHelper";
 
 const DAT1_API_BASE_URL = "https://api.dat1.co/api/v1/hf";
 
-interface Dat1TextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
-	choices: Array<{
-		text: string;
-		finish_reason: TextGenerationOutputFinishReason;
-		seed: number;
-		logprobs: unknown;
-		index: number;
-	}>;
-}
-
 interface Dat1Base64ImageGeneration {
 	data: Array<{
 		b64_json: string;
@@ -47,34 +35,9 @@ export class Dat1ConversationalTask extends BaseConversationalTask {
 	constructor() {
 		super("dat1", DAT1_API_BASE_URL);
 	}
-}
-
-export class Dat1TextGenerationTask extends BaseTextGenerationTask {
-	constructor() {
-		super("dat1", DAT1_API_BASE_URL);
-	}
-
-	override preparePayload(params: BodyParams): Record<string, unknown> {
-		return {
-			model: params.model,
-			...params.args,
-			prompt: params.args.inputs,
-		};
-	}
-
-	override async getResponse(response: Dat1TextCompletionOutput): Promise<TextGenerationOutput> {
-		if (
-			typeof response === "object" &&
-			"choices" in response &&
-			Array.isArray(response?.choices) &&
-			typeof response?.model === "string"
-		) {
-			const completion = response.choices[0];
-			return {
-				generated_text: completion.text,
-			};
-		}
-		throw new InferenceOutputError("Expected Dat1 text generation response format");
+	override makeRoute(): string {
+		return "/chat/completions";
 	}
 }
@@ -83,8 +46,8 @@ export class Dat1TextToImageTask extends TaskProviderHelper implements TextToIma
 		super("dat1", DAT1_API_BASE_URL);
 	}
 
-	makeRoute(): string {
-		return "v1/images/generations";
+	override makeRoute(): string {
+		return "/images/generations";
 	}
 
 	preparePayload(params: BodyParams): Record<string, unknown> {
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 3365f38cfd..da68107055 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -22,17 +22,17 @@ if (!env.HF_TOKEN) {
 	console.warn("Set HF_TOKEN in the env to run the tests for better rate limits");
 }
 
-describe.skip("InferenceClient", () => {
+describe("InferenceClient", () => {
 	// Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error.
 
-	describe("backward compatibility", () => {
+	describe.skip("backward compatibility", () => {
 		it("works with old HfInference name", async () => {
 			const hf = new HfInference(env.HF_TOKEN);
 			expect("fillMask" in hf).toBe(true);
 		});
 	});
 
-	describe.concurrent(
+	describe.skip(
 		"HF Inference",
 		() => {
 			const hf = new InferenceClient(env.HF_TOKEN);
@@ -989,6 +989,65 @@ describe.skip("InferenceClient", () => {
 		TIMEOUT
 	);
+
+	describe.concurrent(
+		"dat1",
+		() => {
+			const client = new InferenceClient(env.DAT1_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["dat1"] = {
+				"unsloth/Llama-3.2-3B-Instruct-GGUF": {
+					hfModelId: "unsloth/Llama-3.2-3B-Instruct-GGUF",
+					providerId: "unsloth-Llama-32-3B-Instruct-GGUF",
+					status: "live",
+					task: "conversational",
+				},
+				"Kwai-Kolors/Kolors": {
+					hfModelId: "Kwai-Kolors/Kolors",
+					providerId: "Kwai-Kolors-Kolors",
+					status: "live",
+					task: "text-to-image",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "unsloth/Llama-3.2-3B-Instruct-GGUF",
+					provider: "dat1",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "unsloth/Llama-3.2-3B-Instruct-GGUF",
+					provider: "dat1",
+					messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+				let out = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						out += chunk.choices[0].delta.content;
+					}
+				}
+				expect(out).toContain("2");
+			});
+
+			it("textToImage", async () => {
+				const res = await client.textToImage({
+					model: "Kwai-Kolors/Kolors",
+					provider: "dat1",
+					inputs: "award winning high resolution photo of a giant tortoise",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+		},
+		TIMEOUT
+	);
 
 	/**
 	 * Compatibility with third-party Inference Providers
 	 */
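The `makeRoute` changes in this patch are easier to follow once you see how the final URL is assembled. Assuming the task helper joins the provider base URL with the route (a simplification of the real `TaskProviderHelper.makeUrl`, which also handles auth and model routing), the stray `v1/` prefix from PATCH 1/4 would have produced `.../hf/v1/images/generations`. A sketch under that assumption:

```ts
// Sketch of endpoint assembly after this patch, assuming a plain
// base-URL + route join (a simplification of TaskProviderHelper.makeUrl).
const DAT1_API_BASE_URL = "https://api.dat1.co/api/v1/hf";

const chatUrl = `${DAT1_API_BASE_URL}/chat/completions`; // Dat1ConversationalTask
const imageUrl = `${DAT1_API_BASE_URL}/images/generations`; // Dat1TextToImageTask

console.log(chatUrl); // https://api.dat1.co/api/v1/hf/chat/completions
console.log(imageUrl); // https://api.dat1.co/api/v1/hf/images/generations
```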
"dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["dat1"] = { + "unsloth/Llama-3.2-3B-Instruct-GGUF": { + hfModelId: "unsloth/Llama-3.2-3B-Instruct-GGUF", + providerId: "unsloth-Llama-32-3B-Instruct-GGUF", + status: "live", + task: "conversational", + }, + "Kwai-Kolors/Kolors": { + hfModelId: "Kwai-Kolors/Kolors", + providerId: "Kwai-Kolors-Kolors", + status: "live", + task: "text-to-image", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "unsloth/Llama-3.2-3B-Instruct-GGUF", + provider: "dat1", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "unsloth/Llama-3.2-3B-Instruct-GGUF", + provider: "dat1", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("2"); + }); + + it("textToImage", async () => { + const res = await client.textToImage({ + model: "Kwai-Kolors/Kolors", + provider: "dat1", + inputs: "award winning high resolution photo of a giant tortoise", + }); + expect(res).toBeInstanceOf(Blob); + }); + }, + TIMEOUT + ) + /** * Compatibility with third-party Inference Providers */ From 7c3bf1ab1a1b63553b33ea9525b88e4ec0fb25dc Mon Sep 17 00:00:00 2001 From: Arseny Yankovski Date: Sun, 18 May 2025 12:46:50 +0200 Subject: [PATCH 3/4] added dat1.co as provider --- packages/inference/test/InferenceClient.spec.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index da68107055..3c9f16926c 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -22,17 +22,17 @@ if (!env.HF_TOKEN) { console.warn("Set HF_TOKEN in the env to run the tests for better rate limits"); } -describe("InferenceClient", () => { +describe.skip("InferenceClient", () => { // Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error. - describe.skip("backward compatibility", () => { + describe("backward compatibility", () => { it("works with old HfInference name", async () => { const hf = new HfInference(env.HF_TOKEN); expect("fillMask" in hf).toBe(true); }); }); - describe.skip( + describe( "HF Inference", () => { const hf = new InferenceClient(env.HF_TOKEN); From 31fa0aa8d43132a47d6718caf75116841d9c2fb7 Mon Sep 17 00:00:00 2001 From: Arseny Yankovski Date: Sun, 18 May 2025 12:47:22 +0200 Subject: [PATCH 4/4] added dat1.co as provider --- packages/inference/test/InferenceClient.spec.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 3c9f16926c..134cb51090 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -32,7 +32,7 @@ describe.skip("InferenceClient", () => { }); }); - describe( + describe.concurrent( "HF Inference", () => { const hf = new InferenceClient(env.HF_TOKEN);