diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs index d510eacb9aa4..7e358d552246 100644 --- a/datafusion/spark/src/function/datetime/make_interval.rs +++ b/datafusion/spark/src/function/datetime/make_interval.rs @@ -22,11 +22,11 @@ use arrow::array::{Array, ArrayRef, IntervalMonthDayNanoBuilder, PrimitiveArray} use arrow::datatypes::DataType::Interval; use arrow::datatypes::IntervalUnit::MonthDayNano; use arrow::datatypes::{DataType, IntervalMonthDayNano}; -use datafusion_common::{ - exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::types::{logical_float64, logical_int32, NativeType}; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, }; use datafusion_functions::utils::make_scalar_function; @@ -43,8 +43,64 @@ impl Default for SparkMakeInterval { impl SparkMakeInterval { pub fn new() -> Self { + let int32 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ); + + let float64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + + let variants = vec![ + TypeSignature::Nullary, + // year + TypeSignature::Coercible(vec![int32.clone()]), + // year, month + TypeSignature::Coercible(vec![int32.clone(), int32.clone()]), + // year, month, week + TypeSignature::Coercible(vec![int32.clone(), int32.clone(), int32.clone()]), + // year, month, week, day + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + ]), + // year, month, week, day, hour + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + ]), + // year, month, week, day, hour, minute + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + ]), + // year, month, week, day, hour, minute, second + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + float64.clone(), + ]), + ]; + Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of(variants, Volatility::Immutable), } } } @@ -74,27 +130,6 @@ impl ScalarUDFImpl for SparkMakeInterval { } make_scalar_function(make_interval_kernel, vec![])(&args.args) } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let length = arg_types.len(); - match length { - x if x > 7 => { - exec_err!( - "make_interval expects between 0 and 7 arguments, got {}", - arg_types.len() - ) - } - _ => Ok((0..arg_types.len()) - .map(|i| { - if i == 6 { - DataType::Float64 - } else { - DataType::Int32 - } - }) - .collect()), - } - } } fn make_interval_kernel(args: &[ArrayRef]) -> Result {