Skip to content

Add Mixedbread AI support #130876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ static TransportVersion def(int id) {
public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00);
public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00);
public static final TransportVersion ML_INFERENCE_MIXEDBREAD_ADDED = def(9_116_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
Expand Down Expand Up @@ -164,6 +168,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addIbmWatsonxNamedWritables(namedWriteables);
addGoogleVertexAiNamedWriteables(namedWriteables);
addMistralNamedWriteables(namedWriteables);
addMixedbreadNamedWriteables(namedWriteables);
addCustomElandWriteables(namedWriteables);
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
Expand Down Expand Up @@ -276,6 +281,34 @@ private static void addMistralNamedWriteables(List<NamedWriteableRegistry.Entry>
// note - no task settings for Mistral embeddings...
}

/**
 * Registers the Mixedbread service and task settings readers, first for embeddings and then for rerank.
 *
 * @param namedWriteables the list the four Mixedbread entries are appended to
 */
private static void addMixedbreadNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    // embeddings
    NamedWriteableRegistry.Entry embeddingsServiceSettings = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        MixedbreadEmbeddingsServiceSettings.NAME,
        MixedbreadEmbeddingsServiceSettings::new
    );
    NamedWriteableRegistry.Entry embeddingsTaskSettings = new NamedWriteableRegistry.Entry(
        TaskSettings.class,
        MixedbreadEmbeddingsTaskSettings.NAME,
        MixedbreadEmbeddingsTaskSettings::new
    );
    namedWriteables.add(embeddingsServiceSettings);
    namedWriteables.add(embeddingsTaskSettings);

    // rerank
    NamedWriteableRegistry.Entry rerankServiceSettings = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        MixedbreadRerankServiceSettings.NAME,
        MixedbreadRerankServiceSettings::new
    );
    NamedWriteableRegistry.Entry rerankTaskSettings = new NamedWriteableRegistry.Entry(
        TaskSettings.class,
        MixedbreadRerankTaskSettings.NAME,
        MixedbreadRerankTaskSettings::new
    );
    namedWriteables.add(rerankServiceSettings);
    namedWriteables.add(rerankTaskSettings);
}

private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
Expand Down Expand Up @@ -392,6 +393,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()),
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
context -> new MixedbreadService(httpFactory.get(), serviceComponents.get()),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.mixedbread;

/**
 * String constants shared by the Mixedbread inference integration: URI paths and
 * the field names used for service settings, task settings and requests.
 * <p>
 * This is a non-instantiable holder class; it is {@code final} and has a private constructor.
 */
public final class MixedbreadConstants {
    // URI paths
    public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
    public static final String RERANK_URI_PATH = "/v1/rerank";

    // common service settings fields
    public static final String API_KEY_FIELD = "api_key";

    // embeddings service and request settings
    public static final String INPUT_FIELD = "input";

    // embeddings task settings fields
    public static final String USER_FIELD = "user";

    // rerank fields
    // NOTE(review): QUERY_FIELD was previously grouped under a duplicated "rerank task settings fields"
    // heading; it looks like a request field rather than a task setting - confirm against the request builder.
    public static final String QUERY_FIELD = "query";
    public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
    public static final String TOP_K_FIELD = "top_k";

    private MixedbreadConstants() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.mixedbread;

import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;

/**
 * Abstract base class for Mixedbread inference models.
 * Extends {@link RateLimitGroupingModel} and holds the state shared by Mixedbread models:
 * the model id, the target URI and the rate limit settings.
 */
public abstract class MixedbreadModel extends RateLimitGroupingModel {
    protected String modelId;
    protected URI uri;
    protected RateLimitSettings rateLimitSettings;

    /**
     * Creates a copy of {@code model} with different task settings and rate limit settings.
     * The model id and URI are carried over from {@code model} so the copy targets the same
     * endpoint and falls into the same rate limit group.
     *
     * @param model             the model to copy
     * @param taskSettings      the task settings for the copy
     * @param rateLimitSettings the rate limit settings for the copy, must not be {@code null}
     */
    public MixedbreadModel(MixedbreadModel model, TaskSettings taskSettings, RateLimitSettings rateLimitSettings) {
        super(model, taskSettings);
        // Without copying these the new instance would report a null model id and URI.
        this.modelId = model.modelId;
        this.uri = model.uri;
        this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings);
    }

    /**
     * Creates a MixedbreadModel from model configurations and secrets.
     * This constructor does not assign {@link #modelId}, {@link #uri} or {@link #rateLimitSettings}.
     *
     * @param configurations the model configurations
     * @param secrets        the secrets for the model
     */
    protected MixedbreadModel(ModelConfigurations configurations, ModelSecrets secrets) {
        super(configurations, secrets);
    }

    /**
     * Creates a MixedbreadModel from an existing model and service settings.
     * This constructor does not assign {@link #modelId}, {@link #uri} or {@link #rateLimitSettings}.
     *
     * @param model           the existing model
     * @param serviceSettings the settings for the inference service
     */
    protected MixedbreadModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
        super(model, serviceSettings);
    }

    /**
     * @return the model id, or {@code null} if it has not been assigned
     */
    public String model() {
        return this.modelId;
    }

    /**
     * @return the URI, or {@code null} if it has not been assigned
     */
    public URI uri() {
        return this.uri;
    }

    @Override
    public RateLimitSettings rateLimitSettings() {
        return this.rateLimitSettings;
    }

    /**
     * Hashes the model id, the URI and the secret settings.
     * The uncast secret settings are used so that hashing also works when the model holds
     * {@link EmptySecretSettings}.
     */
    @Override
    public int rateLimitGroupingHash() {
        return Objects.hash(modelId, uri, super.getSecretSettings());
    }

    /**
     * Replaces the URI. Needed for testing only.
     * If {@code newUri} is not a valid URI the current value is kept and the error is ignored.
     *
     * @param newUri the string form of the new URI
     */
    public void setURI(String newUri) {
        try {
            this.uri = new URI(newUri);
        } catch (URISyntaxException ignored) {
            // deliberately ignored: this is a test-only helper and the previous URI stays in place
        }
    }

    /**
     * Builds the secret settings from the provided map of secrets.
     * A non-null, empty map yields {@link EmptySecretSettings#INSTANCE}, because a Mixedbread model
     * can be used without authentication. Any other value, including {@code null}, is passed to
     * {@link DefaultSecretSettings#fromMap(Map)}.
     * <p>
     * NOTE(review): the earlier documentation stated that a {@code null} map also yields
     * {@link EmptySecretSettings}; the condition does not do that - confirm which behavior is intended.
     *
     * @param secrets the map containing secret settings, may be {@code null}
     * @return the secret settings built from {@code secrets}
     */
    protected static SecretSettings retrieveSecretSettings(Map<String, Object> secrets) {
        return (secrets != null && secrets.isEmpty()) ? EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets);
    }

    /**
     * Returns the secret settings as {@link DefaultSecretSettings}.
     *
     * @return the secret settings, or {@code null} when the model holds no {@link DefaultSecretSettings}
     *         (for example {@link EmptySecretSettings} for unauthenticated use)
     */
    @Override
    public DefaultSecretSettings getSecretSettings() {
        SecretSettings secretSettings = super.getSecretSettings();
        // retrieveSecretSettings may produce EmptySecretSettings, so an unconditional cast can fail.
        if (secretSettings instanceof DefaultSecretSettings) {
            return (DefaultSecretSettings) secretSettings;
        }
        return null;
    }
}
Loading