diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 0b7a45fd30c..a53eb3032ea 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -69,9 +69,11 @@ set(SRC_FILES expression_registry.cc exported_funcs_registry.cc filter.cc + array_ops.cc function_ir_builder.cc function_registry.cc function_registry_arithmetic.cc + function_registry_array.cc function_registry_datetime.cc function_registry_hash.cc function_registry_math_ops.cc @@ -237,6 +239,7 @@ add_gandiva_test(internals-test random_generator_holder_test.cc hash_utils_test.cc gdv_function_stubs_test.cc + array_ops_test.cc EXTRA_DEPENDENCIES LLVM::LLVM_INTERFACE ${GANDIVA_OPENSSL_LIBS} diff --git a/cpp/src/gandiva/annotator.cc b/cpp/src/gandiva/annotator.cc index f6acaff1804..d0117eee4f0 100644 --- a/cpp/src/gandiva/annotator.cc +++ b/cpp/src/gandiva/annotator.cc @@ -46,15 +46,23 @@ FieldDescriptorPtr Annotator::MakeDesc(FieldPtr field, bool is_output) { int data_idx = buffer_count_++; int validity_idx = buffer_count_++; int offsets_idx = FieldDescriptor::kInvalidIdx; + int child_offsets_idx = FieldDescriptor::kInvalidIdx; if (arrow::is_binary_like(field->type()->id())) { offsets_idx = buffer_count_++; } + + if (field->type()->id() == arrow::Type::LIST) { + offsets_idx = buffer_count_++; + if (arrow::is_binary_like(field->type()->field(0)->type()->id())) { + child_offsets_idx = buffer_count_++; + } + } int data_buffer_ptr_idx = FieldDescriptor::kInvalidIdx; if (is_output) { data_buffer_ptr_idx = buffer_count_++; } return std::make_shared(field, data_idx, validity_idx, offsets_idx, - data_buffer_ptr_idx); + data_buffer_ptr_idx, child_offsets_idx); } void Annotator::PrepareBuffersForField(const FieldDescriptor& desc, @@ -74,16 +82,52 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc, if (desc.HasOffsetsIdx()) { uint8_t* offsets_buf = const_cast(array_data.buffers[buffer_idx]->data()); eval_batch->SetBuffer(desc.offsets_idx(), offsets_buf, array_data.offset); - ++buffer_idx; + + if (desc.HasChildOffsetsIdx()) { + if (is_output) { + // if list field is output field, we should put buffer pointer into eval batch + // for resizing + uint8_t* child_offsets_buf = reinterpret_cast( + array_data.child_data.at(0)->buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.child_data_offsets_idx(), child_offsets_buf, + array_data.child_data.at(0)->offset); + } else { + // if list field is input field, just put buffer data into eval batch + uint8_t* child_offsets_buf = const_cast( + array_data.child_data.at(0)->buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.child_data_offsets_idx(), child_offsets_buf, + array_data.child_data.at(0)->offset); + } + } + if (array_data.type->id() != arrow::Type::LIST || + arrow::is_binary_like(array_data.type->field(0)->type()->id())) + // primitive type list data buffer index is 1 + // binary like type list data buffer index is 2 + ++buffer_idx; + } + + if (array_data.type->id() != arrow::Type::LIST) { + uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); + } else { + uint8_t* data_buf = + const_cast(array_data.child_data.at(0)->buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.child_data.at(0)->offset); } - uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); - eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); if (is_output) { // pass in the Buffer object for output data buffers. Can be used for resizing. - uint8_t* data_buf_ptr = - reinterpret_cast(array_data.buffers[buffer_idx].get()); - eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + if (array_data.type->id() != arrow::Type::LIST) { + uint8_t* data_buf_ptr = + reinterpret_cast(array_data.buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + } else { + // list data buffer is in child data buffer + uint8_t* data_buf_ptr = reinterpret_cast( + array_data.child_data.at(0)->buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, + array_data.child_data.at(0)->offset); + } } } diff --git a/cpp/src/gandiva/array_ops.cc b/cpp/src/gandiva/array_ops.cc new file mode 100644 index 00000000000..6763839fd94 --- /dev/null +++ b/cpp/src/gandiva/array_ops.cc @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/array_ops.h" + +#include "arrow/util/value_parsing.h" +#include "gandiva/engine.h" +#include "gandiva/exported_funcs.h" + +/// Stub functions that can be accessed from LLVM or the pre-compiled library. + +extern "C" { + +bool array_utf8_contains_utf8(int64_t context_ptr, const char* entry_buf, + int32_t* entry_child_offsets, int32_t entry_offsets_len, + const char* contains_data, int32_t contains_data_length) { + for (int i = 0; i < entry_offsets_len; i++) { + int32_t entry_len = *(entry_child_offsets + i + 1) - *(entry_child_offsets + i); + if (entry_len != contains_data_length) { + entry_buf = entry_buf + entry_len; + continue; + } + if (strncmp(entry_buf, contains_data, contains_data_length) == 0) { + return true; + } + entry_buf = entry_buf + entry_len; + } + return false; +} + +int64_t array_utf8_length(int64_t context_ptr, const char* entry_buf, + int32_t* entry_child_offsets, int32_t entry_offsets_len) { + int64_t res = entry_offsets_len; + return res; +} +} + +namespace gandiva { +void ExportedArrayFunctions::AddMappings(Engine* engine) const { + std::vector args; + auto types = engine->types(); + + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* child offsets ptr + types->i32_type()}; // int32_t child offsets length + + engine->AddGlobalMappingForFunc("array_utf8_length", types->i64_type() /*return_type*/, + args, reinterpret_cast(array_utf8_length)); + + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* child offsets ptr + types->i32_type(), // int32_t child offsets length + types->i8_ptr_type(), // const char* contains data buf + types->i32_type()}; // int32_t contains data length + + engine->AddGlobalMappingForFunc("array_utf8_contains_utf8", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_utf8_contains_utf8)); +} +} // namespace gandiva diff --git a/cpp/src/gandiva/array_ops.h b/cpp/src/gandiva/array_ops.h new file mode 100644 index 00000000000..19bf35e9f4f --- /dev/null +++ b/cpp/src/gandiva/array_ops.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "gandiva/visibility.h" + +/// Array functions that can be accessed from LLVM. +extern "C" { +GANDIVA_EXPORT +bool array_utf8_contains_utf8(int64_t context_ptr, const char* entry_buf, + int32_t* entry_child_offsets, int32_t entry_offsets_len, + const char* contains_data, int32_t contains_data_length); +GANDIVA_EXPORT +int64_t array_utf8_length(int64_t context_ptr, const char* entry_buf, + int32_t* entry_child_offsets, int32_t entry_offsets_len); +} diff --git a/cpp/src/gandiva/array_ops_test.cc b/cpp/src/gandiva/array_ops_test.cc new file mode 100644 index 00000000000..f2209c8e634 --- /dev/null +++ b/cpp/src/gandiva/array_ops_test.cc @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "gandiva/execution_context.h" +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestArrayOps, TestUtf8ContainsUtf8) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast(&ctx); + const char* entry_buf = "trianglecirclerectangle"; + int32_t entry_child_offsets[] = {0, 8, 14, 24}; + int32_t entry_offsets_len = 3; + const char* contains_data = "triangle"; + int32_t contains_data_length = 8; + + EXPECT_EQ( + array_utf8_contains_utf8(ctx_ptr, entry_buf, entry_child_offsets, entry_offsets_len, + contains_data, contains_data_length), + true); +} + +TEST(TestArrayOps, TestUtf8Length) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast(&ctx); + const char* entry_buf = "trianglecirclerectangle"; + int32_t entry_child_offsets[] = {0, 8, 14, 24}; + int32_t entry_offsets_len = 3; + + EXPECT_EQ(array_utf8_length(ctx_ptr, entry_buf, entry_child_offsets, entry_offsets_len), + 3); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/dex.h b/cpp/src/gandiva/dex.h index 3920f82f1d7..e5afad90fc3 100644 --- a/cpp/src/gandiva/dex.h +++ b/cpp/src/gandiva/dex.h @@ -80,6 +80,19 @@ class GANDIVA_EXPORT VectorReadFixedLenValueDex : public VectorReadBaseDex { void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } }; +/// value component of a fixed-len list ValueVector +class GANDIVA_EXPORT VectorReadFixedLenValueListDex : public VectorReadBaseDex { + public: + explicit VectorReadFixedLenValueListDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + int OffsetsIdx() const { return field_desc_->offsets_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + /// value component of a variable-len ValueVector class GANDIVA_EXPORT VectorReadVarLenValueDex : public VectorReadBaseDex { public: @@ -93,6 +106,21 @@ class GANDIVA_EXPORT VectorReadVarLenValueDex : public VectorReadBaseDex { void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } }; +/// value component of a variable-len list ValueVector +class GANDIVA_EXPORT VectorReadVarLenValueListDex : public VectorReadBaseDex { + public: + explicit VectorReadVarLenValueListDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + int OffsetsIdx() const { return field_desc_->offsets_idx(); } + + int ChildOffsetsIdx() const { return field_desc_->child_data_offsets_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + /// validity based on a local bitmap. class GANDIVA_EXPORT LocalBitMapValidityDex : public Dex { public: diff --git a/cpp/src/gandiva/dex_visitor.h b/cpp/src/gandiva/dex_visitor.h index ba5de970dda..a5de17c41a7 100644 --- a/cpp/src/gandiva/dex_visitor.h +++ b/cpp/src/gandiva/dex_visitor.h @@ -27,7 +27,9 @@ namespace gandiva { class VectorReadValidityDex; class VectorReadFixedLenValueDex; +class VectorReadFixedLenValueListDex; class VectorReadVarLenValueDex; +class VectorReadVarLenValueListDex; class LocalBitMapValidityDex; class LiteralDex; class TrueDex; @@ -48,7 +50,9 @@ class GANDIVA_EXPORT DexVisitor { virtual void Visit(const VectorReadValidityDex& dex) = 0; virtual void Visit(const VectorReadFixedLenValueDex& dex) = 0; + virtual void Visit(const VectorReadFixedLenValueListDex& dex) = 0; virtual void Visit(const VectorReadVarLenValueDex& dex) = 0; + virtual void Visit(const VectorReadVarLenValueListDex& dex) = 0; virtual void Visit(const LocalBitMapValidityDex& dex) = 0; virtual void Visit(const TrueDex& dex) = 0; virtual void Visit(const FalseDex& dex) = 0; @@ -72,7 +76,9 @@ class GANDIVA_EXPORT DexVisitor { class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor { VISIT_DCHECK(VectorReadValidityDex) VISIT_DCHECK(VectorReadFixedLenValueDex) + VISIT_DCHECK(VectorReadFixedLenValueListDex) VISIT_DCHECK(VectorReadVarLenValueDex) + VISIT_DCHECK(VectorReadVarLenValueListDex) VISIT_DCHECK(LocalBitMapValidityDex) VISIT_DCHECK(TrueDex) VISIT_DCHECK(FalseDex) diff --git a/cpp/src/gandiva/exported_funcs.h b/cpp/src/gandiva/exported_funcs.h index 58205266094..9528273507f 100644 --- a/cpp/src/gandiva/exported_funcs.h +++ b/cpp/src/gandiva/exported_funcs.h @@ -32,6 +32,12 @@ class ExportedFuncsBase { virtual void AddMappings(Engine* engine) const = 0; }; +// Class for exporting Array functions +class ExportedArrayFunctions : public ExportedFuncsBase { + void AddMappings(Engine* engine) const override; +}; +REGISTER_EXPORTED_FUNCS(ExportedArrayFunctions); + // Class for exporting Stub functions class ExportedStubFunctions : public ExportedFuncsBase { void AddMappings(Engine* engine) const override; diff --git a/cpp/src/gandiva/expr_decomposer.cc b/cpp/src/gandiva/expr_decomposer.cc index 07252b42fd2..71a81a41787 100644 --- a/cpp/src/gandiva/expr_decomposer.cc +++ b/cpp/src/gandiva/expr_decomposer.cc @@ -39,8 +39,16 @@ Status ExprDecomposer::Visit(const FieldNode& node) { DexPtr validity_dex = std::make_shared(desc); DexPtr value_dex; - if (desc->HasOffsetsIdx()) { - value_dex = std::make_shared(desc); + if (desc->HasChildOffsetsIdx()) { + // handle list type + value_dex = std::make_shared(desc); + } else if (desc->HasOffsetsIdx()) { + if (desc->field()->type()->id() == arrow::Type::LIST) { + // handle list type + value_dex = std::make_shared(desc); + } else { + value_dex = std::make_shared(desc); + } } else { value_dex = std::make_shared(desc); } diff --git a/cpp/src/gandiva/expr_validator.cc b/cpp/src/gandiva/expr_validator.cc index fd46c2894b9..961678fc659 100644 --- a/cpp/src/gandiva/expr_validator.cc +++ b/cpp/src/gandiva/expr_validator.cc @@ -42,7 +42,7 @@ Status ExprValidator::Validate(const ExpressionPtr& expr) { } Status ExprValidator::Visit(const FieldNode& node) { - auto llvm_type = types_->IRType(node.return_type()->id()); + auto llvm_type = types_->DataVecType(node.return_type()); ARROW_RETURN_IF(llvm_type == nullptr, Status::ExpressionValidationError("Field ", node.field()->name(), " has unsupported data type ", @@ -111,7 +111,7 @@ Status ExprValidator::Visit(const IfNode& node) { } Status ExprValidator::Visit(const LiteralNode& node) { - auto llvm_type = types_->IRType(node.return_type()->id()); + auto llvm_type = types_->DataVecType(node.return_type()); ARROW_RETURN_IF(llvm_type == nullptr, Status::ExpressionValidationError("Value ", ToString(node.holder()), " has unsupported data type ", diff --git a/cpp/src/gandiva/expression_registry.cc b/cpp/src/gandiva/expression_registry.cc index c3a08fd3a4f..b3232700d38 100644 --- a/cpp/src/gandiva/expression_registry.cc +++ b/cpp/src/gandiva/expression_registry.cc @@ -14,7 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - +#include #include "gandiva/expression_registry.h" #include "gandiva/function_registry.h" @@ -166,6 +166,10 @@ static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector case arrow::Type::type::INTERVAL_DAY_TIME: vector.push_back(arrow::day_time_interval()); break; + case arrow::Type::type::LIST: + vector.push_back(arrow::list(arrow::uint32())); + //vector.push_back(arrow::list(arrow::utf8())); + break; default: // Unsupported types. test ensures that // when one of these are added build breaks. diff --git a/cpp/src/gandiva/field_descriptor.h b/cpp/src/gandiva/field_descriptor.h index 0fe6fe37f4d..7b2d0c3b4fa 100644 --- a/cpp/src/gandiva/field_descriptor.h +++ b/cpp/src/gandiva/field_descriptor.h @@ -30,12 +30,14 @@ class FieldDescriptor { static const int kInvalidIdx = -1; FieldDescriptor(FieldPtr field, int data_idx, int validity_idx = kInvalidIdx, - int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx) + int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx, + int child_offsets_idx = kInvalidIdx) : field_(field), data_idx_(data_idx), validity_idx_(validity_idx), offsets_idx_(offsets_idx), - data_buffer_ptr_idx_(data_buffer_ptr_idx) {} + data_buffer_ptr_idx_(data_buffer_ptr_idx), + child_offsets_idx_(child_offsets_idx) {} /// Index of validity array in the array-of-buffers int validity_idx() const { return validity_idx_; } @@ -49,6 +51,9 @@ class FieldDescriptor { /// Index of data buffer pointer in the array-of-buffers int data_buffer_ptr_idx() const { return data_buffer_ptr_idx_; } + /// Index of list type child data offsets + int child_data_offsets_idx() const { return child_offsets_idx_; } + FieldPtr field() const { return field_; } const std::string& Name() const { return field_->name(); } @@ -58,12 +63,15 @@ class FieldDescriptor { bool HasDataBufferPtrIdx() const { return data_buffer_ptr_idx_ != kInvalidIdx; } + bool HasChildOffsetsIdx() const { return child_offsets_idx_ != kInvalidIdx; } + private: FieldPtr field_; int data_idx_; int validity_idx_; int offsets_idx_; int data_buffer_ptr_idx_; + int child_offsets_idx_; }; } // namespace gandiva diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index d5d015c10b4..9180e8c33ca 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -16,17 +16,19 @@ // under the License. #include "gandiva/function_registry.h" + +#include +#include +#include + #include "gandiva/function_registry_arithmetic.h" +#include "gandiva/function_registry_array.h" #include "gandiva/function_registry_datetime.h" #include "gandiva/function_registry_hash.h" #include "gandiva/function_registry_math_ops.h" #include "gandiva/function_registry_string.h" #include "gandiva/function_registry_timestamp_arithmetic.h" -#include -#include -#include - namespace gandiva { FunctionRegistry::iterator FunctionRegistry::begin() const { @@ -65,6 +67,9 @@ SignatureMap FunctionRegistry::InitPCMap() { auto v6 = GetDateTimeArithmeticFunctionRegistry(); pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end()); + auto v7 = GetArrayFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v7.begin(), v7.end()); + for (auto& elem : pc_registry_) { for (auto& func_signature : elem.signatures()) { map.insert(std::make_pair(&(func_signature), &elem)); diff --git a/cpp/src/gandiva/function_registry_array.cc b/cpp/src/gandiva/function_registry_array.cc new file mode 100644 index 00000000000..057115caf04 --- /dev/null +++ b/cpp/src/gandiva/function_registry_array.cc @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/function_registry_array.h" + +#include "gandiva/function_registry_common.h" + +namespace gandiva { +std::vector GetArrayFunctionRegistry() { + static std::vector array_fn_registry_ = { + NativeFunction("array_contains", {}, DataTypeVector{list(utf8()), utf8()}, + boolean(), kResultNullIfNull, "array_utf8_contains_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("array_length", {}, DataTypeVector{list(utf8())}, int64(), + kResultNullIfNull, "array_utf8_length", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + }; + return array_fn_registry_; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/function_registry_array.h b/cpp/src/gandiva/function_registry_array.h new file mode 100644 index 00000000000..9b8e4553702 --- /dev/null +++ b/cpp/src/gandiva/function_registry_array.h @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "gandiva/native_function.h" + +namespace gandiva { + +std::vector GetArrayFunctionRegistry(); + +} // namespace gandiva diff --git a/cpp/src/gandiva/function_registry_test.cc b/cpp/src/gandiva/function_registry_test.cc index e3c1e85f79c..f33388e0aef 100644 --- a/cpp/src/gandiva/function_registry_test.cc +++ b/cpp/src/gandiva/function_registry_test.cc @@ -93,4 +93,14 @@ TEST_F(TestFunctionRegistry, TestNoDuplicates) { "different precompiled functions:\n" << stream.str(); } + +TEST_F(TestFunctionRegistry, TestFound2) { + FunctionSignature array_length("array_length", {list(utf8())}, arrow::int64()); + + const NativeFunction* function = registry_.LookupSignature(array_length); + EXPECT_NE(function, nullptr); + EXPECT_THAT(function->signatures(), testing::Contains(array_length)); + EXPECT_EQ(function->pc_name(), "array_utf8_length"); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index db1802ae6b8..2977bfd0acb 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -34,8 +34,6 @@ #include "gandiva/replace_holder.h" #include "gandiva/to_date_holder.h" -/// Stub functions that can be accessed from LLVM or the pre-compiled library. - extern "C" { const uint8_t* gdv_fn_get_json_object_utf8_utf8(int64_t ptr, const char* data, int data_len, @@ -280,6 +278,91 @@ const char* gdv_fn_sha1_decimal128(int64_t context, int64_t x_high, uint64_t x_l out_length); } +/// Stub functions that can be accessed from LLVM or the pre-compiled library. +#define POPULATE_NUMERIC_LIST_TYPE_VECTOR(TYPE, SCALE) \ + int32_t gdv_fn_populate_list_##TYPE##_vector(int64_t context_ptr, int8_t* data_ptr, \ + int32_t* offsets, int64_t slot, \ + TYPE* entry_buf, int32_t entry_len) { \ + auto buffer = reinterpret_cast(data_ptr); \ + int32_t offset = static_cast(buffer->size()); \ + auto status = buffer->Resize(offset + entry_len * SCALE, false /*shrink*/); \ + if (!status.ok()) { \ + gandiva::ExecutionContext* context = \ + reinterpret_cast(context_ptr); \ + context->set_error_msg(status.message().c_str()); \ + return -1; \ + } \ + memcpy(buffer->mutable_data() + offset, (char*)entry_buf, entry_len * SCALE); \ + offsets[slot] = offset / SCALE; \ + offsets[slot + 1] = offset / SCALE + entry_len; \ + return 0; \ + } + +POPULATE_NUMERIC_LIST_TYPE_VECTOR(int32_t, 4) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(int64_t, 8) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(float, 4) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(double, 8) + +int32_t gdv_fn_populate_list_varlen_vector(int64_t context_ptr, int8_t* data_ptr, + int32_t* offsets, int32_t* child_offsets, + int64_t slot, const char* entry_buf, + int32_t* entry_child_offsets, + int32_t entry_offsets_len) { + // we should calculate varlen list type varlen offset + // copy from entry child offsets + // it should be noted that, + // buffer size unit is byte(8 bit), + // offset element unit is int32(32 bit) + auto child_offsets_buffer = reinterpret_cast(child_offsets); + int32_t child_offsets_buffer_offset = + static_cast(child_offsets_buffer->size()); + + // data buffer elelment is char(8 bit) + auto data_buffer = reinterpret_cast(data_ptr); + int32_t data_buffer_offset = static_cast(data_buffer->size()); + + // sets the size in the child offsets buffer + // offsets element is int32, we should resize buffer by extra offsets_len * 4 + auto status = child_offsets_buffer->Resize( + child_offsets_buffer_offset + entry_offsets_len * 4, false /*shrink*/); + if (!status.ok()) { + gandiva::ExecutionContext* context = + reinterpret_cast(context_ptr); + + context->set_error_msg(status.message().c_str()); + return -1; + } + + // append the new child offsets entry to child offsets buffer + // offsets buffer last offset number indicating data length + // we should take this extra offset into consider + // so the initialize child_offsets_buffer length is 1(int32) + memcpy(child_offsets_buffer->mutable_data() + child_offsets_buffer_offset - 4, + (char*)entry_child_offsets, (entry_offsets_len + 1) * 4); + + // compute data length + int32_t data_length = + *(entry_child_offsets + entry_offsets_len) - *(entry_child_offsets); + + // sets the size in the child offsets buffer. + status = data_buffer->Resize(data_buffer_offset + data_length, false /*shrink*/); + if (!status.ok()) { + gandiva::ExecutionContext* context = + reinterpret_cast(context_ptr); + + context->set_error_msg(status.message().c_str()); + return -1; + } + + // append the new child offsets entry to child offsets buffer + memcpy(data_buffer->mutable_data() + data_buffer_offset, entry_buf, data_length); + + // update offsets buffer. + offsets[slot] = child_offsets_buffer_offset / 4 - 1; + offsets[slot + 1] = child_offsets_buffer_offset / 4 - 1 + entry_offsets_len; + return 0; +} + int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length, int32_t* precision_from_str, int32_t* scale_from_str, int64_t* dec_high_from_str, uint64_t* dec_low_from_str) { @@ -454,6 +537,7 @@ GDV_FN_CAST_VARCHAR_REAL(float64, DoubleType) #undef GDV_FN_CAST_VARCHAR_INTEGER #undef GDV_FN_CAST_VARCHAR_REAL +#undef POPULATE_NUMERIC_LIST_TYPE_VECTOR } namespace gandiva { @@ -624,6 +708,18 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { types->i1_type() /*return_type*/, args, reinterpret_cast(gdv_fn_in_expr_lookup_utf8)); +#define ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(LLVM_TYPE, DATA_TYPE) \ + args = {types->i64_type(), types->i8_ptr_type(), types->i32_ptr_type(), \ + types->i64_type(), types->LLVM_TYPE##_ptr_type(), types->i32_type()}; \ + engine->AddGlobalMappingForFunc( \ + "gdv_fn_populate_list_" #DATA_TYPE "_vector", types->i32_type() /*return_type*/, \ + args, reinterpret_cast(gdv_fn_populate_list_##DATA_TYPE##_vector)); + + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(i32, int32_t) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(i64, int64_t) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(float, float) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(double, double) + // gdv_fn_populate_varlen_vector args = {types->i64_type(), // int64_t execution_context types->i8_ptr_type(), // int8_t* data ptr @@ -636,6 +732,20 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_fn_populate_varlen_vector)); + // gdv_fn_populate_list_varlen_vector + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* offsets ptr + types->i32_ptr_type(), // int32_t* child offsets ptr + types->i64_type(), // int64_t slot + types->i8_ptr_type(), // const char* entry_buf + types->i32_ptr_type(), // int32_t* entry child offsets ptr + types->i32_type()}; // int32_t entry child offsets length + + engine->AddGlobalMappingForFunc( + "gdv_fn_populate_list_varlen_vector", types->i32_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_populate_list_varlen_vector)); + // gdv_fn_random args = {types->i64_type()}; engine->AddGlobalMappingForFunc("gdv_fn_random", types->double_type(), args, @@ -1156,4 +1266,5 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_sha256_decimal128)); } +#undef ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION } // namespace gandiva diff --git a/cpp/src/gandiva/jni/expression_registry_helper.cc b/cpp/src/gandiva/jni/expression_registry_helper.cc index 0d1f74ba6f0..8047259a7f6 100644 --- a/cpp/src/gandiva/jni/expression_registry_helper.cc +++ b/cpp/src/gandiva/jni/expression_registry_helper.cc @@ -136,6 +136,10 @@ void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type) gandiva_data_type->set_type(types::GandivaType::INTERVAL); gandiva_data_type->set_intervaltype(types::IntervalType::DAY_TIME); break; + case arrow::Type::LIST: + gandiva_data_type->set_type(types::GandivaType::LIST); + gandiva_data_type->set_valuetype(types::GandivaType::INT32); + break; default: // un-supported types. test ensures that // when one of these are added build breaks. @@ -149,6 +153,7 @@ Java_org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper_getGandivaSu types::GandivaDataTypes gandiva_data_types; auto supported_types = ExpressionRegistry::supported_types(); for (auto const& type : supported_types) { + std::cout << type->ToString() << std::endl; types::ExtGandivaType* gandiva_data_type = gandiva_data_types.add_datatype(); ArrowToProtobuf(type, gandiva_data_type); } diff --git a/cpp/src/gandiva/jni/jni_common.cc b/cpp/src/gandiva/jni/jni_common.cc index e17eac7eab4..0d6ac74d72c 100644 --- a/cpp/src/gandiva/jni/jni_common.cc +++ b/cpp/src/gandiva/jni/jni_common.cc @@ -189,6 +189,18 @@ DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) { } } +DataTypePtr ProtoTypeToList(const types::ExtGandivaType& ext_type) { + switch (ext_type.valuetype()) { + case types::UINT32: + return arrow::list(arrow::uint32()); + case types::UTF8: + return arrow::list(arrow::utf8()); + default: + std::cerr << "Unknown value type: " << ext_type.valuetype() << "\n"; + return nullptr; + } +} + DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { switch (ext_type.type()) { case types::NONE: @@ -236,8 +248,9 @@ DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { return ProtoTypeToTimestamp(ext_type); case types::INTERVAL: return ProtoTypeToInterval(ext_type); - case types::FIXED_SIZE_BINARY: case types::LIST: + return ProtoTypeToList(ext_type); + case types::FIXED_SIZE_BINARY: case types::STRUCT: case types::UNION: case types::DICTIONARY: @@ -533,55 +546,193 @@ bool ParseProtobuf(uint8_t* buf, int bufLen, google::protobuf::Message* msg) { return msg->ParseFromCodedStream(&cis); } -Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows, - jlong* in_buf_addrs, jlong* in_buf_sizes, - int in_bufs_len, - std::shared_ptr* batch) { - std::vector> columns; - auto num_fields = schema->num_fields(); - int buf_idx = 0; - int sz_idx = 0; - for (int i = 0; i < num_fields; i++) { - auto field = schema->field(i); - std::vector> buffers; +std::shared_ptr GetOffsetDataType( + std::shared_ptr parent_type) { + switch (parent_type->id()) { + case arrow::BinaryType::type_id: + return std::make_shared::OffsetType>(); + case arrow::LargeBinaryType::type_id: + return std::make_shared::OffsetType>(); + case arrow::ListType::type_id: + return std::make_shared::OffsetType>(); + case arrow::LargeListType::type_id: + return std::make_shared::OffsetType>(); + default: + return nullptr; + } +} + +template +bool is_fixed_width_type(T _) { + return std::is_base_of::value; +} + +arrow::Status AppendNodes(std::shared_ptr column, + std::vector>* nodes) { + auto type = column->type(); + (*nodes).push_back(std::make_pair(column->length(), column->null_count())); + switch (type->id()) { + case arrow::Type::LIST: + case arrow::Type::LARGE_LIST: { + auto list_array = std::dynamic_pointer_cast(column); + RETURN_NOT_OK(AppendNodes(list_array->values(), nodes)); + } break; + default: { + } break; + } + return arrow::Status::OK(); +} + +arrow::Status AppendBuffers(std::shared_ptr column, + std::vector>* buffers) { + auto type = column->type(); + switch (type->id()) { + case arrow::Type::LIST: + case arrow::Type::LARGE_LIST: { + auto list_array = std::dynamic_pointer_cast(column); + (*buffers).push_back(list_array->null_bitmap()); + (*buffers).push_back(list_array->value_offsets()); + RETURN_NOT_OK(AppendBuffers(list_array->values(), buffers)); + } break; + default: { + for (auto& buffer : column->data()->buffers) { + (*buffers).push_back(buffer); + } + } break; + } + return arrow::Status::OK(); +} + +arrow::Status FIXOffsetBuffer(std::shared_ptr* in_buf, int fix_row) { + if ((*in_buf) == nullptr || (*in_buf)->size() == 0) return arrow::Status::OK(); + if ((*in_buf)->size() * 8 <= fix_row) { + ARROW_ASSIGN_OR_RAISE(auto valid_copy, arrow::AllocateBuffer((*in_buf)->size() + 1)); + std::memcpy(valid_copy->mutable_data(), (*in_buf)->data(), + static_cast((*in_buf)->size())); + (*in_buf) = std::move(valid_copy); + } + arrow::BitUtil::SetBitsTo(const_cast((*in_buf)->data()), fix_row, 1, true); + return arrow::Status::OK(); +} + +arrow::Status MakeArrayData(std::shared_ptr type, int num_rows, + std::vector> in_bufs, + int in_bufs_len, std::shared_ptr* arr_data, + int* buf_idx_ptr) { + if (arrow::is_nested(type->id())) { + // Maybe ListType, MapType, StructType or UnionType + switch (type->id()) { + case arrow::Type::LIST: + case arrow::Type::LARGE_LIST: { + auto offset_data_type = GetOffsetDataType(type); + auto list_type = std::dynamic_pointer_cast(type); + auto child_type = list_type->value_type(); + std::shared_ptr child_array_data, offset_array_data; + // create offset array + // Chendi: For some reason, for ListArray::FromArrays will remove last row from + // offset array, refer to array_nested.cc CleanListOffsets function + FIXOffsetBuffer(&in_bufs[*buf_idx_ptr], num_rows); + RETURN_NOT_OK(MakeArrayData(offset_data_type, num_rows + 1, in_bufs, in_bufs_len, + &offset_array_data, buf_idx_ptr)); + auto offset_array = arrow::MakeArray(offset_array_data); + // create child data array + RETURN_NOT_OK(MakeArrayData(child_type, -1, in_bufs, in_bufs_len, + &child_array_data, buf_idx_ptr)); + auto child_array = arrow::MakeArray(child_array_data); + auto list_array = + arrow::ListArray::FromArrays(*offset_array, *child_array).ValueOrDie(); + *arr_data = list_array->data(); + + } break; + default: + return arrow::Status::NotImplemented("MakeArrayData for type ", type->ToString(), + " is not supported yet."); + } - if (buf_idx >= in_bufs_len) { - return Status::Invalid("insufficient number of in_buf_addrs"); + } else { + int64_t null_count = arrow::kUnknownNullCount; + std::vector> buffers; + if (*buf_idx_ptr >= in_bufs_len) { + return arrow::Status::Invalid("insufficient number of in_buf_addrs"); } - jlong validity_addr = in_buf_addrs[buf_idx++]; - jlong validity_size = in_buf_sizes[sz_idx++]; - auto validity = std::shared_ptr( - new arrow::Buffer(reinterpret_cast(validity_addr), validity_size)); - buffers.push_back(validity); - - if (buf_idx >= in_bufs_len) { - return Status::Invalid("insufficient number of in_buf_addrs"); + if (in_bufs[*buf_idx_ptr]->size() == 0) { + null_count = 0; } - jlong value_addr = in_buf_addrs[buf_idx++]; - jlong value_size = in_buf_sizes[sz_idx++]; - auto data = std::shared_ptr( - new arrow::Buffer(reinterpret_cast(value_addr), value_size)); - buffers.push_back(data); - - if (arrow::is_binary_like(field->type()->id())) { - if (buf_idx >= in_bufs_len) { - return Status::Invalid("insufficient number of in_buf_addrs"); + buffers.push_back(in_bufs[*buf_idx_ptr]); + *buf_idx_ptr += 1; + + if (arrow::is_binary_like(type->id())) { + if (*buf_idx_ptr >= in_bufs_len) { + return arrow::Status::Invalid("insufficient number of in_buf_addrs"); } - // add offsets buffer for variable-len fields. - jlong offsets_addr = in_buf_addrs[buf_idx++]; - jlong offsets_size = in_buf_sizes[sz_idx++]; - auto offsets = std::shared_ptr( - new arrow::Buffer(reinterpret_cast(offsets_addr), offsets_size)); - buffers.push_back(offsets); + buffers.push_back(in_bufs[*buf_idx_ptr]); + auto offsets_size = in_bufs[*buf_idx_ptr]->size(); + *buf_idx_ptr += 1; + if (num_rows == -1) num_rows = offsets_size / 4 - 1; + } + + if (*buf_idx_ptr >= in_bufs_len) { + return arrow::Status::Invalid("insufficient number of in_buf_addrs"); + } + auto value_size = in_bufs[*buf_idx_ptr]->size(); + buffers.push_back(in_bufs[*buf_idx_ptr]); + *buf_idx_ptr += 1; + if (num_rows == -1) { + num_rows = value_size * 8 / arrow::bit_width(type->id()); } - auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers)); - columns.push_back(array_data); + *arr_data = arrow::ArrayData::Make(type, num_rows, std::move(buffers), null_count); } - *batch = arrow::RecordBatch::Make(schema, num_rows, columns); - return Status::OK(); + return arrow::Status::OK(); +} + +arrow::Status MakeRecordBatch(const std::shared_ptr& schema, int num_rows, + std::vector> in_bufs, + int in_bufs_len, + std::shared_ptr* batch) { + std::vector> arrays; + auto num_fields = schema->num_fields(); + int buf_idx = 0; + + for (int i = 0; i < num_fields; i++) { + auto field = schema->field(i); + std::shared_ptr array_data; + RETURN_NOT_OK(MakeArrayData(field->type(), num_rows, in_bufs, in_bufs_len, + &array_data, &buf_idx)); + arrays.push_back(array_data); + } + + *batch = arrow::RecordBatch::Make(schema, num_rows, arrays); + return arrow::Status::OK(); +} + +arrow::Status MakeRecordBatch(const std::shared_ptr& schema, int num_rows, + int64_t* in_buf_addrs, int64_t* in_buf_sizes, + int in_bufs_len, + std::shared_ptr* batch) { + std::vector> arrays; + std::vector> buffers; + for (int i = 0; i < in_bufs_len; i++) { + if (in_buf_addrs[i] != 0) { + auto data = std::shared_ptr(new arrow::Buffer( + reinterpret_cast(in_buf_addrs[i]), in_buf_sizes[i])); + buffers.push_back(data); + } else { + buffers.push_back(std::make_shared(nullptr, 0)); + } + } + return MakeRecordBatch(schema, num_rows, buffers, in_bufs_len, batch); +} + + +Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows, + jlong* in_buf_addrs, jlong* in_buf_sizes, + int in_bufs_len, + std::shared_ptr* batch) { + + return MakeRecordBatch(schema, num_rows,(int64_t*)in_buf_addrs,(int64_t*)in_buf_sizes, in_bufs_len, batch); } // projector related functions. diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index 1a80f1e7586..2d8f8be820a 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ b/cpp/src/gandiva/llvm_generator.cc @@ -32,7 +32,6 @@ #include "gandiva/lvalue.h" namespace gandiva { - #define ADD_TRACE(...) \ if (enable_ir_traces_) { \ AddTrace(__VA_ARGS__); \ @@ -187,6 +186,14 @@ llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx, return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray"); } +/// Get reference to child offsets array at specified index in the args list. +llvm::Value* LLVMGenerator::GetChildOffsetsReference(llvm::Value* arg_addrs, int idx, + FieldPtr field) { + const std::string& name = field->name(); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_coarray"); +} + /// Get reference to local bitmap array at specified index in the args list. llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) { llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, idx, ""); @@ -363,6 +370,38 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(), {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var, output_value->data(), output_value->length()}); + } else if (output_type_id == arrow::Type::LIST) { + auto output_list_internal_type = output->Type()->field(0)->type()->id(); + if (arrow::is_binary_like(output_list_internal_type)) { + auto output_list_value = std::dynamic_pointer_cast(output_value); + llvm::Value* child_output_offset_ref = GetChildOffsetsReference( + arg_addrs, output->child_data_offsets_idx(), output->field()); + AddFunctionCall( + "gdv_fn_populate_list_varlen_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + child_output_offset_ref, loop_var, output_list_value->data(), + output_list_value->child_offsets(), output_list_value->offsets_length()}); + } else if (output_list_internal_type == arrow::Type::INT32) { + AddFunctionCall("gdv_fn_populate_list_int32_t_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length()}); + } else if (output_list_internal_type == arrow::Type::INT64) { + AddFunctionCall("gdv_fn_populate_list_int64_t_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length()}); + } else if (output_list_internal_type == arrow::Type::FLOAT) { + AddFunctionCall("gdv_fn_populate_list_float_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length()}); + } else if (output_list_internal_type == arrow::Type::DOUBLE) { + AddFunctionCall("gdv_fn_populate_list_double_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length()}); + } else { + return Status::NotImplemented("list internal type ", + output->Type()->field(0)->type()->ToString(), + " not supported"); + } } else { return Status::NotImplemented("output type ", output->Type()->ToString(), " not supported"); @@ -568,6 +607,46 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { result_ = lvalue; } +void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueListDex& dex) { + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Value* slot; + + // compute list len from the offsets array. + llvm::Value* offsets_slot_ref = + GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); + llvm::Value* offsets_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx())); + + // => offset_start = offsets[loop_var] + slot = builder->CreateGEP(offsets_slot_ref, offsets_slot_index); + llvm::Value* offset_start = builder->CreateLoad(slot, "offset_start"); + + // => offset_end = offsets[loop_var + 1] + llvm::Value* offsets_slot_index_next = builder->CreateAdd( + offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1"); + slot = builder->CreateGEP(offsets_slot_ref, offsets_slot_index_next); + llvm::Value* offset_end = builder->CreateLoad(slot, "offset_end"); + + // => offsets_len_value = offset_end - offset_start + llvm::Value* list_len = builder->CreateSub(offset_end, offset_start, "offsets_len"); + + // get data array + llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + // do not forget slice offset + llvm::Value* offset_start_int64 = + builder->CreateIntCast(offset_start, generator_->types()->i64_type(), true); + llvm::Value* slot_index = + builder->CreateAdd(offset_start_int64, GetSliceOffset(dex.DataIdx())); + llvm::Value* data_list = builder->CreateGEP(slot_ref, slot_index); + + // TODO: handle bool type bitmap + // TODO: handle decimal precision and scale + + ADD_VISITOR_TRACE("visit fixed-len data list vector " + dex.FieldName() + " length %T", + list_len); + result_.reset(new LValue(data_list, list_len)); +} + void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { llvm::IRBuilder<>* builder = ir_builder(); llvm::Value* slot; @@ -601,6 +680,65 @@ void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { result_.reset(new LValue(data_value, len_value)); } +/* + * create list type field context for each loop + */ +void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueListDex& dex) { + /* Example + * list_data: [["var_len_val11"], ["var_len_val211", "var_len_val22"], + * ["var_len_val3331"]] loop_var: 0, 1, 2 data_buffer: + * var_len_val11var_len_val211var_len_val22var_len_val3331 offsets_buffer: 0, 1, 3, 4 + * list_element_len = offsets[loop_var+1]-offsets[loop_var] => 1, 2, 1 + * child_offsets_buffer: 0, 13, 27, 40, 55 + * for i in list_element_len: + * data_buffer[child_offsets_buffer[offsets[i+1]] - child_offsets_buffer[offsets[i]]] + * => list_data[loop_var][i] + */ + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Value* slot; + + // compute list length from the offsets array + llvm::Value* offsets_slot_ref = + GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); + llvm::Value* offsets_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx())); + + // => offset_start = offsets[loop_var] + slot = builder->CreateGEP(offsets_slot_ref, offsets_slot_index); + llvm::Value* offset_start = builder->CreateLoad(slot, "offset_start"); + + // => offset_end = offsets[loop_var + 1] + llvm::Value* offsets_slot_index_next = builder->CreateAdd( + offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1"); + slot = builder->CreateGEP(offsets_slot_ref, offsets_slot_index_next); + llvm::Value* offset_end = builder->CreateLoad(slot, "offset_end"); + + // => list_data_length = offset_end - offset_start + llvm::Value* list_data_length = + builder->CreateSub(offset_end, offset_start, "offsets_len"); + + // get the child offsets array from the child offsets array, + // start from offset 'offset_start' + llvm::Value* child_offset_slot_ref = + GetBufferReference(dex.ChildOffsetsIdx(), kBufferTypeChildOffsets, dex.Field()); + // do not forget slice offset + llvm::Value* offset_start_int64 = + builder->CreateIntCast(offset_start, generator_->types()->i64_type(), true); + llvm::Value* child_offset_slot_index = + builder->CreateAdd(offset_start_int64, GetSliceOffset(dex.ChildOffsetsIdx())); + llvm::Value* child_offsets = + builder->CreateGEP(child_offset_slot_ref, child_offset_slot_index); + llvm::Value* child_offset_start = + builder->CreateLoad(child_offsets, "child_offset_start"); + + // get the data array + llvm::Value* data_slot_ref = + GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + llvm::Value* data_value = builder->CreateGEP(data_slot_ref, child_offset_start); + + result_.reset(new ListLValue(data_value, child_offsets, list_data_length)); +} + void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) { llvm::IRBuilder<>* builder = ir_builder(); llvm::Value* slot_ref = @@ -738,7 +876,7 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) { llvm::IRBuilder<>* builder = ir_builder(); LLVMTypes* types = generator_->types(); auto arrow_type_id = arrow_return_type->id(); - auto result_type = types->IRType(arrow_type_id); + auto result_type = types->DataVecType(arrow_return_type); // Build combined validity of the args. llvm::Value* is_valid = types->true_constant(); @@ -1125,7 +1263,7 @@ LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition, // Emit the merge block. builder->SetInsertPoint(merge_bb); - auto llvm_type = types->IRType(result_type->id()); + auto llvm_type = types->DataVecType(result_type); llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value"); result_value->addIncoming(then_lvalue->data(), then_bb); result_value->addIncoming(else_lvalue->data(), else_bb); @@ -1170,7 +1308,7 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, std::vector* params) { auto types = generator_->types(); auto arrow_return_type_id = arrow_return_type->id(); - auto llvm_return_type = types->IRType(arrow_return_type_id); + auto llvm_return_type = types->DataVecType(arrow_return_type); DecimalIR decimalIR(generator_->engine_.get()); if (arrow_return_type_id == arrow::Type::DECIMAL) { @@ -1290,6 +1428,10 @@ llvm::Value* LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buff case kBufferTypeOffsets: slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field); break; + + case kBufferTypeChildOffsets: + slot_ref = generator_->GetChildOffsetsReference(arg_addrs_, idx, field); + break; } // Revert to the saved block. @@ -1388,5 +1530,4 @@ void LLVMGenerator::AddTrace(const std::string& msg, llvm::Value* value) { } AddFunctionCall(print_fn_name, types()->i32_type(), args); } - } // namespace gandiva diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index 8ff9711c0f9..f4ad598364d 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -95,7 +95,9 @@ class GANDIVA_EXPORT LLVMGenerator { void Visit(const VectorReadValidityDex& dex) override; void Visit(const VectorReadFixedLenValueDex& dex) override; + void Visit(const VectorReadFixedLenValueListDex& dex) override; void Visit(const VectorReadVarLenValueDex& dex) override; + void Visit(const VectorReadVarLenValueListDex& dex) override; void Visit(const LocalBitMapValidityDex& dex) override; void Visit(const TrueDex& dex) override; void Visit(const FalseDex& dex) override; @@ -118,7 +120,12 @@ class GANDIVA_EXPORT LLVMGenerator { bool has_arena_allocs() { return has_arena_allocs_; } private: - enum BufferType { kBufferTypeValidity = 0, kBufferTypeData, kBufferTypeOffsets }; + enum BufferType { + kBufferTypeValidity = 0, + kBufferTypeData, + kBufferTypeOffsets, + kBufferTypeChildOffsets + }; llvm::IRBuilder<>* ir_builder() { return generator_->ir_builder(); } llvm::Module* module() { return generator_->module(); } @@ -185,6 +192,10 @@ class GANDIVA_EXPORT LLVMGenerator { /// Generate code to load the vector at specified index and cast it as offsets array. llvm::Value* GetOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + /// Generate code to load the vector at specified index and cast it as child offsets + /// array. + llvm::Value* GetChildOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + /// Generate code to load the vector at specified index and cast it as buffer pointer. llvm::Value* GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, FieldPtr field); diff --git a/cpp/src/gandiva/llvm_types.cc b/cpp/src/gandiva/llvm_types.cc index de322a8c0fc..50d218a9252 100644 --- a/cpp/src/gandiva/llvm_types.cc +++ b/cpp/src/gandiva/llvm_types.cc @@ -41,6 +41,7 @@ LLVMTypes::LLVMTypes(llvm::LLVMContext& context) : context_(context) { {arrow::Type::type::STRING, i8_ptr_type()}, {arrow::Type::type::BINARY, i8_ptr_type()}, {arrow::Type::type::DECIMAL, i128_type()}, + {arrow::Type::type::LIST, i32_type()}, {arrow::Type::type::INTERVAL_MONTHS, i32_type()}, {arrow::Type::type::INTERVAL_DAY_TIME, i64_type()}}; } diff --git a/cpp/src/gandiva/llvm_types.h b/cpp/src/gandiva/llvm_types.h index d6f0952713e..04d03e38461 100644 --- a/cpp/src/gandiva/llvm_types.h +++ b/cpp/src/gandiva/llvm_types.h @@ -65,6 +65,10 @@ class GANDIVA_EXPORT LLVMTypes { llvm::PointerType* i128_ptr_type() { return ptr_type(i128_type()); } + llvm::PointerType* float_ptr_type() { return ptr_type(float_type()); } + + llvm::PointerType* double_ptr_type() { return ptr_type(double_type()); } + template llvm::Constant* int_constant(ctype val) { return llvm::ConstantInt::get(context_, llvm::APInt(N, val)); @@ -104,6 +108,13 @@ class GANDIVA_EXPORT LLVMTypes { /// For a given data type, find the ir type used for the data vector slot. llvm::Type* DataVecType(const DataTypePtr& data_type) { + // support list type + // list type data is formed by base type buffer, wrapped with offsets buffer + // offsets buffer is to separate data into list + // not support nested list + if (data_type->id() == arrow::Type::LIST) { + return IRType(data_type->field(0)->type()->id()); + } return IRType(data_type->id()); } diff --git a/cpp/src/gandiva/llvm_types_test.cc b/cpp/src/gandiva/llvm_types_test.cc index 66696830618..665a82d133f 100644 --- a/cpp/src/gandiva/llvm_types_test.cc +++ b/cpp/src/gandiva/llvm_types_test.cc @@ -50,12 +50,22 @@ TEST_F(TestLLVMTypes, TestFound) { types_->i64_type()); EXPECT_EQ(types_->DataVecType(arrow::timestamp(arrow::TimeUnit::MILLI)), types_->i64_type()); + + EXPECT_EQ(types_->IRType(arrow::Type::STRING), types_->i8_ptr_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::boolean())), types_->i1_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::int32())), types_->i32_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::int64())), types_->i64_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::float32())), types_->float_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::float64())), types_->double_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::utf8())), types_->i8_ptr_type()); } TEST_F(TestLLVMTypes, TestNotFound) { EXPECT_EQ(types_->IRType(arrow::Type::SPARSE_UNION), nullptr); EXPECT_EQ(types_->IRType(arrow::Type::DENSE_UNION), nullptr); EXPECT_EQ(types_->DataVecType(arrow::null()), nullptr); + // not support nested list type + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::list(arrow::utf8()))), nullptr); } } // namespace gandiva diff --git a/cpp/src/gandiva/lvalue.h b/cpp/src/gandiva/lvalue.h index df292855b69..4d1dca8f7cf 100644 --- a/cpp/src/gandiva/lvalue.h +++ b/cpp/src/gandiva/lvalue.h @@ -74,4 +74,27 @@ class GANDIVA_EXPORT DecimalLValue : public LValue { llvm::Value* scale_; }; +class GANDIVA_EXPORT ListLValue : public LValue { + public: + ListLValue(llvm::Value* data, llvm::Value* child_offsets, llvm::Value* offsets_length, + llvm::Value* validity = NULLPTR) + : LValue(data, NULLPTR, validity), + child_offsets_(child_offsets), + offsets_length_(offsets_length) {} + + llvm::Value* child_offsets() { return child_offsets_; } + + llvm::Value* offsets_length() { return offsets_length_; } + + void AppendFunctionParams(std::vector* params) override { + LValue::AppendFunctionParams(params); + params->push_back(child_offsets_); + params->push_back(offsets_length_); + } + + private: + llvm::Value* child_offsets_; + llvm::Value* offsets_length_; +}; + } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index 53f5a30fca0..f7d83633b71 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -18,6 +18,8 @@ #pragma once #include + +#include "gandiva/array_ops.h" #include "gandiva/gdv_function_stubs.h" // Use the same names as in arrow data types. Makes it easy to write pre-processor macros. diff --git a/cpp/src/gandiva/projector.cc b/cpp/src/gandiva/projector.cc index 734720c64c9..8a918ba8071 100644 --- a/cpp/src/gandiva/projector.cc +++ b/cpp/src/gandiva/projector.cc @@ -30,7 +30,6 @@ #include "gandiva/llvm_generator.h" namespace gandiva { - class ProjectorCacheKey { public: ProjectorCacheKey(SchemaPtr schema, std::shared_ptr configuration, @@ -256,6 +255,35 @@ Status Projector::Evaluate(const arrow::RecordBatch& batch, // Create and return array arrays. output->clear(); for (auto& array_data : output_data_vecs) { + if (array_data->type->id() == arrow::Type::LIST) { + auto child_data = array_data->child_data[0]; + int64_t child_data_size = 1; + if (arrow::is_binary_like(child_data->type->id())) { + /* when allocate array data, child data length is an initialized value, + * after calculating, child data offsets buffer has been resized for results, + * but array data length is unchanged. + * We should recalculate child data length and make ArrayData with new length + * + * Otherwise, child data offsets buffer length is data length + 1 + * and offset data is int32_t, need use buffer->size()/4 - 1 + */ + child_data_size = child_data->buffers[1]->size() / 4 - 1; + } else if (child_data->type->id() == arrow::Type::INT32) { + child_data_size = child_data->buffers[1]->size() / 4; + } else if (child_data->type->id() == arrow::Type::INT64) { + child_data_size = child_data->buffers[1]->size() / 8; + } else if (child_data->type->id() == arrow::Type::FLOAT) { + child_data_size = child_data->buffers[1]->size() / 4; + } else if (child_data->type->id() == arrow::Type::DOUBLE) { + child_data_size = child_data->buffers[1]->size() / 8; + } + auto new_child_data = arrow::ArrayData::Make( + child_data->type, child_data_size, child_data->buffers, child_data->offset); + array_data = arrow::ArrayData::Make(array_data->type, array_data->length, + array_data->buffers, {new_child_data}, + array_data->null_count, array_data->offset); + } + output->push_back(arrow::MakeArray(array_data)); } return Status::OK(); @@ -281,6 +309,23 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, buffers.push_back(std::move(offsets_buffer)); } + if (type_id == arrow::Type::LIST) { + auto offsets_len = arrow::BitUtil::BytesForBits((num_records + 1) * 32); + + ARROW_ASSIGN_OR_RAISE(auto offsets_buffer, arrow::AllocateBuffer(offsets_len, pool)); + buffers.push_back(std::move(offsets_buffer)); + + if (arrow::is_binary_like(type->field(0)->type()->id())) { + // child offsets length is internal data length + 1 + // offsets element is int32 + // so here i just allocate extra 32 bit for extra 1 length + ARROW_ASSIGN_OR_RAISE( + auto child_offsets_buffer, + arrow::AllocateResizableBuffer(arrow::BitUtil::BytesForBits(32), pool)); + buffers.push_back(std::move(child_offsets_buffer)); + } + } + // The output vector always has a data array. int64_t data_len; if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) { @@ -289,6 +334,8 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, } else if (arrow::is_binary_like(type_id)) { // we don't know the expected size for varlen output vectors. data_len = 0; + } else if (type_id == arrow::Type::LIST) { + data_len = 0; } else { return Status::Invalid("Unsupported output data type " + type->ToString()); } @@ -301,7 +348,24 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, } buffers.push_back(std::move(data_buffer)); - *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + if (type->id() == arrow::Type::LIST) { + auto internal_type = type->field(0)->type(); + ArrayDataPtr child_data; + if (arrow::is_primitive(internal_type->id())) { + child_data = arrow::ArrayData::Make(internal_type, 0 /*initialize length*/, + {nullptr, std::move(buffers[2])}, 0); + } + if (arrow::is_binary_like(internal_type->id())) { + child_data = arrow::ArrayData::Make( + internal_type, 0 /*initialize length*/, + {nullptr, std::move(buffers[2]), std::move(buffers[3])}, 0); + } + *array_data = arrow::ArrayData::Make( + type, num_records, {std::move(buffers[0]), std::move(buffers[1])}, {child_data}); + + } else { + *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + } return Status::OK(); } @@ -328,7 +392,7 @@ Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData& array_data, min_bitmap_len, " actual size ", bitmap_len)); auto type_id = field.type()->id(); - if (arrow::is_binary_like(type_id)) { + if (arrow::is_binary_like(type_id) || type_id == arrow::Type::LIST) { // validate size of offsets buffer. int64_t min_offsets_len = arrow::BitUtil::BytesForBits((num_records + 1) * 32); int64_t offsets_len = array_data.buffers[1]->capacity(); @@ -358,5 +422,4 @@ Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData& array_data, } std::string Projector::DumpIR() { return llvm_generator_->DumpIR(); } - } // namespace gandiva diff --git a/cpp/src/gandiva/proto/Types.proto b/cpp/src/gandiva/proto/Types.proto index 7c0c49f2d85..4ef0db80e34 100644 --- a/cpp/src/gandiva/proto/Types.proto +++ b/cpp/src/gandiva/proto/Types.proto @@ -85,6 +85,7 @@ message ExtGandivaType { optional TimeUnit timeUnit = 6; // used by TIME32/TIME64 optional string timeZone = 7; // used by TIMESTAMP optional IntervalType intervalType = 8; // used by INTERVAL + optional GandivaType valueType = 9; // used by LIST } message Field { diff --git a/cpp/src/gandiva/tests/CMakeLists.txt b/cpp/src/gandiva/tests/CMakeLists.txt index 5fa2da16c63..078ac70cfc0 100644 --- a/cpp/src/gandiva/tests/CMakeLists.txt +++ b/cpp/src/gandiva/tests/CMakeLists.txt @@ -25,6 +25,7 @@ add_gandiva_test(binary_test) add_gandiva_test(date_time_test) add_gandiva_test(to_string_test) add_gandiva_test(utf8_test) +add_gandiva_test(list_test) add_gandiva_test(hash_test) add_gandiva_test(in_expr_test) add_gandiva_test(null_validity_test) diff --git a/cpp/src/gandiva/tests/list_test.cc b/cpp/src/gandiva/tests/list_test.cc new file mode 100644 index 00000000000..d8f44a5c235 --- /dev/null +++ b/cpp/src/gandiva/tests/list_test.cc @@ -0,0 +1,300 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::float64; +using arrow::int32; +using arrow::int64; +using arrow::utf8; +using std::string; +using std::vector; + +class TestList : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +template +void _build_list_array(const vector& values, const vector& length, + const vector& validity, arrow::MemoryPool* pool, + ArrayPtr* array) { + size_t sum = 0; + for (auto& len : length) { + sum += len; + } + EXPECT_TRUE(values.size() == sum); + EXPECT_TRUE(length.size() == validity.size()); + + auto value_builder = std::make_shared(pool); + auto builder = std::make_shared(pool, value_builder); + int i = 0; + for (size_t l = 0; l < length.size(); l++) { + if (validity[l]) { + auto status = builder->Append(); + for (int j = 0; j < length[l]; j++) { + ASSERT_OK(value_builder->Append(values[i])); + i++; + } + } else { + ASSERT_OK(builder->AppendNull()); + for (int j = 0; j < length[l]; j++) { + i++; + } + } + } + ASSERT_OK(builder->Finish(array)); +} + +/* + * expression: + * input: a + * output: res + * typeof(a) can be list / list / list + */ +void _test_list_type_field_alias(DataTypePtr type, ArrayPtr array, + arrow::MemoryPool* pool) { + auto field_a = field("a", type); + auto schema = arrow::schema({field_a}); + auto result = field("res", type); + + auto num_records = 5; + assert(array->length() == num_records); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + + // Make expression + auto field_a_node = TreeExprBuilder::MakeField(field_a); + auto expr = TreeExprBuilder::MakeExpression(field_a_node, result); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + EXPECT_ARROW_ARRAY_EQUALS(array, outputs[0]); + // EXPECT_ARROW_ARRAY_EQUALS will not check the length of child data, but + // ArrayData::Slice method will check length. ArrayData::ToString method will call + // ArrayData::Slice method + EXPECT_TRUE(array->ToString() == outputs[0]->ToString()); + EXPECT_TRUE(array->null_count() == outputs[0]->null_count()); +} + +TEST_F(TestList, TestListUtf8) { + ArrayPtr array; + _build_list_array( + {"a", "b", "bb", "c", "cc", "ccc", "d", "dd", "ddd", "dddd", "e", "ee", "eee", + "eeee", "eeeee"}, + {1, 4, 3, 2, 5}, {true, true, false, true, true}, pool_, &array); + _test_list_type_field_alias(list(utf8()), array, pool_); +} + +TEST_F(TestList, TestListUtf8WithInvalidData) { + ArrayPtr array; + _build_list_array( + {"a", "b", "bb", "c", "cc", "ccc", "d", "dd", "ddd", "dddd", "e", "ee", "eee", + "eeee", "eeeee"}, + {1, 2, 3, 4, 5}, {true, false, true, true, false}, pool_, &array); + _test_list_type_field_alias(list(utf8()), array, pool_); +} + +TEST_F(TestList, TestListInt64) { + ArrayPtr array; + _build_list_array( + {1, 10, 20, 100, 200, 300, 1000, 2000, 3000, 4000, 10000, 20000, 30000, 40000, + 50000}, + {1, 2, 5, 4, 3}, {true, true, true, true, false}, pool_, &array); + _test_list_type_field_alias(list(int64()), array, pool_); +} + +TEST_F(TestList, TestListInt32) { + ArrayPtr array; + _build_list_array( + {1, 10, 20, 100, 200, 300, 1000, 2000, 3000, 4000, 10000, 20000, 30000, 40000, + 50000}, + {5, 2, 3, 4, 1}, {true, false, true, true, true}, pool_, &array); + _test_list_type_field_alias(list(int32()), array, pool_); +} + +TEST_F(TestList, TestListFloat32) { + ArrayPtr array; + _build_list_array( + {1.1f, 11.1f, 22.2f, 111.1f, 222.2f, 333.3f, 1111.1f, 2222.2f, 3333.3f, 4444.4f, + 11111.1f, 22222.2f, 33333.3f, 44444.4f, 55555.5f}, + {1, 2, 3, 4, 5}, {true, true, true, true, true}, pool_, &array); + _test_list_type_field_alias(list(float32()), array, pool_); +} + +TEST_F(TestList, TestListFloat64) { + ArrayPtr array; + _build_list_array( + {1.1, 1.11, 2.22, 1.111, 2.222, 3.333, 1.1111, 2.2222, 3.3333, 4.4444, 1.11111, + 2.22222, 3.33333, 4.44444, 5.55555}, + {1, 2, 4, 3, 5}, {true, false, true, true, true}, pool_, &array); + _test_list_type_field_alias(list(float64()), array, pool_); +} + +/* + * array_length(a) + */ +TEST_F(TestList, TestListUtf8Length) { + // schema for input fields + auto field_a = field("a", list(utf8())); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", int64()); + + // Create a row-batch with some sample data + int num_records = 5; + ArrayPtr array_a; + _build_list_array( + {"a", "b", "bb", "c", "cc", "ccc", "d", "dd", "ddd", "dddd", "e", "ee", "eee", + "eeee", "eeeee"}, + {1, 2, 3, 4, 5}, {true, true, true, true, true}, pool_, &array_a); + + // expected output + auto exp = MakeArrowArrayInt64({1, 2, 3, 4, 5}, {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // build expressions. + // array_length(a) + auto expr = TreeExprBuilder::MakeExpression("array_length", {field_a}, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestList, TestListUtf8LengthWithInvalidData) { + // schema for input fields + auto field_a = field("a", list(utf8())); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", int64()); + + // Create a row-batch with some sample data + int num_records = 5; + ArrayPtr array_a; + _build_list_array( + {"a", "b", "bb", "cc", "cc", "ccc", "d", "dd", "ddd"}, {1, 2, 2, 3, 1}, + {true, false, true, false, true}, pool_, &array_a); + + // expected output + auto exp = MakeArrowArrayInt64({1, 2, 2, 3, 1}, {true, false, true, false, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // build expressions. + // array_length(a) + auto expr = TreeExprBuilder::MakeExpression("array_length", {field_a}, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +/* + * array_contains(a, "element") + */ +TEST_F(TestList, TestListUtf8Contains) { + // schema for input fields + auto field_a = field("a", list(utf8())); + auto field_b = field("b", utf8()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", boolean()); + + // Create a row-batch with some sample data + int num_records = 5; + ArrayPtr array_a; + _build_list_array( + {"rectangle", "circle", "rectangle", "circle", "triangle", "triangle", "circle", + "rectangle"}, + {2, 3, 1, 1, 1}, {true, true, true, true, true}, pool_, &array_a); + auto array_b = + MakeArrowArrayUtf8({"rectangle", "circle", "circle", "circle", "rectangll"}); + + // expected output + auto exp = MakeArrowArrayBool({true, true, false, true, false}, + {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // build expressions. + // array_contains(a, b) + auto expr = TreeExprBuilder::MakeExpression("array_contains", {field_a, field_b}, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/tests/projector_build_validation_test.cc b/cpp/src/gandiva/tests/projector_build_validation_test.cc index 5b86844f940..82b59ef19ad 100644 --- a/cpp/src/gandiva/tests/projector_build_validation_test.cc +++ b/cpp/src/gandiva/tests/projector_build_validation_test.cc @@ -26,6 +26,7 @@ namespace gandiva { using arrow::boolean; using arrow::float32; using arrow::int32; +using arrow::utf8; class TestProjector : public ::testing::Test { public: @@ -80,7 +81,7 @@ TEST_F(TestProjector, TestNotMatchingDataType) { TEST_F(TestProjector, TestNotSupportedDataType) { // schema for input fields - auto field0 = field("f0", list(int32())); + auto field0 = field("f0", map(utf8(), int32())); auto schema = arrow::schema({field0}); // output fields @@ -94,7 +95,7 @@ TEST_F(TestProjector, TestNotSupportedDataType) { std::shared_ptr projector; auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); EXPECT_TRUE(status.IsExpressionValidationError()); - std::string expected_error = "Field f0 has unsupported data type list"; + std::string expected_error = "Field f0 has unsupported data type map"; EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); } diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java index 0155af08234..8bfada07aba 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java @@ -178,10 +178,11 @@ private static ArrowType getArrowType(ExtGandivaType type) { return new ArrowType.Decimal(0, 0, 128); case GandivaType.INTERVAL_VALUE: return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType())); + case GandivaType.LIST_VALUE: + return new ArrowType.List(); case GandivaType.FIXED_SIZE_BINARY_VALUE: case GandivaType.MAP_VALUE: case GandivaType.DICTIONARY_VALUE: - case GandivaType.LIST_VALUE: case GandivaType.STRUCT_VALUE: case GandivaType.UNION_VALUE: default: diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java index 3503bbbf835..1d58b7abc98 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java @@ -227,6 +227,13 @@ private static void initArrowTypeInterval(ArrowType.Interval interval, } } + private static void initArrowTypeList(ArrowType.List list, + GandivaTypes.ExtGandivaType.Builder builder) { + + builder.setType(GandivaTypes.GandivaType.LIST); + builder.setValueType(GandivaTypes.GandivaType.UINT32); + } + /** * Converts an arrow type into a protobuf. * @@ -288,6 +295,7 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp break; } case Type.List: { // 12 + ArrowTypeHelper.initArrowTypeList((ArrowType.List) arrowType, builder); break; } case Type.Struct_: { // 13