Skip to content

Commit fe2a0fb

Browse files
authored
Support filter template for Query/Search (#371)
Signed-off-by: yhmo <[email protected]>
1 parent 8418945 commit fe2a0fb

File tree

10 files changed

+350
-2
lines changed

10 files changed

+350
-2
lines changed

DEVELOPMENT.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ Once the `make test` is done, you will see some executable examples under the pa
147147
- `./cmake_build/examples/sdk_default_value`: example to show the usage of default value.
148148
- `./cmake_build/examples/sdk_dml`: example to show the usage of dml interfaces.
149149
- `./cmake_build/examples/sdk_dynamic_field`: example to show the usage of dynamic fields.
150+
- `./cmake_build/examples/sdk_filter_template`: example to show the usage of filter template.
150151
- `./cmake_build/examples/sdk_full_text_match`: example to show the usage of BM25 function.
151152
- `./cmake_build/examples/sdk_general`: a general example to show the basic usage.
152153
- `./cmake_build/examples/sdk_group_by`: a general example to show the usage of grouping search.

examples/example.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Once the `make test` is done, you will see some executable examples under the pa
1010
- `./cmake_build/examples/sdk_default_value`: example to show the usage of default value.
1111
- `./cmake_build/examples/sdk_dml`: example to show the usage of dml interfaces.
1212
- `./cmake_build/examples/sdk_dynamic_field`: example to show the usage of dynamic fields.
13+
- `./cmake_build/examples/sdk_filter_template`: example to show the usage of filter template.
1314
- `./cmake_build/examples/sdk_full_text_match`: example to show the usage of BM25 function.
1415
- `./cmake_build/examples/sdk_general`: a general example to show the basic usage.
1516
- `./cmake_build/examples/sdk_group_by`: a general example to show the usage of grouping search.

examples/src/filter_template.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Licensed to the LF AI & Data foundation under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
#include <iostream>
18+
#include <string>
19+
#include <thread>
20+
21+
#include "ExampleUtils.h"
22+
#include "milvus/MilvusClient.h"
23+
24+
int
25+
main(int argc, char* argv[]) {
26+
printf("Example start...\n");
27+
28+
auto client = milvus::MilvusClient::Create();
29+
30+
milvus::ConnectParam connect_param{"localhost", 19530, "root", "Milvus"};
31+
auto status = client->Connect(connect_param);
32+
util::CheckStatus("connect milvus server", status);
33+
34+
const std::string collection_name = "TEST_CPP_FILTER_TEMPLATE";
35+
const std::string field_id = "pk";
36+
const std::string field_vector = "vector";
37+
const std::string field_text = "text";
38+
const uint32_t dimension = 4;
39+
40+
// collection schema, drop and create collection
41+
milvus::CollectionSchema collection_schema(collection_name);
42+
collection_schema.AddField(milvus::FieldSchema(field_id, milvus::DataType::INT64, "", true, true));
43+
collection_schema.AddField(
44+
milvus::FieldSchema(field_vector, milvus::DataType::FLOAT_VECTOR).WithDimension(dimension));
45+
collection_schema.AddField(milvus::FieldSchema(field_text, milvus::DataType::VARCHAR).WithMaxLength(1024));
46+
47+
status = client->DropCollection(collection_name);
48+
status = client->CreateCollection(collection_schema);
49+
util::CheckStatus("create collection: " + collection_name, status);
50+
51+
// create index
52+
milvus::IndexDesc index_vector(field_vector, "", milvus::IndexType::FLAT, milvus::MetricType::L2);
53+
status = client->CreateIndex(collection_name, index_vector);
54+
util::CheckStatus("create index on vector field", status);
55+
56+
// tell server prepare to load collection
57+
status = client->LoadCollection(collection_name);
58+
util::CheckStatus("load collection: " + collection_name, status);
59+
60+
// insert some rows
61+
milvus::EntityRows rows;
62+
for (auto i = 0; i < 10000; ++i) {
63+
milvus::EntityRow row; // id is auto-generated
64+
row[field_text] = "text_" + std::to_string(i);
65+
row[field_vector] = util::GenerateFloatVector(dimension);
66+
rows.emplace_back(std::move(row));
67+
}
68+
69+
milvus::DmlResults dml_results;
70+
status = client->Insert(collection_name, "", rows, dml_results);
71+
util::CheckStatus("insert", status);
72+
std::cout << dml_results.InsertCount() << " rows inserted by row-based." << std::endl;
73+
auto ids = dml_results.IdArray().IntIDArray();
74+
75+
{
76+
// query with filter template
77+
std::string filter = field_id + " in {my_ids}"; // "my_ids" is an alias will be used in filter template
78+
std::cout << "Query with filter expression: " << filter << std::endl;
79+
80+
auto begin = ids.begin() + 500;
81+
auto end = begin + 100;
82+
std::vector<int64_t> filter_ids(begin, end);
83+
nlohmann::json filter_template = filter_ids;
84+
85+
milvus::QueryArguments q_arguments{};
86+
q_arguments.SetCollectionName(collection_name);
87+
q_arguments.AddOutputField(field_text);
88+
q_arguments.SetFilter(filter);
89+
q_arguments.AddFilterTemplate("my_ids", filter_template); // filter template
90+
// set to strong level so that the query is executed after the inserted data is consumed by server
91+
q_arguments.SetConsistencyLevel(milvus::ConsistencyLevel::STRONG);
92+
93+
milvus::QueryResults query_results{};
94+
status = client->Query(q_arguments, query_results);
95+
util::CheckStatus("query", status);
96+
97+
milvus::EntityRows output_rows;
98+
status = query_results.OutputRows(output_rows);
99+
util::CheckStatus("get output rows", status);
100+
std::cout << "Query with filter template:" << std::endl;
101+
for (const auto& row : output_rows) {
102+
std::cout << "\t" << row << std::endl;
103+
}
104+
}
105+
106+
{
107+
// search with filter template
108+
std::string filter = field_text + " in {my_texts}"; // "my_texts" is an alias will be used in filter template
109+
std::vector<std::string> texts;
110+
for (auto i = 300; i < 500; i++) {
111+
texts.push_back("text_" + std::to_string(i));
112+
}
113+
nlohmann::json filter_template = texts;
114+
115+
milvus::SearchArguments s_arguments{};
116+
s_arguments.SetCollectionName(collection_name);
117+
s_arguments.SetLimit(static_cast<int64_t>(texts.size()));
118+
s_arguments.SetFilter(filter);
119+
s_arguments.AddFilterTemplate("my_texts", filter_template);
120+
s_arguments.AddOutputField(field_text);
121+
s_arguments.AddFloatVector(field_vector, util::GenerateFloatVector(dimension));
122+
s_arguments.AddFloatVector(field_vector, util::GenerateFloatVector(dimension));
123+
s_arguments.SetConsistencyLevel(milvus::ConsistencyLevel::BOUNDED);
124+
125+
milvus::SearchResults search_results{};
126+
status = client->Search(s_arguments, search_results);
127+
util::CheckStatus("search", status);
128+
129+
std::cout << "Search with filter template:" << std::endl;
130+
for (auto& result : search_results.Results()) {
131+
std::cout << "Result of one target vector:" << std::endl;
132+
milvus::EntityRows output_rows;
133+
status = result.OutputRows(output_rows);
134+
util::CheckStatus("get output rows", status);
135+
for (const auto& row : output_rows) {
136+
std::cout << "\t" << row << std::endl;
137+
}
138+
}
139+
}
140+
141+
client->Disconnect();
142+
return 0;
143+
}

src/impl/types/QueryArguments.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "milvus/types/QueryArguments.h"
1818

1919
#include "../utils/Constants.h"
20+
#include "../utils/TypeUtils.h"
2021

2122
namespace milvus {
2223

@@ -90,6 +91,29 @@ QueryArguments::SetFilter(std::string filter) {
9091
return Status::OK();
9192
}
9293

94+
// Binds a value to a placeholder alias used in the filter expression.
//
// key:             the alias the template value is stored under.
// filter_template: a boolean, number or string, or an array whose elements are all
//                  boolean/number/string (checked with IsValidTemplate).
//
// Returns INVALID_AGUMENT when the value, or any element of an array value, is of
// another JSON type; nothing is stored in that case. Otherwise the value is stored
// and OK is returned. Registering the same key again replaces the earlier value, so
// an arguments object can be re-bound instead of silently keeping the stale value.
Status
QueryArguments::AddFilterTemplate(std::string key, const nlohmann::json& filter_template) {
    if (filter_template.is_array()) {
        for (const auto& ele : filter_template) {
            if (!IsValidTemplate(ele)) {
                return {milvus::StatusCode::INVALID_AGUMENT, "Filter template element must be boolean/number/string"};
            }
        }
    } else if (!IsValidTemplate(filter_template)) {
        return {milvus::StatusCode::INVALID_AGUMENT, "Filter template must be boolean/number/string/array"};
    }

    // key was taken by value, so hand it over to the map instead of copying it once more
    filter_templates_[std::move(key)] = filter_template;
    return Status::OK();
}
111+
112+
// Returns a read-only reference to the filter templates stored in this object,
// keyed by the name each value was registered under.
const std::unordered_map<std::string, nlohmann::json>&
113+
QueryArguments::FilterTemplates() const {
114+
return filter_templates_;
115+
}
116+
93117
int64_t
94118
QueryArguments::Limit() const {
95119
// for history reason, query() requires "limit", search() requires "topk"

src/impl/types/SubSearchRequest.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include "milvus/types/SubSearchRequest.h"
1818

19-
#include <nlohmann/json.hpp>
2019
#include <utility>
2120

2221
#include "../utils/Constants.h"
@@ -25,7 +24,6 @@
2524
#include "milvus/utils/FP16.h"
2625

2726
namespace milvus {
28-
2927
const std::string&
3028
SearchRequestBase::Filter() const {
3129
return filter_expression_;
@@ -37,6 +35,29 @@ SearchRequestBase::SetFilter(std::string filter) {
3735
return Status::OK();
3836
}
3937

38+
// Binds a value to a placeholder alias used in the filter expression.
//
// key:             the alias the template value is stored under.
// filter_template: a boolean, number or string, or an array whose elements are all
//                  boolean/number/string (checked with IsValidTemplate).
//
// Returns INVALID_AGUMENT when the value, or any element of an array value, is of
// another JSON type; nothing is stored in that case. Otherwise the value is stored
// and OK is returned. Registering the same key again replaces the earlier value, so
// a request object can be re-bound instead of silently keeping the stale value.
Status
SearchRequestBase::AddFilterTemplate(std::string key, const nlohmann::json& filter_template) {
    if (filter_template.is_array()) {
        for (const auto& ele : filter_template) {
            if (!IsValidTemplate(ele)) {
                return {milvus::StatusCode::INVALID_AGUMENT, "Filter template element must be boolean/number/string"};
            }
        }
    } else if (!IsValidTemplate(filter_template)) {
        return {milvus::StatusCode::INVALID_AGUMENT, "Filter template must be boolean/number/string/array"};
    }

    // key was taken by value, so hand it over to the map instead of copying it once more
    filter_templates_[std::move(key)] = filter_template;
    return Status::OK();
}
55+
56+
// Returns a read-only reference to the filter templates stored in this object,
// keyed by the name each value was registered under.
const std::unordered_map<std::string, nlohmann::json>&
57+
SearchRequestBase::FilterTemplates() const {
58+
return filter_templates_;
59+
}
60+
4061
FieldDataPtr
4162
SearchRequestBase::TargetVectors() const {
4263
return target_vectors_;

src/impl/utils/DqlUtils.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,97 @@ DeduceGuaranteeTimestamp(const ConsistencyLevel& level, const std::string& db_na
783783
}
784784
}
785785

786+
// Converts a JSON array into a proto TemplateArrayValue.
//
// The element type of the whole array is decided by its first element (Boolean, Integer,
// Double, String or List) and every other element must be of that same type, otherwise
// INVALID_AGUMENT is returned. A list of lists is converted recursively: each sub-list
// becomes one entry of rpc_array's array_data.
//
// array:     the JSON array to convert; an empty array leaves rpc_array untouched and returns OK.
// rpc_array: output message. It may be partially filled when an error is returned.
Status
DeduceTemplateArray(const nlohmann::json& array, proto::schema::TemplateArrayValue& rpc_array) {
    if (array.empty()) {
        return Status::OK();
    }

    const auto& first_ele = array.at(0);
    if (first_ele.is_boolean()) {
        for (const auto& ele : array) {
            if (!ele.is_boolean()) {
                return {StatusCode::INVALID_AGUMENT,
                        "Filter expression template is a list, the first value is Boolean, but some elements are not "
                        "Boolean"};
            }
            rpc_array.mutable_bool_data()->add_data(ele.get<bool>());
        }
    } else if (first_ele.is_number_integer()) {
        for (const auto& ele : array) {
            if (!ele.is_number_integer()) {
                return {StatusCode::INVALID_AGUMENT,
                        "Filter expression template is a list, the first value is Integer, but some elements are not "
                        "Integer"};
            }
            rpc_array.mutable_long_data()->add_data(ele.get<int64_t>());
        }
    } else if (first_ele.is_number_float()) {
        for (const auto& ele : array) {
            if (!ele.is_number_float()) {
                return {StatusCode::INVALID_AGUMENT,
                        "Filter expression template is a list, the first value is Double, but some elements are not "
                        "Double"};
            }
            rpc_array.mutable_double_data()->add_data(ele.get<double>());
        }
    } else if (first_ele.is_string()) {
        for (const auto& ele : array) {
            if (!ele.is_string()) {
                return {StatusCode::INVALID_AGUMENT,
                        "Filter expression template is a list, the first value is String, but some elements are not "
                        "String"};
            }
            rpc_array.mutable_string_data()->add_data(ele.get<std::string>());
        }
    } else if (first_ele.is_array()) {
        // each sub-list is appended directly to this array's array_data; wrapping them in an
        // extra intermediate element would add a nesting level the JSON input does not have
        auto rpc_sub_arrays = rpc_array.mutable_array_data();
        for (const auto& ele : array) {
            if (!ele.is_array()) {
                return {
                    StatusCode::INVALID_AGUMENT,
                    "Filter expression template is a list, the first value is List, but some elements are not List"};
            }

            auto sub_array = rpc_sub_arrays->add_data();
            auto status = DeduceTemplateArray(ele, *sub_array);
            if (!status.IsOk()) {
                return status;
            }
        }
    } else {
        // null/object/binary elements cannot be converted; report it instead of
        // returning OK with an empty rpc_array
        return {StatusCode::INVALID_AGUMENT,
                "Filter expression template is a list, its elements must be Boolean/Integer/Double/String/List"};
    }

    return Status::OK();
}
847+
848+
// Translates the user supplied filter templates into the proto map sent with a request.
//
// templates:     alias -> JSON value, as collected by AddFilterTemplate().
// rpc_templates: output proto map; one TemplateValue is inserted per alias.
//
// Arrays are converted by DeduceTemplateArray(); booleans, integers, floating point
// numbers and strings are stored as scalar values. Any other JSON type, or a failed
// array conversion, stops the conversion and returns the error status.
Status
ConvertFilterTemplates(const std::unordered_map<std::string, nlohmann::json>& templates,
                       ::google::protobuf::Map<std::string, proto::schema::TemplateValue>* rpc_templates) {
    using JsonType = nlohmann::json::value_t;

    for (const auto& entry : templates) {
        const nlohmann::json& json_value = entry.second;
        proto::schema::TemplateValue rpc_value;

        switch (json_value.type()) {
            case JsonType::array: {
                auto rpc_array = rpc_value.mutable_array_val();
                auto array_status = DeduceTemplateArray(json_value, *rpc_array);
                if (!array_status.IsOk()) {
                    return array_status;
                }
                break;
            }
            case JsonType::boolean:
                rpc_value.set_bool_val(json_value.get<bool>());
                break;
            case JsonType::number_integer:
            case JsonType::number_unsigned:
                // signed and unsigned integers are both sent as int64
                rpc_value.set_int64_val(json_value.get<int64_t>());
                break;
            case JsonType::number_float:
                rpc_value.set_float_val(json_value.get<double>());
                break;
            case JsonType::string:
                rpc_value.set_string_val(json_value.get<std::string>());
                break;
            default:
                return {StatusCode::INVALID_AGUMENT, "Unsupported template type"};
        }

        rpc_templates->insert(std::make_pair(entry.first, rpc_value));
    }

    return Status::OK();
}
876+
786877
// current_db is the actual target db that the request is performed, for setting the GuaranteeTimestamp
787878
// to compatible with old versions.
788879
// for examples:
@@ -803,6 +894,15 @@ ConvertQueryRequest(const QueryArguments& arguments, const std::string& current_
803894
}
804895

805896
rpc_request.set_expr(arguments.Filter());
897+
if (!arguments.Filter().empty()) {
898+
auto rpc_templates = rpc_request.mutable_expr_template_values();
899+
const auto& templates = arguments.FilterTemplates();
900+
auto status = ConvertFilterTemplates(templates, rpc_templates);
901+
if (!status.IsOk()) {
902+
return status;
903+
}
904+
}
905+
806906
for (const auto& field : arguments.OutputFields()) {
807907
rpc_request.add_output_fields(field);
808908
}
@@ -868,6 +968,13 @@ ConvertSearchRequest(const SearchArguments& arguments, const std::string& curren
868968
rpc_request.set_dsl_type(proto::common::DslType::BoolExprV1);
869969
if (!arguments.Filter().empty()) {
870970
rpc_request.set_dsl(arguments.Filter());
971+
972+
auto rpc_templates = rpc_request.mutable_expr_template_values();
973+
const auto& templates = arguments.FilterTemplates();
974+
auto status = ConvertFilterTemplates(templates, rpc_templates);
975+
if (!status.IsOk()) {
976+
return status;
977+
}
871978
}
872979
for (const auto& partition_name : arguments.PartitionNames()) {
873980
rpc_request.add_partition_names(partition_name);

src/impl/utils/TypeUtils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,11 @@ doubleToString(double val) {
639639
return stream.str();
640640
}
641641

642+
bool
643+
IsValidTemplate(const nlohmann::json& filter_template) {
644+
return filter_template.is_boolean() || filter_template.is_number() || filter_template.is_string();
645+
}
646+
642647
} // namespace milvus
643648

644649
namespace std {

src/impl/utils/TypeUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,7 @@ ConvertResourceGroupConfig(const proto::rg::ResourceGroupConfig& rpc_config, Res
112112
std::string
113113
doubleToString(double val);
114114

115+
// Returns true when the JSON value is a boolean, a number or a string,
// i.e. a scalar that can be used as a filter template value.
bool
116+
IsValidTemplate(const nlohmann::json& filter_template);
117+
115118
} // namespace milvus

0 commit comments

Comments
 (0)