
Xinyual/add tokenizer and sparse encoding #1


Open

wants to merge 69 commits into base: addTokenizerAndSparseEncoding
Commits (69)

ba553c4
add tokenizer and sparse encoding
xinyual Aug 25, 2023
578274f
add tokenizer and sparse encoding
xinyual Aug 25, 2023
c2ffd7a
add tokenizer and sparse encoding
xinyual Aug 25, 2023
ed8832d
add tokenizer and sparse encoding
xinyual Aug 25, 2023
258f2cb
add tokenizer and sparse encoding
xinyual Aug 25, 2023
e9da84e
remove special token
xinyual Aug 28, 2023
ce18800
add filter
xinyual Aug 28, 2023
761b2d3
try empty model
xinyual Aug 28, 2023
6e6cd83
remove warm up
xinyual Aug 28, 2023
6618845
try empty model
xinyual Aug 28, 2023
1b919d2
add block
xinyual Aug 28, 2023
4f70e4a
add log
xinyual Aug 28, 2023
c3e4186
add log
xinyual Aug 28, 2023
5ccc350
add log
xinyual Aug 28, 2023
73f0329
remove log
xinyual Aug 28, 2023
3b0b7d3
remove pt file detect
xinyual Aug 28, 2023
7b98bf3
add log
xinyual Aug 28, 2023
13e6d77
add functionName pipeline
xinyual Aug 28, 2023
9daeb96
remove verify log
xinyual Aug 28, 2023
e000f22
skip special token in sparse encoding
xinyual Aug 28, 2023
c435a89
skip omit tokenize config
xinyual Aug 29, 2023
8a7d19e
skip omit tokenize config-change warm up logic
xinyual Aug 29, 2023
39449fc
reArch
xinyual Aug 29, 2023
a115e19
deduplicate
xinyual Aug 29, 2023
846f295
omit ml config in sparse encoding
xinyual Aug 29, 2023
564f1e6
add null config in warm up
xinyual Aug 29, 2023
a1b9351
fix original test
xinyual Aug 29, 2023
75de44c
add tokenize ut half
xinyual Aug 29, 2023
a381d64
fix sparse encoding bug
xinyual Aug 30, 2023
ab745e2
add UT for sparse encoding and tokenize
xinyual Aug 30, 2023
97716f2
remove useless framwork type
xinyual Aug 30, 2023
7437fc5
common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java
xinyual Aug 31, 2023
fc89d11
change key for tokenize
xinyual Aug 31, 2023
0eb4221
reArch DLModel
xinyual Aug 31, 2023
a17cef0
reArch DLModel again
xinyual Aug 31, 2023
04d9b13
response format
xinyual Aug 31, 2023
a118a63
tokenize only one output
xinyual Aug 31, 2023
9ff9be9
clean sparse output
xinyual Aug 31, 2023
b95bbb9
clean sparse output
xinyual Aug 31, 2023
2181a64
change UT number
xinyual Aug 31, 2023
976bb79
remove useless predict code
xinyual Sep 4, 2023
1614dab
remove useless part
xinyual Sep 4, 2023
23231bc
change tokenize way
xinyual Sep 5, 2023
0b1e206
reArch add textEmbedding model
xinyual Sep 5, 2023
4ded8d6
add tokenize logic
xinyual Sep 5, 2023
c057977
add abstract
xinyual Sep 5, 2023
3abb2df
clear code
xinyual Sep 5, 2023
540df26
fix it class
xinyual Sep 6, 2023
5dc8197
fix it class
xinyual Sep 6, 2023
5dbc4b7
add IT file
xinyual Sep 6, 2023
9ae26f2
reformulate
xinyual Sep 7, 2023
fe20404
reformulate remote inference
xinyual Sep 7, 2023
e8ee101
reformulate remote inference
xinyual Sep 7, 2023
8a67625
reformulate remote inference json and array
xinyual Sep 7, 2023
5ae0bcd
verify
xinyual Sep 8, 2023
2ba663a
undo string utils
xinyual Sep 8, 2023
6c1259a
skip dummy model
xinyual Sep 11, 2023
afdf235
skip dummy model
xinyual Sep 11, 2023
5a928fe
skip dummy model
xinyual Sep 11, 2023
6018da8
skip dummy model
xinyual Sep 11, 2023
eb5f944
skip dummy model
xinyual Sep 11, 2023
9bf24b4
skip dummy model
xinyual Sep 11, 2023
cd3e513
add inner load Model
xinyual Sep 11, 2023
597ab68
rename variable
xinyual Sep 11, 2023
e1c8b5e
add default for idf
xinyual Sep 12, 2023
0d14e8b
add ut for sparse encoding and tokenizer
xinyual Sep 12, 2023
4c82bad
add close model
xinyual Sep 12, 2023
ced1323
change mock class
xinyual Sep 12, 2023
c651a4e
remove buffer for sparse encoding output
xinyual Sep 12, 2023
common/build.gradle (4 changes: 2 additions & 2 deletions)

@@ -19,8 +19,8 @@ dependencies {
     testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'

     compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
-    compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
-    compileOnly group: 'org.json', name: 'json', version: '20230227'
+    implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
+    implementation group: 'org.json', name: 'json', version: '20230227'
 }

lombok {

@@ -17,6 +17,8 @@ public enum FunctionName {
     RCF_SUMMARIZE,
     LOGISTIC_REGRESSION,
     TEXT_EMBEDDING,
+    SPARSE_ENCODING,
+    TOKENIZE,
     METRICS_CORRELATION,
     REMOTE;


@@ -33,7 +35,7 @@ public static FunctionName from(String value) {
      * @return true for deep learning model.
      */
     public static boolean isDLModel(FunctionName functionName) {
-        if (functionName == TEXT_EMBEDDING) {
+        if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == TOKENIZE) {
             return true;
         }
         return false;
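
The two new enum constants and the widened isDLModel check are the anchor for everything else in this PR. Below is a minimal usage sketch, not part of the diff: the static isDLModel signature and the constants come from the hunk above, while the package name org.opensearch.ml.common and the wrapper class are assumptions for illustration.

import org.opensearch.ml.common.FunctionName;

public class IsDlModelSketch {
    public static void main(String[] args) {
        // All three NLP function names now take the deep-learning model path.
        FunctionName[] nlpFunctions = {
                FunctionName.TEXT_EMBEDDING,
                FunctionName.SPARSE_ENCODING,
                FunctionName.TOKENIZE
        };
        for (FunctionName fn : nlpFunctions) {
            System.out.println(fn + " -> isDLModel=" + FunctionName.isDLModel(fn));
        }
    }
}
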

@@ -104,9 +104,9 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
Map<String, Object> data = StringUtils.fromJson((String) response, "response");
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
} else {
Map<String, Object> map = new HashMap<>();
map.put("response", response);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(map).build());
Map<String, Object> map = new HashMap<>();
map.put("response", response);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(map).build());
}
}


@@ -239,7 +239,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
             }
         }
         MLInputDataset inputDataSet = null;
-        if (algorithm == FunctionName.TEXT_EMBEDDING) {
+        if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.TOKENIZE) {
             ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
             inputDataSet = new TextDocsInputDataSet(textDocs, filter);
         }
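
With MLInput.parse now accepting the two new algorithms, a predict request for SPARSE_ENCODING or TOKENIZE carries the same text_docs payload that TEXT_EMBEDDING already uses. A hedged sketch of building such an input programmatically; the builder calls mirror the test code later in this diff, and the package names are assumptions.

import java.util.Arrays;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;

public class SparseEncodingInputSketch {
    public static MLInput buildSparseEncodingInput() {
        // Same text-docs dataset type that TEXT_EMBEDDING already used; the result filter is optional.
        TextDocsInputDataSet dataset = TextDocsInputDataSet.builder()
                .docs(Arrays.asList("hello world", "sparse encoding example"))
                .build();
        // After this change, SPARSE_ENCODING and TOKENIZE parse into the same dataset path.
        return MLInput.builder()
                .algorithm(FunctionName.SPARSE_ENCODING)
                .inputDataset(dataset)
                .build();
    }
}
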

@@ -25,7 +25,7 @@
  * ML input class which supports a list fo text docs.
  * This class can be used for TEXT_EMBEDDING model.
  */
-@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
+@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.TOKENIZE})
 public class TextDocsMLInput extends MLInput {
     public static final String TEXT_DOCS_FIELD = "text_docs";
     public static final String RESULT_FILTER_FIELD = "result_filter";
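
The widened @MLInput annotation is what lets MLCommonsClassLoader resolve the new function names to TextDocsMLInput. A small sketch of that lookup; the canInitMLInput call is taken from the class-loader test further down, and the MLCommonsClassLoader package is assumed.

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLCommonsClassLoader;

public class MLInputRegistrationSketch {
    public static void main(String[] args) {
        // With TextDocsMLInput annotated for all three names, each resolves to an MLInput class.
        System.out.println(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING)); // true
        System.out.println(MLCommonsClassLoader.canInitMLInput(FunctionName.SPARSE_ENCODING)); // true
        System.out.println(MLCommonsClassLoader.canInitMLInput(FunctionName.TOKENIZE)); // true
    }
}
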

@@ -104,7 +104,7 @@ public MLRegisterModelInput(FunctionName functionName,
         if (modelFormat == null) {
             throw new IllegalArgumentException("model format is null");
         }
-        if (url != null && modelConfig == null) {
+        if (url != null && modelConfig == null && functionName != FunctionName.TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) {
             throw new IllegalArgumentException("model config is null");
         }
     }

@@ -84,7 +84,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
         if (modelContentHashValue == null) {
             throw new IllegalArgumentException("model content hash value is null");
         }
-        if (modelConfig == null) {
+        if (modelConfig == null && functionName != FunctionName.TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) {
             throw new IllegalArgumentException("model config is null");
         }
         if (totalChunks == null) {
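
Both register-input classes relax the same guard, so a TOKENIZE or SPARSE_ENCODING model can be registered without a model config while every other function still requires one. An illustrative restatement of that shared guard, not the PR's actual code; modelConfig is typed loosely here for brevity.

import org.opensearch.ml.common.FunctionName;

public class ModelConfigGuardSketch {
    // Mirrors the relaxed validation in MLRegisterModelInput / MLRegisterModelMetaInput.
    public static void requireModelConfig(FunctionName functionName, Object modelConfig) {
        boolean configOptional = functionName == FunctionName.TOKENIZE
                || functionName == FunctionName.SPARSE_ENCODING;
        if (modelConfig == null && !configOptional) {
            throw new IllegalArgumentException("model config is null");
        }
    }
}
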

@@ -17,6 +17,7 @@
 import java.security.AccessController;
 import java.security.PrivilegedActionException;
 import java.security.PrivilegedExceptionAction;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;

@@ -149,21 +149,27 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException {
         assertArrayEquals(new long[]{1, 2}, metrics);
     }

-    @Test
-    public void testClassLoader_MLInput() throws IOException {
-        assertTrue(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING));
+    private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException {
+        assertTrue(MLCommonsClassLoader.canInitMLInput(functionName));

         String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}";
         XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
             Collections.emptyList()).getNamedXContents()), null, jsonStr);
         parser.nextToken();

-        TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(FunctionName.TEXT_EMBEDDING, new Object[]{parser, FunctionName.TEXT_EMBEDDING}, XContentParser.class, FunctionName.class);
+        TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class);
         assertNotNull(mlInput);
-        assertEquals(FunctionName.TEXT_EMBEDDING, mlInput.getFunctionName());
+        assertEquals(functionName, mlInput.getFunctionName());
         assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size());
     }

+    @Test
+    public void testClassLoader_MLInput() throws IOException {
+        testClassLoader_MLInput_DlModel(FunctionName.TEXT_EMBEDDING);
+        testClassLoader_MLInput_DlModel(FunctionName.TOKENIZE);
+        testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING);
+    }
+
     public enum TestEnum {
         TEST
     }

@@ -110,19 +110,19 @@ public void parse_LinearRegression() throws IOException {
         });
     }

-    @Test
-    public void parse_TextEmbedding() throws IOException {
+    private void parse_NLPModel(FunctionName functionName) throws IOException {
         String sentence = "test sentence";
         String column = "column1";
         Integer position = 1;
         ModelResultFilter resultFilter = ModelResultFilter.builder()
             .targetResponse(Arrays.asList(column))
             .targetResponsePositions(Arrays.asList(position))
             .build();
-        TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence))
-            .resultFilter(resultFilter).build();
-        String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
-        testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
+
+        TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build();
+        String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
+        expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
+        testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
             assertNotNull(parsedInput.getInputDataset());
             TextDocsInputDataSet parsedInputDataSet = (TextDocsInputDataSet) parsedInput.getInputDataset();
             assertEquals(1, parsedInputDataSet.getDocs().size());

@@ -134,19 +134,33 @@ public void parse_TextEmbedding() throws IOException {
     }

     @Test
-    public void parse_TextEmbedding_NullResultFilter() throws IOException {
+    public void parse_NLP_Related() throws IOException {
+        parse_NLPModel(FunctionName.TEXT_EMBEDDING);
+        parse_NLPModel(FunctionName.TOKENIZE);
+        parse_NLPModel(FunctionName.SPARSE_ENCODING);
+    }
+
+    private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws IOException {
         String sentence = "test sentence";
         TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).build();
-        String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"]}";
-        testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
+        String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"]}";
+        expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
+        testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
             assertNotNull(parsedInput.getInputDataset());
             assertEquals(1, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().size());
             assertEquals(sentence, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().get(0));
         });
     }

-    private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr,
-                           Consumer<MLInput> verify) throws IOException {
+
+    @Test
+    public void parse_NLPRelated_NullResultFilter() throws IOException {
+        parse_NLPModel_NullResultFilter(FunctionName.TEXT_EMBEDDING);
+        parse_NLPModel_NullResultFilter(FunctionName.TOKENIZE);
+        parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
+    }
+
+    private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
         MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
         XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
         input.toXContent(builder, ToXContent.EMPTY_PARAMS);

@@ -239,7 +239,7 @@ GET /_plugins/_ml/profile/models/zwla5YUB1qmVrJFlwzXJ
"models": {
"zwla5YUB1qmVrJFlwzXJ": { # model id
"model_state": "LOADED",
"predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel@1a0b0793",
"predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel@1a0b0793",
"target_worker_nodes": [ # plan to deploy model to these nodes
"0TLL4hHxRv6_G3n6y1l0BQ"
],
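
The documented profile output changes only because the dense text-embedding predictor class is now reported as TextEmbeddingDenseModel after the DLModel re-architecture (see the "reArch DLModel" commits). For readers consuming the profile API, here is a hedged sketch of picking those fields out of the response with Gson, which this PR promotes from compileOnly to an implementation dependency in common/build.gradle; the JSON below is a trimmed copy of the documented response, and the wrapper class is illustrative.

import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

public class ProfileResponseSketch {
    public static void main(String[] args) {
        // Trimmed copy of the profile response shown above.
        String json = "{\"models\":{\"zwla5YUB1qmVrJFlwzXJ\":{"
                + "\"model_state\":\"LOADED\","
                + "\"predictor\":\"org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel@1a0b0793\"}}}";
        JsonObject model = JsonParser.parseString(json)
                .getAsJsonObject()
                .getAsJsonObject("models")
                .getAsJsonObject("zwla5YUB1qmVrJFlwzXJ");
        // After the re-architecture, dense embedding predictors report the new class name.
        System.out.println(model.get("model_state").getAsString()); // LOADED
        System.out.println(model.get("predictor").getAsString());   // ...TextEmbeddingDenseModel@1a0b0793
    }
}
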

@@ -193,7 +193,7 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re
      * @param modelContentHash model content hash value
      * @param listener action listener
      */
-    public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, ActionListener<Map<String, Object>> listener) {
+    public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
         try {
             AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                 Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
Expand All @@ -202,7 +202,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
File modelZipFile = new File(modelPath);
log.debug("download model to file {}", modelZipFile.getAbsolutePath());
DownloadUtils.download(url, modelPath, new ProgressBar());
verifyModelZipFile(modelFormat, modelPath, modelName);
verifyModelZipFile(modelFormat, modelPath, modelName, functionName);
String hash = calculateFileHash(modelZipFile);
if (hash.equals(modelContentHash)) {
List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);

@@ -224,7 +224,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
         }
     }

-    public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName) throws IOException {
+    public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
         boolean hasPtFile = false;
         boolean hasOnnxFile = false;
         boolean hasTokenizerFile = false;

@@ -239,7 +239,7 @@ public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePat
                 }
             }
         }
-        if (!hasPtFile && !hasOnnxFile) {
+        if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.TOKENIZE) {
             throw new IllegalArgumentException("Can't find model file");
         }
         if (!hasTokenizerFile) {
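
The extra FunctionName parameter threads from downloadAndSplit into verifyModelZipFile so that a TOKENIZE artifact, which ships only a tokenizer file, is not rejected for lacking a .pt or .onnx model file. Below is an illustrative restatement of the relaxed check, not the actual method; the tokenizer-file error message is invented for the sketch because the hunk above cuts off before it.

import org.opensearch.ml.common.FunctionName;

public class ZipVerificationSketch {
    public static void verify(boolean hasPtFile, boolean hasOnnxFile, boolean hasTokenizerFile,
                              FunctionName functionName) {
        // A TOKENIZE artifact is allowed to contain no torchscript/ONNX model file.
        if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.TOKENIZE) {
            throw new IllegalArgumentException("Can't find model file");
        }
        // Unchanged by this PR; the message here is illustrative only.
        if (!hasTokenizerFile) {
            throw new IllegalArgumentException("Can't find tokenizer file");
        }
    }
}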