Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,21 @@ case class PivotFirst(

override val dataType: DataType = ArrayType(valueDataType)

// True when the pivot column's type lacks proper equals semantics (e.g. collated
// strings, where distinct binary representations may compare equal). Such types
// cannot be looked up by hash, so a comparison-based TreeMap is required.
private val usesTreeMap: Boolean = !TypeUtils.typeWithProperEquals(pivotColumn.dataType)

// Maps each pivot value to its slot in the output array. A TreeMap ordered by the
// type's interpreted ordering when equality must go through comparison; a plain
// HashMap otherwise.
val pivotIndex: Map[Any, Int] = if (usesTreeMap) {
  TreeMap(pivotColumnValues.zipWithIndex: _*)(
    TypeUtils.getInterpretedOrdering(pivotColumn.dataType))
} else {
  HashMap(pivotColumnValues.zipWithIndex: _*)
}

// Null-safe lookup into pivotIndex. When pivotIndex is a TreeMap, its comparison-based lookup
// throws NPE on null keys. Returning -1 for null is safe on the TreeMap path because null can
// never be a TreeMap key (insertion would also NPE), so it can never match any pivot value.
// Otherwise, pivotIndex is a HashMap that handles null keys safely via hash-based lookup.
private def findPivotIndex(key: Any): Int = key match {
  case null if usesTreeMap => -1
  case _ => pivotIndex.getOrElse(key, -1)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,4 +397,83 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession {
Row(500, 200))
}
}

test("pivot with explicit values under UTF8_LCASE collation") {
  withTable("lcase_pivot") {
    // The pivot column is case-insensitive, so the uppercase pivot literals
    // 'Q1'/'Q2' below must match the lowercase 'q1'/'q2' values stored in rows.
    sql(
      """CREATE TABLE lcase_pivot (
        | quarter STRING COLLATE UTF8_LCASE,
        | course STRING,
        | earnings INT
        |) USING PARQUET""".stripMargin)
    sql(
      """INSERT INTO lcase_pivot VALUES
        | ('q1', 'dotNET', 10000),
        | ('q2', 'dotNET', 15000),
        | ('q1', 'Java', 20000),
        | ('q2', 'Java', 30000)""".stripMargin)

    val pivoted = sql(
      """SELECT * FROM lcase_pivot
        |PIVOT (
        | SUM(earnings)
        | FOR quarter IN ('Q1' AS Q1, 'Q2' AS Q2)
        |)""".stripMargin)
    val expected = Seq(
      Row("dotNET", 10000, 15000),
      Row("Java", 20000, 30000))
    checkAnswer(pivoted, expected)
  }
}

test("pivot with explicit values under UNICODE_CI collation") {
  // Same logical string in two Unicode normal forms; UNICODE_CI must treat them
  // as equal, so both rows aggregate into the single 'uber' pivot bucket.
  // scalastyle:off nonascii
  val nfcUber = "\u00FCber" // über (precomposed)
  val nfdUber = "u\u0308ber" // über (decomposed)
  // scalastyle:on nonascii
  withTable("uci_pivot") {
    sql(
      """CREATE TABLE uci_pivot (
        | key STRING COLLATE UNICODE_CI,
        | amount INT
        |) USING PARQUET""".stripMargin)
    sql(s"INSERT INTO uci_pivot VALUES ('$nfcUber', 100)")
    sql(s"INSERT INTO uci_pivot VALUES ('$nfdUber', 200)")
    sql("INSERT INTO uci_pivot VALUES ('other', 50)")

    val pivoted = sql(
      s"""SELECT * FROM uci_pivot
         |PIVOT (
         | SUM(amount) FOR key IN (
         | '$nfcUber' AS uber,
         | 'other' AS other
         | )
         |)""".stripMargin)
    // 100 + 200 fold into one bucket; 'other' stays separate.
    checkAnswer(pivoted, Seq(Row(300, 50)))
  }
}

test("pivot with null collated string column should not NPE") {
  withTable("lcase_null_pivot") {
    // A NULL pivot-column value matches no pivot value and must simply be
    // dropped, rather than crashing the comparison-based pivot-index lookup.
    sql(
      """CREATE TABLE lcase_null_pivot (
        | key STRING COLLATE UTF8_LCASE,
        | amount INT
        |) USING PARQUET""".stripMargin)
    sql(
      """INSERT INTO lcase_null_pivot VALUES
        | ('a', 10),
        | (NULL, 20),
        | ('b', 30)""".stripMargin)

    val pivoted = sql(
      """SELECT * FROM lcase_null_pivot
        |PIVOT (
        | SUM(amount)
        | FOR key IN ('a' AS a, 'b' AS b)
        |)""".stripMargin)
    // The NULL row's amount (20) appears in neither bucket.
    checkAnswer(pivoted, Seq(Row(10, 30)))
  }
}
}