Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -555,20 +555,20 @@ class DataFrameSuite extends QueryTest
Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))

checkAnswer(
arrayData.toDF().orderBy($"data".getItem(0).asc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
arrayData.orderBy($"data".getItem(0).asc),
arrayData.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)

checkAnswer(
arrayData.toDF().orderBy($"data".getItem(0).desc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
arrayData.orderBy($"data".getItem(0).desc),
arrayData.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)

checkAnswer(
arrayData.toDF().orderBy($"data".getItem(1).asc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
arrayData.orderBy($"data".getItem(1).asc),
arrayData.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)

checkAnswer(
arrayData.toDF().orderBy($"data".getItem(1).desc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
arrayData.orderBy($"data".getItem(1).desc),
arrayData.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}

test("limit") {
Expand All @@ -577,12 +577,12 @@ class DataFrameSuite extends QueryTest
testData.take(10).toSeq)

checkAnswer(
arrayData.toDF().limit(1),
arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
arrayData.limit(1),
arrayData.take(1))

checkAnswer(
mapData.toDF().limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
mapData.limit(1),
mapData.take(1))

// SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake
checkAnswer(
Expand All @@ -597,12 +597,12 @@ class DataFrameSuite extends QueryTest
testData.collect().drop(90).toSeq)

checkAnswer(
arrayData.toDF().offset(99),
arrayData.collect().drop(99).map(r => Row.fromSeq(r.productIterator.toSeq)))
arrayData.offset(99),
arrayData.collect().drop(99))

checkAnswer(
mapData.toDF().offset(99),
mapData.collect().drop(99).map(r => Row.fromSeq(r.productIterator.toSeq)))
mapData.offset(99),
mapData.collect().drop(99))
}

test("limit with offset") {
Expand Down
14 changes: 7 additions & 7 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -551,19 +551,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark

checkAnswer(
sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
arrayData.collect().sortBy(_.data(0)).map(Row.fromTuple).toSeq)
arrayData.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)

checkAnswer(
sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
arrayData.collect().sortBy(_.data(0)).reverse.map(Row.fromTuple).toSeq)
arrayData.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)

checkAnswer(
sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
mapData.collect().sortBy(_.data(1)).map(Row.fromTuple).toSeq)
mapData.collect().sortBy(_.getAs[Map[Int, String]](0)(1)).toSeq)

checkAnswer(
sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq)
mapData.collect().sortBy(_.getAs[Map[Int, String]](0)(1)).reverse.toSeq)
}

test("external sorting") {
Expand Down Expand Up @@ -1007,7 +1007,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
StructField("f3", BooleanType, false) ::
StructField("f4", IntegerType, true) :: Nil)

val rowRDD1 = unparsedStrings.map { r =>
val rowRDD1 = unparsedStrings.as[String].rdd.map { r =>
val values = r.split(",").map(_.trim)
val v4 = try values(3).toInt catch {
case _: NumberFormatException => null
Expand Down Expand Up @@ -1037,7 +1037,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
StructField("f12", BooleanType, false) :: Nil), false) ::
StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil)

val rowRDD2 = unparsedStrings.map { r =>
val rowRDD2 = unparsedStrings.as[String].rdd.map { r =>
val values = r.split(",").map(_.trim)
val v4 = try values(3).toInt catch {
case _: NumberFormatException => null
Expand All @@ -1064,7 +1064,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
Row(4, 2147483644) :: Nil)

// The value of a MapType column can be a mutable map.
val rowRDD3 = unparsedStrings.map { r =>
val rowRDD3 = unparsedStrings.as[String].rdd.map { r =>
val values = r.split(",").map(_.trim)
val v4 = try values(3).toInt catch {
case _: NumberFormatException => null
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,11 @@ class UDFSuite extends QueryTest with SharedSparkSession {
sql("""
| SELECT tmp.t.* FROM
| (SELECT arrayDataFunc(data, nestedData) AS t FROM arrayData) tmp
""".stripMargin).toDF(), arrayData.toDF())
""".stripMargin).toDF(), arrayData)
checkAnswer(
sql("""
| SELECT mapDataFunc(data) AS t FROM mapData
""".stripMargin).toDF(), mapData.toDF())
""".stripMargin).toDF(), mapData)
checkAnswer(
sql("""
| SELECT tmp.t.* FROM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
}

test("interval is supported for arrow") {
val collected = calendarIntervalData.toDF().toArrowBatchRdd.collect()
val collected = calendarIntervalData.toArrowBatchRdd.collect()
assert(collected.length == 1)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,25 +185,25 @@ class InMemoryColumnarQuerySuite extends QueryTest
test("SPARK-1678 regression: compression must not lose repeated values") {
checkAnswer(
sql("SELECT * FROM repeatedData"),
repeatedData.collect().toSeq.map(Row.fromTuple))
repeatedData.collect().toSeq)

spark.catalog.cacheTable("repeatedData")

checkAnswer(
sql("SELECT * FROM repeatedData"),
repeatedData.collect().toSeq.map(Row.fromTuple))
repeatedData.collect().toSeq)
}

test("with null values") {
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
nullableRepeatedData.collect().toSeq)

spark.catalog.cacheTable("nullableRepeatedData")

checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
nullableRepeatedData.collect().toSeq)
}

test("SPARK-2729 regression: timestamp data type") {
Expand All @@ -226,13 +226,13 @@ class InMemoryColumnarQuerySuite extends QueryTest
test("SPARK-3320 regression: batched column buffer building should work with empty partitions") {
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
withEmptyParts.collect().toSeq.map(Row.fromTuple))
withEmptyParts.collect().toSeq)

spark.catalog.cacheTable("withEmptyParts")

checkAnswer(
sql("SELECT * FROM withEmptyParts"),
withEmptyParts.collect().toSeq.map(Row.fromTuple))
withEmptyParts.collect().toSeq)
}

test("SPARK-4182 Caching complex types") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,7 @@ abstract class JsonSuite
StructField("f4", ArrayType(StringType), nullable = true) ::
StructField("f5", IntegerType, true) :: Nil)

val rowRDD1 = unparsedStrings.map { r =>
val rowRDD1 = unparsedStrings.as[String].rdd.map { r =>
val values = r.split(",").map(_.trim)
val v5 = try values(3).toInt catch {
case _: NumberFormatException => null
Expand All @@ -1275,7 +1275,7 @@ abstract class JsonSuite
StructField("f12", BooleanType, false) :: Nil), false) ::
StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil)

val rowRDD2 = unparsedStrings.map { r =>
val rowRDD2 = unparsedStrings.as[String].rdd.map { r =>
val values = r.split(",").map(_.trim)
val v4 = try values(3).toInt catch {
case _: NumberFormatException => null
Expand Down
63 changes: 31 additions & 32 deletions sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.test
import java.nio.charset.StandardCharsets
import java.time.{Duration, Period}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSessionProvider
import org.apache.spark.sql.classic
import org.apache.spark.sql.classic.{DataFrame, SQLImplicits}
Expand Down Expand Up @@ -156,44 +155,44 @@ private[sql] trait SQLTestData extends SparkSessionProvider { self =>
df
}

protected lazy val arrayData: RDD[ArrayData] = {
val rdd = spark.sparkContext.parallelize(
protected lazy val arrayData: DataFrame = {
val df = spark.sparkContext.parallelize(
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
rdd.toDF().createOrReplaceTempView("arrayData")
rdd
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil).toDF()
df.createOrReplaceTempView("arrayData")
df
}

protected lazy val mapData: RDD[MapData] = {
val rdd = spark.sparkContext.parallelize(
protected lazy val mapData: DataFrame = {
val df = spark.sparkContext.parallelize(
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
MapData(Map(1 -> "a4", 2 -> "b4")) ::
MapData(Map(1 -> "a5")) :: Nil)
rdd.toDF().createOrReplaceTempView("mapData")
rdd
MapData(Map(1 -> "a5")) :: Nil).toDF()
df.createOrReplaceTempView("mapData")
df
}

protected lazy val calendarIntervalData: RDD[IntervalData] = {
val rdd = spark.sparkContext.parallelize(
IntervalData(new CalendarInterval(1, 1, 1)) :: Nil)
rdd.toDF().createOrReplaceTempView("calendarIntervalData")
rdd
protected lazy val calendarIntervalData: DataFrame = {
val df = spark.sparkContext.parallelize(
IntervalData(new CalendarInterval(1, 1, 1)) :: Nil).toDF()
df.createOrReplaceTempView("calendarIntervalData")
df
}

protected lazy val repeatedData: RDD[StringData] = {
val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test")))
rdd.toDF().createOrReplaceTempView("repeatedData")
rdd
protected lazy val repeatedData: DataFrame = {
val df = spark.sparkContext.parallelize(List.fill(2)(StringData("test"))).toDF()
df.createOrReplaceTempView("repeatedData")
df
}

protected lazy val nullableRepeatedData: RDD[StringData] = {
val rdd = spark.sparkContext.parallelize(
protected lazy val nullableRepeatedData: DataFrame = {
val df = spark.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
rdd.toDF().createOrReplaceTempView("nullableRepeatedData")
rdd
List.fill(2)(StringData("test"))).toDF()
df.createOrReplaceTempView("nullableRepeatedData")
df
}

protected lazy val nullInts: DataFrame = {
Expand Down Expand Up @@ -231,19 +230,19 @@ private[sql] trait SQLTestData extends SparkSessionProvider { self =>
df
}

protected lazy val unparsedStrings: RDD[String] = {
protected lazy val unparsedStrings: DataFrame = {
spark.sparkContext.parallelize(
"1, A1, true, null" ::
"2, B2, false, null" ::
"3, C3, true, null" ::
"4, D4, true, 2147483644" :: Nil)
"4, D4, true, 2147483644" :: Nil).toDF("value")
}

// An RDD with 4 elements and 8 partitions
protected lazy val withEmptyParts: RDD[IntField] = {
val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8)
rdd.toDF().createOrReplaceTempView("withEmptyParts")
rdd
// A DataFrame with 4 elements and 8 partitions
protected lazy val withEmptyParts: DataFrame = {
val df = spark.sparkContext.parallelize((1 to 4).map(IntField), 8).toDF()
df.createOrReplaceTempView("withEmptyParts")
df
}

protected lazy val person: DataFrame = {
Expand Down