Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,64 +98,57 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] {
}

/**
* Computes the set of table changes needed to evolve `originalTarget` schema
* to accommodate `originalSource` schema. When `isByName` is true, fields are matched
* Computes the set of table changes needed to evolve `target` schema
* to accommodate `source` schema. When `isByName` is true, fields are matched
* by name. When false, fields are matched by position.
*/
def computeSchemaChanges(
originalTarget: StructType,
originalSource: StructType,
target: StructType,
source: StructType,
isByName: Boolean): Array[TableChange] =
computeSchemaChanges(
originalTarget,
originalSource,
originalTarget,
originalSource,
target,
source,
fieldPath = Nil,
isByName)
isByName,
error = throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(
target, source, null))

private def computeSchemaChanges(
currentType: DataType,
newType: DataType,
originalTarget: StructType,
originalSource: StructType,
fieldPath: List[String],
isByName: Boolean): Array[TableChange] = {
isByName: Boolean,
error: => Nothing): Array[TableChange] = {
(currentType, newType) match {
case (StructType(currentFields), StructType(newFields)) =>
if (isByName) {
computeSchemaChangesByName(
currentFields, newFields, originalTarget, originalSource, fieldPath)
computeSchemaChangesByName(currentFields, newFields, fieldPath, error)
} else {
computeSchemaChangesByPosition(
currentFields, newFields, originalTarget, originalSource, fieldPath)
computeSchemaChangesByPosition(currentFields, newFields, fieldPath, error)
}

case (ArrayType(currentElementType, _), ArrayType(newElementType, _)) =>
computeSchemaChanges(
currentElementType,
newElementType,
originalTarget,
originalSource,
fieldPath :+ "element",
isByName)
isByName,
error)

case (MapType(currentKeyType, currentValueType, _),
MapType(newKeyType, newValueType, _)) =>
case (MapType(currentKeyType, currentValueType, _), MapType(newKeyType, newValueType, _)) =>
val keyChanges = computeSchemaChanges(
currentKeyType,
newKeyType,
originalTarget,
originalSource,
fieldPath :+ "key",
isByName)
isByName,
error)
val valueChanges = computeSchemaChanges(
currentValueType,
newValueType,
originalTarget,
originalSource,
fieldPath :+ "value",
isByName)
isByName,
error)
keyChanges ++ valueChanges

case (currentType: AtomicType, newType: AtomicType) if currentType != newType =>
Expand All @@ -167,8 +160,7 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] {

case _ =>
// Do not support change between atomic and complex types for now
throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(
originalTarget, originalSource, null)
error
}
}

Expand All @@ -179,9 +171,8 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] {
private def computeSchemaChangesByName(
currentFields: Array[StructField],
newFields: Array[StructField],
originalTarget: StructType,
originalSource: StructType,
fieldPath: List[String]): Array[TableChange] = {
fieldPath: List[String],
error: => Nothing): Array[TableChange] = {
val currentFieldMap = toFieldMap(currentFields)
val newFieldMap = toFieldMap(newFields)

Expand All @@ -192,10 +183,9 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] {
computeSchemaChanges(
f.dataType,
newFieldMap(f.name).dataType,
originalTarget,
originalSource,
fieldPath :+ f.name,
isByName = true)
isByName = true,
error)
}

// Collect newly added fields
Expand All @@ -213,18 +203,16 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] {
private def computeSchemaChangesByPosition(
currentFields: Array[StructField],
newFields: Array[StructField],
originalTarget: StructType,
originalSource: StructType,
fieldPath: List[String]): Array[TableChange] = {
fieldPath: List[String],
error: => Nothing): Array[TableChange] = {
// Update existing field types by pairing fields at the same position.
val updates = currentFields.zip(newFields).flatMap { case (currentField, newField) =>
computeSchemaChanges(
currentField.dataType,
newField.dataType,
originalTarget,
originalSource,
fieldPath :+ currentField.name,
isByName = false)
isByName = false,
error)
}

// Extra source fields beyond the target's field count are new additions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1041,9 +1041,10 @@ case class MergeIntoTable(

override lazy val pendingSchemaChanges: Seq[TableChange] = {
if (schemaEvolutionEnabled && schemaEvolutionReady) {
val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(this)
ResolveSchemaEvolution.computeSchemaChanges(
targetTable.schema, referencedSourceSchema, isByName = true).toSeq
val allChanges = ResolveSchemaEvolution.computeSchemaChanges(
targetTable.schema, sourceTable.schema, isByName = true)
MergeIntoTable.filterValidSchemaEvolution(
allChanges, matchedActions ++ notMatchedActions, sourceTable)
} else {
Seq.empty
}
Expand Down Expand Up @@ -1097,52 +1098,36 @@ object MergeIntoTable {
.toSet
}

// A pruned version of source schema that only contains columns/nested fields
// explicitly and directly assigned to a target counterpart in MERGE INTO actions,
// which are relevant for schema evolution.
// Examples:
// * UPDATE SET target.a = source.a
// * UPDATE SET nested.a = source.nested.a
// * INSERT (a, nested.b) VALUES (source.a, source.nested.b)
// New columns/nested fields in this schema that are not existing in target schema
// will be added for schema evolution.
def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the most complicated method and it made many recursive calls, so I agree we should avoid it.

The idea I had is compatible with the changes Johan made and could look like this:

  /**
   * Derives the schema changes implied by the explicit assignments of a MERGE statement.
   *
   * Assignments are gathered from the UPDATE and INSERT actions among the matched and
   * not-matched actions; every other action contributes none. Each assignment yields either:
   *  - a single `addColumn` change when it is a field-addition candidate, or
   *  - the changes returned by `ResolveSchemaEvolution.computeSchemaChanges` when it is
   *    resolved and its key and value data types differ.
   * Any other assignment yields nothing.
   *
   * @param merge the MERGE plan whose actions are inspected
   * @return the distinct changes, ordered by the first assignment that produced them
   */
  private def computeSchemaChanges(merge: MergeIntoTable): Seq[TableChange] = {
    val actions = merge.matchedActions ++ merge.notMatchedActions

    val assignments = actions.flatMap {
      case a: UpdateAction => a.assignments
      case a: InsertAction => a.assignments
      case _ => Seq.empty
    }

    // Built with an ordered, immutable pipeline instead of a mutable hash set so that the
    // resulting sequence has a deterministic order; duplicates are still removed below.
    val changes = assignments.flatMap {
      case a if isFieldAdditionCandidate(a, merge) =>
        val fieldPath = extractFieldPath(a.key)
        Seq[TableChange](TableChange.addColumn(fieldPath.toArray, a.value.dataType))
      case a if a.resolved && a.key.dataType != a.value.dataType =>
        ResolveSchemaEvolution.computeSchemaChanges(
          a.key.dataType,
          a.value.dataType,
          merge.targetTable.schema,
          merge.sourceTable.schema,
          // The callee declares `fieldPath: List[String]`, so convert explicitly.
          fieldPath = extractFieldPath(a.key).toList,
          isByName = true).toSeq
      case _ =>
        // OK
        Seq.empty[TableChange]
    }

    changes.distinct
  }

  /**
   * Turns an expression into the sequence of name parts it refers to.
   *
   *  - an unresolved attribute yields its name parts;
   *  - an attribute reference yields a single-element path holding its name;
   *  - an alias yields the path of the aliased child;
   *  - a struct-field access yields the child's path followed by the field name, or by
   *    `col<ordinal>` when the access carries no name;
   *  - anything else yields an empty path.
   */
  private def extractFieldPath(expr: Expression): Seq[String] = expr match {
    case GetStructField(parent, idx, maybeName) =>
      val parentPath = extractFieldPath(parent)
      val leafName = maybeName.getOrElse(s"col$idx")
      parentPath :+ leafName
    case Alias(inner, _) =>
      extractFieldPath(inner)
    case attr: AttributeReference =>
      Seq(attr.name)
    case UnresolvedAttribute(parts) =>
      parts
    case _ =>
      Seq.empty
  }

  /**
   * Tells whether all assignments are ready for schema evolution, i.e. each one is either
   * already resolved or is a field-addition candidate for the given MERGE plan.
   */
  private def areSchemaEvolutionReady(
      assignments: Seq[Assignment],
      merge: MergeIntoTable): Boolean = {
    // Equivalent to "all are resolved or candidates": search for a blocking assignment instead.
    val hasBlockingAssignment = assignments.exists { a =>
      !a.resolved && !isFieldAdditionCandidate(a, merge)
    }
    !hasBlockingAssignment
  }

  /**
   * Decides whether an assignment may introduce a new field into the target.
   *
   * All of the following must hold:
   *  - the assignment key is not resolved while the assignment value is resolved;
   *  - the field path extracted from the key equals the one extracted from the value;
   *  - every attribute referenced by the value belongs to the source table's output;
   *  - the target table cannot resolve the key path with the session resolver.
   *
   * NOTE(review): the two paths are compared with plain `==`, whereas the target lookup goes
   * through the configured resolver - confirm the exact-match comparison is intended.
   */
  private def isFieldAdditionCandidate(
      assignment: Assignment,
      merge: MergeIntoTable): Boolean = {
    val targetExpr = assignment.key
    val sourceExpr = assignment.value
    val targetPath = extractFieldPath(targetExpr)
    val sourcePath = extractFieldPath(sourceExpr)

    // Guard clauses, checked in the same order as the conditions listed above.
    if (targetExpr.resolved || !sourceExpr.resolved) {
      false
    } else if (targetPath != sourcePath) {
      false
    } else if (!sourceExpr.references.subsetOf(merge.sourceTable.outputSet)) {
      false
    } else {
      merge.targetTable.resolve(targetPath, SQLConf.get.resolver).isEmpty
    }
  }

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With that said, yours may be better. Let me explore it in more detail.

val actions = merge.matchedActions ++ merge.notMatchedActions
/**
* Filters schema changes to only those relevant to identity assignments
* (e.g. `target.x = source.x`) in the MERGE actions. Only identity assignments can
* introduce new columns or type changes via schema evolution.
*
* A schema change is kept if its field path is equal to or nested under the key path
* of an identity assignment.
*/
private def filterValidSchemaEvolution(
changes: Array[TableChange],
actions: Seq[MergeAction],
source: LogicalPlan): Seq[TableChange] = {
val assignments = actions.collect {
case a: UpdateAction => a.assignments
case a: InsertAction => a.assignments
}.flatten

val containsStarAction = actions.exists {
case _: UpdateStarAction => true
case _: InsertStarAction => true
case _ => false
}

def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType =
StructType(sourceSchema.flatMap { field =>
val fieldPath = basePath :+ field.name

field.dataType match {
// Specifically assigned to in one clause:
// always keep, including all nested attributes
case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field)
// If this is a struct and one of the children is being assigned to in a merge clause,
// keep it and continue filtering children.
case struct: StructType if assignments.exists(assign =>
isPrefix(fieldPath, extractFieldPath(assign.key, allowUnresolved = true))) =>
Some(field.copy(dataType = filterSchema(struct, fieldPath)))
// The field isn't assigned to directly or indirectly (i.e. its children) in any non-*
// clause. Check if it should be kept with any * action.
case struct: StructType if containsStarAction =>
Some(field.copy(dataType = filterSchema(struct, fieldPath)))
case _ if containsStarAction => Some(field)
// The field and its children are not assigned to in any * or non-* action, drop it.
case _ => None
}
})

filterSchema(merge.sourceTable.schema, Seq.empty)
val evolutionPaths = assignments
.filter(isSameColumnAssignment(_, source))
.map(a => extractFieldPath(a.key, allowUnresolved = true))
.filter(_.nonEmpty)

val resolver = SQLConf.get.resolver
changes.filter { case change: TableChange.ColumnChange =>
val changePath = change.fieldNames().toSeq
evolutionPaths.exists { ep =>
ep.length <= changePath.length &&
ep.zip(changePath).forall { case (a, b) => resolver(a, b) }
}
}.toSeq
}

// Helper method to extract field path from an Expression.
Expand All @@ -1156,24 +1141,6 @@ object MergeIntoTable {
}
}

// Helper method to check if a given field path is a prefix of another path.
private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
prefix.length <= path.length && prefix.zip(path).forall {
case (prefixNamePart, pathNamePart) =>
SQLConf.get.resolver(prefixNamePart, pathNamePart)
}

// Helper method to check if an assignment key is equal to a source column
// and if the assignment value is that same source column.
// Example: UPDATE SET target.a = source.a
private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = {
// key must be a non-qualified field path that may be added to target schema via evolution
val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true)
// value should always be resolved (from source)
val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false)
assignmentKeyExpr == assignmentValueExpr && assignmentKeyExpr == sourceFieldPath
}

private def areSchemaEvolutionReady(
assignments: Seq[Assignment],
source: LogicalPlan): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,94 @@ trait MergeIntoSchemaEvolutionTypeWideningAndExtraFieldTests
(3, 75, "newdep")).toDF("pk", "salary", "dep")
)

// When assigning s.bonus to existing t.salary and source.salary has a wider type (long) than
// target.salary (int), no evolution should occur because the assignment uses s.bonus, not
// s.salary. The type mismatch on the same-named column should be irrelevant.
// The expected rows and the expected schema are therefore identical with and without schema
// evolution: `salary` stays IntegerType and no `bonus` column appears in the target.
testEvolution("source has extra column with type mismatch on existing column -" +
"should not evolve when assigning from differently named source column")(
targetData = {
val schema = StructType(Seq(
StructField("pk", IntegerType, nullable = false),
StructField("salary", IntegerType),
StructField("dep", StringType)
))
spark.createDataFrame(spark.sparkContext.parallelize(Seq(
Row(1, 100, "hr"),
Row(2, 200, "software")
)), schema)
},
sourceData = {
// `salary` is LongType here (IntegerType in the target) but is never assigned by the clauses
// below; `bonus` is the extra source column that the clauses actually read.
val schema = StructType(Seq(
StructField("pk", IntegerType, nullable = false),
StructField("salary", LongType),
StructField("dep", StringType),
StructField("bonus", LongType)
))
spark.createDataFrame(spark.sparkContext.parallelize(Seq(
Row(2, 150L, "dummy", 50L),
Row(3, 250L, "dummy", 75L)
)), schema)
},
// Both clauses write s.bonus into the existing target column `salary`.
clauses = Seq(
update(set = "salary = s.bonus"),
insert(values = "(pk, salary, dep) VALUES (s.pk, s.bonus, 'newdep')")
),
expected = Seq(
(1, 100, "hr"),
(2, 50, "software"),
(3, 75, "newdep")).toDF("pk", "salary", "dep"),
expectedWithoutEvolution = Seq(
(1, 100, "hr"),
(2, 50, "software"),
(3, 75, "newdep")).toDF("pk", "salary", "dep"),
// Same three columns and types as the target schema defined above.
expectedSchema = StructType(Seq(
StructField("pk", IntegerType, nullable = false),
StructField("salary", IntegerType),
StructField("dep", StringType)
)),
expectedSchemaWithoutEvolution = StructType(Seq(
StructField("pk", IntegerType, nullable = false),
StructField("salary", IntegerType),
StructField("dep", StringType)
))
)

// When assigning s.bonus (StringType) to target salary (IntegerType), the types are
// incompatible. This should fail both with and without schema evolution because the explicit
// assignment has mismatched types regardless of evolution.
// The same error fragment ("Cannot safely cast") is expected in both modes.
testEvolution("source has extra column with type mismatch on existing column -" +
"should fail when assigning from incompatible source column")(
targetData = {
val schema = StructType(Seq(
StructField("pk", IntegerType, nullable = false),
StructField("salary", IntegerType),
StructField("dep", StringType)
))
spark.createDataFrame(spark.sparkContext.parallelize(Seq(
Row(1, 100, "hr"),
Row(2, 200, "software")
)), schema)
},
sourceData = {
// `bonus` is StringType here, while the clauses below assign it to the IntegerType `salary`.
val schema = StructType(Seq(
StructField("pk", IntegerType, nullable = false),
StructField("salary", LongType),
StructField("dep", StringType),
StructField("bonus", StringType)
))
spark.createDataFrame(spark.sparkContext.parallelize(Seq(
Row(2, 150L, "dummy", "fifty"),
Row(3, 250L, "dummy", "seventy-five")
)), schema)
},
// Both clauses write s.bonus into the existing target column `salary`.
clauses = Seq(
update(set = "salary = s.bonus"),
insert(values = "(pk, salary, dep) VALUES (s.pk, s.bonus, 'newdep')")
),
expectErrorContains = "Cannot safely cast",
expectErrorWithoutEvolutionContains = "Cannot safely cast"
)

// No evolution when using named_struct to construct value without referencing new field
testNestedStructsEvolution("source has extra struct field -" +
"no evolution when not directly referencing new field - INSERT")(
Expand Down