Skip to content

Remote model inference streaming #3898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ subprojects {

configurations.all {
// Force spotless to depend on a newer version of guava due to CVE-2023-2976. Remove after spotless upgrades.
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre"
resolutionStrategy.force "com.google.guava:guava:${versions.guava}"
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
}
}
Expand Down
2 changes: 1 addition & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies {
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
compileOnly group: 'org.json', name: 'json', version: '20231013'
testImplementation group: 'org.json', name: 'json', version: '20231013'
implementation('com.google.guava:guava:32.1.3-jre') {
implementation ("com.google.guava:guava:${versions.guava}") {
exclude group: 'com.google.guava', module: 'failureaccess'
exclude group: 'com.google.code.findbugs', module: 'jsr305'
exclude group: 'org.checkerframework', module: 'checker-qual'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.Version;
import org.opensearch.common.io.stream.BytesStreamOutput;
Expand All @@ -36,6 +38,9 @@
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -346,12 +351,40 @@ public <T> T createPayload(String action, Map<String, String> parameters) {

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
} else if (neededStreamParameterInPayload(parameters)) {
JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject();
jsonObject.addProperty("stream", true);
payload = jsonObject.toString();
}
return (T) payload;
}
return (T) parameters.get("http_body");
}

private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
if (parameters == null) {
return false;
}
boolean isStream = parameters.containsKey("stream");
if (!isStream) {
return false;
}

String llmInterface = parameters.get("_llm_interface");
if (llmInterface.isBlank()) {
return false;
}

llmInterface = llmInterface.trim().toLowerCase(Locale.ROOT);
llmInterface = StringEscapeUtils.unescapeJava(llmInterface);
switch (llmInterface) {
case "openai/v1/chat/completions":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see you mentioned Bedrock converse support too in the description. Are we only going with Open AI for now?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The neededStreamParameterInPayload is to check if we need to add "stream": true in http request body as openai does. But Bedrock converseStream doesn't use that parameter, but use a specific url like POST /model/modelId/converse-stream, so no converse here.

return true;
default:
return false;
}
}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
String newPayload = payload;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,4 +350,8 @@ private MLCommonsSettings() {}
// Feature flag for enabling telemetry static metric collection job -- MLStatsJobProcessor
public static final Setting<Boolean> ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED = Setting
.boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

/**
 * Feature flag that enables or disables the streaming feature.
 * Registered under the key {@code plugins.ml_commons.stream_enabled} with {@code false} as the
 * value passed for its default; declared with the {@code NodeScope} and {@code Dynamic}
 * setting properties.
 */
public static final Setting<Boolean> ML_COMMONS_STREAM_ENABLED = Setting
.boolSetting("plugins.ml_commons.stream_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -43,6 +44,7 @@ public class MLFeatureEnabledSetting {

// This is to identify if this node is in multi-tenancy or not.
private volatile Boolean isMultiTenancyEnabled;
private volatile Boolean isStreamEnabled;

private volatile Boolean isMcpServerEnabled;

Expand All @@ -66,6 +68,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
isRagSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
isMetricCollectionEnabled = ML_COMMONS_METRIC_COLLECTION_ENABLED.get(settings);
isStaticMetricCollectionEnabled = ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED.get(settings);
isStreamEnabled = ML_COMMONS_STREAM_ENABLED.get(settings);

clusterService
.getClusterSettings()
Expand Down Expand Up @@ -94,6 +97,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, it -> isStaticMetricCollectionEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STREAM_ENABLED, it -> isStreamEnabled = it);
}

/**
Expand Down Expand Up @@ -164,6 +168,13 @@ public boolean isMcpServerEnabled() {
return isMcpServerEnabled;
}

/**
 * Reports the current value of the streaming feature flag. When the flag is off,
 * APIs in ml-commons will block stream.
 *
 * @return {@code true} if streaming is enabled, {@code false} otherwise
 */
public boolean isStreamEnabled() {
    final boolean streamEnabled = isStreamEnabled;
    return streamEnabled;
}

/**
 * Appends the given listener to this instance's listener collection.
 *
 * @param listener the listener to register
 */
public void addListener(SettingsChangeListener listener) {
    this.listeners.add(listener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public void setUp() {
MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED,
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED
)
);
when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings);
Expand All @@ -65,6 +66,7 @@ public void testDefaults_allFeaturesEnabled() {
.put("plugins.ml_commons.rag_pipeline_feature_enabled", true)
.put("plugins.ml_commons.metrics_collection_enabled", true)
.put("plugins.ml_commons.metrics_static_collection_enabled", true)
.put("plugins.ml_commons.stream_enabled", true)
.build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
Expand All @@ -81,6 +83,7 @@ public void testDefaults_allFeaturesEnabled() {
assertTrue(setting.isRagSearchPipelineEnabled());
assertTrue(setting.isMetricCollectionEnabled());
assertTrue(setting.isStaticMetricCollectionEnabled());
assertTrue(setting.isStreamEnabled());
}

@Test
Expand All @@ -99,6 +102,7 @@ public void testDefaults_someFeaturesDisabled() {
.put("plugins.ml_commons.rag_pipeline_feature_enabled", false)
.put("plugins.ml_commons.metrics_collection_enabled", false)
.put("plugins.ml_commons.metrics_static_collection_enabled", false)
.put("plugins.ml_commons.stream_enabled", false)
.build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
Expand All @@ -115,6 +119,7 @@ public void testDefaults_someFeaturesDisabled() {
assertFalse(setting.isRagSearchPipelineEnabled());
assertFalse(setting.isMetricCollectionEnabled());
assertFalse(setting.isStaticMetricCollectionEnabled());
assertFalse(setting.isStreamEnabled());
}

@Test
Expand Down
9 changes: 7 additions & 2 deletions memory/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@ dependencies {
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: "${versions.httpcore5}"
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation (group: 'com.google.guava', name: 'guava', version: '32.1.3-jre') {
exclude group: 'com.google.errorprone', module: 'error_prone_annotations'
compileOnly (group: 'com.google.guava', name: 'guava', version: "${versions.guava}") {
exclude group: 'com.google.guava', module: 'failureaccess'
exclude group: 'com.google.code.findbugs', module: 'jsr305'
exclude group: 'org.checkerframework', module: 'checker-qual'
exclude group: 'com.google.errorprone', module: 'error_prone_annotations'
exclude group: 'com.google.j2objc', module: 'j2objc-annotations'
exclude group: 'com.google.guava', module: 'listenablefuture'
}
testImplementation (group: 'junit', name: 'junit', version: '4.13.2') {
exclude module : 'hamcrest'
Expand Down
33 changes: 29 additions & 4 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@ plugins {
}

repositories {
mavenLocal()
mavenCentral()
}

dependencies {
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
implementation project(':opensearch-ml-memory')

compileOnly "org.apache.arrow:arrow-memory-core:${versions.arrow}"
compileOnly "org.apache.arrow:arrow-vector:${versions.arrow}"
compileOnly 'org.checkerframework:checker-qual:3.44.0'
compileOnly "org.slf4j:slf4j-api:${versions.slf4j}"
compileOnly "commons-codec:commons-codec:${versions.commonscodec}"
compileOnly "com.google.errorprone:error_prone_annotations:2.27.0"
compileOnly "com.google.guava:failureaccess:1.0.1"
compileOnly "com.google.guava:guava:${versions.guava}"

compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
testImplementation "org.opensearch.test:framework:${opensearch_version}"
Expand All @@ -41,10 +52,17 @@ dependencies {
implementation group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0'
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
implementation (group: 'com.google.guava', name: 'guava', version: '32.1.3-jre') {
exclude group: 'com.google.errorprone', module: 'error_prone_annotations'
compileOnly (group: 'com.google.guava', name: 'guava', version: "${versions.guava}") {
exclude group: 'com.google.errorprone', module: 'error_prone_annotations'
exclude group: 'com.google.guava', module: 'failureaccess'
exclude group: 'com.google.code.findbugs', module: 'jsr305'
exclude group: 'org.checkerframework', module: 'checker-qual'
exclude group: 'com.google.j2objc', module: 'j2objc-annotations'
exclude group: 'com.google.guava', module: 'listenablefuture'
}
implementation (group: 'com.google.code.gson', name: 'gson', version: '2.11.0') {
exclude group: "com.google.errorprone", module: "error_prone_annotations"
}
implementation group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
implementation platform("ai.djl:bom:0.31.1")
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo'
implementation group: 'ai.djl', name: 'api'
Expand All @@ -70,6 +88,9 @@ dependencies {
implementation platform('software.amazon.awssdk:bom:2.30.18')
api 'software.amazon.awssdk:auth:2.30.18'
implementation 'software.amazon.awssdk:apache-client'
implementation ('software.amazon.awssdk:bedrockruntime') {
exclude group: 'io.netty'
}
implementation ('com.amazonaws:aws-encryption-sdk-java:2.4.1') {
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
}
Expand All @@ -84,11 +105,15 @@ dependencies {
}
implementation('net.minidev:json-smart:2.5.2')
implementation group: 'org.json', name: 'json', version: '20231013'
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18"
api('io.modelcontextprotocol.sdk:mcp:0.9.0')
implementation (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18") {
exclude group: 'io.netty'
}
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.12.0'
implementation group: 'com.squareup.okhttp3', name: 'okhttp-sse', version: '4.12.0'
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.Locale;
import java.util.Map;

import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
Expand All @@ -26,8 +27,10 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.threadpool.ThreadPool;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
Expand All @@ -50,6 +53,11 @@ public class MLEngine {

private Encryptor encryptor;

@Setter
private StreamManager streamManager;
@Setter
private ThreadPool threadPool;

public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
this.mlModelsCachePath = mlCachePath.resolve("models_cache");
Expand Down Expand Up @@ -146,7 +154,11 @@ public Map<String, String> getConnectorCredential(Connector connector) {

/**
 * Creates the {@link Predictable} for the model's algorithm and initialises it exactly once.
 *
 * <p>Remote models are initialised through the overload that also receives the stream manager
 * and thread pool held by this engine; every other algorithm uses the three-argument overload.
 *
 * @param mlModel the model to deploy; its algorithm selects the implementation
 * @param params  additional parameters forwarded to {@code initModel}
 * @return the initialised predictable instance
 */
public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
    Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
    // Only one initModel call per deployment: an unconditional call ahead of this branch would
    // initialise every model twice, and remote models once without the streaming collaborators.
    if (mlModel.getAlgorithm() == FunctionName.REMOTE) {
        predictable.initModel(mlModel, params, encryptor, streamManager, threadPool);
    } else {
        predictable.initModel(mlModel, params, encryptor);
    }
    return predictable;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

import java.util.Map;

import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.threadpool.ThreadPool;

/**
* This is machine learning algorithms predict interface.
Expand Down Expand Up @@ -47,7 +49,19 @@ default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> action
* @param params other parameters
* @param encryptor encryptor
*/
default void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
    // Intentionally empty: implementations that need initialisation override this method.
}

/**
 * Init model with the additional collaborators that streaming-capable implementations may use.
 *
 * <p>The default implementation ignores {@code streamManager} and {@code threadPool} and
 * delegates to {@link #initModel(MLModel, Map, Encryptor)}, so an implementation that only
 * overrides the three-argument form is still initialised when invoked through this overload.
 *
 * @param model         ML model
 * @param params        other parameters
 * @param encryptor     encryptor
 * @param streamManager stream manager made available to the implementation
 * @param threadPool    thread pool made available to the implementation
 */
default void initModel(
    MLModel model,
    Map<String, Object> params,
    Encryptor encryptor,
    StreamManager streamManager,
    ThreadPool threadPool
) {
    initModel(model, params, encryptor);
}

/**
* Close resources like deployed model.
Expand Down
Loading
Loading