Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,43 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object StreamingDeduplicationStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Deduplicate(keys, child) if child.isStreaming =>
      StreamingDeduplicateExec(
        keys, maybeNormalizeFloatingPointKeys(keys, child)) :: Nil

    case DeduplicateWithinWatermark(keys, child) if child.isStreaming =>
      StreamingDeduplicateWithinWatermarkExec(
        keys, maybeNormalizeFloatingPointKeys(keys, child)) :: Nil

    case _ => Nil
  }

  /**
   * If any dedup key contains a floating-point type (including nested types), wraps the physical
   * child in a ProjectExec that normalizes NaN and -0.0 so that semantically equal values produce
   * identical UnsafeRow bytes in the state store.
   *
   * Note that map-typed keys are rejected during analysis for both batch and streaming
   * grouping/dedup keys (this is not streaming-specific), even though the
   * `NormalizeFloatingNumbers` helper itself is able to normalize maps.
   */
  private def maybeNormalizeFloatingPointKeys(
      keys: Seq[Attribute], child: LogicalPlan): SparkPlan = {
    val physicalChild = planLater(child)
    // Rewrite only the key attributes. Each rewritten column keeps its original exprId so the
    // (unchanged) `keys` passed to the exec node still resolve against the projected output.
    val normalizedProjectList = child.output.map { attr =>
      if (keys.exists(_.exprId == attr.exprId)) {
        NormalizeFloatingNumbers.normalize(attr) match {
          // normalize() returned the attribute untouched: no floating-point component to fix.
          case a: Attribute => a
          case other => Alias(other, attr.name)(exprId = attr.exprId)
        }
      } else {
        attr
      }
    }
    // Insert the projection only when at least one key actually required normalization;
    // otherwise plan the child directly and avoid a useless extra operator.
    if (normalizedProjectList.exists(!_.isInstanceOf[Attribute])) {
      execution.ProjectExec(normalizedProjectList, physicalChild)
    } else {
      physicalChild
    }
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,79 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest
)
}

// Two distinct NaN encodings used as fixtures: the canonical quiet NaN
// (0x7fc00000) and a non-canonical pattern (0x7f800001). Semantically both are
// NaN, but their raw bit representations differ — exactly the mismatch that
// NaN normalization in the dedup key must paper over.
private val canonicalNaN: Float = java.lang.Float.intBitsToFloat(0x7fc00000)
private val nonCanonicalNaN: Float = java.lang.Float.intBitsToFloat(0x7f800001)

// Sanity-check the fixtures once, at suite construction time.
assert(
  java.lang.Float.isNaN(canonicalNaN) && java.lang.Float.isNaN(nonCanonicalNaN),
  "Both values must be NaN")
assert(
  java.lang.Float.floatToRawIntBits(canonicalNaN) !=
    java.lang.Float.floatToRawIntBits(nonCanonicalNaN),
  "The two NaN values must have different bit patterns")

test("same NaN bit pattern should be deduplicated") {
  // Baseline: re-ingesting a float with the identical NaN encoding must hit the
  // existing state entry, regardless of any normalization.
  val input = MemoryStream[(Float, Int)]
  val deduped = input.toDS().toDF("value", "id").dropDuplicates("value")

  testStream(deduped, Append)(
    // First NaN row creates a new state entry and is emitted.
    AddData(input, (canonicalNaN, 1)),
    CheckLastBatch((canonicalNaN, 1)),
    assertNumStateRows(total = 1, updated = 1),

    // Same bit pattern again: dropped as a duplicate, state unchanged.
    AddData(input, (canonicalNaN, 2)),
    CheckLastBatch(),
    assertNumStateRows(total = 1, updated = 0),

    CheckAnswer((canonicalNaN, 1))
  )
}

/**
 * Registers a test verifying that two different NaN bit patterns collapse into a
 * single state entry for the given key shape.
 *
 * @param keyDesc  human-readable description of the key shape, used in the test name
 * @param buildKey transforms the base DataFrame (columns: f, id) to add a "key" column
 *                 derived from `f`
 */
private def testNaNDedup(keyDesc: String, buildKey: DataFrame => DataFrame): Unit = {
  test(s"NaN normalization in dedup key: $keyDesc") {
    val stream = MemoryStream[(Float, Int)]
    val withKey = buildKey(stream.toDS().toDF("f", "id"))
    val query = withKey.dropDuplicates("key").select("f", "id")

    testStream(query, Append)(
      AddData(stream, (canonicalNaN, 1)),
      CheckLastBatch((canonicalNaN, 1)),
      assertNumStateRows(total = 1, updated = 1),

      // A NaN with a different raw encoding must still be recognized as a duplicate.
      AddData(stream, (nonCanonicalNaN, 2)),
      CheckLastBatch(),
      assertNumStateRows(total = 1, updated = 0),

      CheckAnswer((canonicalNaN, 1))
    )
  }
}

// Exercise NaN normalization for a plain float key and for floats nested inside
// struct and array keys.
testNaNDedup("scalar float", df => df.withColumn("key", col("f")))
testNaNDedup("struct containing float", df => df.withColumn("key", struct(col("f"))))
testNaNDedup("array containing float", df => df.withColumn("key", array(col("f"))))

test("negative zero and positive zero should be deduplicated") {
  // +0.0f and -0.0f compare equal but carry different bit patterns; normalization
  // must map both to the same state-store key.
  val input = MemoryStream[(Float, Int)]
  val deduped = input.toDS().toDF("value", "id").dropDuplicates("value")

  testStream(deduped, Append)(
    AddData(input, (0.0f, 1)),
    CheckLastBatch((0.0f, 1)),
    assertNumStateRows(total = 1, updated = 1),

    // -0.0f should be dropped as a duplicate of the +0.0f already in state.
    AddData(input, (-0.0f, 2)),
    CheckLastBatch(),
    assertNumStateRows(total = 1, updated = 0),

    CheckAnswer((0.0f, 1))
  )
}

test("deduplicate with some columns") {
val inputData = MemoryStream[(String, Int)]
val result = inputData.toDS().dropDuplicates("_1")
Expand Down