diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
index 99b3b79d1a4e6..b25fce6d5192f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -86,20 +86,21 @@ case class PivotFirst(
 
   override val dataType: DataType = ArrayType(valueDataType)
 
-  val pivotIndex: Map[Any, Int] = if (pivotColumn.dataType.isInstanceOf[AtomicType]) {
-    HashMap(pivotColumnValues.zipWithIndex: _*)
-  } else {
+  private val usesTreeMap: Boolean = !TypeUtils.typeWithProperEquals(pivotColumn.dataType)
+
+  val pivotIndex: Map[Any, Int] = if (usesTreeMap) {
     TreeMap(pivotColumnValues.zipWithIndex: _*)(
       TypeUtils.getInterpretedOrdering(pivotColumn.dataType))
+  } else {
+    HashMap(pivotColumnValues.zipWithIndex: _*)
   }
 
-  // Null-safe lookup into pivotIndex. For atomic types, pivotIndex is a HashMap which
-  // handles null keys safely via hash-based lookup. For non-atomic types, pivotIndex is a TreeMap
-  // whose comparison-based lookup throws NPE on null keys. Returning -1 for null is safe on the
-  // TreeMap path because null can never be a TreeMap key (insertion would also NPE), so it can
-  // never match any pivot value.
+  // Null-safe lookup into pivotIndex. When pivotIndex is a TreeMap, its comparison-based lookup
+  // throws NPE on null keys. Returning -1 for null is safe on the TreeMap path because null can
+  // never be a TreeMap key (insertion would also NPE), so it can never match any pivot value.
+  // Otherwise, pivotIndex is a HashMap that handles null keys safely via hash-based lookup.
   private def findPivotIndex(key: Any): Int = key match {
-    case null if !pivotColumn.dataType.isInstanceOf[AtomicType] => -1
+    case null if usesTreeMap => -1
     case _ => pivotIndex.getOrElse(key, -1)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 5202ae5d4e5d6..e4f9035b83bdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -397,4 +397,83 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession {
       Row(500, 200))
     }
   }
+
+  test("pivot with explicit values under UTF8_LCASE collation") {
+    withTable("lcase_pivot") {
+      sql(
+        """CREATE TABLE lcase_pivot (
+          | quarter STRING COLLATE UTF8_LCASE,
+          | course STRING,
+          | earnings INT
+          |) USING PARQUET""".stripMargin)
+      sql(
+        """INSERT INTO lcase_pivot VALUES
+          | ('q1', 'dotNET', 10000),
+          | ('q2', 'dotNET', 15000),
+          | ('q1', 'Java', 20000),
+          | ('q2', 'Java', 30000)""".stripMargin)
+
+      checkAnswer(
+        sql(
+          """SELECT * FROM lcase_pivot
+            |PIVOT (
+            | SUM(earnings)
+            | FOR quarter IN ('Q1' AS Q1, 'Q2' AS Q2)
+            |)""".stripMargin),
+        Row("dotNET", 10000, 15000) ::
+          Row("Java", 20000, 30000) :: Nil)
+    }
+  }
+
+  test("pivot with explicit values under UNICODE_CI collation") {
+    // scalastyle:off nonascii
+    val precomposed = "\u00FCber" // über (precomposed)
+    val decomposed = "u\u0308ber" // über (decomposed)
+    // scalastyle:on nonascii
+    withTable("uci_pivot") {
+      sql(
+        """CREATE TABLE uci_pivot (
+          | key STRING COLLATE UNICODE_CI,
+          | amount INT
+          |) USING PARQUET""".stripMargin)
+      sql(s"INSERT INTO uci_pivot VALUES ('$precomposed', 100)")
+      sql(s"INSERT INTO uci_pivot VALUES ('$decomposed', 200)")
+      sql("INSERT INTO uci_pivot VALUES ('other', 50)")
+
+      checkAnswer(
+        sql(
+          s"""SELECT * FROM uci_pivot
+             |PIVOT (
+             | SUM(amount) FOR key IN (
+             | '$precomposed' AS uber,
+             | 'other' AS other
+             | )
+             |)""".stripMargin),
+        Row(300, 50))
+    }
+  }
+
+  test("pivot with null collated string column should not NPE") {
+    withTable("lcase_null_pivot") {
+      sql(
+        """CREATE TABLE lcase_null_pivot (
+          | key STRING COLLATE UTF8_LCASE,
+          | amount INT
+          |) USING PARQUET""".stripMargin)
+      sql(
+        """INSERT INTO lcase_null_pivot VALUES
+          | ('a', 10),
+          | (NULL, 20),
+          | ('b', 30)""".stripMargin)
+
+      checkAnswer(
+        sql(
+          """SELECT * FROM lcase_null_pivot
+            |PIVOT (
+            | SUM(amount)
+            | FOR key IN ('a' AS a, 'b' AS b)
+            |)""".stripMargin),
+        Row(10, 30))
+    }
+  }
 }