Skip to content

Commit 3c85ef6

Browse files
committed
fix: Provide Null coercion with other types
1 parent 5b0a80c commit 3c85ef6

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

datafusion/core/src/physical_plan/windows/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use crate::physical_plan::{
3030
use crate::scalar::ScalarValue;
3131
use arrow::datatypes::Schema;
3232
use datafusion_expr::{
33-
window_function::{signature_for_built_in, BuiltInWindowFunction, WindowFunction},
33+
window_function::{signature_for_built_in, BuiltInWindowFunction},
3434
WindowFrame,
3535
};
3636
use datafusion_physical_expr::window::BuiltInWindowFunctionExpr;
@@ -39,6 +39,7 @@ use std::sync::Arc;
3939

4040
mod window_agg_exec;
4141

42+
pub use datafusion_expr::window_function::WindowFunction;
4243
pub use datafusion_physical_expr::window::{
4344
AggregateWindowExpr, BuiltInWindowExpr, WindowExpr,
4445
};

datafusion/expr/src/binary_rule.rs

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ pub fn coerce_types(
8787
| Operator::BitwiseShiftRight => bitwise_coercion(lhs_type, rhs_type),
8888
Operator::And | Operator::Or => match (lhs_type, rhs_type) {
8989
// logical binary boolean operators can only be evaluated on booleans (a Null operand is coerced to Boolean)
90-
(DataType::Boolean, DataType::Boolean) => Some(DataType::Boolean),
90+
(DataType::Boolean, DataType::Boolean)
91+
| (DataType::Boolean, DataType::Null)
92+
| (DataType::Null, DataType::Boolean) => Some(DataType::Boolean),
9193
_ => None,
9294
},
9395
// logical equality operators have their own rules, and always return a boolean
@@ -164,6 +166,11 @@ pub fn comparison_eq_coercion(
164166
// same type => equality is possible
165167
return Some(lhs_type.clone());
166168
}
169+
match (lhs_type, rhs_type) {
170+
(_, DataType::Null) if !is_dictionary(lhs_type) => return Some(lhs_type.clone()),
171+
(DataType::Null, _) if !is_dictionary(rhs_type) => return Some(rhs_type.clone()),
172+
_ => (),
173+
};
167174
comparison_binary_numeric_coercion(lhs_type, rhs_type)
168175
.or_else(|| string_boolean_equality_coercion(lhs_type, rhs_type))
169176
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
@@ -205,6 +212,11 @@ fn comparison_order_coercion(
205212
// same type => all good
206213
return Some(lhs_type.clone());
207214
}
215+
match (lhs_type, rhs_type) {
216+
(_, DataType::Null) if !is_dictionary(lhs_type) => return Some(lhs_type.clone()),
217+
(DataType::Null, _) if !is_dictionary(rhs_type) => return Some(rhs_type.clone()),
218+
_ => (),
219+
};
208220
comparison_binary_numeric_coercion(lhs_type, rhs_type)
209221
.or_else(|| string_coercion(lhs_type, rhs_type))
210222
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
@@ -314,6 +326,9 @@ fn mathematics_numerical_coercion(
314326
(Decimal(_, _), Decimal(_, _)) => {
315327
coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type)
316328
}
329+
(Decimal(precision, scale), Null) | (Null, Decimal(precision, scale)) => {
330+
Some(Decimal(*precision, *scale))
331+
}
317332
(Decimal(_, _), _) => {
318333
let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type);
319334
match converted_decimal_type {
@@ -415,6 +430,7 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
415430
| DataType::Float32
416431
| DataType::Float64
417432
| DataType::Decimal(_, _)
433+
| DataType::Null
418434
)
419435
}
420436

@@ -498,6 +514,8 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
498514
(LargeUtf8, Utf8) => Some(LargeUtf8),
499515
(Utf8, LargeUtf8) => Some(LargeUtf8),
500516
(LargeUtf8, LargeUtf8) => Some(LargeUtf8),
517+
(Utf8, Null) | (Null, Utf8) => Some(Utf8),
518+
(LargeUtf8, Null) | (Null, LargeUtf8) => Some(LargeUtf8),
501519
_ => None,
502520
}
503521
}
@@ -622,6 +640,11 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
622640
// same type => equality is possible
623641
return Some(lhs_type.clone());
624642
}
643+
match (lhs_type, rhs_type) {
644+
(_, DataType::Null) if !is_dictionary(lhs_type) => return Some(lhs_type.clone()),
645+
(DataType::Null, _) if !is_dictionary(rhs_type) => return Some(rhs_type.clone()),
646+
_ => (),
647+
};
625648
numerical_coercion(lhs_type, rhs_type)
626649
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
627650
.or_else(|| temporal_coercion(lhs_type, rhs_type))
@@ -642,14 +665,22 @@ pub fn interval_coercion(
642665
match op {
643666
Operator::Plus | Operator::Minus => match (lhs_type, rhs_type) {
644667
(Timestamp(unit, zone), Interval(_))
645-
| (Interval(_), Timestamp(unit, zone)) => {
668+
| (Interval(_), Timestamp(unit, zone))
669+
| (Timestamp(unit, zone), Null)
670+
| (Null, Timestamp(unit, zone)) => {
646671
Some(Timestamp(unit.clone(), zone.clone()))
647672
}
648-
(Date32, Interval(_)) | (Interval(_), Date32) => {
673+
(Date32, Interval(_))
674+
| (Interval(_), Date32)
675+
| (Date32, Null)
676+
| (Null, Date32) => {
649677
// TODO: returning Date32 here is not correct; it should be replaced with a correctly typed timestamp
650678
Some(Date32)
651679
}
652-
(Date64, Interval(_)) | (Interval(_), Date64) => {
680+
(Date64, Interval(_))
681+
| (Interval(_), Date64)
682+
| (Date64, Null)
683+
| (Null, Date64) => {
653684
// TODO: returning Date64 here is not correct; it should be replaced with a correctly typed timestamp
654685
Some(Date64)
655686
}
@@ -679,7 +710,9 @@ pub fn interval_coercion(
679710
| (UInt8, Interval(itype))
680711
| (Interval(itype), UInt8)
681712
| (Decimal(_, _), Interval(itype))
682-
| (Interval(itype), Decimal(_, _)) => Some(Interval(itype.clone())),
713+
| (Interval(itype), Decimal(_, _))
714+
| (Null, Interval(itype))
715+
| (Interval(itype), Null) => Some(Interval(itype.clone())),
683716
_ => None,
684717
},
685718
_ => None,
@@ -699,7 +732,7 @@ pub fn date_coercion(
699732
// that the coercion removes the least amount of information
700733
match op {
701734
Operator::Minus => match (lhs_type, rhs_type) {
702-
(Date32, Date32) => Some(Int32),
735+
(Date32, Date32) | (Date32, Null) | (Null, Date32) => Some(Int32),
703736
(Timestamp(_, _), Timestamp(_, _)) => Some(Interval(MonthDayNano)),
704737
_ => None,
705738
},

0 commit comments

Comments
 (0)