diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java index 55f1f49f68c21..48303be617286 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java @@ -27,8 +27,8 @@ public class BulkInferenceExecutionState { private final Map bufferedResponses; private final AtomicBoolean finished = new AtomicBoolean(false); - public BulkInferenceExecutionState(int bufferSize) { - this.bufferedResponses = new ConcurrentHashMap<>(bufferSize); + public BulkInferenceExecutionState() { + this.bufferedResponses = new ConcurrentHashMap<>(); } /** @@ -125,7 +125,7 @@ public void addFailure(Exception e) { * Indicates whether the entire bulk execution is marked as finished and all responses have been successfully persisted. */ public boolean finished() { - return finished.get() && getMaxSeqNo() == getPersistedCheckpoint(); + return hasFailure() || (finished.get() && getMaxSeqNo() == getPersistedCheckpoint()); } /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java index 257799962dda7..1dfedd55a39fe 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java @@ -8,253 +8,344 @@ package org.elasticsearch.xpack.esql.inference.bulk; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; /** * Executes a sequence of inference requests in bulk with throttling and concurrency control. */ public class BulkInferenceExecutor { - private final ThrottledInferenceRunner throttledInferenceRunner; - private final BulkInferenceExecutionConfig bulkExecutionConfig; + private final InferenceRunner inferenceRunner; + private final Semaphore permits; + private final ExecutorService executorService; + + /** + * Custom concurrent queue that prevents duplicate bulk requests from being queued. + *

+ * This queue implementation ensures fairness among multiple concurrent bulk operations + * by preventing the same bulk request from being queued multiple times. It uses a + * backing concurrent set to track which requests are already queued. + *

+ */ + private final Queue pendingBulkRequests = new ConcurrentLinkedQueue<>() { + private final Set requests = ConcurrentCollections.newConcurrentSet(); + + @Override + public boolean offer(BulkInferenceRequest bulkInferenceRequest) { + synchronized (requests) { + if (requests.add(bulkInferenceRequest)) { + return super.offer(bulkInferenceRequest); + } + return false; // Already exists, don't add duplicate + } + } + + @Override + public BulkInferenceRequest poll() { + synchronized (requests) { + BulkInferenceRequest request = super.poll(); + if (request != null) { + requests.remove(request); + } + return request; + } + } + }; /** * Constructs a new {@code BulkInferenceExecutor}. * - * @param inferenceRunner The inference runner used to execute individual inference requests. - * @param threadPool The thread pool for executing inference tasks. - * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). + * @param inferenceRunner The inference runner used to execute individual inference requests. + * @param threadPool The thread pool for executing inference tasks. + * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). */ public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadPool, BulkInferenceExecutionConfig bulkExecutionConfig) { - this.throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService(threadPool), bulkExecutionConfig); - this.bulkExecutionConfig = bulkExecutionConfig; + this.inferenceRunner = inferenceRunner; + this.permits = new Semaphore(bulkExecutionConfig.maxOutstandingRequests()); + this.executorService = threadPool.executor(ESQL_WORKER_THREAD_POOL_NAME); } /** - * Executes the provided bulk inference requests. - *

- * Each request is sent to the {@link ThrottledInferenceRunner} to be executed. - * The final listener is notified with all successful responses once all requests are completed. + * Executes multiple inference requests in bulk and collects all responses. * - * @param requests An iterator over the inference requests to be executed. - * @param listener A listener notified with the complete list of responses or a failure. + * @param requests An iterator over the inference requests to execute + * @param listener Called with the list of all responses in request order */ public void execute(BulkInferenceRequestIterator requests, ActionListener> listener) { + List responses = new ArrayList<>(); + execute(requests, responses::add, ActionListener.wrap(ignored -> listener.onResponse(responses), listener::onFailure)); + } + + /** + * Executes multiple inference requests in bulk with streaming response handling. + *

+ * This method orchestrates the entire bulk inference process: + * 1. Creates execution state to track progress and responses + * 2. Sets up response handling pipeline + * 3. Initiates asynchronous request processing + *

+ * + * @param requests An iterator over the inference requests to execute + * @param responseConsumer Called for each successful inference response as they complete + * @param completionListener Called when all requests are complete or if any error occurs + */ + public void execute( + BulkInferenceRequestIterator requests, + Consumer responseConsumer, + ActionListener completionListener + ) { if (requests.hasNext() == false) { - listener.onResponse(List.of()); + completionListener.onResponse(null); return; } - final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState( - bulkExecutionConfig.maxOutstandingRequests() - ); - final ResponseHandler responseHandler = new ResponseHandler(bulkExecutionState, listener, requests.estimatedSize()); - - while (bulkExecutionState.finished() == false && requests.hasNext()) { - InferenceAction.Request request = requests.next(); - long seqNo = bulkExecutionState.generateSeqNo(); - - if (requests.hasNext() == false) { - bulkExecutionState.finish(); - } - - ActionListener inferenceResponseListener = ActionListener.runAfter( - ActionListener.wrap( - r -> bulkExecutionState.onInferenceResponse(seqNo, r), - e -> bulkExecutionState.onInferenceException(seqNo, e) - ), - responseHandler::persistPendingResponses - ); - - if (request == null) { - inferenceResponseListener.onResponse(null); - } else { - throttledInferenceRunner.doInference(request, inferenceResponseListener); - } - } + new BulkInferenceRequest(requests, responseConsumer, completionListener).executePendingRequests(); } /** - * Handles collection and delivery of inference responses once they are complete. + * Encapsulates the execution state and logic for a single bulk inference operation. + *

+ * This inner class manages the complete lifecycle of a bulk inference request, including: + * - Request iteration and permit-based concurrency control + * - Asynchronous execution with hybrid recursion strategy + * - Response collection and ordering via execution state + * - Error handling and completion notification + *

+ *

+ * Each BulkInferenceRequest instance represents one bulk operation that may contain + * multiple individual inference requests. Multiple BulkInferenceRequest instances + * can execute concurrently, with fairness ensured through the pending queue mechanism. + *

*/ - private static class ResponseHandler { - private final List responses; - private final ActionListener> listener; - private final BulkInferenceExecutionState bulkExecutionState; + private class BulkInferenceRequest { + private final BulkInferenceRequestIterator requests; + private final Consumer responseConsumer; + private final ActionListener completionListener; + + private final BulkInferenceExecutionState executionState = new BulkInferenceExecutionState(); private final AtomicBoolean responseSent = new AtomicBoolean(false); - private ResponseHandler( - BulkInferenceExecutionState bulkExecutionState, - ActionListener> listener, - int estimatedSize + BulkInferenceRequest( + BulkInferenceRequestIterator requests, + Consumer responseConsumer, + ActionListener completionListener ) { - this.listener = listener; - this.bulkExecutionState = bulkExecutionState; - this.responses = new ArrayList<>(estimatedSize); + this.requests = requests; + this.responseConsumer = responseConsumer; + this.completionListener = completionListener; } /** - * Persists all buffered responses that can be delivered in order, and sends the final response if all requests are finished. + * Attempts to poll the next request from the iterator and acquire a permit for execution. + *

+ * Because multiple threads may call this concurrently via async callbacks, this method is synchronized to ensure thread-safe access + * to the request iterator. + *

+ * + * @return A BulkRequestItem if a request and permit are available, null otherwise */ - public synchronized void persistPendingResponses() { - long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint(); - - while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) { - persistedSeqNo++; - if (bulkExecutionState.hasFailure() == false) { - try { - InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo); - responses.add(response); - } catch (Exception e) { - bulkExecutionState.addFailure(e); - } + private BulkRequestItem pollPendingRequest() { + synchronized (requests) { + if (requests.hasNext()) { + return new BulkRequestItem(executionState.generateSeqNo(), requests.next()); } - bulkExecutionState.markSeqNoAsPersisted(persistedSeqNo); } - sendResponseOnCompletion(); + return null; } /** - * Sends the final response or failure once all inference tasks have completed. + * Main execution loop that processes inference requests asynchronously with hybrid recursion strategy. + *

+ * This method implements a continuation-based asynchronous pattern with the following features: + * - Queue-based fairness: Multiple bulk requests can be queued and processed fairly + * - Permit-based concurrency control: Limits concurrent inference requests using semaphores + * - Hybrid recursion strategy: Uses direct recursion for performance up to 100 levels, + * then switches to executor-based continuation to prevent stack overflow + * - Duplicate prevention: Custom queue prevents the same bulk request from being queued multiple times + *

+ *

+ * Execution flow: + * 1. Attempts to acquire a permit for concurrent execution + * 2. If no permit available, queues this bulk request for later execution + * 3. Polls for the next available request from the iterator + * 4. If no requests available, schedules the next queued bulk request + * 5. Executes the request asynchronously with proper continuation handling + * 6. Uses hybrid recursion: direct calls up to 100 levels, executor-based beyond that + *

+ *

+ * The loop terminates when: + * - No more requests are available and no permits can be acquired + * - The bulk execution is marked as finished (due to completion or failure) + * - An unrecoverable error occurs during processing + *

*/ - private void sendResponseOnCompletion() { - if (bulkExecutionState.finished() && responseSent.compareAndSet(false, true)) { - if (bulkExecutionState.hasFailure() == false) { - try { - listener.onResponse(responses); + private void executePendingRequests() { + executePendingRequests(0); + } + + private void executePendingRequests(int recursionDepth) { + try { + while (executionState.finished() == false) { + if (permits.tryAcquire() == false) { + if (requests.hasNext()) { + pendingBulkRequests.add(this); + } return; - } catch (Exception e) { - bulkExecutionState.addFailure(e); - } - } + } else { + BulkRequestItem bulkRequestItem = pollPendingRequest(); - listener.onFailure(bulkExecutionState.getFailure()); - } - } - } + if (bulkRequestItem == null) { + // No more requests available + // Release the permit we didn't used and stop processing + permits.release(); - /** - * Manages throttled inference tasks execution. - */ - private static class ThrottledInferenceRunner { - private final InferenceRunner inferenceRunner; - private final ExecutorService executorService; - private final BlockingQueue pendingRequestsQueue; - private final Semaphore permits; - - private ThrottledInferenceRunner(InferenceRunner inferenceRunner, ExecutorService executorService, int maxRunningTasks) { - this.executorService = executorService; - this.permits = new Semaphore(maxRunningTasks); - this.inferenceRunner = inferenceRunner; - this.pendingRequestsQueue = new ArrayBlockingQueue<>(maxRunningTasks); - } + // Check if another bulk request is pending for execution. + BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); - /** - * Creates a new {@code ThrottledInferenceRunner} with the specified configuration. - * - * @param inferenceRunner TThe inference runner used to execute individual inference requests. - * @param executorService The executor used for asynchronous execution. - * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). - */ - public static ThrottledInferenceRunner create( - InferenceRunner inferenceRunner, - ExecutorService executorService, - BulkInferenceExecutionConfig bulkExecutionConfig - ) { - return new ThrottledInferenceRunner(inferenceRunner, executorService, bulkExecutionConfig.maxOutstandingRequests()); - } + while (nexBulkRequest == this) { + nexBulkRequest = pendingBulkRequests.poll(); + } - /** - * Schedules the inference task for execution. If a permit is available, the task runs immediately; otherwise, it is queued. - * - * @param request The inference request. - * @param listener The listener to notify on response or failure. - */ - public void doInference(InferenceAction.Request request, ActionListener listener) { - enqueueTask(request, listener); - executePendingRequests(); - } + if (nexBulkRequest != null) { + executorService.execute(nexBulkRequest::executePendingRequests); + } - /** - * Attempts to execute as many pending inference tasks as possible, limited by available permits. - */ - private void executePendingRequests() { - while (permits.tryAcquire()) { - AbstractRunnable task = pendingRequestsQueue.poll(); + return; + } - if (task == null) { - permits.release(); - return; - } + if (requests.hasNext() == false) { + // This is the last request - mark bulk execution as finished + // to prevent further processing attempts + executionState.finish(); + } - try { - executorService.execute(task); - } catch (Exception e) { - task.onFailure(e); - permits.release(); + final ActionListener inferenceResponseListener = ActionListener.runAfter( + ActionListener.wrap( + r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r), + e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e) + ), + () -> { + // Release the permit we used + permits.release(); + + try { + synchronized (executionState) { + persistPendingResponses(); + } + + if (executionState.finished() && responseSent.compareAndSet(false, true)) { + onBulkCompletion(); + } + + if (responseSent.get()) { + // Response has already been sent + // No need to continue processing this bulk. + // Check if another bulk request is pending for execution. + BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); + if (nexBulkRequest != null) { + executorService.execute(nexBulkRequest::executePendingRequests); + } + return; + } + if (executionState.finished() == false) { + // Execute any pending requests if any + if (recursionDepth > 100) { + executorService.execute(this::executePendingRequests); + } else { + this.executePendingRequests(recursionDepth + 1); + } + } + } catch (Exception e) { + if (responseSent.compareAndSet(false, true)) { + completionListener.onFailure(e); + } + } + } + ); + + // Handle null requests (edge case in some iterators) + if (bulkRequestItem.request() == null) { + inferenceResponseListener.onResponse(null); + return; + } + + // Execute the inference request with proper origin context + inferenceRunner.doInference(bulkRequestItem.request(), inferenceResponseListener); + } } + } catch (Exception e) { + executionState.addFailure(e); } } /** - * Add an inference task to the queue. - * - * @param request The inference request. - * * @param listener The listener to notify on response or failure. + * Processes and delivers buffered responses in order, ensuring proper sequencing. + *

+ * This method is synchronized to ensure thread-safe access to the execution state + * and prevent concurrent response processing which could cause ordering issues. + * Processing stops immediately if a failure is detected to implement fail-fast behavior. + *

*/ - private void enqueueTask(InferenceAction.Request request, ActionListener listener) { - try { - pendingRequestsQueue.put(createTask(request, listener)); - } catch (Exception e) { - listener.onFailure(new IllegalStateException("An error occurred while adding the inference request to the queue", e)); + private void persistPendingResponses() { + long persistedSeqNo = executionState.getPersistedCheckpoint(); + + while (persistedSeqNo < executionState.getProcessedCheckpoint()) { + persistedSeqNo++; + if (executionState.hasFailure() == false) { + try { + InferenceAction.Response response = executionState.fetchBufferedResponse(persistedSeqNo); + responseConsumer.accept(response); + } catch (Exception e) { + executionState.addFailure(e); + } + } + executionState.markSeqNoAsPersisted(persistedSeqNo); } } /** - * Wraps an inference request into an {@link AbstractRunnable} that releases its permit on completion and triggers any remaining - * queued tasks. - * - * @param request The inference request. - * @param listener The listener to notify on completion. - * @return A runnable task encapsulating the request. + * Call the completion listener when all requests have completed. */ - private AbstractRunnable createTask(InferenceAction.Request request, ActionListener listener) { - final ActionListener completionListener = ActionListener.runAfter(listener, () -> { - permits.release(); - executePendingRequests(); - }); - - return new AbstractRunnable() { - @Override - protected void doRun() { - try { - inferenceRunner.doInference(request, completionListener); - } catch (Throwable e) { - listener.onFailure(new RuntimeException("Unexpected failure while running inference", e)); - } + private void onBulkCompletion() { + if (executionState.hasFailure() == false) { + try { + completionListener.onResponse(null); + return; + } catch (Exception e) { + executionState.addFailure(e); } + } - @Override - public void onFailure(Exception e) { - completionListener.onFailure(e); - } - }; + completionListener.onFailure(executionState.getFailure()); } } - private static ExecutorService executorService(ThreadPool threadPool) { - return threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME); + /** + * Encapsulates an inference request with its associated sequence number. + *

+ * The sequence number is used for ordering responses and tracking completion + * in the bulk execution state. + *

+ * + * @param seqNo Unique sequence number for this request in the bulk operation + * @param request The actual inference request to execute + */ + private record BulkRequestItem(long seqNo, InferenceAction.Request request) { + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java index 7e44c681c6fc4..03d07adbd87cd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java @@ -26,6 +26,8 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.allOf; @@ -61,7 +63,7 @@ public void shutdownThreadPool() { } public void testSuccessfulExecution() throws Exception { - List requests = randomInferenceRequestList(between(1, 1000)); + List requests = randomInferenceRequestList(between(1, 1_000)); List responses = randomInferenceResponseList(requests.size()); InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { @@ -141,6 +143,35 @@ public void testInferenceRunnerSometimesFails() throws Exception { }); } + public void testParallelBulkExecution() throws Exception { + int batches = between(50, 100); + CountDownLatch latch = new CountDownLatch(batches); + + for (int i = 0; i < batches; i++) { + runWithRandomDelay(() -> { + List requests = randomInferenceRequestList(between(1, 1_000)); + List responses = randomInferenceResponseList(requests.size()); + + InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { + runWithRandomDelay(() -> { + ActionListener l = invocation.getArgument(1); + l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)))); + }); + return null; + }); + + ActionListener> listener = ActionListener.wrap(r -> { + assertThat(r, equalTo(responses)); + latch.countDown(); + }, ESTestCase::fail); + + bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + }); + } + + latch.await(10, TimeUnit.SECONDS); + } + private BulkInferenceExecutor bulkExecutor(InferenceRunner inferenceRunner) { return new BulkInferenceExecutor(inferenceRunner, threadPool, randomBulkExecutionConfig()); } @@ -195,11 +226,7 @@ private void runWithRandomDelay(Runnable runnable) { if (randomBoolean()) { runnable.run(); } else { - threadPool.schedule( - runnable, - TimeValue.timeValueNanos(between(1, 1_000)), - threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME) - ); + threadPool.schedule(runnable, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.generic()); } } }