Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Once the `make test` is done, you will see some executable examples under the pa
- `./cmake_build/examples/sdk_default_value`: example to show the usage of default value.
- `./cmake_build/examples/sdk_dml`: example to show the usage of dml interfaces.
- `./cmake_build/examples/sdk_dynamic_field`: example to show the usage of dynamic fields.
- `./cmake_build/examples/sdk_filter_template`: example to show the usage of filter template.
- `./cmake_build/examples/sdk_full_text_match`: example to show the usage of BM25 function.
- `./cmake_build/examples/sdk_general`: a general example to show the basic usage.
- `./cmake_build/examples/sdk_group_by`: a general example to show the usage of grouping search.
Expand Down
1 change: 1 addition & 0 deletions examples/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Once the `make test` is done, you will see some executable examples under the pa
- `./cmake_build/examples/sdk_default_value`: example to show the usage of default value.
- `./cmake_build/examples/sdk_dml`: example to show the usage of dml interfaces.
- `./cmake_build/examples/sdk_dynamic_field`: example to show the usage of dynamic fields.
- `./cmake_build/examples/sdk_filter_template`: example to show the usage of filter template.
- `./cmake_build/examples/sdk_full_text_match`: example to show the usage of BM25 function.
- `./cmake_build/examples/sdk_general`: a general example to show the basic usage.
- `./cmake_build/examples/sdk_group_by`: a general example to show the usage of grouping search.
Expand Down
143 changes: 143 additions & 0 deletions examples/src/filter_template.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <string>
#include <thread>

#include "ExampleUtils.h"
#include "milvus/MilvusClient.h"

// Example: using filter templates with Query() and Search().
//
// A filter template lets the filter expression reference a named placeholder (e.g. "{my_ids}")
// whose actual value is supplied separately through AddFilterTemplate(). The example:
//   1. connects to a local server and (re)creates a small collection,
//   2. inserts 10000 rows with auto-generated primary keys,
//   3. queries by a list of primary keys passed as a template,
//   4. searches with a list of strings passed as a template.
int
main(int argc, char* argv[]) {
    printf("Example start...\n");

    auto client = milvus::MilvusClient::Create();

    milvus::ConnectParam connect_param{"localhost", 19530, "root", "Milvus"};
    auto status = client->Connect(connect_param);
    util::CheckStatus("connect milvus server", status);

    const std::string collection_name = "TEST_CPP_FILTER_TEMPLATE";
    const std::string field_id = "pk";
    const std::string field_vector = "vector";
    const std::string field_text = "text";
    const uint32_t dimension = 4;

    // collection schema, drop and create collection
    milvus::CollectionSchema collection_schema(collection_name);
    collection_schema.AddField(milvus::FieldSchema(field_id, milvus::DataType::INT64, "", true, true));
    collection_schema.AddField(
        milvus::FieldSchema(field_vector, milvus::DataType::FLOAT_VECTOR).WithDimension(dimension));
    collection_schema.AddField(milvus::FieldSchema(field_text, milvus::DataType::VARCHAR).WithMaxLength(1024));

    // The drop is best-effort: its status is intentionally not checked, it is immediately
    // overwritten by the status of CreateCollection().
    status = client->DropCollection(collection_name);
    status = client->CreateCollection(collection_schema);
    util::CheckStatus("create collection: " + collection_name, status);

    // create index
    milvus::IndexDesc index_vector(field_vector, "", milvus::IndexType::FLAT, milvus::MetricType::L2);
    status = client->CreateIndex(collection_name, index_vector);
    util::CheckStatus("create index on vector field", status);

    // tell server prepare to load collection
    status = client->LoadCollection(collection_name);
    util::CheckStatus("load collection: " + collection_name, status);

    // insert some rows
    milvus::EntityRows rows;
    for (auto i = 0; i < 10000; ++i) {
        milvus::EntityRow row;  // id is auto-generated
        row[field_text] = "text_" + std::to_string(i);
        row[field_vector] = util::GenerateFloatVector(dimension);
        rows.emplace_back(std::move(row));
    }

    milvus::DmlResults dml_results;
    status = client->Insert(collection_name, "", rows, dml_results);
    util::CheckStatus("insert", status);
    std::cout << dml_results.InsertCount() << " rows inserted by row-based." << std::endl;
    auto ids = dml_results.IdArray().IntIDArray();

    {
        // query with filter template
        std::string filter = field_id + " in {my_ids}";  // "my_ids" is the alias that the filter template binds to
        std::cout << "Query with filter expression: " << filter << std::endl;

        // Take up to 100 of the returned ids, starting at offset 500. The window is clamped to the
        // number of ids actually returned so the iterator arithmetic can never run past ids.end().
        const auto id_count = static_cast<int64_t>(ids.size());
        const int64_t first = id_count > 500 ? 500 : id_count;
        const int64_t last = (id_count - first) > 100 ? first + 100 : id_count;
        std::vector<int64_t> filter_ids(ids.begin() + first, ids.begin() + last);
        nlohmann::json filter_template = filter_ids;

        milvus::QueryArguments q_arguments{};
        q_arguments.SetCollectionName(collection_name);
        q_arguments.AddOutputField(field_text);
        q_arguments.SetFilter(filter);
        q_arguments.AddFilterTemplate("my_ids", filter_template);  // filter template
        // set to strong level so that the query is executed after the inserted data is consumed by server
        q_arguments.SetConsistencyLevel(milvus::ConsistencyLevel::STRONG);

        milvus::QueryResults query_results{};
        status = client->Query(q_arguments, query_results);
        util::CheckStatus("query", status);

        milvus::EntityRows output_rows;
        status = query_results.OutputRows(output_rows);
        util::CheckStatus("get output rows", status);
        std::cout << "Query with filter template:" << std::endl;
        for (const auto& row : output_rows) {
            std::cout << "\t" << row << std::endl;
        }
    }

    {
        // search with filter template
        std::string filter = field_text + " in {my_texts}";  // "my_texts" is the alias that the filter template binds to
        std::vector<std::string> texts;
        for (auto i = 300; i < 500; i++) {
            texts.push_back("text_" + std::to_string(i));
        }
        nlohmann::json filter_template = texts;

        milvus::SearchArguments s_arguments{};
        s_arguments.SetCollectionName(collection_name);
        s_arguments.SetLimit(static_cast<int64_t>(texts.size()));
        s_arguments.SetFilter(filter);
        s_arguments.AddFilterTemplate("my_texts", filter_template);
        s_arguments.AddOutputField(field_text);
        // two target vectors are added, so the result loop below prints one result set per target vector
        s_arguments.AddFloatVector(field_vector, util::GenerateFloatVector(dimension));
        s_arguments.AddFloatVector(field_vector, util::GenerateFloatVector(dimension));
        s_arguments.SetConsistencyLevel(milvus::ConsistencyLevel::BOUNDED);

        milvus::SearchResults search_results{};
        status = client->Search(s_arguments, search_results);
        util::CheckStatus("search", status);

        std::cout << "Search with filter template:" << std::endl;
        for (auto& result : search_results.Results()) {
            std::cout << "Result of one target vector:" << std::endl;
            milvus::EntityRows output_rows;
            status = result.OutputRows(output_rows);
            util::CheckStatus("get output rows", status);
            for (const auto& row : output_rows) {
                std::cout << "\t" << row << std::endl;
            }
        }
    }

    client->Disconnect();
    return 0;
}
24 changes: 24 additions & 0 deletions src/impl/types/QueryArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "milvus/types/QueryArguments.h"

#include "../utils/Constants.h"
#include "../utils/TypeUtils.h"

namespace milvus {

Expand Down Expand Up @@ -90,6 +91,29 @@ QueryArguments::SetFilter(std::string filter) {
return Status::OK();
}

// Registers the value bound to the placeholder `key` of the filter expression
// (e.g. key "my_ids" for a filter such as "pk in {my_ids}").
//
// `filter_template` must be a boolean, a number, a string, or an array whose elements are all
// boolean/number/string; anything else is rejected with INVALID_AGUMENT and nothing is stored.
// Registering the same key again replaces the previously stored value.
Status
QueryArguments::AddFilterTemplate(std::string key, const nlohmann::json& filter_template) {
    if (filter_template.is_array()) {
        for (const auto& ele : filter_template) {
            if (!IsValidTemplate(ele)) {
                return {milvus::StatusCode::INVALID_AGUMENT, "Filter template element must be boolean/number/string"};
            }
        }
    } else if (!IsValidTemplate(filter_template)) {
        return {milvus::StatusCode::INVALID_AGUMENT, "Filter template must be boolean/number/string/array"};
    }

    // Assign instead of insert(): insert() is a no-op when the key already exists, which would
    // silently drop the new value while still reporting success to the caller.
    filter_templates_[std::move(key)] = filter_template;
    return Status::OK();
}

// Read-only view of the filter templates stored by AddFilterTemplate(), keyed by the name they
// were registered under.
const std::unordered_map<std::string, nlohmann::json>&
QueryArguments::FilterTemplates() const {
    const auto& registered_templates = filter_templates_;
    return registered_templates;
}

int64_t
QueryArguments::Limit() const {
// for history reason, query() requires "limit", search() requires "topk"
Expand Down
25 changes: 23 additions & 2 deletions src/impl/types/SubSearchRequest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include "milvus/types/SubSearchRequest.h"

#include <nlohmann/json.hpp>
#include <utility>

#include "../utils/Constants.h"
Expand All @@ -25,7 +24,6 @@
#include "milvus/utils/FP16.h"

namespace milvus {

const std::string&
SearchRequestBase::Filter() const {
return filter_expression_;
Expand All @@ -37,6 +35,29 @@ SearchRequestBase::SetFilter(std::string filter) {
return Status::OK();
}

// Registers the value bound to the placeholder `key` of the filter expression
// (e.g. key "my_texts" for a filter such as "text in {my_texts}").
//
// `filter_template` must be a boolean, a number, a string, or an array whose elements are all
// boolean/number/string; anything else is rejected with INVALID_AGUMENT and nothing is stored.
// Registering the same key again replaces the previously stored value.
Status
SearchRequestBase::AddFilterTemplate(std::string key, const nlohmann::json& filter_template) {
    if (filter_template.is_array()) {
        for (const auto& ele : filter_template) {
            if (!IsValidTemplate(ele)) {
                return {milvus::StatusCode::INVALID_AGUMENT, "Filter template element must be boolean/number/string"};
            }
        }
    } else if (!IsValidTemplate(filter_template)) {
        return {milvus::StatusCode::INVALID_AGUMENT, "Filter template must be boolean/number/string/array"};
    }

    // Assign instead of insert(): insert() is a no-op when the key already exists, which would
    // silently drop the new value while still reporting success to the caller.
    filter_templates_[std::move(key)] = filter_template;
    return Status::OK();
}

// Read-only view of the filter templates stored by AddFilterTemplate(), keyed by the name they
// were registered under.
const std::unordered_map<std::string, nlohmann::json>&
SearchRequestBase::FilterTemplates() const {
    const auto& registered_templates = filter_templates_;
    return registered_templates;
}

FieldDataPtr
SearchRequestBase::TargetVectors() const {
return target_vectors_;
Expand Down
107 changes: 107 additions & 0 deletions src/impl/utils/DqlUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,97 @@ DeduceGuaranteeTimestamp(const ConsistencyLevel& level, const std::string& db_na
}
}

// Converts a JSON array used as a filter-template value into its RPC form.
//
// The element type of the whole array is deduced from its first element; every other element
// must have the same JSON type, otherwise INVALID_AGUMENT is returned. Supported element types
// are Boolean, Integer, Double, String and (recursively) List.
//
// @param array      the JSON array to convert; a non-array value is rejected.
// @param rpc_array  output message; elements are appended to the typed field matching the deduced type.
//                   On error it may already hold the elements converted before the failure.
// @return Status::OK() on success (an empty array is accepted and leaves rpc_array untouched).
Status
DeduceTemplateArray(const nlohmann::json& array, proto::schema::TemplateArrayValue& rpc_array) {
    // at(0) below throws for non-array JSON values, so reject them explicitly
    if (!array.is_array()) {
        return {StatusCode::INVALID_AGUMENT, "Filter expression template is expected to be a list"};
    }
    if (array.empty()) {
        return Status::OK();
    }

    const auto& first_ele = array.at(0);
    if (first_ele.is_boolean()) {
        for (const auto& ele : array) {
            if (!ele.is_boolean()) {
                return {StatusCode::INVALID_AGUMENT,
                        "Filter expression template is a list, the first value is Boolean, but some elements are not "
                        "Boolean"};
            }
            rpc_array.mutable_bool_data()->add_data(ele.get<bool>());
        }
    } else if (first_ele.is_number_integer()) {
        for (const auto& ele : array) {
            if (!ele.is_number_integer()) {
                return {StatusCode::INVALID_AGUMENT,
                        "Filter expression template is a list, the first value is Integer, but some elements are not "
                        "Integer"};
            }
            rpc_array.mutable_long_data()->add_data(ele.get<int64_t>());
        }
    } else if (first_ele.is_number_float()) {
        for (const auto& ele : array) {
            if (!ele.is_number_float()) {
                return {StatusCode::INVALID_AGUMENT,
                        "Filter expression template is a list, the first value is Double, but some elements are not "
                        "Double"};
            }
            rpc_array.mutable_double_data()->add_data(ele.get<double>());
        }
    } else if (first_ele.is_string()) {
        for (const auto& ele : array) {
            if (!ele.is_string()) {
                return {StatusCode::INVALID_AGUMENT,
                        "Filter expression template is a list, the first value is String, but some elements are not "
                        "String"};
            }
            rpc_array.mutable_string_data()->add_data(ele.get<std::string>());
        }
    } else if (first_ele.is_array()) {
        // Each JSON sub-list becomes exactly one entry of rpc_array's array_data. Adding an extra
        // wrapper entry first would encode [[a],[b]] as [[[a],[b]]].
        auto rpc_sub_arrays = rpc_array.mutable_array_data();
        for (const auto& ele : array) {
            if (!ele.is_array()) {
                return {
                    StatusCode::INVALID_AGUMENT,
                    "Filter expression template is a list, the first value is List, but some elements are not List"};
            }

            auto sub_array = rpc_sub_arrays->add_data();
            auto status = DeduceTemplateArray(ele, *sub_array);
            if (!status.IsOk()) {
                return status;
            }
        }
    } else {
        // null/object/binary elements cannot be represented; report it instead of sending an empty value
        return {StatusCode::INVALID_AGUMENT,
                "Filter expression template is a list, but its elements are not Boolean/Integer/Double/String/List"};
    }

    return Status::OK();
}

// Translates every registered filter template into its RPC TemplateValue and inserts it into
// `rpc_templates` under the same key.
//
// Arrays are delegated to DeduceTemplateArray(); booleans, integers, floating-point numbers and
// strings are stored directly. Any other JSON type aborts the conversion with INVALID_AGUMENT;
// a failure reported by DeduceTemplateArray() is returned unchanged. Entries converted before a
// failure remain in `rpc_templates`.
Status
ConvertFilterTemplates(const std::unordered_map<std::string, nlohmann::json>& templates,
                       ::google::protobuf::Map<std::string, proto::schema::TemplateValue>* rpc_templates) {
    using JsonType = nlohmann::json::value_t;

    for (const auto& entry : templates) {
        const nlohmann::json& json_value = entry.second;
        proto::schema::TemplateValue rpc_value;

        switch (json_value.type()) {
            case JsonType::array: {
                auto rpc_array = rpc_value.mutable_array_val();
                auto status = DeduceTemplateArray(json_value, *rpc_array);
                if (!status.IsOk()) {
                    return status;
                }
                break;
            }
            case JsonType::boolean:
                rpc_value.set_bool_val(json_value.get<bool>());
                break;
            // signed and unsigned integers are both sent as int64
            case JsonType::number_integer:
            case JsonType::number_unsigned:
                rpc_value.set_int64_val(json_value.get<int64_t>());
                break;
            case JsonType::number_float:
                rpc_value.set_float_val(json_value.get<double>());
                break;
            case JsonType::string:
                rpc_value.set_string_val(json_value.get<std::string>());
                break;
            default:
                return {StatusCode::INVALID_AGUMENT, "Unsupported template type"};
        }

        rpc_templates->insert(std::make_pair(entry.first, rpc_value));
    }

    return Status::OK();
}

// current_db is the actual target db on which the request is performed; it is used to set the
// GuaranteeTimestamp in a way that stays compatible with old versions.
// for example:
Expand All @@ -803,6 +894,15 @@ ConvertQueryRequest(const QueryArguments& arguments, const std::string& current_
}

rpc_request.set_expr(arguments.Filter());
if (!arguments.Filter().empty()) {
auto rpc_templates = rpc_request.mutable_expr_template_values();
const auto& templates = arguments.FilterTemplates();
auto status = ConvertFilterTemplates(templates, rpc_templates);
if (!status.IsOk()) {
return status;
}
}

for (const auto& field : arguments.OutputFields()) {
rpc_request.add_output_fields(field);
}
Expand Down Expand Up @@ -868,6 +968,13 @@ ConvertSearchRequest(const SearchArguments& arguments, const std::string& curren
rpc_request.set_dsl_type(proto::common::DslType::BoolExprV1);
if (!arguments.Filter().empty()) {
rpc_request.set_dsl(arguments.Filter());

auto rpc_templates = rpc_request.mutable_expr_template_values();
const auto& templates = arguments.FilterTemplates();
auto status = ConvertFilterTemplates(templates, rpc_templates);
if (!status.IsOk()) {
return status;
}
}
for (const auto& partition_name : arguments.PartitionNames()) {
rpc_request.add_partition_names(partition_name);
Expand Down
5 changes: 5 additions & 0 deletions src/impl/utils/TypeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,11 @@ doubleToString(double val) {
return stream.str();
}

// Tells whether a single JSON value is a scalar that is accepted as a filter-template value:
// a boolean, a number or a string. Arrays, objects and null yield false.
bool
IsValidTemplate(const nlohmann::json& filter_template) {
    if (filter_template.is_string()) {
        return true;
    }
    if (filter_template.is_number()) {
        return true;
    }
    return filter_template.is_boolean();
}

} // namespace milvus

namespace std {
Expand Down
3 changes: 3 additions & 0 deletions src/impl/utils/TypeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,7 @@ ConvertResourceGroupConfig(const proto::rg::ResourceGroupConfig& rpc_config, Res
std::string
doubleToString(double val);

// Returns true when `filter_template` is a JSON boolean, number or string, i.e. a scalar value
// that is accepted as a filter-template value; returns false for every other JSON type.
bool
IsValidTemplate(const nlohmann::json& filter_template);

} // namespace milvus
Loading
Loading