Skip to content

Commit bfeb504

Browse files
committed
fix test failures
add feature flag Signed-off-by: Jing Zhang <[email protected]>
1 parent a002860 commit bfeb504

File tree

13 files changed

+195
-32
lines changed

13 files changed

+195
-32
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,9 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
362362
}
363363

364364
private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
365+
if (parameters == null) {
366+
return false;
367+
}
365368
boolean isStream = parameters.containsKey("stream");
366369
if (!isStream) {
367370
return false;

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,4 +350,8 @@ private MLCommonsSettings() {}
350350
// Feature flag for enabling telemetry static metric collection job -- MLStatsJobProcessor
351351
public static final Setting<Boolean> ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED = Setting
352352
.boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
353+
354+
/** This setting is to enable/disable streaming feature. */
355+
public static final Setting<Boolean> ML_COMMONS_STREAM_ENABLED = Setting
356+
.boolSetting("plugins.ml_commons.stream_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
353357
}

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED;
2020
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
2121
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED;
22+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;
2223

2324
import java.util.ArrayList;
2425
import java.util.List;
@@ -43,6 +44,7 @@ public class MLFeatureEnabledSetting {
4344

4445
// This is to identify if this node is in multi-tenancy or not.
4546
private volatile Boolean isMultiTenancyEnabled;
47+
private volatile Boolean isStreamEnabled;
4648

4749
private volatile Boolean isMcpServerEnabled;
4850

@@ -66,6 +68,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
6668
isRagSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
6769
isMetricCollectionEnabled = ML_COMMONS_METRIC_COLLECTION_ENABLED.get(settings);
6870
isStaticMetricCollectionEnabled = ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED.get(settings);
71+
isStreamEnabled = ML_COMMONS_STREAM_ENABLED.get(settings);
6972

7073
clusterService
7174
.getClusterSettings()
@@ -94,6 +97,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
9497
clusterService
9598
.getClusterSettings()
9699
.addSettingsUpdateConsumer(ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, it -> isStaticMetricCollectionEnabled = it);
100+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STREAM_ENABLED, it -> isStreamEnabled = it);
97101
}
98102

99103
/**
@@ -164,6 +168,13 @@ public boolean isMcpServerEnabled() {
164168
return isMcpServerEnabled;
165169
}
166170

171+
/** Whether the streaming feature is enabled. If disabled, APIs in ml-commons will block stream.
172+
* @return whether the streaming is enabled.
173+
*/
174+
public boolean isStreamEnabled() {
175+
return isStreamEnabled;
176+
}
177+
167178
public void addListener(SettingsChangeListener listener) {
168179
listeners.add(listener);
169180
}

common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public void setUp() {
4444
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
4545
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
4646
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
47+
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED
4748
)
4849
);
4950
when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings);
@@ -65,6 +66,7 @@ public void testDefaults_allFeaturesEnabled() {
6566
.put("plugins.ml_commons.rag_pipeline_feature_enabled", true)
6667
.put("plugins.ml_commons.metrics_collection_enabled", true)
6768
.put("plugins.ml_commons.metrics_static_collection_enabled", true)
69+
.put("plugins.ml_commons.stream_enabled", true)
6870
.build();
6971

7072
MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
@@ -81,6 +83,7 @@ public void testDefaults_allFeaturesEnabled() {
8183
assertTrue(setting.isRagSearchPipelineEnabled());
8284
assertTrue(setting.isMetricCollectionEnabled());
8385
assertTrue(setting.isStaticMetricCollectionEnabled());
86+
assertTrue(setting.isStreamEnabled());
8487
}
8588

8689
@Test
@@ -99,6 +102,7 @@ public void testDefaults_someFeaturesDisabled() {
99102
.put("plugins.ml_commons.rag_pipeline_feature_enabled", false)
100103
.put("plugins.ml_commons.metrics_collection_enabled", false)
101104
.put("plugins.ml_commons.metrics_static_collection_enabled", false)
105+
.put("plugins.ml_commons.stream_enabled", false)
102106
.build();
103107

104108
MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
@@ -115,6 +119,7 @@ public void testDefaults_someFeaturesDisabled() {
115119
assertFalse(setting.isRagSearchPipelineEnabled());
116120
assertFalse(setting.isMetricCollectionEnabled());
117121
assertFalse(setting.isStaticMetricCollectionEnabled());
122+
assertFalse(setting.isStreamEnabled());
118123
}
119124

120125
@Test

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,11 @@ public Map<String, String> getConnectorCredential(Connector connector) {
154154

155155
public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
156156
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
157-
predictable.initModel(mlModel, params, encryptor, streamManager, threadPool);
157+
if (mlModel.getAlgorithm() == FunctionName.REMOTE) {
158+
predictable.initModel(mlModel, params, encryptor, streamManager, threadPool);
159+
} else {
160+
predictable.initModel(mlModel, params, encryptor);
161+
}
158162
return predictable;
159163
}
160164

ml-algorithms/src/main/java/org/opensearch/ml/engine/arrow/RemoteModelStreamProducer.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.engine.arrow;
27

38
import java.nio.charset.StandardCharsets;

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public class AbstractConnectorExecutorTest {
2121
@Before
2222
public void setUp() {
2323
MockitoAnnotations.initMocks(this);
24+
when(mockConnector.getAccessKey()).thenReturn("access_key");
25+
when(mockConnector.getSecretKey()).thenReturn("secret_key");
26+
when(mockConnector.getSessionToken()).thenReturn("session_token");
27+
when(mockConnector.getRegion()).thenReturn("us-east-1-test");
2428
executor = new AwsConnectorExecutor(mockConnector);
2529
connectorClientConfig = new ConnectorClientConfig();
2630
}

0 commit comments

Comments
 (0)