diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 56e994be86eb4..74cfe37c8a3aa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -12,10 +12,12 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import java.util.Locale; import java.util.Objects; import java.util.function.Function; @@ -34,17 +36,22 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String SERVER_ERROR_OBJECT = "Received an error response"; public static final String BAD_REQUEST = "Received a bad request status code"; public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code"; + protected static final String STREAM_ERROR = "stream_error"; protected final String requestType; protected final ResponseParser parseFunction; private final Function errorParseFunction; private final boolean canHandleStreamingResponses; - public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function errorParseFunction) { + protected BaseResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction + ) { this(requestType, parseFunction, errorParseFunction, false); } - public BaseResponseHandler( + protected BaseResponseHandler( String requestType, ResponseParser parseFunction, Function errorParseFunction, @@ -109,19 +116,136 @@ private void checkForErrorObject(Request request, HttpResult result) { } protected Exception buildError(String message, Request request, HttpResult result) { - var errorEntityMsg = errorParseFunction.apply(result); - return buildError(message, request, result, errorEntityMsg); + var errorResponse = errorParseFunction.apply(result); + return buildError(message, request, result, errorResponse); } protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { var responseStatusCode = result.response().getStatusLine().getStatusCode(); return new ElasticsearchStatusException( - errorMessage(message, request, result, errorResponse, responseStatusCode), + extractErrorMessage(message, request, errorResponse, responseStatusCode), toRestStatus(responseStatusCode) ); } - protected String errorMessage(String message, Request request, HttpResult result, ErrorResponse errorResponse, int statusCode) { + /** + * Builds an error for a streaming request with a custom error type. + * This method is used when an error response is received from the external service. + * Only streaming requests support this format, and it should be used when the error response. 
+ * + * @param message the error message to include in the exception + * @param request the request that caused the error + * @param result the HTTP result containing the error response + * @param errorResponse the parsed error response from the HTTP result + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildChatCompletionError( + String message, + Request request, + HttpResult result, + ErrorResponse errorResponse + ) { + assert request.isStreaming() : "Only streaming requests support this format"; + var statusCode = result.response().getStatusLine().getStatusCode(); + var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode); + var restStatus = toRestStatus(statusCode); + + if (errorResponse.errorStructureFound() + && errorResponse instanceof UnifiedChatCompletionExceptionConvertible chatCompletionExceptionConvertible) { + return chatCompletionExceptionConvertible.toUnifiedChatCompletionException(errorMessage, restStatus); + } else { + return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); + } + } + + /** + * Builds a default {@link UnifiedChatCompletionException} for a streaming request. + * This method is used when an error response is received but no specific error handling is implemented. + * Only streaming requests should use this method. + * + * @param errorResponse the error response parsed from the HTTP result + * @param errorMessage the error message to include in the exception + * @param restStatus the REST status code of the response + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected static UnifiedChatCompletionException buildDefaultChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } + + /** + * Builds a mid-stream error for a streaming request with a custom error type. + * This method is used when an error occurs while processing a streaming response and allows for custom error handling. + * Only streaming requests should use this method. 
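Illustrative note (not part of the patch): to make the fallback path of the new base-class helpers concrete, here is a minimal sketch written as if inside a hypothetical handler subclass. The endpoint id and message text are invented; ErrorResponse.UNDEFINED_ERROR, toRestStatus, and buildDefaultChatCompletionError are the real members shown in this diff.

    // Sketch: fallback behaviour for a 429 whose body could not be parsed into a provider error.
    RestStatus status = toRestStatus(429);   // resolves to RestStatus.TOO_MANY_REQUESTS
    UnifiedChatCompletionException e = buildDefaultChatCompletionError(
        ErrorResponse.UNDEFINED_ERROR,
        "Received a rate limit status code for request from inference entity id [my-endpoint] status [429]",  // illustrative message
        status
    );
    // The exception's "type" comes from createErrorType(...) (the ErrorResponse class's simple name,
    // or "unknown" when null), and its "code" is the lower-cased RestStatus name, e.g. "too_many_requests".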
+ * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message + * @param e the exception that caused the error, can be null + * @param midStreamErrorExtractor a function that extracts the mid-stream error response from the message + * @return a {@link UnifiedChatCompletionException} representing the mid-stream error + */ + protected UnifiedChatCompletionException buildMidStreamChatCompletionError( + String inferenceEntityId, + String message, + Exception e, + Function midStreamErrorExtractor + ) { + // Extract the error response from the message using the provided method + var error = midStreamErrorExtractor.apply(message); + // Check if the error response matches the expected type + if (error.errorStructureFound() && error instanceof MidStreamUnifiedChatCompletionExceptionConvertible midStreamError) { + // If it matches, we can build a custom mid-stream error exception + return midStreamError.toUnifiedChatCompletionException(inferenceEntityId); + } else if (e != null) { + // If the error response does not match, we can still return an exception based on the original throwable + return UnifiedChatCompletionException.fromThrowable(e); + } else { + // If no specific error response is found, we return a default mid-stream error + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, error); + } + } + + /** + * Builds a default mid-stream error for a streaming request. + * This method is used when no specific error response is found in the message. + * Only streaming requests should use this method. + * + * @param inferenceEntityId the ID of the inference entity + * @param errorResponse the error response extracted from the message + * @return a {@link UnifiedChatCompletionException} representing the default mid-stream error + */ + protected static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), + createErrorType(errorResponse), + STREAM_ERROR + ); + } + + /** + * Creates a string representation of the error type based on the provided ErrorResponse. + * This method is used to generate a human-readable error type for logging or exception messages. + * + * @param errorResponse the ErrorResponse object + * @return a string representing the error type + */ + private static String createErrorType(ErrorResponse errorResponse) { + return errorResponse != null ? 
errorResponse.getClass().getSimpleName() : "unknown"; + } + + protected static String extractErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) { return (errorResponse == null || errorResponse.errorStructureFound() == false || Strings.isNullOrEmpty(errorResponse.getErrorMessage())) @@ -135,7 +259,7 @@ protected String errorMessage(String message, Request request, HttpResult result ); } - public static RestStatus toRestStatus(int statusCode) { + protected static RestStatus toRestStatus(int statusCode) { RestStatus code = null; if (statusCode < 500) { code = RestStatus.fromCode(statusCode); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/MidStreamUnifiedChatCompletionExceptionConvertible.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/MidStreamUnifiedChatCompletionExceptionConvertible.java new file mode 100644 index 0000000000000..60a5a99c6cfed --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/MidStreamUnifiedChatCompletionExceptionConvertible.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; + +public interface MidStreamUnifiedChatCompletionExceptionConvertible { + + UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionExceptionConvertible.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionExceptionConvertible.java new file mode 100644 index 0000000000000..672e9f9b2ded5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionExceptionConvertible.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; + +public interface UnifiedChatCompletionExceptionConvertible { + + UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/OpenAiStreamingChatCompletionErrorResponse.java similarity index 67% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/OpenAiStreamingChatCompletionErrorResponse.java index 93e1d6388f357..bca3b12dbae51 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/OpenAiStreamingChatCompletionErrorResponse.java @@ -8,19 +8,26 @@ package org.elasticsearch.xpack.inference.external.response.streaming; import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.MidStreamUnifiedChatCompletionExceptionConvertible; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; import java.util.Objects; import java.util.Optional; +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR_OBJECT; + /** * Represents an error response from a streaming inference service. * This class extends {@link ErrorResponse} and provides additional fields @@ -38,17 +45,21 @@ * * TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication. 
*/ -public class StreamingErrorResponse extends ErrorResponse { +public class OpenAiStreamingChatCompletionErrorResponse extends ErrorResponse + implements + UnifiedChatCompletionExceptionConvertible, + MidStreamUnifiedChatCompletionExceptionConvertible { private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( "streaming_error", true, - args -> Optional.ofNullable((StreamingErrorResponse) args[0]) - ); - private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( - "streaming_error", - true, - args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) + args -> Optional.ofNullable((OpenAiStreamingChatCompletionErrorResponse) args[0]) ); + private static final ConstructingObjectParser ERROR_BODY_PARSER = + new ConstructingObjectParser<>( + "streaming_error", + true, + args -> new OpenAiStreamingChatCompletionErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) + ); static { ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); @@ -105,13 +116,34 @@ public static ErrorResponse fromString(String response) { private final String param; private final String type; - StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) { + OpenAiStreamingChatCompletionErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) { super(errorMessage); this.code = code; this.param = param; this.type = Objects.requireNonNull(type); } + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, this.type(), this.code(), this.param()); + } + + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + this.getErrorMessage() + ), + this.type(), + this.code(), + this.param() + ); + } + @Nullable public String code() { return code; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index ad40d43b3af3b..0a59681186977 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -25,7 +25,13 @@ import static org.elasticsearch.core.Strings.format; +/** + * Handles streaming chat completion responses and error parsing for Elastic Inference Service endpoints. + * This handler is designed to work with the unified Elastic Inference Service chat completion API. 
+ */ public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { + private static final String ERROR_TYPE = "error"; + public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, true); } @@ -34,53 +40,67 @@ public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String reques public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); // EIS uses the unified API spec - var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor( + (m, e) -> buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(openAiProcessor); return new StreamingUnifiedChatCompletionResults(openAiProcessor); } + /** + * Builds an error for the Elastic Inference Service. + * This method is called when an error response is received from the service. + * + * @param message The error message to include in the exception. + * @param request The request that caused the error. + * @param result The HTTP result containing the error response. + * @param errorResponse The parsed error response from the service. + * @return An instance of {@link Exception} representing the error. + */ @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var restStatus = toRestStatus(responseStatusCode); - return new UnifiedChatCompletionException( - restStatus, - errorMessage(message, request, result, errorResponse, responseStatusCode), - "error", - restStatus.name().toLowerCase(Locale.ROOT) - ); + var statusCode = result.response().getStatusLine().getStatusCode(); + var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode); + var restStatus = toRestStatus(statusCode); + + if (errorResponse.errorStructureFound()) { + return new UnifiedChatCompletionException(restStatus, errorMessage, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT)); } else { - return super.buildError(message, request, result, errorResponse); + return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); } } - private static Exception buildMidStreamError(Request request, String message, Exception e) { + /** + * Builds a mid-stream error for the Elastic Inference Service. + * This method is called when an error occurs during the streaming process. + * + * @param inferenceEntityId The ID of the inference entity. + * @param message The error message received from the service. + * @param e The exception that occurred, if any. + * @return An instance of {@link UnifiedChatCompletionException} representing the mid-stream error. 
+ */ + private UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { var errorResponse = ElasticInferenceServiceErrorResponseEntity.fromString(message); + // Check if the error response contains a specific structure if (errorResponse.errorStructureFound()) { return new UnifiedChatCompletionException( RestStatus.INTERNAL_SERVER_ERROR, format( "%s for request from inference entity id [%s]. Error message: [%s]", SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), + inferenceEntityId, errorResponse.getErrorMessage() ), - "error", - "stream_error" + ERROR_TYPE, + STREAM_ERROR ); } else if (e != null) { return UnifiedChatCompletionException.fromThrowable(e); } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - "error", - "stream_error" - ); + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 9e6fdb6eb8bb5..d7c13c730cf15 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -10,8 +10,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -23,19 +21,24 @@ import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.MidStreamUnifiedChatCompletionExceptionConvertible; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; import java.nio.charset.StandardCharsets; -import java.util.Locale; import java.util.Objects; import java.util.Optional; import java.util.concurrent.Flow; import static org.elasticsearch.core.Strings.format; +/** + * Handles streaming chat completion responses and error parsing for Google Vertex AI inference endpoints. + * This handler is designed to work with the unified Google Vertex AI chat completion API. 
+ */ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVertexAiResponseHandler { private static final String ERROR_FIELD = "error"; @@ -52,7 +55,14 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher buildMidStreamError(request, m, e)); + var googleVertexAiProcessor = new GoogleVertexAiUnifiedStreamingProcessor( + (message, exception) -> buildMidStreamChatCompletionError( + request.getInferenceEntityId(), + message, + exception, + GoogleVertexAiErrorResponse::fromString + ) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(googleVertexAiProcessor); @@ -60,58 +70,14 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher, Void> ERROR_PARSER = new ConstructingObjectParser<>( "google_vertex_ai_error_wrapper", true, @@ -170,6 +136,28 @@ static ErrorResponse fromString(String response) { this.status = status; } + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, this.status(), String.valueOf(this.code()), null); + } + + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + this.getErrorMessage() + ), + this.status(), + String.valueOf(this.code()), + null + ); + + } + public int code() { return code; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java index 8dffd612db5c8..c36e46c450b70 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -16,14 +16,12 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.MidStreamUnifiedChatCompletionExceptionConvertible; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; import java.util.Optional; import static org.elasticsearch.core.Strings.format; @@ -40,62 +38,18 @@ public HuggingFaceChatCompletionResponseHandler(String requestType, ResponsePars super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse); } + /** + * Builds an error for mid-stream responses from Hugging Face. + * This method is called when an error response is received during streaming operations. 
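Illustrative note (not part of the patch): the Google Vertex AI wiring above generalises to any provider, since the streaming processor's on-error callback simply delegates to the shared helper with a provider-specific parser. This sketch uses the hypothetical AcmeErrorResponse::fromString from the earlier sketch, reuses the OpenAI processor as the EIS and OpenAI handlers do (a real provider would substitute its own processor), and restores the Flow.Publisher generic that the flattened diff elides.

    // Sketch: a hypothetical handler's parseResult, mirroring the Vertex AI / OpenAI wiring in this patch.
    @Override
    public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
        var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
        var processor = new OpenAiUnifiedStreamingProcessor(
            (message, exception) -> buildMidStreamChatCompletionError(
                request.getInferenceEntityId(),
                message,
                exception,
                AcmeErrorResponse::fromString   // any Function<String, ErrorResponse> extractor works here
            )
        );
        flow.subscribe(serverSentEventProcessor);
        serverSentEventProcessor.subscribe(processor);
        return new StreamingUnifiedChatCompletionResults(processor);
    }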
+ * + * @param inferenceEntityId The ID of the inference entity that made the request. + * @param message The error message to include in the exception. + * @param e The exception that occurred. + * @return An instance of {@link UnifiedChatCompletionException} representing the error. + */ @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof HuggingFaceErrorResponseEntity - ? new UnifiedChatCompletionException( - restStatus, - errorMessage, - HUGGING_FACE_ERROR, - restStatus.name().toLowerCase(Locale.ROOT) - ) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } - } - - @Override - protected Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message); - if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), - errorResponse.getErrorMessage() - ), - HUGGING_FACE_ERROR, - extractErrorCode(streamingHuggingFaceErrorResponseEntity) - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - createErrorType(errorResponse), - "stream_error" - ); - } - } - - private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { - return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null - ? 
String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode()) - : null; + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + return buildMidStreamChatCompletionError(inferenceEntityId, message, e, HuggingFaceStreamingErrorResponseEntity::fromString); } /** @@ -110,17 +64,19 @@ private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity s * } * */ - private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse { + private static class HuggingFaceStreamingErrorResponseEntity extends ErrorResponse + implements + MidStreamUnifiedChatCompletionExceptionConvertible { private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( HUGGING_FACE_ERROR, true, - args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0]) + args -> Optional.ofNullable((HuggingFaceStreamingErrorResponseEntity) args[0]) ); - private static final ConstructingObjectParser ERROR_BODY_PARSER = + private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( HUGGING_FACE_ERROR, true, - args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1]) + args -> new HuggingFaceStreamingErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1]) ); static { @@ -157,7 +113,7 @@ private static ErrorResponse fromString(String response) { @Nullable private final Integer httpStatusCode; - StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) { + HuggingFaceStreamingErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) { super(errorMessage); this.httpStatusCode = httpStatusCode; } @@ -167,5 +123,20 @@ public Integer httpStatusCode() { return httpStatusCode; } + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + this.getErrorMessage() + ), + HUGGING_FACE_ERROR, + this.httpStatusCode() != null ? 
String.valueOf(this.httpStatusCode()) : null + ); + + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java index d30a60c341a58..f79d9b9e69079 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java @@ -7,14 +7,20 @@ package org.elasticsearch.xpack.inference.services.huggingface.response; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; -public class HuggingFaceErrorResponseEntity extends ErrorResponse { +import java.util.Locale; + +public class HuggingFaceErrorResponseEntity extends ErrorResponse implements UnifiedChatCompletionExceptionConvertible { + private static final String HUGGING_FACE_ERROR = "hugging_face_error"; public HuggingFaceErrorResponseEntity(String message) { super(message); @@ -52,4 +58,9 @@ public static ErrorResponse fromResponse(HttpResult response) { return ErrorResponse.UNDEFINED_ERROR; } + + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, HUGGING_FACE_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java index 41b82bbf2cd02..6eeac36dd4370 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java @@ -7,45 +7,18 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; - /** * Handles streaming chat completion responses and error parsing for Watsonx inference endpoints. 
* Adapts the OpenAI handler to support Watsonx's error schema. */ public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { - private static final String WATSONX_ERROR = "watsonx_error"; - public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse); } - @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof IbmWatsonxErrorResponseEntity - ? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java index 012283d54be89..d4f24ac07651a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java @@ -7,17 +7,23 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx.response; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; +import java.util.Locale; import java.util.Map; import java.util.Objects; -public class IbmWatsonxErrorResponseEntity extends ErrorResponse { +public class IbmWatsonxErrorResponseEntity extends ErrorResponse implements UnifiedChatCompletionExceptionConvertible { + + private static final String WATSONX_ERROR = "watsonx_error"; private IbmWatsonxErrorResponseEntity(String errorMessage) { super(errorMessage); @@ -41,4 +47,9 @@ public static ErrorResponse fromResponse(HttpResult response) { return ErrorResponse.UNDEFINED_ERROR; } + + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java index a9d6df687fe99..9cf705b9ca9da 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java @@ -7,45 +7,18 @@ package org.elasticsearch.xpack.inference.services.mistral; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; - /** * Handles streaming chat completion responses and error parsing for Mistral inference endpoints. * Adapts the OpenAI handler to support Mistral's error schema. */ public class MistralUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { - private static final String MISTRAL_ERROR = "mistral_error"; - public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, MistralErrorResponse::fromResponse); } - @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof MistralErrorResponse - ? 
new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java index 02dfb746fae53..5a6f1260f6ec0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java @@ -7,10 +7,14 @@ package org.elasticsearch.xpack.inference.services.mistral.response; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible; import java.nio.charset.StandardCharsets; +import java.util.Locale; /** * Represents an error response entity for Mistral inference services. @@ -65,12 +69,19 @@ * } * */ -public class MistralErrorResponse extends ErrorResponse { +public class MistralErrorResponse extends ErrorResponse implements UnifiedChatCompletionExceptionConvertible { + + private static final String MISTRAL_ERROR = "mistral_error"; public MistralErrorResponse(String message) { super(message); } + @Override + public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) { + return new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } + /** * Creates an ErrorResponse from the given HttpResult. * Attempts to read the body as a UTF-8 string and constructs a MistralErrorResponseEntity. 
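Illustrative note (not part of the patch): with the Watsonx and Mistral buildError overrides deleted above, those handlers reduce to a constructor because their error entities now implement UnifiedChatCompletionExceptionConvertible and BaseResponseHandler#buildChatCompletionError consults that interface directly. A hypothetical new provider would follow the same shape; the class name and the fromResponse(HttpResult) factory below are invented for this sketch.

    // Sketch: no buildError override needed once the provider's ErrorResponse implements the interface.
    public class AcmeUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
        public AcmeUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
            super(requestType, parseFunction, AcmeErrorResponse::fromResponse);  // hypothetical fromResponse(HttpResult) factory
        }
    }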
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index e1a0117c7bcca..4c3588443501c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -8,30 +8,26 @@ package org.elasticsearch.xpack.inference.services.openai; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.OpenAiStreamingChatCompletionErrorResponse; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; -import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse; -import java.util.Locale; import java.util.concurrent.Flow; import java.util.function.Function; -import static org.elasticsearch.core.Strings.format; - /** * Handles streaming chat completion responses and error parsing for OpenAI inference endpoints. * This handler is designed to work with the unified OpenAI chat completion API. 
*/ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, StreamingErrorResponse::fromResponse); + super(requestType, parseFunction, OpenAiStreamingChatCompletionErrorResponse::fromResponse); } public OpenAiUnifiedChatCompletionResponseHandler( @@ -45,64 +41,30 @@ public OpenAiUnifiedChatCompletionResponseHandler( @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); - var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor( + (m, e) -> buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(openAiProcessor); return new StreamingUnifiedChatCompletionResults(openAiProcessor); } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof StreamingErrorResponse oer - ? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param()) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } - } - - protected static String createErrorType(ErrorResponse errorResponse) { - return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; - } - - protected Exception buildMidStreamError(Request request, String message, Exception e) { - return buildMidStreamError(request.getInferenceEntityId(), message, e); + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse); } - public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) { - var errorResponse = StreamingErrorResponse.fromString(message); - if (errorResponse instanceof StreamingErrorResponse oer) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - inferenceEntityId, - errorResponse.getErrorMessage() - ), - oer.type(), - oer.code(), - oer.param() - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), - createErrorType(errorResponse), - "stream_error" - ); - } + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for OpenAI inference endpoints. 
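Illustrative note (not part of the patch): a hedged sketch of what the renamed OpenAiStreamingChatCompletionErrorResponse does for mid-stream "error" events. The JSON body is an invented example assumed to follow the OpenAI-style error schema the parser expects, and the endpoint id is hypothetical.

    // Sketch: parse a mid-stream error payload and convert it, as buildMidStreamChatCompletionError does.
    var body = "{\"error\":{\"message\":\"quota exceeded\",\"type\":\"insufficient_quota\",\"code\":\"insufficient_quota\"}}";
    ErrorResponse parsed = OpenAiStreamingChatCompletionErrorResponse.fromString(body);
    if (parsed.errorStructureFound() && parsed instanceof MidStreamUnifiedChatCompletionExceptionConvertible convertible) {
        // Carries INTERNAL_SERVER_ERROR plus the type/code/param copied from the payload.
        UnifiedChatCompletionException e = convertible.toUnifiedChatCompletionException("my-endpoint-id");
    }
    // If the payload cannot be parsed, fromString returns a non-convertible ErrorResponse and the caller
    // falls back to UnifiedChatCompletionException.fromThrowable(...) or the default stream error.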
+ * This method is called when an error response is received during streaming. + * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message received during streaming + * @param e the exception that occurred + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + // Use the custom type StreamingErrorResponse for mid-stream errors + return buildMidStreamChatCompletionError(inferenceEntityId, message, e, OpenAiStreamingChatCompletionErrorResponse::fromString); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java index a8c4f7c57796b..1d59b6e03c9e0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java @@ -40,7 +40,14 @@ * Each chunk should be in a valid JSON format, as that is the format the Elastic API uses. */ public class ElasticCompletionPayload implements SageMakerStreamSchemaPayload, ElasticPayload { - private static final XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + private static final OpenAiUnifiedChatCompletionResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + "sagemaker openai chat completion", + ((request, result) -> { + assert false : "do not call this"; + throw new UnsupportedOperationException("SageMaker should not call this object's response parser."); + }) + ); + private static final XContentParserConfiguration PARSER_CONFIGURATION = XContentParserConfiguration.EMPTY.withDeprecationHandler( LoggingDeprecationHandler.INSTANCE ); @@ -94,7 +101,7 @@ public SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompleti public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) { var responseData = response.asUtf8String(); try { - var results = OpenAiUnifiedStreamingProcessor.parse(parserConfig, responseData) + var results = OpenAiUnifiedStreamingProcessor.parse(PARSER_CONFIGURATION, responseData) .collect( () -> new ArrayDeque(), ArrayDeque::offer, @@ -102,7 +109,7 @@ public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody( ); return new StreamingUnifiedChatCompletionResults.Results(results); } catch (Exception e) { - throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), responseData, e); + throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), responseData, e); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java index 03e1941df6938..978338b350ef0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; @@ -40,6 +39,11 @@ import java.util.Map; import java.util.stream.Stream; +/** + * Handles chat completion requests and responses for OpenAI models in SageMaker. + * This class implements the SageMakerStreamSchemaPayload interface to provide + * the necessary methods for handling OpenAI chat completions. + */ public class OpenAiCompletionPayload implements SageMakerStreamSchemaPayload { private static final XContent jsonXContent = JsonXContent.jsonXContent; @@ -50,7 +54,7 @@ public class OpenAiCompletionPayload implements SageMakerStreamSchemaPayload { private static final String USER_FIELD = "user"; private static final String USER_ROLE = "user"; private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; - private static final ResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + private static final OpenAiUnifiedChatCompletionResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( "sagemaker openai chat completion", ((request, result) -> { assert false : "do not call this"; @@ -88,12 +92,12 @@ public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody( var serverSentEvents = serverSentEvents(response); var results = serverSentEvents.flatMap(event -> { if ("error".equals(event.type())) { - throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), null); + throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), event.data(), null); } else { try { return OpenAiUnifiedStreamingProcessor.parse(parserConfig, event); } catch (Exception e) { - throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), e); + throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), event.data(), e); } } })