diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc index a9047449b8f..10dc60c6480 100644 --- a/cpp/src/gandiva/function_registry_arithmetic.cc +++ b/cpp/src/gandiva/function_registry_arithmetic.cc @@ -90,15 +90,14 @@ std::vector GetArithmeticFunctionRegistry() { // add/sub/multiply/divide/mod BINARY_SYMMETRIC_FN(add, {}), BINARY_SYMMETRIC_FN(subtract, {}), BINARY_SYMMETRIC_FN(multiply, {}), - NUMERIC_TYPES_WITHOUT_DECIMAL(BINARY_SYMMETRIC_SAFE_INTERNAL_NULL, divide, {}), + NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_INTERNAL_NULL, divide, {}), BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int8), BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int16), BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int32), BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int64), BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, float32), BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, float64), - BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, decimal128), - BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(divide, {}, decimal128), + BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, decimal128), BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int32), BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int64), BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float32), diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc index 5e208967c60..e97e0aa7d5d 100644 --- a/cpp/src/gandiva/precompiled/decimal_ops.cc +++ b/cpp/src/gandiva/precompiled/decimal_ops.cc @@ -347,7 +347,7 @@ BasicDecimal128 Multiply(const BasicDecimalScalar128& x, const BasicDecimalScala return result; } -BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x, +BasicDecimal128 Divide(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, int32_t out_precision, int32_t out_scale, bool* overflow) { if (y.value() == 0) { @@ -392,7 +392,7 @@ BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x, return result; } -BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x, +BasicDecimal128 Mod(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, int32_t out_precision, int32_t out_scale, bool* overflow) { if (y.value() == 0) { diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h b/cpp/src/gandiva/precompiled/decimal_ops.h index 292dce2208c..951ab723908 100644 --- a/cpp/src/gandiva/precompiled/decimal_ops.h +++ b/cpp/src/gandiva/precompiled/decimal_ops.h @@ -41,12 +41,12 @@ arrow::BasicDecimal128 Multiply(const BasicDecimalScalar128& x, int32_t out_scale, bool* overflow); /// Divide 'x' by 'y', and return the result. -arrow::BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x, +arrow::BasicDecimal128 Divide(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, int32_t out_precision, int32_t out_scale, bool* overflow); /// Divide 'x' by 'y', and return the remainder. -arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x, +arrow::BasicDecimal128 Mod(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, int32_t out_precision, int32_t out_scale, bool* overflow); diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/cpp/src/gandiva/precompiled/decimal_ops_test.cc index 445201fa006..6f29f5f02fc 100644 --- a/cpp/src/gandiva/precompiled/decimal_ops_test.cc +++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc @@ -144,13 +144,13 @@ void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x, case DecimalTypeUtil::kOpDivide: op_name = "divide"; - out_value = decimalops::Divide(context, x, y, out_type->precision(), + out_value = decimalops::Divide(x, y, out_type->precision(), out_type->scale(), &overflow); break; case DecimalTypeUtil::kOpMod: op_name = "mod"; - out_value = decimalops::Mod(context, x, y, out_type->precision(), out_type->scale(), + out_value = decimalops::Mod(x, y, out_type->precision(), out_type->scale(), &overflow); break; @@ -451,16 +451,14 @@ TEST_F(TestDecimalSql, DivideByZero) { context.Reset(); result_precision = 38; result_scale = 19; - decimalops::Divide(reinterpret_cast(&context), - DecimalScalar128{"201", 20, 3}, DecimalScalar128{"0", 20, 2}, + decimalops::Divide(DecimalScalar128{"201", 20, 3}, DecimalScalar128{"0", 20, 2}, result_precision, result_scale, &overflow); // EXPECT_TRUE(context.has_error()); // EXPECT_EQ(context.get_error(), "divide by zero error"); // divide-by-nonzero should not cause an error. context.Reset(); - decimalops::Divide(reinterpret_cast(&context), - DecimalScalar128{"201", 20, 3}, DecimalScalar128{"1", 20, 2}, + decimalops::Divide(DecimalScalar128{"201", 20, 3}, DecimalScalar128{"1", 20, 2}, result_precision, result_scale, &overflow); EXPECT_FALSE(context.has_error()); @@ -468,7 +466,7 @@ TEST_F(TestDecimalSql, DivideByZero) { context.Reset(); result_precision = 20; result_scale = 3; - decimalops::Mod(reinterpret_cast(&context), DecimalScalar128{"201", 20, 3}, + decimalops::Mod(DecimalScalar128{"201", 20, 3}, DecimalScalar128{"0", 20, 2}, result_precision, result_scale, &overflow); // EXPECT_TRUE(context.has_error()); @@ -476,7 +474,7 @@ TEST_F(TestDecimalSql, DivideByZero) { // mod-by-nonzero should not cause an error. context.Reset(); - decimalops::Mod(reinterpret_cast(&context), DecimalScalar128{"201", 20, 3}, + decimalops::Mod(DecimalScalar128{"201", 20, 3}, DecimalScalar128{"1", 20, 2}, result_precision, result_scale, &overflow); EXPECT_FALSE(context.has_error()); diff --git a/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/cpp/src/gandiva/precompiled/decimal_wrapper.cc index 082d5832d14..7ab887c38bd 100644 --- a/cpp/src/gandiva/precompiled/decimal_wrapper.cc +++ b/cpp/src/gandiva/precompiled/decimal_wrapper.cc @@ -52,35 +52,67 @@ void multiply_decimal128_decimal128(int64_t x_high, uint64_t x_low, int32_t x_pr } FORCE_INLINE -void divide_decimal128_decimal128(int64_t context, int64_t x_high, uint64_t x_low, - int32_t x_precision, int32_t x_scale, int64_t y_high, - uint64_t y_low, int32_t y_precision, int32_t y_scale, +void divide_decimal128_decimal128(int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, bool x_valid, + int64_t y_high, uint64_t y_low, int32_t y_precision, + int32_t y_scale, bool y_valid, bool* out_valid, int32_t out_precision, int32_t out_scale, int64_t* out_high, uint64_t* out_low) { + if (!x_valid || !y_valid) { + *out_valid = false; + arrow::BasicDecimal128 out = 0; + *out_high = out.high_bits(); + *out_low = out.low_bits(); + return; + } gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale); gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale); bool overflow; + if (y.value() == 0) { + *out_valid = false; + arrow::BasicDecimal128 out = 0; + *out_high = out.high_bits(); + *out_low = out.low_bits(); + return; + } // TODO ravindra: generate error on overflows (ARROW-4570). arrow::BasicDecimal128 out = - gandiva::decimalops::Divide(context, x, y, out_precision, out_scale, &overflow); + gandiva::decimalops::Divide(x, y, out_precision, out_scale, &overflow); + *out_valid = true; *out_high = out.high_bits(); *out_low = out.low_bits(); } FORCE_INLINE -void mod_decimal128_decimal128(int64_t context, int64_t x_high, uint64_t x_low, - int32_t x_precision, int32_t x_scale, int64_t y_high, - uint64_t y_low, int32_t y_precision, int32_t y_scale, +void mod_decimal128_decimal128(int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, bool x_valid, + int64_t y_high, uint64_t y_low, int32_t y_precision, + int32_t y_scale, bool y_valid, bool* out_valid, int32_t out_precision, int32_t out_scale, int64_t* out_high, uint64_t* out_low) { + if (!x_valid || !y_valid) { + *out_valid = false; + arrow::BasicDecimal128 out = 0; + *out_high = out.high_bits(); + *out_low = out.low_bits(); + return; + } gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale); gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale); bool overflow; + if (y.value() == 0) { + *out_valid = false; + arrow::BasicDecimal128 out = 0; + *out_high = out.high_bits(); + *out_low = out.low_bits(); + return; + } // TODO ravindra: generate error on overflows (ARROW-4570). arrow::BasicDecimal128 out = - gandiva::decimalops::Mod(context, x, y, out_precision, out_scale, &overflow); + gandiva::decimalops::Mod(x, y, out_precision, out_scale, &overflow); + *out_valid = true; *out_high = out.high_bits(); *out_low = out.low_bits(); }