diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f11dcbd1e7c1e..c5e08af300364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -511,13 +511,43 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object StreamingDeduplicationStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case Deduplicate(keys, child) if child.isStreaming => - StreamingDeduplicateExec(keys, planLater(child)) :: Nil + StreamingDeduplicateExec( + keys, maybeNormalizeFloatingPointKeys(keys, child)) :: Nil case DeduplicateWithinWatermark(keys, child) if child.isStreaming => - StreamingDeduplicateWithinWatermarkExec(keys, planLater(child)) :: Nil + StreamingDeduplicateWithinWatermarkExec( + keys, maybeNormalizeFloatingPointKeys(keys, child)) :: Nil case _ => Nil } + + /** + * If any dedup key contains a floating-point type (including nested types), wraps the physical + * child in a ProjectExec that normalizes NaN and -0.0 so that semantically equal values produce + * identical UnsafeRow bytes in the state store. + * + * Note that the streaming dedupe node does not support map-typed keys, although the + * `NormalizeFloatingNumbers` helper does.
+ */ + private def maybeNormalizeFloatingPointKeys( + keys: Seq[Attribute], child: LogicalPlan): SparkPlan = { + val physicalChild = planLater(child) + val normalizedProjectList = child.output.map { attr => + if (keys.exists(_.exprId == attr.exprId)) { + NormalizeFloatingNumbers.normalize(attr) match { + case a: Attribute => a + case other => Alias(other, attr.name)(exprId = attr.exprId) + } + } else { + attr + } + } + if (normalizedProjectList.exists(!_.isInstanceOf[Attribute])) { + execution.ProjectExec(normalizedProjectList, physicalChild) + } else { + physicalChild + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 7aec3353cd4dc..003f71d16437a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -60,6 +60,79 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest ) } + // Canonical quiet NaN and a non-canonical NaN with a different bit pattern. + // Both are NaN, but they have distinct raw bit representations. 
+ private val canonicalNaN = java.lang.Float.intBitsToFloat(0x7fc00000) + private val nonCanonicalNaN = java.lang.Float.intBitsToFloat(0x7f800001) + assert(java.lang.Float.isNaN(canonicalNaN) && java.lang.Float.isNaN(nonCanonicalNaN), + "Both values must be NaN") + assert(java.lang.Float.floatToRawIntBits(canonicalNaN) != + java.lang.Float.floatToRawIntBits(nonCanonicalNaN), + "The two NaN values must have different bit patterns") + + test("same NaN bit pattern should be deduplicated") { + val inputData = MemoryStream[(Float, Int)] + val result = inputData.toDS().toDF("value", "id").dropDuplicates("value") + + testStream(result, Append)( + AddData(inputData, (canonicalNaN, 1)), + CheckLastBatch((canonicalNaN, 1)), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, (canonicalNaN, 2)), + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + CheckAnswer((canonicalNaN, 1)) + ) + } + + /** + * Tests that different NaN bit patterns are deduplicated for a given key type. + * `buildKey` transforms the base DataFrame (columns: f, id) to add a "key" column. 
+ */ + private def testNaNDedup(keyDesc: String, buildKey: DataFrame => DataFrame): Unit = { + test(s"NaN normalization in dedup key: $keyDesc") { + val inputData = MemoryStream[(Float, Int)] + val base = inputData.toDS().toDF("f", "id") + val result = buildKey(base).dropDuplicates("key").select("f", "id") + + testStream(result, Append)( + AddData(inputData, (canonicalNaN, 1)), + CheckLastBatch((canonicalNaN, 1)), + assertNumStateRows(total = 1, updated = 1), + + // Non-canonical NaN should be treated as a duplicate + AddData(inputData, (nonCanonicalNaN, 2)), + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + CheckAnswer((canonicalNaN, 1)) + ) + } + } + + testNaNDedup("scalar float", _.withColumn("key", col("f"))) + testNaNDedup("struct containing float", _.withColumn("key", struct(col("f")))) + testNaNDedup("array containing float", _.withColumn("key", array(col("f")))) + + test("negative zero and positive zero should be deduplicated") { + val inputData = MemoryStream[(Float, Int)] + val result = inputData.toDS().toDF("value", "id").dropDuplicates("value") + + testStream(result, Append)( + AddData(inputData, (0.0f, 1)), + CheckLastBatch((0.0f, 1)), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, (-0.0f, 2)), + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + CheckAnswer((0.0f, 1)) + ) + } + test("deduplicate with some columns") { val inputData = MemoryStream[(String, Int)] val result = inputData.toDS().dropDuplicates("_1")