diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 486804a1f57ee..13be296bd46da 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -329,6 +329,7 @@ static TransportVersion def(int id) { public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00); public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00); public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00); + public static final TransportVersion ML_INFERENCE_MIXEDBREAD_ADDED = def(9_116_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index c347fa1dca4ce..c126a304459f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -106,6 +106,10 @@ import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsTaskSettings; +import 
org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; @@ -164,6 +168,7 @@ public static List getNamedWriteables() { addIbmWatsonxNamedWritables(namedWriteables); addGoogleVertexAiNamedWriteables(namedWriteables); addMistralNamedWriteables(namedWriteables); + addMixedbreadNamedWriteables(namedWriteables); addCustomElandWriteables(namedWriteables); addAnthropicNamedWritables(namedWriteables); addAmazonBedrockNamedWriteables(namedWriteables); @@ -276,6 +281,34 @@ private static void addMistralNamedWriteables(List // note - no task settings for Mistral embeddings... 
} + private static void addMixedbreadNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + MixedbreadEmbeddingsServiceSettings.NAME, + MixedbreadEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + MixedbreadEmbeddingsTaskSettings.NAME, + MixedbreadEmbeddingsTaskSettings::new + ) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + MixedbreadRerankServiceSettings.NAME, + MixedbreadRerankServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, MixedbreadRerankTaskSettings.NAME, MixedbreadRerankTaskSettings::new) + ); + } + private static void addAzureAiStudioNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index de31f9d6cefc8..fb7a89341757f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -133,6 +133,7 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService; @@ -392,6 +393,7 @@ public List getInferenceServiceFactories() { context -> new 
GoogleAiStudioService(httpFactory.get(), serviceComponents.get()), context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()), context -> new MistralService(httpFactory.get(), serviceComponents.get()), + context -> new MixedbreadService(httpFactory.get(), serviceComponents.get()), context -> new AnthropicService(httpFactory.get(), serviceComponents.get()), context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadConstants.java new file mode 100644 index 0000000000000..9df2beef04872 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadConstants.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
/**
 * Constants shared by the Mixedbread inference service: request URI paths and field names.
 * This is a non-instantiable holder class.
 */
public final class MixedbreadConstants {
    // request URI paths
    public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
    public static final String RERANK_URI_PATH = "/v1/rerank";

    // common service settings fields
    public static final String API_KEY_FIELD = "api_key";

    // embeddings service and request settings
    public static final String INPUT_FIELD = "input";

    // rerank request fields
    public static final String QUERY_FIELD = "query";

    // embeddings task settings fields
    public static final String USER_FIELD = "user";

    // rerank task settings fields
    public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
    public static final String TOP_K_FIELD = "top_k";

    private MixedbreadConstants() {}
}
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +/** + * Abstract class representing a Mixedbread model for inference. + * This class extends RateLimitGroupingModel and provides common functionality for Mixedbread models. + */ +public abstract class MixedbreadModel extends RateLimitGroupingModel { + protected String modelId; + protected URI uri; + protected RateLimitSettings rateLimitSettings; + + public MixedbreadModel(MixedbreadModel model, TaskSettings taskSettings, RateLimitSettings rateLimitSettings) { + super(model, taskSettings); + this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings); + } + + /** + * Constructor for creating a MixedbreadModel with specified configurations and secrets. + * + * @param configurations the model configurations + * @param secrets the secret settings for the model + */ + protected MixedbreadModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + /** + * Constructor for creating a MixedbreadModel with specified model, service settings, and secret settings. 
+ * @param model the model configurations + * @param serviceSettings the settings for the inference service + */ + protected MixedbreadModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + public String model() { + return this.modelId; + } + + public URI uri() { + return this.uri; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public int rateLimitGroupingHash() { + return Objects.hash(modelId, uri, getSecretSettings()); + } + + // Needed for testing only + public void setURI(String newUri) { + try { + this.uri = new URI(newUri); + } catch (URISyntaxException e) { + // swallow any error + } + } + + /** + * Retrieves the secret settings from the provided map of secrets. + * If the map is null or empty, it returns an instance of EmptySecretSettings. + * Caused by the fact that Mixedbread model doesn't have out of the box security settings and can be used witout authentication. + * + * @param secrets the map containing secret settings + * @return an instance of SecretSettings + */ + protected static SecretSettings retrieveSecretSettings(Map secrets) { + return (secrets != null && secrets.isEmpty()) ? EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java new file mode 100644 index 0000000000000..9e59298d22e14 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java @@ -0,0 +1,385 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import 
org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.mixedbread.action.MixedbreadActionCreator; +import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; + +public class MixedbreadService extends SenderService { + public static final String NAME = "mixedbread"; + private static final String SERVICE_NAME = "Mixedbread"; + /** + * The optimal batch size depends on the hardware the model is deployed on. 
+ * For Mixedbread use a conservatively small max batch size as it is + * unknown how the model is deployed + */ + static final int EMBEDDING_MAX_BATCH_SIZE = 20; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TEXT_EMBEDDING); + + /** + * Constructor for creating a MixedbreadService with specified HTTP request sender factory and service components. + * + * @param factory the factory to create HTTP request senders + * @param serviceComponents the components required for the inference service + */ + public MixedbreadService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new MixedbreadActionCreator(getSender(), getServiceComponents()); + + if (model instanceof MixedbreadEmbeddingsModel mixedbreadEmbeddingsModel) { + mixedbreadEmbeddingsModel.accept(actionCreator).execute(inputs, timeout, listener); + } else if (model instanceof MixedbreadRerankModel mixedbreadRerankModel) { + mixedbreadRerankModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + /** + * Creates a MixedbreadModel based on the provided parameters. 
+ * + * @param inferenceId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param serviceSettings the settings for the inference service + * @param taskSettings the settings specific to the task + * @param chunkingSettings the settings for chunking, if applicable + * @param secretSettings the secret settings for the model, such as API keys or tokens + * @param failureMessage the message to use in case of failure + * @param context the context for parsing configuration settings + * @return a new instance of MixedbreadModel based on the provided parameters + */ + protected MixedbreadModel createModel( + String inferenceId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + switch (taskType) { + case TEXT_EMBEDDING: + return new MixedbreadEmbeddingsModel( + inferenceId, + taskType, + NAME, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + context + ); + default: + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof MixedbreadEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var updatedServiceSettings = new MixedbreadEmbeddingsServiceSettings( + serviceSettings.modelId(), + serviceSettings.uri(), + embeddingSize, + serviceSettings.rateLimitSettings() + ); + + return new MixedbreadEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + @Override + protected void doChunkedInfer( + Model model, + EmbeddingsInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model 
instanceof MixedbreadEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var MixedbreadModel = (MixedbreadEmbeddingsModel) model; + var actionCreator = new MixedbreadActionCreator(getSender(), getServiceComponents()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + MixedbreadModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = MixedbreadModel.accept(actionCreator); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } + } + + @Override + public Set supportedStreamingTasks() { + return null; // EnumSet.of(COMPLETION, CHAT_COMPLETION); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + + MixedbreadModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + 
throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + private MixedbreadModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, 
+ serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_MIXEDBREAD_ADDED; + } + + /** + * Configuration class for the Mixedbread inference service. + * It provides the settings and configurations required for the service. + */ + public static class Configuration { + public static InferenceServiceConfiguration get() { + return CONFIGURATION.getOrCompute(); + } + + private Configuration() {} + + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "Refer to the Mixedbread models documentation for the list of available models." 
+ ) + .setLabel("Model") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreator.java new file mode 100644 index 0000000000000..9415f88f6d297 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreator.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.action; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.mixedbread.request.MixedbreadEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.mixedbread.request.MixedbreadRerankRequest; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankResponseHandler; +import org.elasticsearch.xpack.inference.services.mixedbread.response.MixedbreadEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.mixedbread.response.MixedbreadRerankResponseEntity; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class MixedbreadActionCreator implements MixedbreadVisitor { + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Mixedbread %s request from inference entity id [%s]"; + private static final String 
INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected Mixedbread %s request but got %s"; + + private static final ResponseHandler EMBEDDINGS_HANDLER = new MixedbreadEmbeddingsResponseHandler( + "mixedbread text embedding", + MixedbreadEmbeddingsResponseEntity::fromResponse + ); + + private static final ResponseHandler RERANK_HANDLER = new MixedbreadRerankResponseHandler("mixedbread rerank", (request, response) -> { + if ((request instanceof MixedbreadRerankRequest) == false) { + var errorMessage = format( + INVALID_REQUEST_TYPE_MESSAGE, + "RERANK", + request != null ? request.getClass().getSimpleName() : "null" + ); + throw new IllegalArgumentException(errorMessage); + } + return MixedbreadRerankResponseEntity.fromResponse((MixedbreadRerankRequest) request, response); + }); + + private final Sender sender; + private final ServiceComponents serviceComponents; + + /** + * Constructs a new MixedbreadActionCreator with the specified sender and service components. + * + * @param sender the sender to use for executing actions + * @param serviceComponents the service components providing necessary services + */ + public MixedbreadActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(MixedbreadEmbeddingsModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + EMBEDDINGS_HANDLER, + embeddingsInput -> new MixedbreadEmbeddingsRequest( + serviceComponents.truncator(), + truncate(embeddingsInput.getStringInputs(), model.getServiceSettings().maxInputTokens()), + model + ), + EmbeddingsInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + + @Override + public ExecutableAction create(MixedbreadRerankModel model, 
Map taskSettings) { + var overriddenModel = MixedbreadRerankModel.of(model, taskSettings); + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + inputs -> new MixedbreadRerankRequest( + model, + inputs.getQuery(), + inputs.getChunks(), + inputs.getReturnDocuments(), + inputs.getTopN() + ), + QueryAndDocsInputs.class + ); + var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + + /** + * Builds an error message for failed requests. + * + * @param requestType the type of request that failed + * @param inferenceId the inference entity ID associated with the request + * @return a formatted error message + */ + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadVisitor.java new file mode 100644 index 0000000000000..0b49fa34eb91a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadVisitor.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; + +import java.util.Map; + +public interface MixedbreadVisitor { + ExecutableAction create(MixedbreadEmbeddingsModel embeddingsModel); + + ExecutableAction create(MixedbreadRerankModel rerankModel, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsModel.java new file mode 100644 index 0000000000000..438234877e3f8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsModel.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.embeddings; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadModel; +import org.elasticsearch.xpack.inference.services.mixedbread.action.MixedbreadVisitor; + +import java.util.Map; + +public class MixedbreadEmbeddingsModel extends MixedbreadModel { + + /** + * Constructor for creating a MixedbreadEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public MixedbreadEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + MixedbreadEmbeddingsServiceSettings.fromMap(serviceSettings, context), + EmptyTaskSettings.INSTANCE, // no task settings for Mixedbread embeddings + chunkingSettings, + retrieveSecretSettings(secrets) + ); + } + + /** + * Constructor for creating a MixedbreadEmbeddingsModel with specified parameters. 
+ * + * @param model the base MixedbreadEmbeddingsModel to copy properties from + * @param serviceSettings the settings for the inference service, specific to embeddings + */ + public MixedbreadEmbeddingsModel(MixedbreadEmbeddingsModel model, MixedbreadEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + setPropertiesFromServiceSettings(serviceSettings); + } + + /** + * Sets properties from the provided MixedbreadEmbeddingsServiceSettings. + * + * @param serviceSettings the service settings to extract properties from + */ + private void setPropertiesFromServiceSettings(MixedbreadEmbeddingsServiceSettings serviceSettings) { + this.modelId = serviceSettings.modelId(); + this.uri = serviceSettings.uri(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + /** + * Constructor for creating a MixedbreadEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param taskSettings the task settings for the model + * @param chunkingSettings the chunking settings for processing input data + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public MixedbreadEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + MixedbreadEmbeddingsServiceSettings serviceSettings, + TaskSettings taskSettings, + ChunkingSettings chunkingSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + @Override + public MixedbreadEmbeddingsServiceSettings getServiceSettings() { + return 
(MixedbreadEmbeddingsServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor to create an executable action for this Mixedbread embeddings model. + * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing the Mixedbread embeddings model + */ + public ExecutableAction accept(MixedbreadVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsRequestTaskSettings.java new file mode 100644 index 0000000000000..d6429cb4c5bef --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsRequestTaskSettings.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.USER_FIELD; + +public record MixedbreadEmbeddingsRequestTaskSettings(@Nullable String user) { + public static final MixedbreadEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new MixedbreadEmbeddingsRequestTaskSettings(null); + + /** + * Extracts the task settings from a map. 
All settings are considered optional and the absence of a setting + * does not throw an error. + * + * @param map the settings received from a request + * @return a {@link MixedbreadEmbeddingsRequestTaskSettings} + */ + public static MixedbreadEmbeddingsRequestTaskSettings fromMap(Map map) { + if (map.isEmpty()) { + return MixedbreadEmbeddingsRequestTaskSettings.EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + String user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new MixedbreadEmbeddingsRequestTaskSettings(user); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsResponseHandler.java new file mode 100644 index 0000000000000..a0fcb80515f3c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsResponseHandler.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.embeddings; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.mixedbread.response.MixedbreadErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler; + +public class MixedbreadEmbeddingsResponseHandler extends OpenAiResponseHandler { + /** + * Constructs a new MixedbreadEmbeddingsResponseHandler with the specified request type and response parser. + * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public MixedbreadEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, MixedbreadErrorResponse::fromResponse, false); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..6e1a5f203a348 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsServiceSettings.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings.extractUri; + +public class MixedbreadEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "mixedbread_embeddings_service_settings"; + // There is no default rate limit for Mixedbread, so we set a 
reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + private final String modelId; + private final URI uri; + private final Integer dimensions; + private final RateLimitSettings rateLimitSettings; + + /** + * Creates a new instance of MixedbreadEmbeddingsServiceSettings from a map of settings. + * + * @param map the map containing the settings + * @param context the context for parsing configuration settings + * @return a new instance of MixedbreadEmbeddingsServiceSettings + * @throws ValidationException if any required fields are missing or invalid + */ + public static MixedbreadEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + var rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + MixedbreadService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new MixedbreadEmbeddingsServiceSettings(model, uri, dimensions, rateLimitSettings); + } + + /** + * Constructs a new MixedbreadEmbeddingsServiceSettings from a StreamInput. 
+ * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public MixedbreadEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.uri = createUri(in.readString()); + this.dimensions = in.readOptionalVInt(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new MixedbreadEmbeddingsServiceSettings with the specified parameters. + * + * @param modelId the identifier for the model + * @param uri the URI of the Mixedbread service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public MixedbreadEmbeddingsServiceSettings( + String modelId, + URI uri, + @Nullable Integer dimensions, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.uri = uri; + this.dimensions = dimensions; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Constructs a new MixedbreadEmbeddingsServiceSettings with the specified parameters. 
+ * + * @param modelId the identifier for the model + * @param url the URL of the Mixedbread service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public MixedbreadEmbeddingsServiceSettings( + String modelId, + String url, + @Nullable Integer dimensions, + @Nullable RateLimitSettings rateLimitSettings + ) { + this(modelId, createUri(url), dimensions, rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_MIXEDBREAD_ADDED; + } + + @Override + public String modelId() { + return this.modelId; + } + + public URI uri() { + return this.uri; + } + + @Override + public Integer dimensions() { + return this.dimensions; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + /** + * Returns the rate limit settings for this service. 
+ * + * @return the rate limit settings, never null + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(uri.toString()); + out.writeOptionalVInt(dimensions); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.field(MODEL_ID, modelId); + builder.field(URL, uri.toString()); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MixedbreadEmbeddingsServiceSettings that = (MixedbreadEmbeddingsServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, dimensions, rateLimitSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..93a9bdd805f95 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/embeddings/MixedbreadEmbeddingsTaskSettings.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.USER_FIELD; + +public class MixedbreadEmbeddingsTaskSettings implements TaskSettings { + public static final String NAME = "mixedbread_embeddings_task_settings"; + + public static MixedbreadEmbeddingsTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, validationException); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new MixedbreadEmbeddingsTaskSettings(user); + } + + /** + * Creates a new {@link MixedbreadEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones + * passed in via requestSettings if 
the fields are not null. + * + * @param originalSettings the original {@link MixedbreadEmbeddingsTaskSettings} from the inference entity configuration from storage + * @param requestSettings the {@link MixedbreadEmbeddingsRequestTaskSettings} from the request + * @return a new {@link MixedbreadEmbeddingsTaskSettings} + */ + public static MixedbreadEmbeddingsTaskSettings of( + MixedbreadEmbeddingsTaskSettings originalSettings, + MixedbreadEmbeddingsRequestTaskSettings requestSettings + ) { + var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user(); + return new MixedbreadEmbeddingsTaskSettings(userToUse); + } + + public MixedbreadEmbeddingsTaskSettings(@Nullable String user) { + this.user = user; + } + + public MixedbreadEmbeddingsTaskSettings(StreamInput in) throws IOException { + this.user = in.readOptionalString(); + } + + private final String user; + + public String user() { + return this.user; + } + + @Override + public boolean isEmpty() { + return user == null || user.isEmpty(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_14_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(this.user); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (user != null) { + builder.field(USER_FIELD, user); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MixedbreadEmbeddingsTaskSettings that = (MixedbreadEmbeddingsTaskSettings) o; + return Objects.equals(user, that.user); + } + + @Override + public int hashCode() { + return Objects.hashCode(user); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) 
{ + MixedbreadEmbeddingsRequestTaskSettings requestSettings = MixedbreadEmbeddingsRequestTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return MixedbreadEmbeddingsTaskSettings.of(this, requestSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequest.java new file mode 100644 index 0000000000000..d07f09e94dee7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequest.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.mixedbread.embeddings.MixedbreadEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class MixedbreadEmbeddingsRequest implements Request { + private final URI uri; + private final MixedbreadEmbeddingsModel model; + private final 
String inferenceEntityId; + private final Truncator.TruncationResult truncationResult; + private final Truncator truncator; + + /** + * Constructs a new MixedbreadEmbeddingsRequest with the specified truncator, input, and model. + * + * @param truncator the truncator to handle input truncation + * @param input the input to be truncated + * @param model the Mixedbread embeddings model to be used for the request + */ + public MixedbreadEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, MixedbreadEmbeddingsModel model) { + this.uri = model.uri(); + this.model = model; + this.inferenceEntityId = model.getInferenceEntityId(); + this.truncator = truncator; + this.truncationResult = input; + } + + /** + * Returns the URI for this request. + * + * @return the URI of the Mixedbread embeddings model + */ + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(this.uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new MixedbreadEmbeddingsRequestEntity(model.model(), truncationResult.input())) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + if (model.getSecretSettings() instanceof DefaultSecretSettings) { + var secretSettings = model.getSecretSettings(); + httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey())); + } + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return uri; + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + return new MixedbreadEmbeddingsRequest(truncator, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..1987f4e1cc8e7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequestEntity.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.ENCODING_FORMAT_FIELD; + +/** + * MixedbreadEmbeddingsRequestEntity is responsible for creating the request entity for Mixedbread embeddings. + * It implements ToXContentObject to allow serialization to XContent format. + */ +public record MixedbreadEmbeddingsRequestEntity(String model, List input) implements ToXContentObject { + + public static final String INPUT_FIELD = "input"; + public static final String MODEL_FIELD = "model"; + + /** + * Constructs a MixedbreadEmbeddingsRequestEntity with the specified model ID and input. 
+ * + * @param model the ID of the model to use for embeddings + * @param input the list of input to generate embeddings for + */ + public MixedbreadEmbeddingsRequestEntity { + Objects.requireNonNull(model); + Objects.requireNonNull(input); + } + + /** + * Constructs a MixedbreadEmbeddingsRequestEntity with the specified model ID and a single content string. + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_FIELD, model); + builder.field(INPUT_FIELD, input); + builder.field(ENCODING_FORMAT_FIELD, "float"); + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRequest.java new file mode 100644 index 0000000000000..37c571419e73c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRequest.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request; + +import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadModel; + +import java.net.URI; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; +import static org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRequestFields.API_KEY_HEADER; + +public abstract class MixedbreadRequest implements Request { + protected final URI uri; + protected final String inferenceEntityId; + + protected MixedbreadRequest(MixedbreadModel model) { + this.uri = model.uri(); + this.inferenceEntityId = model.getInferenceEntityId(); + } + + protected void setAuthHeader(HttpEntityEnclosingRequestBase request, MixedbreadModel model) { + var apiKey = model.getSecretSettings().apiKey(); + request.setHeader(API_KEY_HEADER, apiKey.toString()); + request.setHeader(createAuthBearerHeader(apiKey)); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public String getInferenceEntityId() { + return this.inferenceEntityId; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequest.java new file mode 100644 index 0000000000000..e7014f8f59c0e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequest.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class MixedbreadRerankRequest extends MixedbreadRequest { + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final MixedbreadRerankModel rerankModel; + + public MixedbreadRerankRequest( + MixedbreadRerankModel model, + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN + ) { + super(model); + this.rerankModel = Objects.requireNonNull(model); + this.query = query; + this.input = Objects.requireNonNull(input); + this.returnDocuments = returnDocuments; + this.topN = topN; + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(this.uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(createRequestEntity()).getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + setAuthHeader(httpPost, rerankModel); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public Request truncate() { + // Not applicable for rerank, only used in text embedding requests + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // Not applicable for rerank, only used in text embedding requests + 
return null; + } + + private MixedbreadRerankRequestEntity createRequestEntity() { + var taskSettings = rerankModel.getTaskSettings(); + return new MixedbreadRerankRequestEntity(rerankModel.model(), query, input, topN, returnDocuments, taskSettings); + } + + public Integer getTopN() { + return topN != null ? topN : rerankModel.getTaskSettings().topK(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntity.java new file mode 100644 index 0000000000000..c5cd3806619a6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntity.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.mixedbread.request;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadConstants.INPUT_FIELD;
+import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadConstants.QUERY_FIELD;
+import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadConstants.RETURN_DOCUMENTS_FIELD;
+import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadConstants.TOP_K_FIELD;
+import static org.elasticsearch.xpack.inference.services.mixedbread.request.MixedbreadEmbeddingsRequestEntity.MODEL_FIELD;
+
+/**
+ * JSON body of a Mixedbread rerank request.
+ * <p>
+ * {@code topN} and {@code returnDocuments} are request-level values; when one of them is {@code null} the
+ * corresponding value from {@code taskSettings} is written instead, and the field is omitted if both are {@code null}.
+ *
+ * @param model           the model identifier, required
+ * @param query           the query to rank against, required
+ * @param input           the documents to rank, required
+ * @param topN            request-level result limit, may be {@code null}
+ * @param returnDocuments request-level flag for returning document text, may be {@code null}
+ * @param taskSettings    the stored task settings used as fallback, required
+ */
+public record MixedbreadRerankRequestEntity(
+    String model,
+    String query,
+    List input,
+    @Nullable Integer topN,
+    @Nullable Boolean returnDocuments,
+    MixedbreadRerankTaskSettings taskSettings
+) implements ToXContentObject {
+
+    // Compact constructor: rejects null for every component that is not marked @Nullable.
+    public MixedbreadRerankRequestEntity {
+        Objects.requireNonNull(model);
+        Objects.requireNonNull(query);
+        Objects.requireNonNull(input);
+        Objects.requireNonNull(taskSettings);
+    }
+
+    /**
+     * Writes the model, query and input fields, followed by the optional top-k and return-documents fields.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        builder.field(MODEL_FIELD, model);
+        builder.field(QUERY_FIELD, query);
+        builder.field(INPUT_FIELD, input);
+
+        // The request-level value wins; the task setting is only a fallback.
+        if (topN != null) {
+            builder.field(TOP_K_FIELD, topN);
+        } else if (taskSettings.topK() != null) {
+            builder.field(TOP_K_FIELD, taskSettings.topK());
+        }
+
+        // Same precedence as for top-k above.
+        if (returnDocuments != null) {
+            builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments);
+        } else if (taskSettings.returnDocuments() != null) {
+            builder.field(RETURN_DOCUMENTS_FIELD, taskSettings.returnDocuments());
+        }
+        builder.endObject();
+        return builder;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModel.java
new file mode 100644
index 0000000000000..1e0fee86b59a5
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModel.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.mixedbread.rerank;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadModel;
+import org.elasticsearch.xpack.inference.services.mixedbread.action.MixedbreadVisitor;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.util.Map;
+
+/**
+ * Model definition for the Mixedbread rerank task: service settings, task settings and secret settings.
+ */
+public class MixedbreadRerankModel extends MixedbreadModel {
+    /**
+     * Returns a model whose task settings are the given model's settings overridden by the request-level
+     * {@code taskSettings}. The given model is returned unchanged when {@code taskSettings} is {@code null} or empty.
+     *
+     * @param model        the stored model
+     * @param taskSettings task settings supplied with the request, may be {@code null}
+     * @return the model to use for the request
+     */
+    public static MixedbreadRerankModel of(MixedbreadRerankModel model, Map taskSettings) {
+        if (taskSettings == null || taskSettings.isEmpty()) {
+            return model;
+        }
+
+        final var requestTaskSettings = MixedbreadRerankRequestTaskSettings.fromMap(taskSettings);
+        final var taskSettingToUse = MixedbreadRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings);
+
+        return new MixedbreadRerankModel(model, taskSettingToUse);
+    }
+
+    /**
+     * Creates a model from already-parsed settings objects.
+     */
+    public MixedbreadRerankModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        MixedbreadRerankServiceSettings serviceSettings,
+        MixedbreadRerankTaskSettings taskSettings,
+        DefaultSecretSettings secrets
+    ) {
+        super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets));
+    }
+
+    /**
+     * Creates a model by parsing the raw service, task and secret settings maps.
+     */
+    public MixedbreadRerankModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map serviceSettings,
+        Map taskSettings,
+        @Nullable Map secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            MixedbreadRerankServiceSettings.fromMap(serviceSettings, context),
+            MixedbreadRerankTaskSettings.fromMap(taskSettings),
+            DefaultSecretSettings.fromMap(secrets)
+        );
+    }
+
+    /**
+     * Copies {@code model} with different task settings, keeping the rate limit settings of its service settings.
+     */
+    public MixedbreadRerankModel(MixedbreadRerankModel model, MixedbreadRerankTaskSettings taskSettings) {
+        super(model, taskSettings, model.getServiceSettings().rateLimitSettings());
+    }
+
+    @Override
+    public MixedbreadRerankServiceSettings getServiceSettings() {
+        return (MixedbreadRerankServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public MixedbreadRerankTaskSettings getTaskSettings() {
+        return (MixedbreadRerankTaskSettings) super.getTaskSettings();
+    }
+
+    @Override
+    public DefaultSecretSettings getSecretSettings() {
+        return super.getSecretSettings();
+    }
+
+    /**
+     * Accepts a visitor to create an executable action for this Mixedbread rerank model.
+     *
+     * @param creator the visitor that creates the executable action
+     * @param taskSettings the request-level task settings handed on to the visitor
+     * @return an ExecutableAction representing the Mixedbread rerank model
+     */
+    public ExecutableAction accept(MixedbreadVisitor creator, Map taskSettings) {
+        return creator.create(this, taskSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankRequestTaskSettings.java
new file mode 100644
index 0000000000000..4f589abc6e368
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankRequestTaskSettings.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadConstants.RETURN_DOCUMENTS_FIELD; +import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadConstants.TOP_K_FIELD; + +public record MixedbreadRerankRequestTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) { + + public static final MixedbreadRerankRequestTaskSettings EMPTY_SETTINGS = new MixedbreadRerankRequestTaskSettings(null, null); + + /** + * Extracts the task settings from a map. All settings are considered optional and the absence of a setting + * does not throw an error. 
+ * + * @param map the settings received from a request + * @return a {@link MixedbreadRerankRequestTaskSettings} + */ + public static MixedbreadRerankRequestTaskSettings fromMap(Map map) { + if (map.isEmpty()) { + return MixedbreadRerankRequestTaskSettings.EMPTY_SETTINGS; + } + + final var validationException = new ValidationException(); + + final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException); + final var topN = extractOptionalPositiveInteger(map, TOP_K_FIELD, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new MixedbreadRerankRequestTaskSettings(returnDocuments, topN); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankResponseHandler.java new file mode 100644 index 0000000000000..90393addf24d4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankResponseHandler.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.mixedbread.rerank;
+
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.services.mixedbread.response.MixedbreadErrorResponse;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler;
+
+/**
+ * Response handler for Mixedbread rerank requests. It reuses {@link OpenAiResponseHandler} and only swaps in
+ * {@link MixedbreadErrorResponse} for turning error responses into error messages.
+ */
+public class MixedbreadRerankResponseHandler extends OpenAiResponseHandler {
+    /**
+     * Constructs a new MixedbreadRerankResponseHandler with the specified request type and response parser.
+     *
+     * @param requestType the type of request this handler will process
+     * @param parseFunction the function to parse the response
+     */
+    public MixedbreadRerankResponseHandler(String requestType, ResponseParser parseFunction) {
+        // NOTE(review): the trailing `false` is presumably the parent's streaming-support flag — confirm against OpenAiResponseHandler.
+        super(requestType, parseFunction, MixedbreadErrorResponse::fromResponse, false);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettings.java
new file mode 100644
index 0000000000000..5fad859376147
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettings.java
@@ -0,0 +1,167 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; + +public class MixedbreadRerankServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "mixedbread_rerank_service_settings"; + // There is no default rate limit for Mixedbread, so we set a reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + /** + * Creates a new instance of MixedbreadRerankServiceSettings from a map of settings. 
+ * + * @param map the map containing the settings + * @param context the context for parsing configuration settings + * @return a new instance of MixedbreadRerankServiceSettings + * @throws ValidationException if any required fields are missing or invalid + */ + public static MixedbreadRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var uri = HuggingFaceServiceSettings.extractUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + return new MixedbreadRerankServiceSettings(uri, rateLimitSettings); + } + + private final URI uri; + private final RateLimitSettings rateLimitSettings; + + public MixedbreadRerankServiceSettings(String url) { + uri = createUri(url); + rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; + } + + MixedbreadRerankServiceSettings(URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.uri = Objects.requireNonNull(uri); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Constructs a new MixedbreadEmbeddingsServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public MixedbreadRerankServiceSettings(StreamInput in) throws IOException { + this.uri = createUri(in.readString()); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new MixedbreadRerankServiceSettings with the specified parameters. 
+ * + * @param modelId the identifier for the model + * @param uri the URI of the Mixedbread service + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public MixedbreadRerankServiceSettings(String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.uri = uri; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public URI uri() { + return this.uri; + } + + /** + * Returns the rate limit settings for this service. + * + * @return the rate limit settings, never null + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public String modelId() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.field(URL, uri.toString()); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_MIXEDBREAD_ADDED; + } + + @Override + public boolean supportsVersion(TransportVersion version) { + return version.onOrAfter(TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED) + || version.isPatchFrom(TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(uri.toString()); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + 
MixedbreadRerankServiceSettings that = (MixedbreadRerankServiceSettings) o; + return Objects.equals(uri, that.uri) && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettings.java new file mode 100644 index 0000000000000..1f70a79fd39de --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettings.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadConstants.RETURN_DOCUMENTS_FIELD; +import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadConstants.TOP_K_FIELD; + +public class MixedbreadRerankTaskSettings implements TaskSettings { + public static final String NAME = "mixedbread_rerank_task_settings"; + + public static MixedbreadRerankTaskSettings fromMap(Map map) { + final var validationException = new ValidationException(); + + final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException); + final var topN = extractOptionalPositiveInteger(map, TOP_K_FIELD, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new MixedbreadRerankTaskSettings(returnDocuments, topN); + } + + /** + * Creates a new {@link MixedbreadRerankTaskSettings} object by overriding the values in originalSettings with the ones + * passed in via requestSettings if the fields are not null. 
+ * @param originalSettings the original {@link MixedbreadRerankTaskSettings} from the inference entity configuration from storage + * @param requestSettings the {@link MixedbreadRerankRequestTaskSettings} from the request + * @return a new {@link MixedbreadRerankTaskSettings} + */ + public static MixedbreadRerankTaskSettings of( + MixedbreadRerankTaskSettings originalSettings, + MixedbreadRerankRequestTaskSettings requestSettings + ) { + + final var returnDocuments = requestSettings.returnDocuments() == null + ? originalSettings.returnDocuments() + : requestSettings.returnDocuments(); + final var topK = requestSettings.topN() == null ? originalSettings.topK() : requestSettings.topN(); + + return new MixedbreadRerankTaskSettings(returnDocuments, topK); + } + + public MixedbreadRerankTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topK) { + this.returnDocuments = returnDocuments; + this.topK = topK; + } + + public MixedbreadRerankTaskSettings(StreamInput in) throws IOException { + this.returnDocuments = in.readOptionalBoolean(); + this.topK = in.readOptionalVInt(); + } + + private final Boolean returnDocuments; + private final Integer topK; + + public Boolean returnDocuments() { + return returnDocuments; + } + + public Integer topK() { + return topK; + } + + public boolean areAnyParametersAvailable() { + return returnDocuments != null && topK != null; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_MIXEDBREAD_ADDED; + } + + @Override + public boolean isEmpty() { + return returnDocuments == null && topK == null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalBoolean(returnDocuments); + out.writeOptionalVInt(topK); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if 
(returnDocuments != null) { + builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments); + } + if (topK != null) { + builder.field(TOP_K_FIELD, topK); + } + + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return "MixedbreadRerankTaskSettings{" + ", returnDocuments=" + returnDocuments + ", topN=" + topK + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MixedbreadRerankTaskSettings that = (MixedbreadRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topK, that.topK); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topK); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + MixedbreadRerankRequestTaskSettings requestSettings = MixedbreadRerankRequestTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, requestSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..70dbbab3fda11 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadEmbeddingsResponseEntity.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.response; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; + +import java.io.IOException; + +public class MixedbreadEmbeddingsResponseEntity { + /** + * Parses the response from a Mixedbread embeddings request and returns the results. + * + * @param request the original request that was sent + * @param response the HTTP result containing the response data + * @return an InferenceServiceResults object containing the parsed results + * @throws IOException if there is an error parsing the response + */ + public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { + // expected response type is the same as the HuggingFace Embeddings + return HuggingFaceEmbeddingsResponseEntity.fromResponse(request, response); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadErrorResponse.java new file mode 100644 index 0000000000000..3e4d4d14cf7d8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadErrorResponse.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.response; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.nio.charset.StandardCharsets; + +public class MixedbreadErrorResponse extends ErrorResponse { + public MixedbreadErrorResponse(String message) { + super(message); + } + + public static ErrorResponse fromResponse(HttpResult response) { + try { + String errorMessage = new String(response.body(), StandardCharsets.UTF_8); + return new MixedbreadErrorResponse(errorMessage); + } catch (Exception e) { + // swallow the error + } + return ErrorResponse.UNDEFINED_ERROR; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntity.java new file mode 100644 index 0000000000000..f71f2c9542e23 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntity.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.response; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.mixedbread.request.MixedbreadRerankRequest; + +import java.io.IOException; +import java.util.Comparator; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class MixedbreadRerankResponseEntity { + /** + * Parses the response from a Mixedbread embeddings request and returns the results. 
+ * + * @param request the original request that was sent + * @param response the HTTP result containing the response data + * @return an InferenceServiceResults object containing the parsed results + * @throws IOException if there is an error parsing the response + */ + public static InferenceServiceResults fromResponse(MixedbreadRerankRequest request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + var rankedDocs = doParse(jsonParser); + var rankedDocsByRelevanceStream = rankedDocs.stream() + .sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed()); + var rankedDocStreamTopN = request.getTopN() == null + ? rankedDocsByRelevanceStream + : rankedDocsByRelevanceStream.limit(request.getTopN()); + return new RankedDocsResults(rankedDocStreamTopN.toList()); + } + } + + private static List doParse(XContentParser parser) throws IOException { + return parseList(parser, (listParser, index) -> { + var parsedRankedDoc = RankedDocEntry.parse(parser); + return new RankedDocsResults.RankedDoc(parsedRankedDoc.index, parsedRankedDoc.score, parsedRankedDoc.text); + }); + } + + private record RankedDocEntry(Integer index, Float score, @Nullable String text) { + + private static final ParseField TEXT = new ParseField("text"); + private static final ParseField SCORE = new ParseField("score"); + private static final ParseField INDEX = new ParseField("index"); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "mixedbread_rerank_response", + true, + args -> new RankedDocEntry((int) args[0], (float) args[1], (String) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), INDEX); + 
PARSER.declareFloat(ConstructingObjectParser.constructorArg(), SCORE); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TEXT); + } + + public static RankedDocEntry parse(XContentParser parser) { + return PARSER.apply(parser, null); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..19953667f5883 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadEmbeddingsRequestEntityTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.List; + +public class MixedbreadEmbeddingsRequestEntityTests extends ESTestCase { + + public void testXContent_WritesModelInputAndFormat() throws IOException { + var entity = new MixedbreadEmbeddingsRequestEntity("mixedbread-embed", List.of("abc")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"model":"mixedbread-embed","input":["abc"],"encoding_format":"float"}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntityTests.java new file mode 100644 index 0000000000000..5555c72318d52 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntityTests.java @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +/** Checks the JSON request body that {@code MixedbreadRerankRequestEntity} writes through {@code toXContent} for different combinations of direct constructor arguments and {@code MixedbreadRerankTaskSettings}. */ public class MixedbreadRerankRequestEntityTests extends ESTestCase { + + /** One input document, with 2 and {@code Boolean.TRUE} passed directly and both task-settings arguments null: expects {@code model}, {@code query}, {@code input}, {@code top_k} = 2 and {@code return_documents} = true. */ public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + "model", + "query", + List.of("abc"), + 2, + Boolean.TRUE, + new MixedbreadRerankTaskSettings(null, null) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc" + ], + "top_k": 2, + "return_documents": true + } + """)); + } + + /** Both optional direct arguments and both task-settings arguments are null: expects only {@code model}, {@code query} and {@code input} in the body. */ public void testXContent_SingleRequest_WritesMinimalFields() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + "model", + "query", + List.of("abc"), + null, + null, + new MixedbreadRerankTaskSettings(null, null) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc" + ] + } + """)); + } + + /** Direct optional arguments are null and the second task-settings argument is 2: expects {@code top_k} = 2 and no {@code return_documents}. NOTE(review): no competing direct value is supplied, so this shows the task-settings value being used rather than precedence between two values - confirm that is the intent of the name. */ public void testXContent_SingleRequest_OverridesTopKField() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + "model", + 
"query", + List.of("abc"), + null, + null, + new MixedbreadRerankTaskSettings(null, 2) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc" + ], + "top_k": 2 + } + """)); + } + + /** Direct optional arguments are null and the first task-settings argument is {@code Boolean.TRUE}: expects {@code return_documents} = true and no {@code top_k}. NOTE(review): as with the top_k case, no competing direct value is supplied - confirm the name matches the intent. */ public void testXContent_SingleRequest_OverridesReturnDocumentsField() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + "model", + "query", + List.of("abc"), + null, + null, + new MixedbreadRerankTaskSettings(Boolean.TRUE, null) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc" + ], + "return_documents": true + } + """)); + } + + /** Two input documents ({@code abc}, {@code def}) with 2 and {@code Boolean.TRUE} passed directly: expects both documents in {@code input}, plus {@code top_k} = 2 and {@code return_documents} = true. */ public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + "model", + "query", + List.of("abc", "def"), + 2, + Boolean.TRUE, + new MixedbreadRerankTaskSettings(null, null) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc", + "def" + ], + "top_k": 2, + "return_documents": true + } + """)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadErrorResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadErrorResponseTests.java new file mode
100644 index 0000000000000..2147674fd0f52 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadErrorResponseTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.mockito.Mockito.mock; + +/** Checks what {@code MixedbreadErrorResponse.fromResponse} returns for a sample error body. */ public class MixedbreadErrorResponseTests extends ESTestCase { + /** Sample error body used by the test: a JSON object with a single {@code error} string field. */ public static final String ERROR_RESPONSE_JSON = """ + { + "error": "A valid user token is required" + } + """; + + /** Wraps the UTF-8 bytes of {@code ERROR_RESPONSE_JSON} in an {@code HttpResult} with a mocked {@code HttpResponse}, then expects a non-null result whose {@code getErrorMessage()} equals the whole body text rather than only the value of its {@code error} field. NOTE(review): confirm that returning the raw body as the message is the intended contract. */ public void testFromResponse() { + var errorResponse = MixedbreadErrorResponse.fromResponse( + new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorResponse); + assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage()); + } +}