Skip to content

Commit 6a33df4

Browse files
committed
174: Adding recipe for custom compute functions
This recipe shows the major portions of a custom (or new) compute function: defining a compute kernel; creating a function instance; associating the kernel with the function; registering the function in a registry; and calling the function.
1 parent 3fae0d1 commit 6a33df4

File tree

1 file changed

+270
-0
lines changed

1 file changed

+270
-0
lines changed

cpp/code/compute_fn.cc

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
// ------------------------------
2+
// Dependencies
3+
4+
// standard dependencies
5+
#include <stdint.h>
6+
#include <string>
7+
#include <iostream>
8+
9+
// arrow dependencies
10+
#include <arrow/api.h>
11+
#include <arrow/compute/api.h>
12+
#include <arrow/compute/exec/key_hash.h>
13+
14+
#include "common.h"
15+
16+
17+
// >> aliases for types in standard library
18+
using std::shared_ptr;
19+
using std::vector;
20+
21+
// arrow util types
22+
using arrow::Result;
23+
using arrow::Status;
24+
using arrow::Datum;
25+
26+
// arrow data types and helpers
27+
using arrow::UInt32Builder;
28+
using arrow::Int32Builder;
29+
30+
using arrow::Array;
31+
using arrow::ArraySpan;
32+
33+
34+
// aliases for types used in `NamedScalarFn`
35+
// |> kernel parameters
36+
using arrow::compute::KernelContext;
37+
using arrow::compute::ExecSpan;
38+
using arrow::compute::ExecResult;
39+
40+
// |> other context types
41+
using arrow::compute::ExecContext;
42+
using arrow::compute::LightContext;
43+
44+
// |> common types for compute functions
45+
using arrow::compute::FunctionRegistry;
46+
using arrow::compute::FunctionDoc;
47+
using arrow::compute::InputType;
48+
using arrow::compute::OutputType;
49+
using arrow::compute::Arity;
50+
51+
// |> the "kind" of function we want
52+
using arrow::compute::ScalarFunction;
53+
54+
// |> structs and classes for hashing
55+
using arrow::util::MiniBatch;
56+
using arrow::util::TempVectorStack;
57+
58+
using arrow::compute::KeyColumnArray;
59+
using arrow::compute::Hashing32;
60+
61+
// |> functions used for hashing
62+
using arrow::compute::ColumnArrayFromArrayData;
63+
64+
65+
// ------------------------------
66+
// Structs and Classes
67+
68+
// >> Documentation for a compute function
69+
/**
70+
* Create a const instance of `FunctionDoc` that contains 3 attributes:
71+
* 1. Short description
72+
* 2. Long description (limited to 78 characters)
73+
* 3. Name of input arguments
74+
*/
75+
const FunctionDoc named_scalar_fn_doc {
76+
"Unary function that calculates a hash for each row of the input"
77+
,"This function uses an xxHash-like algorithm which produces 32-bit hashes."
78+
,{ "input_array" }
79+
};
80+
81+
82+
// >> Kernel implementations for a compute function
83+
/**
84+
* Create implementations that will be associated with our compute function. When a
85+
* compute function is invoked, the compute API framework will delegate execution to an
86+
* associated kernel that matches: (1) input argument types/shapes and (2) output argument
87+
* types/shapes.
88+
*
89+
* Kernel implementations may be functions or may be methods (functions within a class or
90+
* struct).
91+
*/
92+
struct NamedScalarFn {
93+
94+
/**
95+
* A kernel implementation that expects a single array as input, and outputs an array of
96+
* uint32 values. We write this implementation knowing what function we want to
97+
* associate it with ("NamedScalarFn"), but that association is made later (see
98+
* `RegisterScalarFnKernels()` below).
99+
*/
100+
static Status
101+
Exec(KernelContext *ctx, const ExecSpan &input_arg, ExecResult *out) {
102+
StartRecipe("DefineAComputeKernel");
103+
104+
if (input_arg.num_values() != 1 or not input_arg[0].is_array()) {
105+
return Status::Invalid("Unsupported argument types or shape");
106+
}
107+
108+
// >> Initialize stack-based memory allocator with an allocator and memory size
109+
TempVectorStack stack_memallocator;
110+
auto input_dtype_width = input_arg[0].type()->bit_width();
111+
if (input_dtype_width > 0) {
112+
ARROW_RETURN_NOT_OK(
113+
stack_memallocator.Init(
114+
ctx->exec_context()->memory_pool()
115+
,input_dtype_width * max_batchsize
116+
)
117+
);
118+
}
119+
120+
// >> Prepare input data structure for propagation to hash function
121+
// NOTE: "start row index" and "row count" can potentially be options in the future
122+
ArraySpan hash_input = input_arg[0].array;
123+
int64_t hash_startrow = 0;
124+
int64_t hash_rowcount = hash_input.length;
125+
ARROW_ASSIGN_OR_RAISE(
126+
KeyColumnArray input_keycol
127+
,ColumnArrayFromArrayData(hash_input.ToArrayData(), hash_startrow, hash_rowcount)
128+
);
129+
130+
// >> Call hashing function
131+
vector<uint32_t> hash_results;
132+
hash_results.resize(hash_input.length);
133+
134+
LightContext hash_ctx;
135+
hash_ctx.hardware_flags = ctx->exec_context()->cpu_info()->hardware_flags();
136+
hash_ctx.stack = &stack_memallocator;
137+
138+
Hashing32::HashMultiColumn({ input_keycol }, &hash_ctx, hash_results.data());
139+
140+
// >> Prepare results of hash function for kernel output argument
141+
UInt32Builder builder;
142+
builder.Reserve(hash_results.size());
143+
builder.AppendValues(hash_results);
144+
145+
ARROW_ASSIGN_OR_RAISE(auto result_array, builder.Finish());
146+
out->value = result_array->data();
147+
148+
EndRecipe("DefineAComputeKernel");
149+
return Status::OK();
150+
}
151+
152+
153+
static constexpr uint32_t max_batchsize = MiniBatch::kMiniBatchLength;
154+
};
155+
156+
157+
// ------------------------------
158+
// Functions
159+
160+
161+
// >> Function registration and kernel association
162+
/**
163+
* A convenience function that shows how we construct an instance of `ScalarFunction` that
164+
* will be registered in a function registry. The instance is constructed with: (1) a
165+
* unique name ("named_scalar_fn"), (2) an "arity" (`Arity::Unary()`), and (3) an instance
166+
* of `FunctionDoc`.
167+
*
168+
* The function name is used to invoke it from a function registry after it has been
169+
* registered. The "arity" is the cardinality of the function's parameters--1 parameter is
170+
* a unary function, 2 parameters is a binary function, etc. Finally, it is helpful to
171+
* associate the function with documentation, which uses the `FunctionDoc` struct.
172+
*/
173+
shared_ptr<ScalarFunction>
174+
RegisterScalarFnKernels() {
175+
StartRecipe("AddKernelsToFunction");
176+
// Instantiate a function to be registered
177+
auto fn_named_scalar = std::make_shared<ScalarFunction>(
178+
"named_scalar_fn"
179+
,Arity::Unary()
180+
,std::move(named_scalar_fn_doc)
181+
);
182+
183+
// Associate a kernel implementation with the function using
184+
// `ScalarFunction::AddKernel()`
185+
DCHECK_OK(
186+
fn_named_scalar->AddKernel(
187+
{ InputType(arrow::int32()) }
188+
,OutputType(arrow::uint32())
189+
,NamedScalarFn::Exec
190+
)
191+
);
192+
193+
EndRecipe("AddKernelsToFunction");
194+
return fn_named_scalar;
195+
}
196+
197+
198+
/**
199+
* A convenience function that shows how we register a custom function with a
200+
* `FunctionRegistry`. To keep this simple and general, this function takes a pointer to a
201+
* FunctionRegistry as an input argument, then invokes `FunctionRegistry::AddFunction()`.
202+
*/
203+
void
204+
RegisterNamedScalarFn(FunctionRegistry *registry) {
205+
auto scalar_fn = RegisterScalarFnKernels();
206+
DCHECK_OK(registry->AddFunction(std::move(scalar_fn)));
207+
}
208+
209+
210+
// >> Convenience functions
211+
/**
212+
* An optional convenience function to easily invoke our compute function. This executes
213+
* our compute function by invoking `CallFunction` with the name that we used to register
214+
* the function ("named_scalar_fn" in this case).
215+
*/
216+
ARROW_EXPORT
217+
Result<Datum>
218+
NamedScalarFn(const Datum &input_arg, ExecContext *ctx) {
219+
auto func_name = "named_scalar_fn";
220+
return CallFunction(func_name, { input_arg }, ctx);
221+
}
222+
223+
224+
/**
 * Builds a small int32 array (the first ten Fibonacci numbers) to use as test input.
 *
 * @return the finished array, or the first error reported by the builder.
 */
Result<shared_ptr<Array>>
BuildIntArray() {
  const vector<int32_t> fib_vals { 0, 1, 1, 2, 3, 5, 8, 13, 21, 34 };

  Int32Builder int_builder;

  // Reserve space up front, then append all values in one call
  ARROW_RETURN_NOT_OK(int_builder.Reserve(fib_vals.size()));
  ARROW_RETURN_NOT_OK(int_builder.AppendValues(fib_vals));

  return int_builder.Finish();
}
233+
234+
235+
// Empty test fixture; it adds no state or setup on top of `::testing::Test`.
class ComputeFunctionTest : public ::testing::Test {
};
236+
237+
// Registers the custom compute function in the global function registry, invokes it on a
// small int32 array, and prints the resulting hashes.
//
// A gtest `TEST` body is a void function, so failures are reported with `ASSERT_*`
// (which records the failure and returns) rather than by returning an exit code.
TEST(ComputeFunctionTest, TestRegisterAndCallFunction) {
  // >> Construct some test data
  auto build_result = BuildIntArray();
  ASSERT_TRUE(build_result.ok()) << build_result.status().message();

  // >> Peek at the data
  auto col_vals = *build_result;
  std::cout << col_vals->ToString() << std::endl;

  // >> Invoke compute function
  StartRecipe("RegisterAndCallComputeFunction");
  // |> First, register
  auto fn_registry = arrow::compute::GetFunctionRegistry();
  RegisterNamedScalarFn(fn_registry);

  // |> Then, invoke (the execution context is passed explicitly as `nullptr`)
  Datum col_as_datum { col_vals };
  auto fn_result = NamedScalarFn(col_as_datum, nullptr);
  ASSERT_TRUE(fn_result.ok()) << fn_result.status().message();

  auto result_data = fn_result->make_array();
  // One hash is expected per input row
  ASSERT_EQ(result_data->length(), col_vals->length());

  std::cout << "Success:" << std::endl;
  std::cout << "\t" << result_data->ToString() << std::endl;

  EndRecipe("RegisterAndCallComputeFunction");
}

0 commit comments

Comments
 (0)