diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 6f806760b3736..3507627a58571 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -97,6 +97,21 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def prettyName: String = "raise_error" + override def sql: String = { + // When constructed from the single-argument form (just an error message string), + // output only the original message to produce valid, roundtrippable SQL. + (errorClass, errorParms) match { + case (Literal(cls, _), CreateMap(Seq(Literal(key, _), msg), _)) + if cls != null && + (cls.toString == "USER_RAISED_EXCEPTION" || + cls.toString == "_LEGACY_ERROR_USER_RAISED_EXCEPTION") && + key != null && key.toString == "errorMessage" => + s"$prettyName(${msg.sql})" + case _ => + super.sql + } + } + override def eval(input: InternalRow): Any = { val error = errorClass.eval(input).asInstanceOf[UTF8String] val parms: MapData = errorParms.eval(input).asInstanceOf[MapData] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 40e6fe1a90a63..ff0183eb265b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -55,6 +55,16 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("SPARK-55109: RaiseError.sql uses single-argument form only for known error classes") { + assert(RaiseError(Literal("error!")).sql === "raise_error('error!')") + + // A custom errorClass should NOT produce the single-argument form + val customError = RaiseError( + Literal("CUSTOM_ERROR"), + CreateMap(Seq(Literal("errorMessage"), Literal("error!")))) + assert(customError.sql === "raise_error('CUSTOM_ERROR', map('errorMessage', 'error!'))") + } + test("uuid") { checkEvaluation(Length(Uuid(Some(0))), 36) val r = new Random()