Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.hadoop.fs.FileStatus

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.avro.AvroUtils
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.types.{DataType, StructType}
Expand All @@ -43,13 +43,14 @@ case class AvroTable(
AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files)

override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
new WriteBuilder {
override def build(): Write =
AvroWrite(paths, formatName, supportsDataType, mergedWriteInfo(info))
createFileWriteBuilder(info) {
(mergedInfo, partSchema, bSpec, customLocs, dynamicOverwrite, truncate) =>
AvroWrite(paths, formatName, supportsDataType, mergedInfo, partSchema, bSpec,
customLocs, dynamicOverwrite, truncate)
}
}

override def supportsDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)

override def formatName: String = "AVRO"
override def formatName: String = "Avro"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.v2.avro
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.sql.avro.AvroUtils
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.connector.write.LogicalWriteInfo
import org.apache.spark.sql.execution.datasources.OutputWriterFactory
import org.apache.spark.sql.execution.datasources.v2.FileWrite
Expand All @@ -29,7 +30,12 @@ case class AvroWrite(
paths: Seq[String],
formatName: String,
supportsDataType: DataType => Boolean,
info: LogicalWriteInfo) extends FileWrite {
info: LogicalWriteInfo,
partitionSchema: StructType,
override val bucketSpec: Option[BucketSpec] = None,
override val customPartitionLocations: Map[Map[String, String], String] = Map.empty,
override val dynamicPartitionOverwrite: Boolean,
override val isTruncate: Boolean) extends FileWrite {
override def prepareWrite(
sqlConf: SQLConf,
job: Job,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2, FileTable}
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.internal.connector.V1Function
import org.apache.spark.sql.types.{DataType, MetadataBuilder, StringType, StructField, StructType}
Expand Down Expand Up @@ -247,7 +247,34 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
constructV1TableCmd(None, c.tableSpec, ident, StructType(fields), c.partitioning,
c.ignoreIfExists, storageFormat, provider)
} else {
c
// File sources: validate data types and create via
// V1 command. Non-file V2 providers keep V2 plan.
DataSourceV2Utils.getTableProvider(
provider, conf) match {
case Some(f: FileDataSourceV2) =>
val ft = f.getTable(
c.tableSchema, c.partitioning.toArray,
new org.apache.spark.sql.util
.CaseInsensitiveStringMap(
java.util.Collections.emptyMap()))
ft match {
case ft: FileTable =>
c.tableSchema.foreach { field =>
if (!ft.supportsDataType(
field.dataType)) {
throw QueryCompilationErrors
.dataTypeUnsupportedByDataSourceError(
ft.formatName, field)
}
}
case _ =>
}
constructV1TableCmd(None, c.tableSpec, ident,
StructType(c.columns.map(_.toV1Column)),
c.partitioning,
c.ignoreIfExists, storageFormat, provider)
case _ => c
}
}

case c @ CreateTableAsSelect(
Expand All @@ -267,7 +294,17 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
constructV1TableCmd(Some(c.query), c.tableSpec, ident, new StructType, c.partitioning,
c.ignoreIfExists, storageFormat, provider)
} else {
c
// File sources: create via V1 command.
// Non-file V2 providers keep V2 plan.
DataSourceV2Utils.getTableProvider(
provider, conf) match {
case Some(_: FileDataSourceV2) =>
constructV1TableCmd(Some(c.query),
c.tableSpec, ident, new StructType,
c.partitioning, c.ignoreIfExists,
storageFormat, provider)
case _ => c
}
}

case RefreshTable(ResolvedV1TableOrViewIdentifier(ident)) =>
Expand All @@ -281,7 +318,16 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
throw QueryCompilationErrors.unsupportedTableOperationError(
ident, "REPLACE TABLE")
} else {
c
// File sources don't support REPLACE TABLE in
// the session catalog (requires StagingTableCatalog).
DataSourceV2Utils.getTableProvider(
provider, conf) match {
case Some(_: FileDataSourceV2) =>
throw QueryCompilationErrors
.unsupportedTableOperationError(
ident, "REPLACE TABLE")
case _ => c
}
}

case c @ ReplaceTableAsSelect(ResolvedV1Identifier(ident), _, _, _, _, _, _) =>
Expand All @@ -290,7 +336,14 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
throw QueryCompilationErrors.unsupportedTableOperationError(
ident, "REPLACE TABLE AS SELECT")
} else {
c
DataSourceV2Utils.getTableProvider(
provider, conf) match {
case Some(_: FileDataSourceV2) =>
throw QueryCompilationErrors
.unsupportedTableOperationError(
ident, "REPLACE TABLE AS SELECT")
case _ => c
}
}

// For CREATE TABLE LIKE, use the v1 command if both the target and source are in the session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import java.util.Locale

import scala.jdk.CollectionConverters._

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.annotation.Stable
import org.apache.spark.sql
import org.apache.spark.sql.SaveMode
Expand Down Expand Up @@ -168,8 +171,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
val catalogManager = df.sparkSession.sessionState.catalogManager
val fileV2CreateMode = (curmode == SaveMode.ErrorIfExists ||
curmode == SaveMode.Ignore) &&
provider.isInstanceOf[FileDataSourceV2]
curmode match {
case SaveMode.Append | SaveMode.Overwrite =>
case _ if curmode == SaveMode.Append || curmode == SaveMode.Overwrite ||
fileV2CreateMode =>
val (table, catalog, ident) = provider match {
case supportsExtract: SupportsCatalogOptions =>
val ident = supportsExtract.extractIdentifier(dsOptions)
Expand All @@ -178,7 +185,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram

(catalog.loadTable(ident), Some(catalog), Some(ident))
case _: TableProvider =>
val t = getTable
val t = try {
getTable
} catch {
case _: SparkUnsupportedOperationException if fileV2CreateMode =>
return saveToV1SourceCommand(path)
}
if (t.supports(BATCH_WRITE)) {
(t, None, None)
} else {
Expand All @@ -189,15 +201,40 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
}
}

if (fileV2CreateMode) {
val outputPath = Option(dsOptions.get("path")).map(new Path(_))
outputPath.foreach { p =>
val hadoopConf = df.sparkSession.sessionState
.newHadoopConfWithOptions(extraOptions.toMap)
val fs = p.getFileSystem(hadoopConf)
val qualifiedPath = fs.makeQualified(p)
if (fs.exists(qualifiedPath)) {
if (curmode == SaveMode.ErrorIfExists) {
throw QueryCompilationErrors.outputPathAlreadyExistsError(qualifiedPath)
} else {
return LocalRelation(
DataSourceV2Relation.create(table, catalog, ident, dsOptions).output)
}
}
}
}

val relation = DataSourceV2Relation.create(table, catalog, ident, dsOptions)
checkPartitioningMatchesV2Table(table)
if (curmode == SaveMode.Append) {
if (curmode == SaveMode.Append || fileV2CreateMode) {
AppendData.byName(relation, df.logicalPlan, finalOptions)
} else {
// Truncate the table. TableCapabilityCheck will throw a nice exception if this
// isn't supported
OverwriteByExpression.byName(
relation, df.logicalPlan, Literal(true), finalOptions)
val dynamicOverwrite =
df.sparkSession.sessionState.conf.partitionOverwriteMode ==
PartitionOverwriteMode.DYNAMIC &&
partitioningColumns.exists(_.nonEmpty)
if (dynamicOverwrite) {
OverwritePartitionsDynamic.byName(
relation, df.logicalPlan, finalOptions)
} else {
OverwriteByExpression.byName(
relation, df.logicalPlan, Literal(true), finalOptions)
}
}

case createMode =>
Expand Down Expand Up @@ -226,14 +263,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
finalOptions,
ignoreIfExists = createMode == SaveMode.Ignore)
case _: TableProvider =>
if (getTable.supports(BATCH_WRITE)) {
throw QueryCompilationErrors.writeWithSaveModeUnsupportedBySourceError(
source, createMode.name())
} else {
// Streaming also uses the data source V2 API. So it may be that the data source
// implements v2, but has no v2 implementation for batch writes. In that case, we
// fallback to saving as though it's a V1 source.
saveToV1SourceCommand(path)
try {
if (getTable.supports(BATCH_WRITE)) {
throw QueryCompilationErrors.writeWithSaveModeUnsupportedBySourceError(
source, createMode.name())
} else {
// Streaming also uses the data source V2 API. So it may be that the data source
// implements v2, but has no v2 implementation for batch writes. In that case, we
// fallback to saving as though it's a V1 source.
saveToV1SourceCommand(path)
}
} catch {
case _: SparkUnsupportedOperationException =>
saveToV1SourceCommand(path)
}
}
}
Expand Down Expand Up @@ -439,8 +481,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram

val session = df.sparkSession
val v2ProviderOpt = lookupV2Provider()
val canUseV2 = v2ProviderOpt.isDefined || (hasCustomSessionCatalog &&
!df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME)
val canUseV2 = v2ProviderOpt.isDefined ||
(hasCustomSessionCatalog &&
!df.sparkSession.sessionState.catalogManager
.catalog(CatalogManager.SESSION_CATALOG_NAME)
.isInstanceOf[CatalogExtension])

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
Expand Down Expand Up @@ -477,6 +521,45 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
AppendData.byName(v2Relation, df.logicalPlan, extraOptions.toMap)

// For file tables, Overwrite on existing table uses
// OverwriteByExpression (truncate + append) instead of
// ReplaceTableAsSelect (which requires StagingTableCatalog).
case (SaveMode.Overwrite, Some(table: FileTable)) =>
checkPartitioningMatchesV2Table(table)
val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
val conf = df.sparkSession.sessionState.conf
val dynamicPartitionOverwrite = table.partitioning.length > 0 &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC &&
partitioningColumns.exists(_.nonEmpty)
if (dynamicPartitionOverwrite) {
OverwritePartitionsDynamic.byName(
v2Relation, df.logicalPlan, extraOptions.toMap)
} else {
OverwriteByExpression.byName(
v2Relation, df.logicalPlan, Literal(true), extraOptions.toMap)
}

// File table Overwrite when table doesn't exist: create it.
case (SaveMode.Overwrite, None)
if v2ProviderOpt.exists(_.isInstanceOf[FileDataSourceV2]) =>
val tableSpec = UnresolvedTableSpec(
properties = Map.empty,
provider = Some(source),
optionExpression = OptionList(Seq.empty),
location = extraOptions.get("path"),
comment = extraOptions.get(TableCatalog.PROP_COMMENT),
collation = extraOptions.get(TableCatalog.PROP_COLLATION),
serde = None,
external = false,
constraints = Seq.empty)
CreateTableAsSelect(
UnresolvedIdentifier(nameParts),
partitioningAsV2,
df.queryExecution.analyzed,
tableSpec,
writeOptions = extraOptions.toMap,
ignoreIfExists = false)

case (SaveMode.Overwrite, _) =>
val tableSpec = UnresolvedTableSpec(
properties = Map.empty,
Expand Down Expand Up @@ -595,8 +678,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram

private def lookupV2Provider(): Option[TableProvider] = {
DataSource.lookupDataSourceV2(source, df.sparkSession.sessionState.conf) match {
// TODO(SPARK-28396): File source v2 write path is currently broken.
case Some(_: FileDataSourceV2) => None
case other => other
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggr
import org.apache.spark.sql.execution.RowToColumnConverter
import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

Expand All @@ -43,12 +44,22 @@ object AggregatePushDownUtils {

var finalSchema = new StructType()

val caseSensitive = SQLConf.get.caseSensitiveAnalysis

def getStructFieldForCol(colName: String): StructField = {
schema.apply(colName)
if (caseSensitive) {
schema.apply(colName)
} else {
schema.find(_.name.equalsIgnoreCase(colName)).getOrElse(schema.apply(colName))
}
}

def isPartitionCol(colName: String) = {
partitionNames.contains(colName)
if (caseSensitive) {
partitionNames.contains(colName)
} else {
partitionNames.exists(_.equalsIgnoreCase(colName))
}
}

def processMinOrMax(agg: AggregateFunc): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,10 +487,10 @@ case class DataSource(
val caseSensitive = conf.caseSensitiveAnalysis
PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)

val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collect {
val fileIndex = catalogTable.map(_.identifier).flatMap { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collectFirst {
case LogicalRelationWithTable(t: HadoopFsRelation, _) => t.location
}.head
}
}
// For partitioned relation r, r.schema's column ordering can be different from the column
// ordering of data.logicalPlan (partition columns are all moved after data column). This
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.v2.{ExtractV2Table, PushedDownOperators}
import org.apache.spark.sql.execution.datasources.v2.{ExtractV2Table, FileTable, PushedDownOperators}
import org.apache.spark.sql.execution.streaming.runtime.StreamingRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
Expand Down Expand Up @@ -360,6 +360,13 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
case u: UnresolvedCatalogRelation if u.isStreaming =>
getStreamingRelation(u.tableMeta, u.options, Unassigned)

// TODO(SPARK-56233): Add MICRO_BATCH_READ capability to FileTable
// so streaming reads don't need V1 fallback.
case StreamingRelationV2(
_, _, ft: FileTable, extraOptions, _, _, _, None, name)
if ft.catalogTable.isDefined =>
getStreamingRelation(ft.catalogTable.get, extraOptions, name)

case s @ StreamingRelationV2(
_, _, table, extraOptions, _, _, _,
Some(UnresolvedCatalogRelation(tableMeta, _, true)), name) =>
Expand Down
Loading