Skip to content

Commit d44ef25

Browse files
committed
fix test failures
add feature flag Signed-off-by: Jing Zhang <[email protected]>
1 parent 1514269 commit d44ef25

File tree

12 files changed

+244
-31
lines changed

12 files changed

+244
-31
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,9 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
362362
}
363363

364364
private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
365+
if (parameters == null) {
366+
return false;
367+
}
365368
boolean isStream = parameters.containsKey("stream");
366369
if (!isStream) {
367370
return false;

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,4 +342,8 @@ private MLCommonsSettings() {}
342342
/** This setting sets the remote metadata service name */
343343
public static final Setting<String> REMOTE_METADATA_SERVICE_NAME = Setting
344344
.simpleString("plugins.ml_commons." + REMOTE_METADATA_SERVICE_NAME_KEY, Setting.Property.NodeScope, Setting.Property.Final);
345+
346+
/** This setting is to enable/disable streaming feature. */
347+
public static final Setting<Boolean> ML_COMMONS_STREAM_ENABLED = Setting
348+
.boolSetting("plugins.ml_commons.stream_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
345349
}

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED;
1616
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED;
1717
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
18+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;
1819

1920
import java.util.ArrayList;
2021
import java.util.List;
@@ -39,6 +40,7 @@ public class MLFeatureEnabledSetting {
3940

4041
// This is to identify if this node is in multi-tenancy or not.
4142
private volatile Boolean isMultiTenancyEnabled;
43+
private volatile Boolean isStreamEnabled;
4244

4345
private final List<SettingsChangeListener> listeners = new ArrayList<>();
4446

@@ -51,6 +53,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
5153
isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings);
5254
isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings);
5355
isMultiTenancyEnabled = ML_COMMONS_MULTI_TENANCY_ENABLED.get(settings);
56+
isStreamEnabled = ML_COMMONS_STREAM_ENABLED.get(settings);
5457

5558
clusterService
5659
.getClusterSettings()
@@ -69,6 +72,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
6972
clusterService
7073
.getClusterSettings()
7174
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, it -> isBatchInferenceEnabled = it);
75+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STREAM_ENABLED, it -> isStreamEnabled = it);
7276
}
7377

7478
/**
@@ -131,6 +135,14 @@ public boolean isMultiTenancyEnabled() {
131135
return isMultiTenancyEnabled;
132136
}
133137

138+
/**
139+
* Whether the streaming feature is enabled. If disabled, APIs in ml-commons will block stream.
140+
* @return whether the streaming is enabled.
141+
*/
142+
public boolean isStreamEnabled() {
143+
return isStreamEnabled;
144+
}
145+
134146
public void addListener(SettingsChangeListener listener) {
135147
listeners.add(listener);
136148
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,11 @@ public Map<String, String> getConnectorCredential(Connector connector) {
154154

155155
public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
156156
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
157-
predictable.initModel(mlModel, params, encryptor, streamManager, threadPool);
157+
if (mlModel.getAlgorithm() == FunctionName.REMOTE) {
158+
predictable.initModel(mlModel, params, encryptor, streamManager, threadPool);
159+
} else {
160+
predictable.initModel(mlModel, params, encryptor);
161+
}
158162
return predictable;
159163
}
160164

ml-algorithms/src/main/java/org/opensearch/ml/engine/arrow/RemoteModelStreamProducer.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.engine.arrow;
27

38
import java.nio.charset.StandardCharsets;

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public class AbstractConnectorExecutorTest {
2121
@Before
2222
public void setUp() {
2323
MockitoAnnotations.initMocks(this);
24+
when(mockConnector.getAccessKey()).thenReturn("access_key");
25+
when(mockConnector.getSecretKey()).thenReturn("secret_key");
26+
when(mockConnector.getSessionToken()).thenReturn("session_token");
27+
when(mockConnector.getRegion()).thenReturn("us-east-1-test");
2428
executor = new AwsConnectorExecutor(mockConnector);
2529
connectorClientConfig = new ConnectorClientConfig();
2630
}

0 commit comments

Comments
 (0)