diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java
index 8e40bba8b32f7..1eb530ac1bb9e 100644
--- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java
@@ -350,7 +350,8 @@ private ElasticInferenceService createElasticInferenceService() {
             createWithEmptySettings(threadPool),
             ElasticInferenceServiceSettingsTests.create(gatewayUrl),
             modelRegistry,
-            new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool)
+            new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool),
+            mockClusterServiceEmpty()
         );
     }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index de31f9d6cefc8..b729857c91f81 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -311,7 +311,8 @@ public Collection createComponents(PluginServices services) {
                 serviceComponents.get(),
                 inferenceServiceSettings,
                 modelRegistry.get(),
-                authorizationHandler
+                authorizationHandler,
+                context
             ),
             context -> new SageMakerService(
                 new SageMakerModelBuilder(sageMakerSchemas),
@@ -321,7 +322,8 @@ public Collection createComponents(PluginServices services) {
                 ),
                 sageMakerSchemas,
                 services.threadPool(),
-                sageMakerConfigurations::getOrCompute
+                sageMakerConfigurations::getOrCompute,
+                context
             )
         )
     );
@@ -383,24 +385,24 @@ public void loadExtensions(ExtensionLoader loader) {
     public List getInferenceServiceFactories() {
         return List.of(
-            context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
-            context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
-            context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
-            context -> new CohereService(httpFactory.get(), serviceComponents.get()),
-            context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
-            context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
-            context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
-            context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()),
-            context -> new MistralService(httpFactory.get(), serviceComponents.get()),
-            context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
-            context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
-            context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
-            context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
-            context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
-            context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
-            context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
+            context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new OpenAiService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new CohereService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new MistralService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context),
+            context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context),
+            context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context),
             ElasticsearchInternalService::new,
-            context -> new CustomService(httpFactory.get(), serviceComponents.get())
+            context -> new CustomService(httpFactory.get(), serviceComponents.get(), context)
         );
     }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
index ff8ae6fd5aac3..5074749c1cd9f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
@@ -9,6 +9,7 @@
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Nullable;
@@ -42,11 +43,13 @@ public abstract class SenderService implements InferenceService {
     protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);
     private final Sender sender;
     private final ServiceComponents serviceComponents;
+    private final ClusterService clusterService;
-    public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
+    public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
         Objects.requireNonNull(factory);
         sender = factory.createSender();
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
+        this.clusterService = Objects.requireNonNull(clusterService);
     }
     public Sender getSender() {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
index 7897317736c72..da608779fee0a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
@@ -11,6 +11,7 @@
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
@@ -19,6 +20,7 @@
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
+import org.elasticsearch.inference.InferenceServiceExtension;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
@@ -85,8 +87,20 @@ public class AlibabaCloudSearchService extends SenderService {
         InputType.INTERNAL_SEARCH
     );
-    public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
-        super(factory, serviceComponents);
+    public AlibabaCloudSearchService(
+        HttpRequestSender.Factory factory,
+        ServiceComponents serviceComponents,
+        InferenceServiceExtension.InferenceServiceFactoryContext context
+    ) {
+        this(factory, serviceComponents, context.clusterService());
+    }
+
+    public AlibabaCloudSearchService(
+        HttpRequestSender.Factory factory,
+        ServiceComponents serviceComponents,
+        ClusterService clusterService
+    ) {
+        super(factory, serviceComponents, clusterService);
     }
     @Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
index 591607953ea1a..c2b0ae8e69c37 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
@@ -11,6 +11,7 @@
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
@@ -20,6 +21,7 @@
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
+import org.elasticsearch.inference.InferenceServiceExtension;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
@@ -93,9 +95,19 @@ public class AmazonBedrockService extends SenderService {
     public AmazonBedrockService(
         HttpRequestSender.Factory httpSenderFactory,
         AmazonBedrockRequestSender.Factory amazonBedrockFactory,
-        ServiceComponents serviceComponents
+        ServiceComponents serviceComponents,
+        InferenceServiceExtension.InferenceServiceFactoryContext context
     ) {
-        super(httpSenderFactory, serviceComponents);
+        this(httpSenderFactory, amazonBedrockFactory, serviceComponents, context.clusterService());
+    }
+
+    public AmazonBedrockService(
+        HttpRequestSender.Factory httpSenderFactory,
+        AmazonBedrockRequestSender.Factory
amazonBedrockFactory, + ServiceComponents serviceComponents, + ClusterService clusterService + ) { + super(httpSenderFactory, serviceComponents, clusterService); this.amazonBedrockSender = amazonBedrockFactory.createSender(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 791518ccc9168..8cf5446f8b6d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -11,12 +11,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +60,16 @@ public class AnthropicService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION); - public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AnthropicService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 17c7cbd6bdf0e..4a5a8be8b6633 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,16 @@ 
public class AzureAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e9ff97c1ba725..3d9a3dd516a2d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -69,8 +71,16 @@ public class AzureOpenAiService extends SenderService { private static final String SERVICE_NAME = "Azure OpenAI"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); - public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public AzureOpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index c2f1221763165..fb6c630bd60c9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import 
org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,16 @@ public class CohereService extends SenderService { // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated // on every request - public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CohereService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 4e81d37ead3ad..5f5078affa9d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -74,8 +76,16 @@ public class CustomService extends SenderService { TaskType.COMPLETION ); - public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public CustomService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 56719199e094f..8a77efbd604d2 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -10,12 +10,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -58,8 +60,16 @@ public class DeepSeekService extends SenderService { ); private static final EnumSet SUPPORTED_TASK_TYPES_FOR_STREAMING = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public DeepSeekService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 36712ed922e95..58e964bb5c25f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -22,6 +23,7 @@ import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.MinimalServiceSettings; @@ -139,9 +141,28 @@ public ElasticInferenceService( ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + InferenceServiceExtension.InferenceServiceFactoryContext context ) { - super(factory, serviceComponents); + this( + factory, + serviceComponents, 
+ elasticInferenceServiceSettings, + modelRegistry, + authorizationRequestHandler, + context.clusterService() + ); + } + + public ElasticInferenceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + ModelRegistry modelRegistry, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + ClusterService clusterService + ) { + super(factory, serviceComponents, clusterService); this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 9841ea64370c3..4c8997f35555b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; @@ -19,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -82,8 +84,16 @@ public class GoogleAiStudioService extends SenderService { InputType.INTERNAL_SEARCH ); - public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleAiStudioService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3b59e999125e5..2c2c667cd6eee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; 
import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -97,8 +99,16 @@ public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); } - public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public GoogleVertexAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index b0d40b41914d5..325f88c8904a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -8,9 +8,11 @@ package org.elasticsearch.xpack.inference.services.huggingface; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -44,8 +46,16 @@ public abstract class HuggingFaceBaseService extends SenderService { */ static final int EMBEDDING_MAX_BATCH_SIZE = 20; - public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceBaseService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d10fb77290c6b..bc64e832d182a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -11,10 +11,12 @@ import 
org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -71,8 +73,16 @@ public class HuggingFaceService extends HuggingFaceBaseService { OpenAiChatCompletionResponseEntity::fromResponse ); - public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index e61995aac91f3..5f9288bb99c24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -11,12 +11,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -57,8 +59,16 @@ public class HuggingFaceElserService extends HuggingFaceBaseService { private static final String SERVICE_NAME = "Hugging Face ELSER"; private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); - public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public HuggingFaceElserService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 9bc63be1f9e7e..9617bff0d3f3d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -83,8 +85,16 @@ public class IbmWatsonxService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public IbmWatsonxService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..00e1aede95a2b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -76,8 +78,16 @@ public class JinaAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public JinaAIService( + HttpRequestSender.Factory factory, + 
ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index b11feb117d761..3048847ea90d7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -84,8 +86,16 @@ public class MistralService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public MistralService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index edff1dfc08cba..b9e9e34c44736 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import 
org.elasticsearch.inference.Model; @@ -91,8 +93,16 @@ public class OpenAiService extends SenderService { OpenAiChatCompletionResponseEntity::fromResponse ); - public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public OpenAiService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index aafd6c46857fc..653c4288263f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -20,6 +21,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -37,6 +39,7 @@ import java.util.EnumSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import static org.elasticsearch.core.Strings.format; @@ -55,13 +58,26 @@ public class SageMakerService implements InferenceService { private final SageMakerSchemas schemas; private final ThreadPool threadPool; private final LazyInitializable configuration; + private final ClusterService clusterService; public SageMakerService( SageMakerModelBuilder modelBuilder, SageMakerClient client, SageMakerSchemas schemas, ThreadPool threadPool, - CheckedSupplier, RuntimeException> configurationMap + CheckedSupplier, RuntimeException> configurationMap, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(modelBuilder, client, schemas, threadPool, configurationMap, context.clusterService()); + } + + public SageMakerService( + SageMakerModelBuilder modelBuilder, + SageMakerClient client, + SageMakerSchemas schemas, + ThreadPool threadPool, + CheckedSupplier, RuntimeException> configurationMap, + ClusterService clusterService ) { this.modelBuilder = modelBuilder; this.client = client; @@ -74,6 +90,7 @@ public SageMakerService( .setConfigurations(configurationMap.get()) .build() ); + this.clusterService = Objects.requireNonNull(clusterService); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 0ffec057dc2b4..9698ee4c0d4bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; @@ -18,6 +19,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -96,8 +98,16 @@ public class VoyageAIService extends SenderService { InputType.INTERNAL_SEARCH ); - public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + public VoyageAIService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 5d7a6a149f941..7457859a64603 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; @@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -64,7 +66,7 @@ public void testStart_InitializesTheSender() throws IOException { var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); 
service.start(mock(Model.class), listener); @@ -84,7 +86,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); @@ -102,8 +104,8 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep } private static final class TestSenderService extends SenderService { - TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 8fbbd33d569e4..f0258e9f66ed5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -91,7 +91,13 @@ public void shutdown() throws IOException { } public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -116,7 +122,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -143,7 +155,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP } public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { 
assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -169,7 +187,13 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -190,7 +214,13 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = service.parsePersistedConfig( "id", TaskType.TEXT_EMBEDDING, @@ -210,7 +240,13 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -235,7 +271,13 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var persistedConfig = getPersistedConfigMap( AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), @@ -262,7 +304,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = OpenAiChatCompletionModelTests.createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -279,7 +321,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO public void 
testUpdateModelWithEmbeddingDetails_UpdatesEmbeddingSizeAndSimilarity() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -316,7 +358,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -360,7 +402,7 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -404,7 +446,7 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc taskSettingsMap, secretSettingsMap ); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -452,7 +494,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = createModelForTaskType(taskType, chunkingSettings); PlainActionFuture> listener = new PlainActionFuture<>(); @@ -482,7 +524,13 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin @SuppressWarnings("checkstyle:LineLength") public void testGetConfiguration() throws Exception { - try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + try ( + var service = new AlibabaCloudSearchService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { String content = XContentHelper.stripWhitespace( """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index a014f27e7f0cc..c3b1cab4b4e0a 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -959,7 +959,14 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc ); var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1007,7 +1014,12 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); @@ -1042,7 +1054,14 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var results = new TextEmbeddingFloatResults( List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) @@ -1088,7 +1107,14 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); requestSender.enqueue(mockResults); @@ -1132,7 +1158,14 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var model = AmazonBedrockChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1166,7 +1199,14 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + 
createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { var embeddingSize = randomNonNegativeInt(); var provider = randomFrom(AmazonBedrockProvider.values()); var model = AmazonBedrockEmbeddingsModelTests.createModel( @@ -1205,7 +1245,12 @@ public void testInfer_UnauthorizedResponse() throws IOException { ); try ( - var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { requestSender.enqueue( @@ -1240,7 +1285,7 @@ public void testInfer_UnauthorizedResponse() throws IOException { } public void testSupportsStreaming() throws IOException { - try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) { + try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1284,7 +1329,14 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { { var mockResults1 = new TextEmbeddingFloatResults( @@ -1345,7 +1397,12 @@ private AmazonBedrockService createAmazonBedrockService() { ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), mockClusterServiceEmpty() ); - return new AmazonBedrockService(mock(HttpRequestSender.Factory.class), amazonBedrockFactory, createWithEmptySettings(threadPool)); + return new AmazonBedrockService( + mock(HttpRequestSender.Factory.class), + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index a3f0b01901009..9111866d29c88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -453,7 +453,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -486,7 +486,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try 
(var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", @@ -579,7 +579,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AnthropicChatCompletionModelTests.createChatCompletionModel( getUrl(webServer), "secret", @@ -679,13 +679,13 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AnthropicService createServiceWithMockSender() { - return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index fee2adcf664ec..3383762a9f332 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1073,7 +1073,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesRerankModel() throws I public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1098,7 +1098,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1124,7 +1124,7 @@ private void 
testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testUpdateModelWithChatCompletionDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1152,7 +1152,7 @@ public void testUpdateModelWithChatCompletionDetails_NonNullSimilarityInOriginal private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewTokens) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -1185,7 +1185,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -1223,7 +1223,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1293,7 +1293,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1379,7 +1379,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep public void testInfer_WithChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson)); var model = AzureAiStudioChatCompletionModelTests.createModel( @@ -1416,7 +1416,7 @@ public void testInfer_WithChatCompletionModel() throws IOException { public void testInfer_WithRerankModel() throws IOException { var 
senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson)); var model = AzureAiStudioRerankModelTests.createModel( @@ -1457,7 +1457,7 @@ public void testInfer_WithRerankModel() throws IOException { public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1534,7 +1534,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureAiStudioChatCompletionModelTests.createModel( "id", getUrl(webServer), @@ -1666,7 +1666,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1675,7 +1675,11 @@ public void testSupportsStreaming() throws IOException { // ---------------------------------------------------------------- private AzureAiStudioService createService() { - return new AzureAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureAiStudioService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index de2e9ae9a21b8..f3d65c5589169 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -752,7 +752,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -785,7 +785,7 @@ public void 
testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep public void testInfer_SendsRequest() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -844,7 +844,7 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureOpenAiCompletionModelTests.createModelWithRandomValues(); assertThrows( ElasticsearchStatusException.class, @@ -864,7 +864,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = AzureOpenAiEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -891,7 +891,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -952,7 +952,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException, URISyn private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1065,7 +1065,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = AzureOpenAiCompletionModelTests.createCompletionModel( "resource", "deployment", @@ -1209,14 +1209,18 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws 
IOException { - try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) { + try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } } private AzureOpenAiService createAzureOpenAiService() { - return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new AzureOpenAiService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 52e4f904a4de0..8f189baa33b20 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -779,7 +779,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new CohereService(factory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -812,7 +812,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException public void testInfer_SendsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -886,7 +886,7 @@ public void testInfer_SendsRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = CohereCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -906,7 +906,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var embeddingType = randomFrom(CohereEmbeddingType.values()); var model = CohereEmbeddingsModelTests.createModel( @@ -933,7 +933,7 @@ private void 
testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -975,7 +975,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsAreEmpty() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1051,7 +1051,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1125,7 +1125,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v1API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1200,7 +1200,7 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v2API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1297,7 +1297,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1387,7 +1387,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new 
CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { // Batching will call the service with 2 inputs String responseJson = """ @@ -1507,7 +1507,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1591,7 +1591,7 @@ public void testGetConfiguration() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) { + try (var service = new CohereService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1632,7 +1632,7 @@ private Map getRequestConfigMap(Map serviceSetti } private CohereService createCohereService() { - return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new CohereService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index cc1bb4471c0a9..a707030a34189 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -53,6 +53,7 @@ import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -148,7 +149,7 @@ private static void assertCompletionModel(Model model) { public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new CustomService(senderFactory, createWithEmptySettings(threadPool)); + return new CustomService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private static Map createServiceSettingsMap(TaskType taskType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index af38ee38e1eff..908451b8e681f 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -360,7 +360,8 @@ public void testDoChunkedInferAlwaysFails() throws IOException { private DeepSeekService createService() { return new DeepSeekService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 6ce484954d3ce..94d1e064648ff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1427,7 +1427,8 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ createWithEmptySettings(threadPool), new ElasticInferenceServiceSettings(Settings.EMPTY), modelRegistry, - mockAuthHandler + mockAuthHandler, + mockClusterServiceEmpty() ); } @@ -1456,7 +1457,8 @@ private ElasticInferenceService createService( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - mockAuthHandler + mockAuthHandler, + mockClusterServiceEmpty() ); } @@ -1469,7 +1471,8 @@ private ElasticInferenceService createServiceWithAuthHandler( createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool), + mockClusterServiceEmpty() ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 41175581df1cf..435ea9de5911b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -658,7 +658,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -696,7 +696,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD var model = GoogleAiStudioEmbeddingsModelTests.createModel("model", getUrl(webServer), "secret"); - try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool), 
mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -730,7 +730,7 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForModelThatD public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "candidates": [ @@ -818,7 +818,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "embeddings": [ @@ -897,7 +897,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { "embeddings": [ @@ -998,7 +998,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed public void testInfer_ResourceNotFound() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1033,7 +1033,7 @@ public void testInfer_ResourceNotFound() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = GoogleAiStudioCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1052,7 +1052,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = GoogleAiStudioEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1124,7 +1124,7 @@ public void testGetConfiguration() throws Exception { } public void 
testSupportsStreaming() throws IOException { - try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) { + try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1171,6 +1171,10 @@ private Map getRequestConfigMap( } private GoogleAiStudioService createGoogleAiStudioService() { - return new GoogleAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new GoogleAiStudioService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 99a09b983787d..26fd076e72462 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -1043,7 +1043,7 @@ public void testGetConfiguration() throws Exception { private GoogleVertexAiService createGoogleVertexAiService() { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool)); + return new GoogleVertexAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 3be4b72c1237f..2cdf3f5263751 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -29,6 +29,7 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; import static org.mockito.Mockito.mock; @@ -92,7 +93,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep private static final class TestService extends HuggingFaceService { TestService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); + super(factory, serviceComponents, mockClusterServiceEmpty()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index 814d533129439..93156d4331263 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -81,7 +81,7 @@ public void shutdown() throws IOException { public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -137,7 +137,8 @@ public void testGetConfiguration() throws Exception { try ( var service = new HuggingFaceElserService( HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), - createWithEmptySettings(threadPool) + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() ) ) { String content = XContentHelper.stripWhitespace(""" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index e2850910ac64a..c770672c5d5f2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -258,7 +258,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); - try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); service.infer( mockModel, @@ -328,7 +328,7 @@ public void testUnifiedCompletionInfer() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -357,7 +357,7 @@ public void testUnifiedCompletionNonStreamingError() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); var latch = new CountDownLatch(1); service.unifiedCompletionInfer( @@ -486,7 +486,7 @@ public void 
testUnifiedCompletionMalformedError() throws Exception { private void testStreamError(String expectedResponse) throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -548,7 +548,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceEventsAssertion streamCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -621,7 +621,7 @@ public void testInfer_StreamRequestRetry() throws Exception { } public void testSupportsStreaming() throws IOException { - try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) { + try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); } @@ -1009,7 +1009,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1060,7 +1060,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( ValidationException.class, @@ -1087,7 +1087,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException { public void testInfer_SendsElserRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -1139,7 +1139,7 @@ public void testInfer_SendsElserRequest() throws IOException { public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var model = HuggingFaceElserModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThrows( ElasticsearchStatusException.class, @@ -1158,7 +1158,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { var embeddingSize = randomNonNegativeInt(); var model = HuggingFaceEmbeddingsModelTests.createModel( randomAlphaOfLength(10), @@ -1179,7 +1179,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ { @@ -1233,7 +1233,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th public void testChunkedInfer() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { String responseJson = """ [ @@ -1340,7 +1340,11 @@ public void testGetConfiguration() throws Exception { } private HuggingFaceService createHuggingFaceService() { - return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + return new HuggingFaceService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ); } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 3295ecfd4ece5..ddc62b5a412b9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -597,7 +597,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept var mockModel = getInvalidModel("model_id", "service_name"); - try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) { + try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { PlainActionFuture listener = new 
PlainActionFuture<>();
service.infer(
mockModel,
@@ -635,7 +635,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException {
var model = IbmWatsonxEmbeddingsModelTests.createModel(modelId, projectId, URI.create(url), apiVersion, apiKey, getUrl(webServer));
- try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new IbmWatsonxService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
var thrownException = expectThrows(
@@ -1018,12 +1018,12 @@ private Map getRequestConfigMap(
}
private IbmWatsonxService createIbmWatsonxService() {
- return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
+ return new IbmWatsonxService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService {
IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
- super(factory, serviceComponents);
+ super(factory, serviceComponents, mockClusterServiceEmpty());
}
@Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
index eca76bc1a702a..d36c574e0aa99 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
@@ -778,7 +778,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException
var mockModel = getInvalidModel("model_id", "service_name");
- try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
mockModel,
@@ -819,7 +819,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel
private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var embeddingSize = randomNonNegativeInt();
var embeddingType = randomFrom(JinaAIEmbeddingType.values());
var model = JinaAIEmbeddingsModelTests.createModel(
@@ -846,7 +846,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
public void testInfer_Embedding_UnauthorisedResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -889,7 +889,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException {
public void testInfer_Rerank_UnauthorisedResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -923,7 +923,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException {
public void testInfer_Embedding_Get_Response_Ingest() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -994,7 +994,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException {
public void testInfer_Embedding_Get_Response_Search() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1065,7 +1065,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException {
public void testInfer_Embedding_Get_Response_clustering() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{"model":"jina-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5},
@@ -1120,7 +1120,7 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException {
public void testInfer_Embedding_Get_Response_NullInputType() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1210,7 +1210,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx
""";
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -1295,7 +1295,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce
""";
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -1392,7 +1392,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO
""";
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -1475,7 +1475,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept
""";
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -1540,7 +1540,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept
public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1637,7 +1637,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept
private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
// Batching will call the service with 2 input
String responseJson = """
@@ -1800,7 +1800,7 @@ public void testGetConfiguration() throws Exception {
}
public void testDoesNotSupportsStreaming() throws IOException {
- try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()))) {
+ try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) {
assertFalse(service.canStream(TaskType.COMPLETION));
assertFalse(service.canStream(TaskType.ANY));
}
@@ -1841,7 +1841,7 @@ private Map getRequestConfigMap(Map serviceSetti
}
private JinaAIService createJinaAIService() {
- return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
+ return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index 4ba9b8aa24394..8e170b25393e6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -249,7 +249,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws
var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
- try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
mockModel,
@@ -308,7 +308,7 @@ public void testUnifiedCompletionInfer() throws Exception {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
PlainActionFuture listener = new PlainActionFuture<>();
service.unifiedCompletionInfer(
@@ -353,7 +353,7 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception {
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
var latch = new CountDownLatch(1);
service.unifiedCompletionInfer(
@@ -421,7 +421,7 @@ public void testInfer_StreamRequest() throws Exception {
private InferenceEventsAssertion streamCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var model = MistralChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
@@ -459,7 +459,7 @@ public void testInfer_StreamRequest_ErrorResponse() {
}
public void testSupportsStreaming() throws IOException {
- try (var service = new MistralService(mock(), createWithEmptySettings(mock()))) {
+ try (var service = new MistralService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) {
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
@@ -942,7 +942,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var model = new Model(ModelConfigurationsTests.createRandomInstance());
assertThrows(
@@ -962,7 +962,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel
private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var embeddingSize = randomNonNegativeInt();
var model = MistralEmbeddingModelTests.createModel(
randomAlphaOfLength(10),
@@ -990,7 +990,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I
var mockModel = getInvalidModel("model_id", "service_name");
- try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
mockModel,
@@ -1028,7 +1028,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException {
var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null);
- try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
var thrownException = expectThrows(
@@ -1086,7 +1086,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1173,7 +1173,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
public void testInfer_UnauthorisedResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1276,7 +1276,7 @@ public void testGetConfiguration() throws Exception {
// ----------------------------------------------------------------
private MistralService createService() {
- return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
+ return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
private Map getRequestConfigMap(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index c19eb664e88ac..83455861198d3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -847,7 +847,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException
var mockModel = getInvalidModel("model_id", "service_name");
- try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
mockModel,
@@ -885,7 +885,7 @@ public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException {
var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user");
- try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
var thrownException = expectThrows(
ValidationException.class,
@@ -924,7 +924,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
var mockModel = getInvalidModel("model_id", "service_name", TaskType.SPARSE_EMBEDDING);
- try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
mockModel,
@@ -965,7 +965,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws
var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
- try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
mockModel,
@@ -1003,7 +1003,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws
public void testInfer_SendsRequest() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1099,7 +1099,7 @@ public void testUnifiedCompletionInfer() throws Exception {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
PlainActionFuture listener = new PlainActionFuture<>();
service.unifiedCompletionInfer(
@@ -1132,7 +1132,7 @@ public void testUnifiedCompletionError() throws Exception {
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
var latch = new CountDownLatch(1);
service.unifiedCompletionInfer(
@@ -1189,7 +1189,7 @@ public void testMidStreamUnifiedCompletionError() throws Exception {
private void testStreamError(String expectedResponse) throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
PlainActionFuture listener = new PlainActionFuture<>();
service.unifiedCompletionInfer(
@@ -1267,7 +1267,7 @@ public void testInfer_StreamRequest() throws Exception {
private InferenceEventsAssertion streamCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
@@ -1344,7 +1344,7 @@ public void testInfer_StreamRequestRetry() throws Exception {
}
public void testSupportsStreaming() throws IOException {
- try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) {
+ try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) {
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
@@ -1400,7 +1400,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
public void testInfer_UnauthorisedResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1485,7 +1485,7 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
// response with 2 embeddings
String responseJson = """
@@ -1656,6 +1656,6 @@ public void testGetConfiguration() throws Exception {
}
private OpenAiService createOpenAiService() {
- return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
+ return new OpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
index d7d9473f18084..bf883a6345398 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
@@ -47,6 +47,7 @@
import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener;
import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS;
import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -84,7 +85,7 @@ public void init() {
ThreadPool threadPool = mock();
when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
- sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of);
+ sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of, mockClusterServiceEmpty());
}
public void testSupportedTaskTypes() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
index 8602621e9eb78..72a3b530ab647 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
@@ -718,7 +718,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio
var mockModel = getInvalidModel("model_id", "service_name");
- try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
mockModel,
@@ -763,7 +763,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept
"voyage-3-large"
);
- try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
PlainActionFuture listener = new PlainActionFuture<>();
var thrownException = expectThrows(
@@ -806,7 +806,7 @@ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel
private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
var embeddingSize = randomNonNegativeInt();
var model = VoyageAIEmbeddingsModelTests.createModel(
randomAlphaOfLength(10),
@@ -831,7 +831,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
public void testInfer_Embedding_UnauthorisedResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -873,7 +873,7 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException {
public void testInfer_Rerank_UnauthorisedResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -907,7 +907,7 @@ public void testInfer_Rerank_UnauthorisedResponse() throws IOException {
public void testInfer_Embedding_Get_Response_Ingest() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -989,7 +989,7 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException {
public void testInfer_Embedding_Get_Response_Search() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1071,7 +1071,7 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException {
public void testInfer_Embedding_Get_Response_NullInputType() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1163,7 +1163,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOEx
""";
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false, false);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -1251,7 +1251,7 @@ public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOExce
""";
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false, false);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -1345,7 +1345,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IO
""";
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null, null);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -1423,7 +1423,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept
""";
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = VoyageAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true, true);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -1490,7 +1490,7 @@ public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOExcept
public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
String responseJson = """
{
@@ -1599,7 +1599,7 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept
private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel model) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
- try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) {
+ try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
// Batching will call the service with 2 input
String responseJson = """
@@ -1745,7 +1745,7 @@ public void testGetConfiguration() throws Exception {
}
public void testDoesNotSupportsStreaming() throws IOException {
- try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()))) {
+ try (var service = new VoyageAIService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) {
assertFalse(service.canStream(TaskType.COMPLETION));
assertFalse(service.canStream(TaskType.ANY));
}
@@ -1786,7 +1786,7 @@ private Map getRequestConfigMap(Map serviceSetti
}
private VoyageAIService createVoyageAIService() {
- return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
+ return new VoyageAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty());
}
}