From 4371bbd24249f5b38d95035d31c47d87b9aac884 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 4 Jun 2025 00:35:07 +0300 Subject: [PATCH 1/7] Refactor response handlers to improve error handling and streamline mid-stream error processing --- .../http/retry/BaseResponseHandler.java | 231 +++++++++++++++++- ...eUnifiedChatCompletionResponseHandler.java | 147 ++++++++--- ...iUnifiedChatCompletionResponseHandler.java | 103 ++++---- ...gingFaceChatCompletionResponseHandler.java | 91 ++++--- ...lUnifiedChatCompletionResponseHandler.java | 29 +-- ...iUnifiedChatCompletionResponseHandler.java | 126 ++++++---- .../openai/OpenAiCompletionPayload.java | 12 +- 7 files changed, 534 insertions(+), 205 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 56e994be86eb4..43954a5af8b91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -12,10 +12,12 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import java.util.Locale; import java.util.Objects; import java.util.function.Function; @@ -34,17 +36,23 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String SERVER_ERROR_OBJECT = "Received an error response"; public static final String BAD_REQUEST = "Received a bad request status code"; public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code"; + protected static final String ERROR_TYPE = "error"; + protected static final String STREAM_ERROR = "stream_error"; protected final String requestType; protected final ResponseParser parseFunction; private final Function errorParseFunction; private final boolean canHandleStreamingResponses; - public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function errorParseFunction) { + protected BaseResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction + ) { this(requestType, parseFunction, errorParseFunction, false); } - public BaseResponseHandler( + protected BaseResponseHandler( String requestType, ResponseParser parseFunction, Function errorParseFunction, @@ -109,19 +117,230 @@ private void checkForErrorObject(Request request, HttpResult result) { } protected Exception buildError(String message, Request request, HttpResult result) { - var errorEntityMsg = errorParseFunction.apply(result); - return buildError(message, request, result, errorEntityMsg); + var errorResponse = errorParseFunction.apply(result); + return buildError(message, request, result, errorResponse); } protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { var responseStatusCode = result.response().getStatusLine().getStatusCode(); return new ElasticsearchStatusException( - errorMessage(message, request, result, errorResponse, responseStatusCode), + errorMessage(message, request, errorResponse, responseStatusCode), toRestStatus(responseStatusCode) ); } - protected String errorMessage(String message, Request request, HttpResult result, ErrorResponse errorResponse, int statusCode) { + /** + * Builds an error for a streaming request with a custom error type. + * This method is used when an error response is received from the external service. + * Only streaming requests support this format, and it should be used when the error response. + * + * @param message the error message to include in the exception + * @param request the request that caused the error + * @param result the HTTP result containing the error response + * @param errorResponse the parsed error response from the HTTP result + * @param errorResponseClass the class of the expected error response type + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildChatCompletionError( + String message, + Request request, + HttpResult result, + ErrorResponse errorResponse, + Class errorResponseClass + ) { + assert request.isStreaming() : "Only streaming requests support this format"; + var statusCode = result.response().getStatusLine().getStatusCode(); + var errorMessage = errorMessage(message, request, errorResponse, statusCode); + var restStatus = toRestStatus(statusCode); + + return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClass); + } + + /** + * Builds a {@link UnifiedChatCompletionException} for a streaming request. + * This method is used when an error response is received from the external service. + * Only streaming requests should use this method. + * + * @param errorResponse the error response parsed from the HTTP result + * @param errorMessage the error message to include in the exception + * @param restStatus the REST status code of the response + * @param errorResponseClass the class of the expected error response type + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus, + Class errorResponseClass + ) { + if (errorResponseClass.isInstance(errorResponse)) { + return buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus); + } else { + return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); + } + } + + /** + * Builds a custom {@link UnifiedChatCompletionException} for a streaming request. + * This method is called when a specific error response is found in the HTTP result. + * It must be implemented by subclasses to handle specific error response formats. + * Only streaming requests should use this method. + * + * @param errorResponse the error response parsed from the HTTP result + * @param errorMessage the error message to include in the exception + * @param restStatus the REST status code of the response + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + throw new UnsupportedOperationException( + "Custom error handling is not implemented. Please override buildProviderSpecificChatCompletionError method." + ); + } + + /** + * Builds a default {@link UnifiedChatCompletionException} for a streaming request. + * This method is used when an error response is received but no specific error handling is implemented. + * Only streaming requests should use this method. + * + * @param errorResponse the error response parsed from the HTTP result + * @param errorMessage the error message to include in the exception + * @param restStatus the REST status code of the response + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildDefaultChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } + + /** + * Builds a mid-stream error for a streaming request. + * This method is used when an error occurs while processing a streaming response. + * It must be implemented by subclasses to handle specific error response formats. + * Only streaming requests should use this method. + * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message + * @param e the exception that caused the error, can be null + * @return a {@link UnifiedChatCompletionException} representing the mid-stream error + */ + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + throw new UnsupportedOperationException( + "Mid-stream error handling is not implemented. Please override buildMidStreamChatCompletionError method." + ); + } + + /** + * Builds a mid-stream error for a streaming request with a custom error type. + * This method is used when an error occurs while processing a streaming response and allows for custom error handling. + * Only streaming requests should use this method. + * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message + * @param e the exception that caused the error, can be null + * @param errorResponseClass the class of the expected error response type + * @return a {@link UnifiedChatCompletionException} representing the mid-stream error + */ + protected UnifiedChatCompletionException buildMidStreamChatCompletionError( + String inferenceEntityId, + String message, + Exception e, + Class errorResponseClass + ) { + // Extract the error response from the message using the provided method + var errorResponse = extractMidStreamChatCompletionErrorResponse(message); + // Check if the error response matches the expected type + if (errorResponseClass.isInstance(errorResponse)) { + // If it matches, we can build a custom mid-stream error exception + return buildProviderSpecificMidStreamChatCompletionError(inferenceEntityId, errorResponse); + } else if (e != null) { + // If the error response does not match, we can still return an exception based on the original throwable + return UnifiedChatCompletionException.fromThrowable(e); + } else { + // If no specific error response is found, we return a default mid-stream error + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse); + } + } + + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for a streaming request. + * This method is called when a specific error response is found in the message. + * It must be implemented by subclasses to handle specific error response formats. + * Only streaming requests should use this method. + * + * @param inferenceEntityId the ID of the inference entity + * @param errorResponse the error response parsed from the message + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + throw new UnsupportedOperationException( + "Mid-stream error handling is not implemented for this response handler. " + + "Please override buildProviderSpecificMidStreamChatCompletionError method." + ); + } + + /** + * Builds a default mid-stream error for a streaming request. + * This method is used when no specific error response is found in the message. + * Only streaming requests should use this method. + * + * @param inferenceEntityId the ID of the inference entity + * @param errorResponse the error response extracted from the message + * @return a {@link UnifiedChatCompletionException} representing the default mid-stream error + */ + protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), + createErrorType(errorResponse), + STREAM_ERROR + ); + } + + /** + * Extracts the mid-stream error response from the message. + * This method is used to parse the error response from a streaming message. + * It must be implemented by subclasses to handle specific error response formats. + * Only streaming requests should use this method. + * + * @param message the message containing the error response + * @return an {@link ErrorResponse} object representing the mid-stream error + */ + protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { + throw new UnsupportedOperationException( + "Mid-stream error extraction is not implemented. Please override extractMidStreamChatCompletionErrorResponse method." + ); + } + + /** + * Creates a string representation of the error type based on the provided ErrorResponse. + * This method is used to generate a human-readable error type for logging or exception messages. + * + * @param errorResponse the ErrorResponse object + * @return a string representing the error type + */ + protected static String createErrorType(ErrorResponse errorResponse) { + return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; + } + + protected String errorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) { return (errorResponse == null || errorResponse.errorStructureFound() == false || Strings.isNullOrEmpty(errorResponse.getErrorMessage())) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index ad40d43b3af3b..346b6a4f9026f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -25,6 +25,10 @@ import static org.elasticsearch.core.Strings.format; +/** + * Handles streaming chat completion responses and error parsing for Elastic Inference Service endpoints. + * This handler is designed to work with the unified Elastic Inference Service chat completion API. + */ public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, true); @@ -34,53 +38,128 @@ public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String reques public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); // EIS uses the unified API spec - var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor( + (m, e) -> buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(openAiProcessor); return new StreamingUnifiedChatCompletionResults(openAiProcessor); } + /** + * Builds an error for the Elastic Inference Service. + * This method is called when an error response is received from the service. + * + * @param message The error message to include in the exception. + * @param request The request that caused the error. + * @param result The HTTP result containing the error response. + * @param errorResponse The parsed error response from the service. + * @return An instance of {@link Exception} representing the error. + */ @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var restStatus = toRestStatus(responseStatusCode); - return new UnifiedChatCompletionException( - restStatus, - errorMessage(message, request, result, errorResponse, responseStatusCode), - "error", - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, ErrorResponse.class); + } + + /** + * Builds a custom {@link UnifiedChatCompletionException} for the Elastic Inference Service. + * This method is called when an error response is received from the service. + * + * @param errorResponse The error response received from the service. + * @param errorMessage The error message to include in the exception. + * @param restStatus The HTTP status of the error response. + * @param errorResponseClass The class of the error response. + * @return An instance of {@link UnifiedChatCompletionException} with details from the error response. + */ + @Override + protected UnifiedChatCompletionException buildChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus, + Class errorResponseClass + ) { + return new UnifiedChatCompletionException(restStatus, errorMessage, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT)); } - private static Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = ElasticInferenceServiceErrorResponseEntity.fromString(message); + /** + * Builds a mid-stream error for the Elastic Inference Service. + * This method is called when an error occurs during the streaming process. + * + * @param inferenceEntityId The ID of the inference entity. + * @param message The error message received from the service. + * @param e The exception that occurred, if any. + * @return An instance of {@link UnifiedChatCompletionException} representing the mid-stream error. + */ + @Override + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + var errorResponse = extractMidStreamChatCompletionErrorResponse(message); + // Check if the error response contains a specific structure if (errorResponse.errorStructureFound()) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), - errorResponse.getErrorMessage() - ), - "error", - "stream_error" - ); + return buildProviderSpecificMidStreamChatCompletionError(inferenceEntityId, errorResponse); } else if (e != null) { return UnifiedChatCompletionException.fromThrowable(e); } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - "error", - "stream_error" - ); + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse); } } + + /** + * Extracts the error response from the message. This method is specific to the Elastic Inference Service + * and should parse the message according to its error response format. + * + * @param message The message containing the error response. + * @return An instance of {@link ErrorResponse} parsed from the message. + */ + @Override + protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { + return ElasticInferenceServiceErrorResponseEntity.fromString(message); + } + + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for the Elastic Inference Service. + * This method is called when a specific error response structure is found in the message. + * + * @param inferenceEntityId The ID of the inference entity. + * @param errorResponse The error response parsed from the message. + * @return An instance of {@link UnifiedChatCompletionException} with details from the error response. + */ + @Override + protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + errorResponse.getErrorMessage() + ), + ERROR_TYPE, + STREAM_ERROR + ); + } + + /** + * Builds a default mid-stream {@link UnifiedChatCompletionException} for the Elastic Inference Service. + * This method is called when specific error response structure is NOT found in the message. + * + * @param inferenceEntityId The ID of the inference entity. + * @param errorResponse The error response parsed from the message. + * @return An instance of {@link UnifiedChatCompletionException} with a generic error message. + */ + @Override + protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), + ERROR_TYPE, + STREAM_ERROR + ); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 8c355c9f67f18..29c5c910abf75 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -10,8 +10,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -29,13 +27,16 @@ import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; import java.nio.charset.StandardCharsets; -import java.util.Locale; import java.util.Objects; import java.util.Optional; import java.util.concurrent.Flow; import static org.elasticsearch.core.Strings.format; +/** + * Handles streaming chat completion responses and error parsing for Google Vertex AI inference endpoints. + * This handler is designed to work with the unified Google Vertex AI chat completion API. + */ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVertexAiResponseHandler { private static final String ERROR_FIELD = "error"; @@ -54,7 +55,9 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher buildMidStreamError(request, m, e)); + var googleVertexAiProcessor = new GoogleVertexAiUnifiedStreamingProcessor( + (m, e) -> buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(googleVertexAiProcessor); @@ -62,57 +65,57 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher, Void> ERROR_PARSER = new ConstructingObjectParser<>( "google_vertex_ai_error_wrapper", true, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java index 8dffd612db5c8..d3cbffb10f203 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -41,55 +41,54 @@ public HuggingFaceChatCompletionResponseHandler(String requestType, ResponsePars } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof HuggingFaceErrorResponseEntity - ? new UnifiedChatCompletionException( - restStatus, - errorMessage, - HUGGING_FACE_ERROR, - restStatus.name().toLowerCase(Locale.ROOT) - ) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, HuggingFaceErrorResponseEntity.class); } @Override - protected Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message); - if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), - errorResponse.getErrorMessage() - ), - HUGGING_FACE_ERROR, - extractErrorCode(streamingHuggingFaceErrorResponseEntity) - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - createErrorType(errorResponse), - "stream_error" - ); - } + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException(restStatus, errorMessage, HUGGING_FACE_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } + + /** + * Builds an error for mid-stream responses from Hugging Face. + * This method is called when an error response is received during streaming operations. + * + * @param inferenceEntityId The ID of the inference entity that made the request. + * @param message The error message to include in the exception. + * @param e The exception that occurred. + * @return An instance of {@link UnifiedChatCompletionException} representing the error. + */ + @Override + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + return buildMidStreamChatCompletionError(inferenceEntityId, message, e, StreamingHuggingFaceErrorResponseEntity.class); + } + + @Override + protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + errorResponse.getErrorMessage() + ), + HUGGING_FACE_ERROR, + extractErrorCode((StreamingHuggingFaceErrorResponseEntity) errorResponse) + ); + } + + @Override + protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { + return StreamingHuggingFaceErrorResponseEntity.fromString(message); } private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java index a9d6df687fe99..bb7ee509fa3dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.mistral; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; @@ -30,22 +31,16 @@ public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseP } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof MistralErrorResponse - ? new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, MistralErrorResponse.class); + } + + @Override + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index e1a0117c7bcca..7e779645c77d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse; -import java.util.Locale; import java.util.concurrent.Flow; import java.util.function.Function; @@ -45,64 +44,95 @@ public OpenAiUnifiedChatCompletionResponseHandler( @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); - var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor( + (m, e) -> buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(openAiProcessor); return new StreamingUnifiedChatCompletionResults(openAiProcessor); } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof StreamingErrorResponse oer - ? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param()) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, StreamingErrorResponse.class); } - protected static String createErrorType(ErrorResponse errorResponse) { - return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; + /** + * Builds a custom {@link UnifiedChatCompletionException} for OpenAI inference endpoints. + * This method is called when an error response is received. + * + * @param errorResponse the parsed error response from the service + * @param errorMessage the error message received + * @param restStatus the HTTP status code of the error + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + @Override + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + var streamingError = (StreamingErrorResponse) errorResponse; + return new UnifiedChatCompletionException( + restStatus, + errorMessage, + streamingError.type(), + streamingError.code(), + streamingError.param() + ); + } + + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for OpenAI inference endpoints. + * This method is called when an error response is received during streaming. + * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message received during streaming + * @param e the exception that occurred + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + @Override + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + // Use the custom type StreamingErrorResponse for mid-stream errors + return buildMidStreamChatCompletionError(inferenceEntityId, message, e, StreamingErrorResponse.class); } - protected Exception buildMidStreamError(Request request, String message, Exception e) { - return buildMidStreamError(request.getInferenceEntityId(), message, e); + /** + * Extracts the mid-stream error response from the message. + * + * @param message the message containing the error response + * @return the extracted {@link ErrorResponse} + */ + @Override + protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { + return StreamingErrorResponse.fromString(message); } - public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) { - var errorResponse = StreamingErrorResponse.fromString(message); - if (errorResponse instanceof StreamingErrorResponse oer) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - inferenceEntityId, - errorResponse.getErrorMessage() - ), - oer.type(), - oer.code(), - oer.param() - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), - createErrorType(errorResponse), - "stream_error" - ); - } + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for OpenAI inference endpoints. + * This method is called when an error response is received during streaming. + * + * @param inferenceEntityId the ID of the inference entity + * @param errorResponse the parsed error response from the service + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + @Override + protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + var streamingError = (StreamingErrorResponse) errorResponse; + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + streamingError.getErrorMessage() + ), + streamingError.type(), + streamingError.code(), + streamingError.param() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java index 64b42f00d2d5b..d7e506d690610 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java @@ -22,7 +22,6 @@ import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; @@ -40,6 +39,11 @@ import java.util.Map; import java.util.stream.Stream; +/** + * Handles chat completion requests and responses for OpenAI models in SageMaker. + * This class implements the SageMakerStreamSchemaPayload interface to provide + * the necessary methods for handling OpenAI chat completions. + */ public class OpenAiCompletionPayload implements SageMakerStreamSchemaPayload { private static final XContent jsonXContent = JsonXContent.jsonXContent; @@ -50,7 +54,7 @@ public class OpenAiCompletionPayload implements SageMakerStreamSchemaPayload { private static final String USER_FIELD = "user"; private static final String USER_ROLE = "user"; private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; - private static final ResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + private static final OpenAiUnifiedChatCompletionResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( "sagemaker openai chat completion", ((request, result) -> { assert false : "do not call this"; @@ -88,12 +92,12 @@ public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody( var serverSentEvents = serverSentEvents(response); var results = serverSentEvents.flatMap(event -> { if ("error".equals(event.type())) { - throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), null); + throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), event.data(), null); } else { try { return OpenAiUnifiedStreamingProcessor.parse(parserConfig, event); } catch (Exception e) { - throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), e); + throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), event.data(), e); } } }) From 569e351b69ded5be27c80bfc37097df024b1be72 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 7 Jul 2025 19:18:17 +0300 Subject: [PATCH 2/7] Refactor error handling in streaming chat completion responses --- ...nUnifiedChatCompletionResponseHandler.java | 30 ++++++++----------- .../elastic/ElasticCompletionPayload.java | 6 +++- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java index 41b82bbf2cd02..1135315440e23 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; @@ -30,22 +31,17 @@ public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, Respons } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof IbmWatsonxErrorResponseEntity - ? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, IbmWatsonxErrorResponseEntity.class); } + + @Override + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java index a8c4f7c57796b..74722fc410835 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java @@ -102,7 +102,11 @@ public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody( ); return new StreamingUnifiedChatCompletionResults.Results(results); } catch (Exception e) { - throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), responseData, e); + throw new OpenAiUnifiedChatCompletionResponseHandler(null, null).buildMidStreamChatCompletionError( + model.getInferenceEntityId(), + responseData, + e + ); } } From 0a09f00c0692c872eadb0b6baed58128db5f3d75 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 7 Jul 2025 20:27:37 +0300 Subject: [PATCH 3/7] Refactor mid-stream error handling in response handlers --- .../http/retry/BaseResponseHandler.java | 17 ----------------- ...iceUnifiedChatCompletionResponseHandler.java | 3 +-- ...xAiUnifiedChatCompletionResponseHandler.java | 3 +-- ...nAiUnifiedChatCompletionResponseHandler.java | 1 - .../elastic/ElasticCompletionPayload.java | 17 ++++++++++------- 5 files changed, 12 insertions(+), 29 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 43954a5af8b91..0900080019f79 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -224,23 +224,6 @@ protected UnifiedChatCompletionException buildDefaultChatCompletionError( ); } - /** - * Builds a mid-stream error for a streaming request. - * This method is used when an error occurs while processing a streaming response. - * It must be implemented by subclasses to handle specific error response formats. - * Only streaming requests should use this method. - * - * @param inferenceEntityId the ID of the inference entity - * @param message the error message - * @param e the exception that caused the error, can be null - * @return a {@link UnifiedChatCompletionException} representing the mid-stream error - */ - public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { - throw new UnsupportedOperationException( - "Mid-stream error handling is not implemented. Please override buildMidStreamChatCompletionError method." - ); - } - /** * Builds a mid-stream error for a streaming request with a custom error type. * This method is used when an error occurs while processing a streaming response and allows for custom error handling. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index 346b6a4f9026f..7bdaa45184010 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -91,8 +91,7 @@ protected UnifiedChatCompletionException buildChatCompletionError( * @param e The exception that occurred, if any. * @return An instance of {@link UnifiedChatCompletionException} representing the mid-stream error. */ - @Override - public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + private UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { var errorResponse = extractMidStreamChatCompletionErrorResponse(message); // Check if the error response contains a specific structure if (errorResponse.errorStructureFound()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 7049c4f192e1e..79fb1239662b8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -83,8 +83,7 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro ); } - @Override - public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + private UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { return buildMidStreamChatCompletionError(inferenceEntityId, message, e, GoogleVertexAiErrorResponse.class); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index 7e779645c77d9..3ba3478940bee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -91,7 +91,6 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro * @param e the exception that occurred * @return an instance of {@link UnifiedChatCompletionException} with details from the error response */ - @Override public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { // Use the custom type StreamingErrorResponse for mid-stream errors return buildMidStreamChatCompletionError(inferenceEntityId, message, e, StreamingErrorResponse.class); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java index 74722fc410835..1d59b6e03c9e0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java @@ -40,7 +40,14 @@ * Each chunk should be in a valid JSON format, as that is the format the Elastic API uses. */ public class ElasticCompletionPayload implements SageMakerStreamSchemaPayload, ElasticPayload { - private static final XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + private static final OpenAiUnifiedChatCompletionResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + "sagemaker openai chat completion", + ((request, result) -> { + assert false : "do not call this"; + throw new UnsupportedOperationException("SageMaker should not call this object's response parser."); + }) + ); + private static final XContentParserConfiguration PARSER_CONFIGURATION = XContentParserConfiguration.EMPTY.withDeprecationHandler( LoggingDeprecationHandler.INSTANCE ); @@ -94,7 +101,7 @@ public SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompleti public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) { var responseData = response.asUtf8String(); try { - var results = OpenAiUnifiedStreamingProcessor.parse(parserConfig, responseData) + var results = OpenAiUnifiedStreamingProcessor.parse(PARSER_CONFIGURATION, responseData) .collect( () -> new ArrayDeque(), ArrayDeque::offer, @@ -102,11 +109,7 @@ public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody( ); return new StreamingUnifiedChatCompletionResults.Results(results); } catch (Exception e) { - throw new OpenAiUnifiedChatCompletionResponseHandler(null, null).buildMidStreamChatCompletionError( - model.getInferenceEntityId(), - responseData, - e - ); + throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), responseData, e); } } From f4582f39ab6bb33c66f1283b0098151efd84b61e Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 8 Jul 2025 17:32:13 +0300 Subject: [PATCH 4/7] Refactor error handling in streaming response handlers to use functional interfaces for improved flexibility --- .../http/retry/BaseResponseHandler.java | 102 +++++------------ .../retry/ChatCompletionErrorBuilder.java | 33 ++++++ ...eUnifiedChatCompletionResponseHandler.java | 104 +++++------------- ...iUnifiedChatCompletionResponseHandler.java | 33 +++--- ...gingFaceChatCompletionResponseHandler.java | 29 +++-- ...nUnifiedChatCompletionResponseHandler.java | 13 ++- ...lUnifiedChatCompletionResponseHandler.java | 12 +- ...iUnifiedChatCompletionResponseHandler.java | 35 +++--- 8 files changed, 158 insertions(+), 203 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorBuilder.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 0900080019f79..f103a914e1b47 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -19,7 +19,9 @@ import java.util.Locale; import java.util.Objects; +import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.Supplier; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; @@ -124,7 +126,7 @@ protected Exception buildError(String message, Request request, HttpResult resul protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { var responseStatusCode = result.response().getStatusLine().getStatusCode(); return new ElasticsearchStatusException( - errorMessage(message, request, errorResponse, responseStatusCode), + extractErrorMessage(message, request, errorResponse, responseStatusCode), toRestStatus(responseStatusCode) ); } @@ -138,7 +140,8 @@ protected Exception buildError(String message, Request request, HttpResult resul * @param request the request that caused the error * @param result the HTTP result containing the error response * @param errorResponse the parsed error response from the HTTP result - * @param errorResponseClass the class of the expected error response type + * @param errorResponseClassSupplier the supplier that provides the class of the expected error response type + * @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors * @return an instance of {@link UnifiedChatCompletionException} with details from the error response */ protected UnifiedChatCompletionException buildChatCompletionError( @@ -146,14 +149,15 @@ protected UnifiedChatCompletionException buildChatCompletionError( Request request, HttpResult result, ErrorResponse errorResponse, - Class errorResponseClass + Supplier> errorResponseClassSupplier, + ChatCompletionErrorBuilder chatCompletionErrorBuilder ) { assert request.isStreaming() : "Only streaming requests support this format"; var statusCode = result.response().getStatusLine().getStatusCode(); - var errorMessage = errorMessage(message, request, errorResponse, statusCode); + var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode); var restStatus = toRestStatus(statusCode); - return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClass); + return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClassSupplier, chatCompletionErrorBuilder); } /** @@ -164,43 +168,24 @@ protected UnifiedChatCompletionException buildChatCompletionError( * @param errorResponse the error response parsed from the HTTP result * @param errorMessage the error message to include in the exception * @param restStatus the REST status code of the response - * @param errorResponseClass the class of the expected error response type + * @param errorResponseClassSupplier the supplier that provides the class of the expected error response type + * @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors * @return an instance of {@link UnifiedChatCompletionException} with details from the error response */ protected UnifiedChatCompletionException buildChatCompletionError( ErrorResponse errorResponse, String errorMessage, RestStatus restStatus, - Class errorResponseClass + Supplier> errorResponseClassSupplier, + ChatCompletionErrorBuilder chatCompletionErrorBuilder ) { - if (errorResponseClass.isInstance(errorResponse)) { - return buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus); + if (errorResponseClassSupplier.get().isInstance(errorResponse)) { + return chatCompletionErrorBuilder.buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus); } else { return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); } } - /** - * Builds a custom {@link UnifiedChatCompletionException} for a streaming request. - * This method is called when a specific error response is found in the HTTP result. - * It must be implemented by subclasses to handle specific error response formats. - * Only streaming requests should use this method. - * - * @param errorResponse the error response parsed from the HTTP result - * @param errorMessage the error message to include in the exception - * @param restStatus the REST status code of the response - * @return an instance of {@link UnifiedChatCompletionException} with details from the error response - */ - protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus - ) { - throw new UnsupportedOperationException( - "Custom error handling is not implemented. Please override buildProviderSpecificChatCompletionError method." - ); - } - /** * Builds a default {@link UnifiedChatCompletionException} for a streaming request. * This method is used when an error response is received but no specific error handling is implemented. @@ -211,7 +196,7 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro * @param restStatus the REST status code of the response * @return an instance of {@link UnifiedChatCompletionException} with details from the error response */ - protected UnifiedChatCompletionException buildDefaultChatCompletionError( + private static UnifiedChatCompletionException buildDefaultChatCompletionError( ErrorResponse errorResponse, String errorMessage, RestStatus restStatus @@ -232,21 +217,25 @@ protected UnifiedChatCompletionException buildDefaultChatCompletionError( * @param inferenceEntityId the ID of the inference entity * @param message the error message * @param e the exception that caused the error, can be null - * @param errorResponseClass the class of the expected error response type + * @param errorResponseClassSupplier a supplier that provides the class of the expected error response type + * @param specificErrorBuilder a function that builds a specific error based on the inference entity ID and error response + * @param midStreamErrorExtractor a function that extracts the mid-stream error response from the message * @return a {@link UnifiedChatCompletionException} representing the mid-stream error */ protected UnifiedChatCompletionException buildMidStreamChatCompletionError( String inferenceEntityId, String message, Exception e, - Class errorResponseClass + Supplier> errorResponseClassSupplier, + BiFunction specificErrorBuilder, + Function midStreamErrorExtractor ) { // Extract the error response from the message using the provided method - var errorResponse = extractMidStreamChatCompletionErrorResponse(message); + var errorResponse = midStreamErrorExtractor.apply(message); // Check if the error response matches the expected type - if (errorResponseClass.isInstance(errorResponse)) { + if (errorResponseClassSupplier.get().isInstance(errorResponse)) { // If it matches, we can build a custom mid-stream error exception - return buildProviderSpecificMidStreamChatCompletionError(inferenceEntityId, errorResponse); + return specificErrorBuilder.apply(inferenceEntityId, errorResponse); } else if (e != null) { // If the error response does not match, we can still return an exception based on the original throwable return UnifiedChatCompletionException.fromThrowable(e); @@ -256,26 +245,6 @@ protected UnifiedChatCompletionException buildMidStreamChatCompletionError( } } - /** - * Builds a custom mid-stream {@link UnifiedChatCompletionException} for a streaming request. - * This method is called when a specific error response is found in the message. - * It must be implemented by subclasses to handle specific error response formats. - * Only streaming requests should use this method. - * - * @param inferenceEntityId the ID of the inference entity - * @param errorResponse the error response parsed from the message - * @return an instance of {@link UnifiedChatCompletionException} with details from the error response - */ - protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( - String inferenceEntityId, - ErrorResponse errorResponse - ) { - throw new UnsupportedOperationException( - "Mid-stream error handling is not implemented for this response handler. " - + "Please override buildProviderSpecificMidStreamChatCompletionError method." - ); - } - /** * Builds a default mid-stream error for a streaming request. * This method is used when no specific error response is found in the message. @@ -285,7 +254,7 @@ protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompl * @param errorResponse the error response extracted from the message * @return a {@link UnifiedChatCompletionException} representing the default mid-stream error */ - protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( + protected static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( String inferenceEntityId, ErrorResponse errorResponse ) { @@ -297,21 +266,6 @@ protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionErro ); } - /** - * Extracts the mid-stream error response from the message. - * This method is used to parse the error response from a streaming message. - * It must be implemented by subclasses to handle specific error response formats. - * Only streaming requests should use this method. - * - * @param message the message containing the error response - * @return an {@link ErrorResponse} object representing the mid-stream error - */ - protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { - throw new UnsupportedOperationException( - "Mid-stream error extraction is not implemented. Please override extractMidStreamChatCompletionErrorResponse method." - ); - } - /** * Creates a string representation of the error type based on the provided ErrorResponse. * This method is used to generate a human-readable error type for logging or exception messages. @@ -319,11 +273,11 @@ protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String messa * @param errorResponse the ErrorResponse object * @return a string representing the error type */ - protected static String createErrorType(ErrorResponse errorResponse) { + private static String createErrorType(ErrorResponse errorResponse) { return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; } - protected String errorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) { + private static String extractErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) { return (errorResponse == null || errorResponse.errorStructureFound() == false || Strings.isNullOrEmpty(errorResponse.getErrorMessage())) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorBuilder.java new file mode 100644 index 0000000000000..c083fd9b4a82b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorBuilder.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; + +/** + * Functional interface for building provider-specific chat completion errors. + * This interface is used to create exceptions that are specific to the chat completion service being used. + */ +@FunctionalInterface +public interface ChatCompletionErrorBuilder { + + /** + * Builds a provider-specific chat completion error based on the given parameters. + * + * @param errorResponse The error response received from the service. + * @param errorMessage A custom error message to include in the exception. + * @param restStatus The HTTP status code associated with the error. + * @return An instance of {@link UnifiedChatCompletionException} representing the error. + */ + UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index 7bdaa45184010..4edd1cef8e05a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -59,27 +59,22 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher ErrorResponse.class, + ElasticInferenceServiceUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError + ); } - /** - * Builds a custom {@link UnifiedChatCompletionException} for the Elastic Inference Service. - * This method is called when an error response is received from the service. - * - * @param errorResponse The error response received from the service. - * @param errorMessage The error message to include in the exception. - * @param restStatus The HTTP status of the error response. - * @param errorResponseClass The class of the error response. - * @return An instance of {@link UnifiedChatCompletionException} with details from the error response. - */ - @Override - protected UnifiedChatCompletionException buildChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus, - Class errorResponseClass + private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse response, + String message, + RestStatus restStatus ) { - return new UnifiedChatCompletionException(restStatus, errorMessage, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT)); + return new UnifiedChatCompletionException(restStatus, message, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT)); } /** @@ -92,73 +87,24 @@ protected UnifiedChatCompletionException buildChatCompletionError( * @return An instance of {@link UnifiedChatCompletionException} representing the mid-stream error. */ private UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { - var errorResponse = extractMidStreamChatCompletionErrorResponse(message); + var errorResponse = ElasticInferenceServiceErrorResponseEntity.fromString(message); // Check if the error response contains a specific structure if (errorResponse.errorStructureFound()) { - return buildProviderSpecificMidStreamChatCompletionError(inferenceEntityId, errorResponse); + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + errorResponse.getErrorMessage() + ), + ERROR_TYPE, + STREAM_ERROR + ); } else if (e != null) { return UnifiedChatCompletionException.fromThrowable(e); } else { return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse); } } - - /** - * Extracts the error response from the message. This method is specific to the Elastic Inference Service - * and should parse the message according to its error response format. - * - * @param message The message containing the error response. - * @return An instance of {@link ErrorResponse} parsed from the message. - */ - @Override - protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { - return ElasticInferenceServiceErrorResponseEntity.fromString(message); - } - - /** - * Builds a custom mid-stream {@link UnifiedChatCompletionException} for the Elastic Inference Service. - * This method is called when a specific error response structure is found in the message. - * - * @param inferenceEntityId The ID of the inference entity. - * @param errorResponse The error response parsed from the message. - * @return An instance of {@link UnifiedChatCompletionException} with details from the error response. - */ - @Override - protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( - String inferenceEntityId, - ErrorResponse errorResponse - ) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - inferenceEntityId, - errorResponse.getErrorMessage() - ), - ERROR_TYPE, - STREAM_ERROR - ); - } - - /** - * Builds a default mid-stream {@link UnifiedChatCompletionException} for the Elastic Inference Service. - * This method is called when specific error response structure is NOT found in the message. - * - * @param inferenceEntityId The ID of the inference entity. - * @param errorResponse The error response parsed from the message. - * @return An instance of {@link UnifiedChatCompletionException} with a generic error message. - */ - @Override - protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( - String inferenceEntityId, - ErrorResponse errorResponse - ) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), - ERROR_TYPE, - STREAM_ERROR - ); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 79fb1239662b8..b60e074cf3b10 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -54,7 +54,14 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + (message, exception) -> buildMidStreamChatCompletionError( + request.getInferenceEntityId(), + message, + exception, + () -> GoogleVertexAiErrorResponse.class, + GoogleVertexAiUnifiedChatCompletionResponseHandler::buildProviderSpecificMidStreamChatCompletionError, + GoogleVertexAiErrorResponse::fromString + ) ); flow.subscribe(serverSentEventProcessor); @@ -64,11 +71,17 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher GoogleVertexAiErrorResponse.class, + GoogleVertexAiUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError + ); } - @Override - protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( ErrorResponse errorResponse, String errorMessage, RestStatus restStatus @@ -83,12 +96,7 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro ); } - private UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { - return buildMidStreamChatCompletionError(inferenceEntityId, message, e, GoogleVertexAiErrorResponse.class); - } - - @Override - protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + private static UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( String inferenceEntityId, ErrorResponse errorResponse ) { @@ -107,11 +115,6 @@ protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompl ); } - @Override - protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { - return GoogleVertexAiErrorResponse.fromString(message); - } - public static class GoogleVertexAiErrorResponse extends ErrorResponse { private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( "google_vertex_ai_error_wrapper", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java index d3cbffb10f203..3ac2c616adb8b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -42,11 +42,17 @@ public HuggingFaceChatCompletionResponseHandler(String requestType, ResponsePars @Override protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - return buildChatCompletionError(message, request, result, errorResponse, HuggingFaceErrorResponseEntity.class); + return buildChatCompletionError( + message, + request, + result, + errorResponse, + () -> HuggingFaceErrorResponseEntity.class, + HuggingFaceChatCompletionResponseHandler::buildProviderSpecificChatCompletionError + ); } - @Override - protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( ErrorResponse errorResponse, String errorMessage, RestStatus restStatus @@ -65,11 +71,17 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro */ @Override public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { - return buildMidStreamChatCompletionError(inferenceEntityId, message, e, StreamingHuggingFaceErrorResponseEntity.class); + return buildMidStreamChatCompletionError( + inferenceEntityId, + message, + e, + () -> StreamingHuggingFaceErrorResponseEntity.class, + HuggingFaceChatCompletionResponseHandler::buildProviderSpecificMidStreamChatCompletionError, + StreamingHuggingFaceErrorResponseEntity::fromString + ); } - @Override - protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + private static UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( String inferenceEntityId, ErrorResponse errorResponse ) { @@ -86,11 +98,6 @@ protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompl ); } - @Override - protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { - return StreamingHuggingFaceErrorResponseEntity.fromString(message); - } - private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null ? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode()) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java index 1135315440e23..0b18b4c6e7035 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java @@ -32,16 +32,21 @@ public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, Respons @Override protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - return buildChatCompletionError(message, request, result, errorResponse, IbmWatsonxErrorResponseEntity.class); + return buildChatCompletionError( + message, + request, + result, + errorResponse, + () -> IbmWatsonxErrorResponseEntity.class, + IbmWatsonUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError + ); } - @Override - protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( ErrorResponse errorResponse, String errorMessage, RestStatus restStatus ) { return new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java index bb7ee509fa3dc..aea2308dc4aa2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java @@ -32,11 +32,17 @@ public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseP @Override protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - return buildChatCompletionError(message, request, result, errorResponse, MistralErrorResponse.class); + return buildChatCompletionError( + message, + request, + result, + errorResponse, + () -> MistralErrorResponse.class, + MistralUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError + ); } - @Override - protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( ErrorResponse errorResponse, String errorMessage, RestStatus restStatus diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index 3ba3478940bee..8f9d707581b4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -54,7 +54,14 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher StreamingErrorResponse.class, + OpenAiUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError + ); } /** @@ -66,8 +73,7 @@ protected UnifiedChatCompletionException buildError(String message, Request requ * @param restStatus the HTTP status code of the error * @return an instance of {@link UnifiedChatCompletionException} with details from the error response */ - @Override - protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( ErrorResponse errorResponse, String errorMessage, RestStatus restStatus @@ -93,18 +99,14 @@ protected UnifiedChatCompletionException buildProviderSpecificChatCompletionErro */ public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { // Use the custom type StreamingErrorResponse for mid-stream errors - return buildMidStreamChatCompletionError(inferenceEntityId, message, e, StreamingErrorResponse.class); - } - - /** - * Extracts the mid-stream error response from the message. - * - * @param message the message containing the error response - * @return the extracted {@link ErrorResponse} - */ - @Override - protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { - return StreamingErrorResponse.fromString(message); + return buildMidStreamChatCompletionError( + inferenceEntityId, + message, + e, + () -> StreamingErrorResponse.class, + OpenAiUnifiedChatCompletionResponseHandler::buildProviderSpecificMidStreamChatCompletionError, + StreamingErrorResponse::fromString + ); } /** @@ -115,8 +117,7 @@ protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String messa * @param errorResponse the parsed error response from the service * @return an instance of {@link UnifiedChatCompletionException} with details from the error response */ - @Override - protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + private static UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( String inferenceEntityId, ErrorResponse errorResponse ) { From 8ff0bb6623633e5f2d6ad0fbcf76da3a4718d601 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 8 Jul 2025 21:18:06 +0300 Subject: [PATCH 5/7] Refactor error handling in BaseResponseHandler to check for error structure before type matching --- .../inference/external/http/retry/BaseResponseHandler.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index f103a914e1b47..52e243f3a99e3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -179,7 +179,7 @@ protected UnifiedChatCompletionException buildChatCompletionError( Supplier> errorResponseClassSupplier, ChatCompletionErrorBuilder chatCompletionErrorBuilder ) { - if (errorResponseClassSupplier.get().isInstance(errorResponse)) { + if (errorResponse.errorStructureFound() && errorResponseClassSupplier.get().isInstance(errorResponse)) { return chatCompletionErrorBuilder.buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus); } else { return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); @@ -233,7 +233,7 @@ protected UnifiedChatCompletionException buildMidStreamChatCompletionError( // Extract the error response from the message using the provided method var errorResponse = midStreamErrorExtractor.apply(message); // Check if the error response matches the expected type - if (errorResponseClassSupplier.get().isInstance(errorResponse)) { + if (errorResponse.errorStructureFound() && errorResponseClassSupplier.get().isInstance(errorResponse)) { // If it matches, we can build a custom mid-stream error exception return specificErrorBuilder.apply(inferenceEntityId, errorResponse); } else if (e != null) { From a7320f49193ac04ecf2fd5f31daf64d3fa6d310a Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 10 Jul 2025 00:34:39 +0300 Subject: [PATCH 6/7] Refactor error handling in response classes to use ChatCompletionErrorResponse for improved consistency and maintainability --- .../http/retry/BaseResponseHandler.java | 54 +++--------- ...iedChatCompletionExceptionConvertible.java | 16 ++++ ...iedChatCompletionExceptionConvertible.java | 17 ++++ ...StreamingChatCompletionErrorResponse.java} | 48 +++++++++-- ...eUnifiedChatCompletionResponseHandler.java | 26 +++--- ...iUnifiedChatCompletionResponseHandler.java | 74 +++++++---------- ...gingFaceChatCompletionResponseHandler.java | 83 ++++++------------- .../HuggingFaceErrorResponseEntity.java | 13 ++- ...nUnifiedChatCompletionResponseHandler.java | 28 ------- .../IbmWatsonxErrorResponseEntity.java | 13 ++- ...lUnifiedChatCompletionResponseHandler.java | 28 ------- .../response/MistralErrorResponse.java | 13 ++- ...iUnifiedChatCompletionResponseHandler.java | 76 +---------------- 13 files changed, 188 insertions(+), 301 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/MidStreamUnifiedChatCompletionExceptionConvertible.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionExceptionConvertible.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/{StreamingErrorResponse.java => OpenAiStreamingChatCompletionErrorResponse.java} (67%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 52e243f3a99e3..74cfe37c8a3aa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -19,9 +19,7 @@ import java.util.Locale; import java.util.Objects; -import java.util.function.BiFunction; import java.util.function.Function; -import java.util.function.Supplier; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; @@ -38,7 +36,6 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String SERVER_ERROR_OBJECT = "Received an error response"; public static final String BAD_REQUEST = "Received a bad request status code"; public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code"; - protected static final String ERROR_TYPE = "error"; protected static final String STREAM_ERROR = "stream_error"; protected final String requestType; @@ -140,47 +137,22 @@ protected Exception buildError(String message, Request request, HttpResult resul * @param request the request that caused the error * @param result the HTTP result containing the error response * @param errorResponse the parsed error response from the HTTP result - * @param errorResponseClassSupplier the supplier that provides the class of the expected error response type - * @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors * @return an instance of {@link UnifiedChatCompletionException} with details from the error response */ protected UnifiedChatCompletionException buildChatCompletionError( String message, Request request, HttpResult result, - ErrorResponse errorResponse, - Supplier> errorResponseClassSupplier, - ChatCompletionErrorBuilder chatCompletionErrorBuilder + ErrorResponse errorResponse ) { assert request.isStreaming() : "Only streaming requests support this format"; var statusCode = result.response().getStatusLine().getStatusCode(); var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode); var restStatus = toRestStatus(statusCode); - return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClassSupplier, chatCompletionErrorBuilder); - } - - /** - * Builds a {@link UnifiedChatCompletionException} for a streaming request. - * This method is used when an error response is received from the external service. - * Only streaming requests should use this method. - * - * @param errorResponse the error response parsed from the HTTP result - * @param errorMessage the error message to include in the exception - * @param restStatus the REST status code of the response - * @param errorResponseClassSupplier the supplier that provides the class of the expected error response type - * @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors - * @return an instance of {@link UnifiedChatCompletionException} with details from the error response - */ - protected UnifiedChatCompletionException buildChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus, - Supplier> errorResponseClassSupplier, - ChatCompletionErrorBuilder chatCompletionErrorBuilder - ) { - if (errorResponse.errorStructureFound() && errorResponseClassSupplier.get().isInstance(errorResponse)) { - return chatCompletionErrorBuilder.buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus); + if (errorResponse.errorStructureFound() + && errorResponse instanceof UnifiedChatCompletionExceptionConvertible chatCompletionExceptionConvertible) { + return chatCompletionExceptionConvertible.toUnifiedChatCompletionException(errorMessage, restStatus); } else { return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); } @@ -196,7 +168,7 @@ protected UnifiedChatCompletionException buildChatCompletionError( * @param restStatus the REST status code of the response * @return an instance of {@link UnifiedChatCompletionException} with details from the error response */ - private static UnifiedChatCompletionException buildDefaultChatCompletionError( + protected static UnifiedChatCompletionException buildDefaultChatCompletionError( ErrorResponse errorResponse, String errorMessage, RestStatus restStatus @@ -217,8 +189,6 @@ private static UnifiedChatCompletionException buildDefaultChatCompletionError( * @param inferenceEntityId the ID of the inference entity * @param message the error message * @param e the exception that caused the error, can be null - * @param errorResponseClassSupplier a supplier that provides the class of the expected error response type - * @param specificErrorBuilder a function that builds a specific error based on the inference entity ID and error response * @param midStreamErrorExtractor a function that extracts the mid-stream error response from the message * @return a {@link UnifiedChatCompletionException} representing the mid-stream error */ @@ -226,22 +196,20 @@ protected UnifiedChatCompletionException buildMidStreamChatCompletionError( String inferenceEntityId, String message, Exception e, - Supplier> errorResponseClassSupplier, - BiFunction specificErrorBuilder, Function midStreamErrorExtractor ) { // Extract the error response from the message using the provided method - var errorResponse = midStreamErrorExtractor.apply(message); + var error = midStreamErrorExtractor.apply(message); // Check if the error response matches the expected type - if (errorResponse.errorStructureFound() && errorResponseClassSupplier.get().isInstance(errorResponse)) { + if (error.errorStructureFound() && error instanceof MidStreamUnifiedChatCompletionExceptionConvertible midStreamError) { // If it matches, we can build a custom mid-stream error exception - return specificErrorBuilder.apply(inferenceEntityId, errorResponse); + return midStreamError.toUnifiedChatCompletionException(inferenceEntityId); } else if (e != null) { // If the error response does not match, we can still return an exception based on the original throwable return UnifiedChatCompletionException.fromThrowable(e); } else { // If no specific error response is found, we return a default mid-stream error - return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse); + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, error); } } @@ -277,7 +245,7 @@ private static String createErrorType(ErrorResponse errorResponse) { return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; } - private static String extractErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) { + protected static String extractErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) { return (errorResponse == null || errorResponse.errorStructureFound() == false || Strings.isNullOrEmpty(errorResponse.getErrorMessage())) @@ -291,7 +259,7 @@ private static String extractErrorMessage(String message, Request request, Error ); } - public static RestStatus toRestStatus(int statusCode) { + protected static RestStatus toRestStatus(int statusCode) { RestStatus code = null; if (statusCode < 500) { code = RestStatus.fromCode(statusCode); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/MidStreamUnifiedChatCompletionExceptionConvertible.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/MidStreamUnifiedChatCompletionExceptionConvertible.java new file mode 100644 index 0000000000000..60a5a99c6cfed --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/MidStreamUnifiedChatCompletionExceptionConvertible.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; + +public interface MidStreamUnifiedChatCompletionExceptionConvertible { + + UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionExceptionConvertible.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionExceptionConvertible.java new file mode 100644 index 0000000000000..672e9f9b2ded5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionExceptionConvertible.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; + +public interface UnifiedChatCompletionExceptionConvertible { + + UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/OpenAiStreamingChatCompletionErrorResponse.java similarity index 67% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/OpenAiStreamingChatCompletionErrorResponse.java index 93e1d6388f357..bca3b12dbae51 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/OpenAiStreamingChatCompletionErrorResponse.java @@ -8,19 +8,26 @@ package org.elasticsearch.xpack.inference.external.response.streaming; import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.MidStreamUnifiedChatCompletionExceptionConvertible; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; import java.util.Objects; import java.util.Optional; +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR_OBJECT; + /** * Represents an error response from a streaming inference service. * This class extends {@link ErrorResponse} and provides additional fields @@ -38,17 +45,21 @@ * * TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication. */ -public class StreamingErrorResponse extends ErrorResponse { +public class OpenAiStreamingChatCompletionErrorResponse extends ErrorResponse + implements + UnifiedChatCompletionExceptionConvertible, + MidStreamUnifiedChatCompletionExceptionConvertible { private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( "streaming_error", true, - args -> Optional.ofNullable((StreamingErrorResponse) args[0]) - ); - private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( - "streaming_error", - true, - args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) + args -> Optional.ofNullable((OpenAiStreamingChatCompletionErrorResponse) args[0]) ); + private static final ConstructingObjectParser ERROR_BODY_PARSER = + new ConstructingObjectParser<>( + "streaming_error", + true, + args -> new OpenAiStreamingChatCompletionErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) + ); static { ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); @@ -105,13 +116,34 @@ public static ErrorResponse fromString(String response) { private final String param; private final String type; - StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) { + OpenAiStreamingChatCompletionErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) { super(errorMessage); this.code = code; this.param = param; this.type = Objects.requireNonNull(type); } + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, this.type(), this.code(), this.param()); + } + + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + this.getErrorMessage() + ), + this.type(), + this.code(), + this.param() + ); + } + @Nullable public String code() { return code; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index 4edd1cef8e05a..0a59681186977 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -30,6 +30,8 @@ * This handler is designed to work with the unified Elastic Inference Service chat completion API. */ public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { + private static final String ERROR_TYPE = "error"; + public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, true); } @@ -59,22 +61,16 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher ErrorResponse.class, - ElasticInferenceServiceUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError - ); - } + assert request.isStreaming() : "Only streaming requests support this format"; + var statusCode = result.response().getStatusLine().getStatusCode(); + var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode); + var restStatus = toRestStatus(statusCode); - private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( - ErrorResponse response, - String message, - RestStatus restStatus - ) { - return new UnifiedChatCompletionException(restStatus, message, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT)); + if (errorResponse.errorStructureFound()) { + return new UnifiedChatCompletionException(restStatus, errorMessage, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT)); + } else { + return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); + } } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index b60e074cf3b10..d7c13c730cf15 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -21,6 +21,8 @@ import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.MidStreamUnifiedChatCompletionExceptionConvertible; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; @@ -58,8 +60,6 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher GoogleVertexAiErrorResponse.class, - GoogleVertexAiUnifiedChatCompletionResponseHandler::buildProviderSpecificMidStreamChatCompletionError, GoogleVertexAiErrorResponse::fromString ) ); @@ -71,51 +71,13 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher GoogleVertexAiErrorResponse.class, - GoogleVertexAiUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError - ); - } - - private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus - ) { - var vertexAIErrorResponse = (GoogleVertexAiErrorResponse) errorResponse; - return new UnifiedChatCompletionException( - restStatus, - errorMessage, - vertexAIErrorResponse.status(), - String.valueOf(vertexAIErrorResponse.code()), - null - ); - } - - private static UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( - String inferenceEntityId, - ErrorResponse errorResponse - ) { - var vertexAIErrorResponse = (GoogleVertexAiErrorResponse) errorResponse; - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - inferenceEntityId, - errorResponse.getErrorMessage() - ), - vertexAIErrorResponse.status(), - String.valueOf(vertexAIErrorResponse.code()), - null - ); + return buildChatCompletionError(message, request, result, errorResponse); } - public static class GoogleVertexAiErrorResponse extends ErrorResponse { + public static class GoogleVertexAiErrorResponse extends ErrorResponse + implements + UnifiedChatCompletionExceptionConvertible, + MidStreamUnifiedChatCompletionExceptionConvertible { private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( "google_vertex_ai_error_wrapper", true, @@ -174,6 +136,28 @@ static ErrorResponse fromString(String response) { this.status = status; } + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, this.status(), String.valueOf(this.code()), null); + } + + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + this.getErrorMessage() + ), + this.status(), + String.valueOf(this.code()), + null + ); + + } + public int code() { return code; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java index 3ac2c616adb8b..c36e46c450b70 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -16,14 +16,12 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.MidStreamUnifiedChatCompletionExceptionConvertible; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; import java.util.Optional; import static org.elasticsearch.core.Strings.format; @@ -40,26 +38,6 @@ public HuggingFaceChatCompletionResponseHandler(String requestType, ResponsePars super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse); } - @Override - protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - return buildChatCompletionError( - message, - request, - result, - errorResponse, - () -> HuggingFaceErrorResponseEntity.class, - HuggingFaceChatCompletionResponseHandler::buildProviderSpecificChatCompletionError - ); - } - - private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus - ) { - return new UnifiedChatCompletionException(restStatus, errorMessage, HUGGING_FACE_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); - } - /** * Builds an error for mid-stream responses from Hugging Face. * This method is called when an error response is received during streaming operations. @@ -71,37 +49,7 @@ private static UnifiedChatCompletionException buildProviderSpecificChatCompletio */ @Override public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { - return buildMidStreamChatCompletionError( - inferenceEntityId, - message, - e, - () -> StreamingHuggingFaceErrorResponseEntity.class, - HuggingFaceChatCompletionResponseHandler::buildProviderSpecificMidStreamChatCompletionError, - StreamingHuggingFaceErrorResponseEntity::fromString - ); - } - - private static UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( - String inferenceEntityId, - ErrorResponse errorResponse - ) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - inferenceEntityId, - errorResponse.getErrorMessage() - ), - HUGGING_FACE_ERROR, - extractErrorCode((StreamingHuggingFaceErrorResponseEntity) errorResponse) - ); - } - - private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { - return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null - ? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode()) - : null; + return buildMidStreamChatCompletionError(inferenceEntityId, message, e, HuggingFaceStreamingErrorResponseEntity::fromString); } /** @@ -116,17 +64,19 @@ private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity s * } * */ - private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse { + private static class HuggingFaceStreamingErrorResponseEntity extends ErrorResponse + implements + MidStreamUnifiedChatCompletionExceptionConvertible { private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( HUGGING_FACE_ERROR, true, - args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0]) + args -> Optional.ofNullable((HuggingFaceStreamingErrorResponseEntity) args[0]) ); - private static final ConstructingObjectParser ERROR_BODY_PARSER = + private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( HUGGING_FACE_ERROR, true, - args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1]) + args -> new HuggingFaceStreamingErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1]) ); static { @@ -163,7 +113,7 @@ private static ErrorResponse fromString(String response) { @Nullable private final Integer httpStatusCode; - StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) { + HuggingFaceStreamingErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) { super(errorMessage); this.httpStatusCode = httpStatusCode; } @@ -173,5 +123,20 @@ public Integer httpStatusCode() { return httpStatusCode; } + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + this.getErrorMessage() + ), + HUGGING_FACE_ERROR, + this.httpStatusCode() != null ? String.valueOf(this.httpStatusCode()) : null + ); + + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java index d30a60c341a58..f79d9b9e69079 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java @@ -7,14 +7,20 @@ package org.elasticsearch.xpack.inference.services.huggingface.response; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; -public class HuggingFaceErrorResponseEntity extends ErrorResponse { +import java.util.Locale; + +public class HuggingFaceErrorResponseEntity extends ErrorResponse implements UnifiedChatCompletionExceptionConvertible { + private static final String HUGGING_FACE_ERROR = "hugging_face_error"; public HuggingFaceErrorResponseEntity(String message) { super(message); @@ -52,4 +58,9 @@ public static ErrorResponse fromResponse(HttpResult response) { return ErrorResponse.UNDEFINED_ERROR; } + + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, HUGGING_FACE_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java index 0b18b4c6e7035..6eeac36dd4370 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java @@ -7,46 +7,18 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; - /** * Handles streaming chat completion responses and error parsing for Watsonx inference endpoints. * Adapts the OpenAI handler to support Watsonx's error schema. */ public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { - private static final String WATSONX_ERROR = "watsonx_error"; - public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse); } - @Override - protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - return buildChatCompletionError( - message, - request, - result, - errorResponse, - () -> IbmWatsonxErrorResponseEntity.class, - IbmWatsonUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError - ); - } - - private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus - ) { - return new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java index 012283d54be89..d4f24ac07651a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java @@ -7,17 +7,23 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx.response; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; +import java.util.Locale; import java.util.Map; import java.util.Objects; -public class IbmWatsonxErrorResponseEntity extends ErrorResponse { +public class IbmWatsonxErrorResponseEntity extends ErrorResponse implements UnifiedChatCompletionExceptionConvertible { + + private static final String WATSONX_ERROR = "watsonx_error"; private IbmWatsonxErrorResponseEntity(String errorMessage) { super(errorMessage); @@ -41,4 +47,9 @@ public static ErrorResponse fromResponse(HttpResult response) { return ErrorResponse.UNDEFINED_ERROR; } + + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java index aea2308dc4aa2..9cf705b9ca9da 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java @@ -7,46 +7,18 @@ package org.elasticsearch.xpack.inference.services.mistral; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; - /** * Handles streaming chat completion responses and error parsing for Mistral inference endpoints. * Adapts the OpenAI handler to support Mistral's error schema. */ public class MistralUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { - private static final String MISTRAL_ERROR = "mistral_error"; - public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, MistralErrorResponse::fromResponse); } - @Override - protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - return buildChatCompletionError( - message, - request, - result, - errorResponse, - () -> MistralErrorResponse.class, - MistralUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError - ); - } - - private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus - ) { - return new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java index 02dfb746fae53..5a6f1260f6ec0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java @@ -7,10 +7,14 @@ package org.elasticsearch.xpack.inference.services.mistral.response; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; import java.nio.charset.StandardCharsets; +import java.util.Locale; /** * Represents an error response entity for Mistral inference services. @@ -65,12 +69,19 @@ * } * */ -public class MistralErrorResponse extends ErrorResponse { +public class MistralErrorResponse extends ErrorResponse implements UnifiedChatCompletionExceptionConvertible { + + private static final String MISTRAL_ERROR = "mistral_error"; public MistralErrorResponse(String message) { super(message); } + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } + /** * Creates an ErrorResponse from the given HttpResult. * Attempts to read the body as a UTF-8 string and constructs a MistralErrorResponseEntity. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index 8f9d707581b4e..4c3588443501c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -8,29 +8,26 @@ package org.elasticsearch.xpack.inference.services.openai; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.OpenAiStreamingChatCompletionErrorResponse; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; -import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse; import java.util.concurrent.Flow; import java.util.function.Function; -import static org.elasticsearch.core.Strings.format; - /** * Handles streaming chat completion responses and error parsing for OpenAI inference endpoints. * This handler is designed to work with the unified OpenAI chat completion API. */ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, StreamingErrorResponse::fromResponse); + super(requestType, parseFunction, OpenAiStreamingChatCompletionErrorResponse::fromResponse); } public OpenAiUnifiedChatCompletionResponseHandler( @@ -54,38 +51,7 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher StreamingErrorResponse.class, - OpenAiUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError - ); - } - - /** - * Builds a custom {@link UnifiedChatCompletionException} for OpenAI inference endpoints. - * This method is called when an error response is received. - * - * @param errorResponse the parsed error response from the service - * @param errorMessage the error message received - * @param restStatus the HTTP status code of the error - * @return an instance of {@link UnifiedChatCompletionException} with details from the error response - */ - private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus - ) { - var streamingError = (StreamingErrorResponse) errorResponse; - return new UnifiedChatCompletionException( - restStatus, - errorMessage, - streamingError.type(), - streamingError.code(), - streamingError.param() - ); + return buildChatCompletionError(message, request, result, errorResponse); } /** @@ -99,40 +65,6 @@ private static UnifiedChatCompletionException buildProviderSpecificChatCompletio */ public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { // Use the custom type StreamingErrorResponse for mid-stream errors - return buildMidStreamChatCompletionError( - inferenceEntityId, - message, - e, - () -> StreamingErrorResponse.class, - OpenAiUnifiedChatCompletionResponseHandler::buildProviderSpecificMidStreamChatCompletionError, - StreamingErrorResponse::fromString - ); - } - - /** - * Builds a custom mid-stream {@link UnifiedChatCompletionException} for OpenAI inference endpoints. - * This method is called when an error response is received during streaming. - * - * @param inferenceEntityId the ID of the inference entity - * @param errorResponse the parsed error response from the service - * @return an instance of {@link UnifiedChatCompletionException} with details from the error response - */ - private static UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( - String inferenceEntityId, - ErrorResponse errorResponse - ) { - var streamingError = (StreamingErrorResponse) errorResponse; - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - inferenceEntityId, - streamingError.getErrorMessage() - ), - streamingError.type(), - streamingError.code(), - streamingError.param() - ); + return buildMidStreamChatCompletionError(inferenceEntityId, message, e, OpenAiStreamingChatCompletionErrorResponse::fromString); } } From 8cc7d72f2994ffe70e5af25bc9fb32e4f149d07b Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Thu, 10 Jul 2025 01:21:38 +0300 Subject: [PATCH 7/7] Remove redundant class --- .../retry/ChatCompletionErrorBuilder.java | 33 ------------------- 1 file changed, 33 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorBuilder.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorBuilder.java deleted file mode 100644 index c083fd9b4a82b..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorBuilder.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.retry; - -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; - -/** - * Functional interface for building provider-specific chat completion errors. - * This interface is used to create exceptions that are specific to the chat completion service being used. - */ -@FunctionalInterface -public interface ChatCompletionErrorBuilder { - - /** - * Builds a provider-specific chat completion error based on the given parameters. - * - * @param errorResponse The error response received from the service. - * @param errorMessage A custom error message to include in the exception. - * @param restStatus The HTTP status code associated with the error. - * @return An instance of {@link UnifiedChatCompletionException} representing the error. - */ - UnifiedChatCompletionException buildProviderSpecificChatCompletionError( - ErrorResponse errorResponse, - String errorMessage, - RestStatus restStatus - ); -}