@@ -826,8 +826,8 @@ public virtual SqlExpression Case(
826826 }
827827
828828 // Simplify:
829- // a == null ? null : a -> a
830- // a != null ? a : null -> a
829+ // a == b ? b : a -> a
830+ // a != b ? a : b -> a
831831 // And lift:
832832 // a == b ? null : a -> NULLIF(a, b)
833833 // a != b ? a : null -> NULLIF(a, b)
@@ -838,28 +838,39 @@ public virtual SqlExpression Case(
838838 Test : SqlBinaryExpression { OperatorType : ExpressionType . Equal or ExpressionType . NotEqual } binary ,
839839 Result : var result
840840 }
841- ]
842- && binary . OperatorType switch
843- {
844- ExpressionType . Equal when result is SqlConstantExpression { Value : null } && elseResult is not null => elseResult ,
845- ExpressionType . NotEqual when elseResult is null or SqlConstantExpression { Value : null } => result ,
846- _ => null
847- } is SqlExpression conditionalResult )
841+ ] )
848842 {
849843 var ( left , right ) = ( binary . Left , binary . Right ) ;
850844
851- if ( left . Equals ( conditionalResult ) )
845+ // Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasonining below
846+ var ( ifEqual , ifNotEqual ) = binary . OperatorType is ExpressionType . Equal
847+ ? ( result , elseResult ?? Constant ( null , result . Type , result . TypeMapping ) )
848+ : ( elseResult ?? Constant ( null , result . Type , result . TypeMapping ) , result ) ;
849+
850+ if ( left . Equals ( ifNotEqual ) )
852851 {
853- return right is SqlConstantExpression { Value : null }
854- ? left
855- : Function ( "NULLIF" , [ left , right ] , nullable : true , [ false , false ] , left . Type , left . TypeMapping ) ;
852+ switch ( ifEqual )
853+ {
854+ // a == b ? b : a -> a
855+ case SqlConstantExpression { Value : null } :
856+ return Function ( "NULLIF" , [ left , right ] , nullable : true , [ false , false ] , left . Type , left . TypeMapping ) ;
857+ // a == b ? null : a -> NULLIF(a, b)
858+ case var _ when ifEqual . Equals ( right ) :
859+ return left ;
860+ }
856861 }
857862
858- if ( right . Equals ( conditionalResult ) )
863+ if ( right . Equals ( ifNotEqual ) )
859864 {
860- return left is SqlConstantExpression { Value : null }
861- ? right
862- : Function ( "NULLIF" , [ right , left ] , nullable : true , [ false , false ] , right . Type , right . TypeMapping ) ;
865+ switch ( ifEqual )
866+ {
867+ // b == a ? b : a -> a
868+ case SqlConstantExpression { Value : null } :
869+ return Function ( "NULLIF" , [ right , left ] , nullable : true , [ false , false ] , right . Type , right . TypeMapping ) ;
870+ // b == a ? null : a -> NULLIF(a, b)
871+ case var _ when ifEqual . Equals ( left ) :
872+ return right ;
873+ }
863874 }
864875 }
865876
0 commit comments