Skip to content

[FEATURE] Improve EncryptorImpl with Asynchronous Handling for Scalability #3919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -79,9 +80,9 @@ public interface Connector extends ToXContentObject, Writeable {

<T> T createPayload(String action, Map<String, String> parameters);

void decrypt(String action, BiFunction<String, String, String> function, String tenantId);
void decrypt(String action, BiFunction<String, String, Future<String>> function, String tenantId);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider an ActionFuture here. It has better handling of OpenSearch-specific thread pools, exceptions, and task APIs.


void encrypt(BiFunction<String, String, String> function, String tenantId);
void encrypt(BiFunction<String, String, Future<String>> function, String tenantId);

Connector cloneConnector();

Expand All @@ -91,7 +92,7 @@ public interface Connector extends ToXContentObject, Writeable {

void writeTo(StreamOutput out) throws IOException;

void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function);
void update(MLCreateConnectorInput updateContent, BiFunction<String, String, Future<String>> function);

<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -300,7 +302,7 @@ public void writeTo(StreamOutput out) throws IOException {
}

@Override
public void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function) {
public void update(MLCreateConnectorInput updateContent, BiFunction<String, String, Future<String>> function) {
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
Expand Down Expand Up @@ -377,17 +379,32 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
}

@Override
public void decrypt(String action, BiFunction<String, String, String> function, String tenantId) {
Map<String, String> decrypted = new HashMap<>();
public void decrypt(String action, BiFunction<String, String, Future<String>> function, String tenantId) {
Map<String, Future<String>> decryptingTempCredential = new HashMap<>();
decryptedCredential = new HashMap<>();
for (String key : credential.keySet()) {
decrypted.put(key, function.apply(credential.get(key), tenantId));
decryptingTempCredential.put(key, function.apply(credential.get(key), tenantId));
}
this.decryptedCredential = decrypted;
fillCredential(decryptingTempCredential, decryptedCredential);
Comment on lines +384 to +388
Copy link
Member

@dbwiddis dbwiddis Jun 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section of code is very confusing. It took me a good 10 minutes to figure out that we're updating a superclass field. Suggestions to improve readability:

  • keep this. or maybe even super. prefix to make it clear this isn't a local variable
  • it makes llittle sense to create an empty map (in a superclass field) and do nothing with it but pass it as an argument. If you just called fillCredential(tempMap, this.(en|de)cryptedCredential) and created the empty map as the first step in that method it'd be more readable.

Similar comments apply to both en/de and Http/Mcp connectors.

Optional<ConnectorAction> connectorAction = findAction(action);
Map<String, String> headers = connectorAction.map(ConnectorAction::getHeaders).orElse(null);
this.decryptedHeaders = createDecryptedHeaders(headers);
}

private void fillCredential(Map<String, Future<String>> decrypted, Map<String, String> decryptedCredential) {
for (String key : decrypted.keySet()) {
try {
if (decrypted.get(key) != null) {
decryptedCredential.put(key, decrypted.get(key).get());
} else {
decryptedCredential.put(key, null);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when the key is null, and you put value of null, what is the intention here? why don't you skip when the key is null?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 why we are assigning null here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 why not use putIfAbsent()?

}
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add meaningful exception message to indicate what goes wrong here. something like, failed to process fill Credentials.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you switch from Future to ActionFuture some of this exception handling is already included.

}
}
}

@Override
public Connector cloneConnector() {
try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) {
Expand All @@ -400,11 +417,12 @@ public Connector cloneConnector() {
}

@Override
public void encrypt(BiFunction<String, String, String> function, String tenantId) {
public void encrypt(BiFunction<String, String, Future<String>> function, String tenantId) {
Map<String, Future<String>> encryptingCredential = new HashMap<>();
for (String key : credential.keySet()) {
String encrypted = function.apply(credential.get(key), tenantId);
credential.put(key, encrypted);
encryptingCredential.put(key, function.apply(credential.get(key), tenantId));
}
fillCredential(encryptingCredential, credential);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -205,21 +207,37 @@ protected Map<String, String> createDecryptedHeaders(Map<String, String> headers
}

@Override
public void decrypt(String action, BiFunction<String, String, String> function, String tenantId) {
Map<String, String> decrypted = new HashMap<>();
public void decrypt(String action, BiFunction<String, String, Future<String>> function, String tenantId) {
Map<String, Future<String>> decryptingTempCredential = new HashMap<>();
decryptedCredential = new HashMap<>();
for (String key : credential.keySet()) {
decrypted.put(key, function.apply(credential.get(key), tenantId));
decryptingTempCredential.put(key, function.apply(credential.get(key), tenantId));
}
this.decryptedCredential = decrypted;
fillCredential(decryptingTempCredential, decryptedCredential);
this.decryptedHeaders = createDecryptedHeaders(headers);
}

private void fillCredential(Map<String, Future<String>> decrypted, Map<String, String> decryptedCredential) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Input validation?
if (decrypted == null || decryptedCredential == null) {
throw new IllegalArgumentException("Input maps cannot be null");
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better yet since we're creating a new map for decryptedCredentials we don't care what it is if we immediately overwrite it here.

for (String key : decrypted.keySet()) {
try {
if (decrypted.get(key) != null) {
decryptedCredential.put(key, decrypted.get(key).get());
} else {
decryptedCredential.put(key, null);
}
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
}

@Override
public void encrypt(BiFunction<String, String, String> function, String tenantId) {
public void encrypt(BiFunction<String, String, Future<String>> function, String tenantId) {
Map<String, Future<String>> encryptingCredential = new HashMap<>();
for (String key : credential.keySet()) {
String encrypted = function.apply(credential.get(key), tenantId);
credential.put(key, encrypted);
encryptingCredential.put(key, function.apply(credential.get(key), tenantId));
}
fillCredential(encryptingCredential, credential);
}

@Override
Expand Down Expand Up @@ -332,7 +350,7 @@ public void writeTo(StreamOutput out) throws IOException {
}

@Override
public void update(MLCreateConnectorInput updateContent, BiFunction<String, String, String> function) {
public void update(MLCreateConnectorInput updateContent, BiFunction<String, String, Future<String>> function) {
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.function.BiFunction;

import org.junit.Assert;
Expand All @@ -39,13 +41,13 @@ public class AwsConnectorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

BiFunction<String, String, String> encryptFunction;
BiFunction<String, String, String> decryptFunction;
BiFunction<String, String, Future<String>> encryptFunction;
BiFunction<String, String, Future<String>> decryptFunction;;

@Before
public void setUp() {
encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT);
decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT);
encryptFunction = (s, v) -> CompletableFuture.supplyAsync(() -> "encrypted: " + s.toLowerCase(Locale.ROOT));
decryptFunction = (s, v) -> CompletableFuture.supplyAsync(() -> "decrypted: " + s.toUpperCase(Locale.ROOT));
Comment on lines +49 to +50
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're not including an Executor argument here which means the async part will execute on the ForkJoinPool.commonPool(). This can't be shut down and in theory could trigger spurious thread leak detection.

You should be using a thread pool here. Example.

Similar comment in HttpConnectorTest and AbstractConnectorTest.

}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.function.BiFunction;

import org.junit.Assert;
Expand All @@ -38,8 +40,8 @@ public class HttpConnectorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

BiFunction<String, String, String> encryptFunction;
BiFunction<String, String, String> decryptFunction;
BiFunction<String, String, Future<String>> encryptFunction;
BiFunction<String, String, Future<String>> decryptFunction;

String TEST_CONNECTOR_JSON_STRING = "{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
Expand All @@ -55,8 +57,8 @@ public class HttpConnectorTest {

@Before
public void setUp() {
encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT);
decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT);
encryptFunction = (s, v) -> CompletableFuture.supplyAsync(() -> "encrypted: " + s.toLowerCase(Locale.ROOT));
decryptFunction = (s, v) -> CompletableFuture.supplyAsync(() -> "decrypted: " + s.toUpperCase(Locale.ROOT));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.function.BiFunction;

import org.junit.Assert;
Expand All @@ -40,16 +42,16 @@ public class McpConnectorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

BiFunction<String, String, String> encryptFunction;
BiFunction<String, String, String> decryptFunction;
BiFunction<String, String, Future<String>> encryptFunction;
BiFunction<String, String, Future<String>> decryptFunction;

String TEST_CONNECTOR_JSON_STRING =
"{\"name\":\"test_mcp_connector_name\",\"version\":\"1\",\"description\":\"this is a test mcp connector\",\"protocol\":\"mcp_sse\",\"credential\":{\"key\":\"test_key_value\"},\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\",\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"},\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"},\"parameters\":{\"sse_endpoint\":\"/custom/sse\"}}";

@Before
public void setUp() {
encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT);
decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT);
encryptFunction = (s, v) -> CompletableFuture.supplyAsync(() -> "encrypted: " + s.toLowerCase(Locale.ROOT));
decryptFunction = (s, v) -> CompletableFuture.supplyAsync(() -> "decrypted: " + s.toUpperCase(Locale.ROOT));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Future;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -221,7 +222,7 @@ private void validateInput(Input input) {
}
}

public String encrypt(String credential, String tenantId) {
public Future<String> encrypt(String credential, String tenantId) {
return encryptor.encrypt(credential, tenantId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.engine.encryptor;

import java.util.concurrent.Future;

public interface Encryptor {

/**
Expand All @@ -14,7 +16,7 @@ public interface Encryptor {
* @param tenantId id of the tenant
* @return String encryptedText.
*/
String encrypt(String plainText, String tenantId);
Future<String> encrypt(String plainText, String tenantId);

/**
* Takes encryptedText and returns plain text.
Expand All @@ -23,7 +25,7 @@ public interface Encryptor {
* @param tenantId id of the tenant
* @return String plainText.
*/
String decrypt(String encryptedText, String tenantId);
Future<String> decrypt(String encryptedText, String tenantId);

/**
* Set up the masterKey for dynamic updating
Expand Down
Loading