diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 99ab1fed0f4d2..ca6c2233ef870 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -81,7 +81,10 @@ def test_function_parity(self):
         missing_in_py = jvm_fn_set.difference(py_fn_set)
 
         # Functions that we expect to be missing in python until they are added to pyspark
-        expected_missing_in_py = set()
+        expected_missing_in_py = set(
+            # TODO(SPARK-53107): Implement the time_trunc function in Python
+            ["time_trunc"]
+        )
 
         self.assertEqual(
             expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected"
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index 4d8f658ca32d9..49fa45ed02cbe 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -6292,6 +6292,27 @@ object functions {
   def timestamp_add(unit: String, quantity: Column, ts: Column): Column =
     Column.internalFn("timestampadd", lit(unit), quantity, ts)
 
+  /**
+   * Returns `time` truncated to the `unit`.
+   *
+   * @param unit
+   *   A STRING representing the unit to truncate the time to. Supported units are: "HOUR",
+   *   "MINUTE", "SECOND", "MILLISECOND", and "MICROSECOND". The unit is case-insensitive.
+   * @param time
+   *   A TIME to truncate.
+   * @return
+   *   A TIME truncated to the specified unit.
+   * @note
+   *   If any of the inputs is `NULL`, the result is `NULL`.
+   * @throws IllegalArgumentException
+   *   If the `unit` is not supported.
+   * @group datetime_funcs
+   * @since 4.1.0
+   */
+  def time_trunc(unit: Column, time: Column): Column = {
+    Column.fn("time_trunc", unit, time)
+  }
+
   /**
    * Parses the `timestamp` expression with the `format` expression to a timestamp without time
    * zone. Returns null with invalid input.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala
index 8506ab4527c9b..005bfcb13d2e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import java.time.LocalTime
 import java.time.temporal.ChronoUnit
 
-import org.apache.spark.{SparkConf, SparkDateTimeException}
+import org.apache.spark.{SparkConf, SparkDateTimeException, SparkIllegalArgumentException}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -241,6 +241,60 @@ abstract class TimeFunctionsSuiteBase extends QueryTest with SharedSparkSession
     checkAnswer(result2, expected)
   }
 
+  test("SPARK-53107: time_trunc function") {
+    // Input data for the function (including null values).
+    val schema = StructType(Seq(
+      StructField("unit", StringType),
+      StructField("time", TimeType())
+    ))
+    val data = Seq(
+      Row("HOUR", LocalTime.parse("00:00:00")),
+      Row("second", LocalTime.parse("01:02:03.4")),
+      Row("MicroSecond", LocalTime.parse("23:59:59.999999")),
+      Row(null, LocalTime.parse("01:02:03")),
+      Row("MiNuTe", null),
+      Row(null, null)
+    )
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+
+    // Test the function using both `selectExpr` and `select`.
+    val result1 = df.selectExpr(
+      "time_trunc(unit, time)"
+    )
+    val result2 = df.select(
+      time_trunc(col("unit"), col("time"))
+    )
+    // Check that both methods produce the same result.
+    checkAnswer(result1, result2)
+
+    // Expected output of the function.
+    val expected = Seq(
+      "00:00:00",
+      "01:02:03",
+      "23:59:59.999999",
+      null,
+      null,
+      null
+    ).toDF("timeString").select(col("timeString").cast("time"))
+    // Check that the results match the expected output.
+    checkAnswer(result1, expected)
+    checkAnswer(result2, expected)
+
+    // Error is thrown for malformed input.
+    val invalidUnitDF = Seq(("invalid_unit", LocalTime.parse("01:02:03"))).toDF("unit", "time")
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        invalidUnitDF.select(time_trunc(col("unit"), col("time"))).collect()
+      },
+      condition = "INVALID_PARAMETER_VALUE.TIME_UNIT",
+      parameters = Map(
+        "functionName" -> "`time_trunc`",
+        "parameter" -> "`unit`",
+        "invalidValue" -> "'invalid_unit'"
+      )
+    )
+  }
+
   test("SPARK-52883: to_time function without format") {
     // Input data for the function.
     val schema = StructType(Seq(
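A minimal usage sketch of the new Scala API, not part of the patch above. It assumes a Spark build that already includes this change and TIME-type support (4.1.0+); the standalone object, session setup, and literal values are illustrative only.

// Hypothetical standalone example; assumes the time_trunc patch is applied.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, time_trunc}

object TimeTruncExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("time_trunc-example")
      .getOrCreate()

    // Build a one-row DataFrame with a TIME column via a SQL cast.
    val df = spark.sql("SELECT CAST('01:02:03.456' AS TIME) AS t")

    // Truncate to the minute: 01:02:03.456 becomes 01:02:00.
    // Per the scaladoc, the unit is case-insensitive, a NULL unit or time
    // yields NULL, and an unsupported unit raises an error.
    df.select(time_trunc(lit("MINUTE"), col("t"))).show()

    spark.stop()
  }
}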