Skip to content

Commit a255e32

Browse files
authored
Merge pull request #914 from FlorentinD/listComprehension
Implement list-comprehension
2 parents 5e63178 + ff34fcb commit a255e32

File tree

7 files changed

+149
-11
lines changed

7 files changed

+149
-11
lines changed

morpheus-spark-cypher/src/main/scala/org/opencypher/morpheus/impl/SparkSQLExprMapper.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
*/
2727
package org.opencypher.morpheus.impl
2828

29-
import org.apache.spark.sql.catalyst.expressions.CaseWhen
29+
import org.apache.spark.sql.catalyst.expressions.{ArrayFilter, ArrayTransform, CaseWhen, ExprId, LambdaFunction, NamedLambdaVariable}
3030
import org.apache.spark.sql.functions.{array_contains => _, translate => _, _}
3131
import org.apache.spark.sql.types._
3232
import org.apache.spark.sql.{Column, DataFrame}
@@ -88,6 +88,9 @@ object SparkSQLExprMapper {
8888
*/
8989
def asSparkSQLExpr(implicit header: RecordHeader, df: DataFrame, parameters: CypherMap): Column = {
9090
val outCol = expr match {
91+
case v: LambdaVar =>
92+
val sparkType = v.cypherType.toSparkType.getOrElse(throw IllegalStateException(s"No valid dataType for LambdaVar $v"))
93+
new Column(NamedLambdaVariable(v.name, sparkType, nullable = v.cypherType.isNullable, ExprId(v.hashCode.toLong)))
9194
// Evaluate based on already present data; no recursion
9295
case _: Var | _: HasLabel | _: HasType | _: StartNode | _: EndNode => column_for(expr)
9396
// Evaluate bottom-up
@@ -329,6 +332,25 @@ object SparkSQLExprMapper {
329332
case _: ListSliceFrom => list_slice(child0, Some(child1), None)
330333
case _: ListSliceTo => list_slice(child0, None, Some(child1))
331334

335+
case ListComprehension(variable, innerPredicate, extractExpression, listExpr) =>
336+
val lambdaVar = variable.asSparkSQLExpr.expr match {
337+
case v: NamedLambdaVariable => v
338+
case err => throw IllegalStateException(s"$variable should be converted into a NamedLambdaVariable instead of $err")
339+
}
340+
val filteredExpr = innerPredicate match {
341+
case Some(filterExpr) =>
342+
val filterFunc = LambdaFunction(filterExpr.asSparkSQLExpr.expr, Seq(lambdaVar))
343+
ArrayFilter(listExpr.asSparkSQLExpr.expr, filterFunc)
344+
case None => listExpr.asSparkSQLExpr.expr
345+
}
346+
val result = extractExpression match{
347+
case Some(extractExpr) =>
348+
val extractFunc = LambdaFunction(extractExpr.asSparkSQLExpr.expr, Seq(lambdaVar))
349+
ArrayTransform(filteredExpr, extractFunc)
350+
case None => filteredExpr
351+
}
352+
new Column(result)
353+
332354
case MapExpression(items) => expr.cypherType.material match {
333355
case CTMap(_) =>
334356
val innerColumns = items.map {

morpheus-tck/src/test/resources/failing_blacklist

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,9 @@ Feature "UnwindAcceptance": Scenario "Unwinding a collected expression"
2020
Feature "UnwindAcceptance": Scenario "Unwind does not remove variables from scope"
2121
Feature "TypeConversionFunctions": Scenario "`toInteger()` handling mixed number types"
2222
Feature "TypeConversionFunctions": Scenario "`toInteger()` handling Any type"
23-
Feature "TypeConversionFunctions": Scenario "`toInteger()` on a list of strings"
2423
Feature "TypeConversionFunctions": Scenario "`toFloat()` on mixed number types"
2524
Feature "TypeConversionFunctions": Scenario "`toFloat()` handling Any type"
26-
Feature "TypeConversionFunctions": Scenario "`toFloat()` on a list of strings"
2725
Feature "TypeConversionFunctions": Scenario "`toString()` should work on Any type"
28-
Feature "TypeConversionFunctions": Scenario "`toString()` on a list of integers"
2926
Feature "TypeConversionFunctions": Scenario "`toBoolean()` on invalid types #1"
3027
Feature "TypeConversionFunctions": Scenario "`toBoolean()` on invalid types #2"
3128
Feature "TypeConversionFunctions": Scenario "`toBoolean()` on invalid types #3"
@@ -192,7 +189,6 @@ Feature "Aggregation": Scenario "`min()` over mixed values"
192189
Feature "Aggregation": Scenario "`max()` over mixed values"
193190
Feature "Aggregation": Scenario "`max()` over mixed numeric values"
194191
Feature "ListOperations": Scenario "IN should return true if correct list found despite other lists having nulls"
195-
Feature "ListOperations": Scenario "Size of list comprehension"
196192
Feature "ListOperations": Scenario "IN should return false when matching a number with a string - list version"
197193
Feature "ListOperations": Scenario "IN should return false when matching a number with a string"
198194
Feature "ListOperations": Scenario "IN should return true when LHS and RHS contain a nested list - singleton version"

morpheus-testing/src/test/scala/org/opencypher/morpheus/impl/acceptance/ExpressionTests.scala

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,4 +1468,78 @@ class ExpressionTests extends MorpheusTestSuite with ScanGraphInit with Checkers
14681468
}
14691469
}
14701470
}
1471+
1472+
describe("list comprehension") {
1473+
it("supports list comprehension with static mapping") {
1474+
val result = morpheus.cypher(
1475+
"""
1476+
|WITH [1, 2, 3] AS things
1477+
|RETURN [n IN things | 1] AS value
1478+
""".stripMargin)
1479+
1480+
result.records.toMaps shouldEqual Bag(
1481+
CypherMap("value" -> List(1, 1, 1))
1482+
)
1483+
}
1484+
1485+
it("supports list comprehension with simple mapping") {
1486+
val result = morpheus.cypher(
1487+
"""
1488+
|WITH [1, 2, 3] AS things
1489+
|RETURN [n IN things | n*3] AS value
1490+
""".stripMargin)
1491+
1492+
result.records.toMaps shouldEqual Bag(
1493+
CypherMap("value" -> List(3, 6, 9))
1494+
)
1495+
}
1496+
1497+
it("supports list comprehension with more complex mapping") {
1498+
val result = morpheus.cypher(
1499+
"""
1500+
|WITH ['1', '2', '3'] AS things
1501+
|RETURN [n IN things | toInteger(n)*3 + toInteger(n)] AS value
1502+
""".stripMargin)
1503+
1504+
result.records.toMaps shouldEqual Bag(
1505+
CypherMap("value" -> List(4, 8, 12))
1506+
)
1507+
}
1508+
1509+
it("supports list comprehension with inner predicate") {
1510+
val result = morpheus.cypher(
1511+
"""
1512+
|WITH [1, 2, 3] AS things
1513+
|RETURN [n IN things WHERE n > 2] AS value
1514+
""".stripMargin)
1515+
1516+
result.records.toMaps shouldEqual Bag(
1517+
CypherMap("value" -> List(3))
1518+
)
1519+
}
1520+
1521+
it("supports list comprehension with inner predicate and more complex mapping") {
1522+
val result = morpheus.cypher(
1523+
"""
1524+
|WITH ['1', '2', '3'] AS things
1525+
|RETURN [n IN things WHERE toInteger(n) > 2 | toInteger(n)*3 + toInteger(n)] AS value
1526+
""".stripMargin)
1527+
1528+
result.records.toMaps shouldEqual Bag(
1529+
CypherMap("value" -> List(12))
1530+
)
1531+
}
1532+
1533+
it("supports nested list comprehensions") {
1534+
val result = morpheus.cypher(
1535+
"""
1536+
|WITH [[1,2,3], [2,2,3], [3,4]] AS things
1537+
|RETURN [n IN things | [n IN n WHERE n < 2]] AS value
1538+
""".stripMargin)
1539+
1540+
result.records.toMaps shouldEqual Bag(
1541+
CypherMap("value" -> List(List(1), List(), List()))
1542+
)
1543+
}
1544+
}
14711545
}

okapi-ir/src/main/scala/org/opencypher/okapi/ir/api/expr/Expr.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ final case class SimpleVar(name: String)(val cypherType: CypherType) extends Ret
196196
override def withOwner(expr: Var): SimpleVar = SimpleVar(expr.name)(expr.cypherType)
197197
}
198198

199+
final case class LambdaVar(name: String)(val cypherType: CypherType) extends Var
200+
199201
final case class StartNode(rel: Expr)(val cypherType: CypherType) extends Expr {
200202
type This = StartNode
201203

@@ -1099,6 +1101,18 @@ final case class ListSliceFrom(list: Expr, from: Expr) extends ListSlice(Some(fr
10991101

11001102
final case class ListSliceTo(list: Expr, to: Expr) extends ListSlice(None, Some(to))
11011103

1104+
final case class ListComprehension(variable: Expr, innerPredicate: Option[Expr], extractExpression: Option[Expr], expr : Expr) extends Expr {
1105+
override def withoutType: String = {
1106+
val p = innerPredicate.map(" WHERE " + _.withoutType).getOrElse("")
1107+
val e = extractExpression.map(" | " + _.withoutType).getOrElse("")
1108+
s"[${variable.withoutType} IN ${expr.withoutType}$p$e]"
1109+
}
1110+
override def cypherType: CypherType = extractExpression match {
1111+
case Some(x) => CTList(x.cypherType)
1112+
case None => expr.cypherType
1113+
}
1114+
}
1115+
11021116
final case class ContainerIndex(container: Expr, index: Expr)(val cypherType: CypherType) extends Expr {
11031117

11041118
override def withoutType: String = s"${container.withoutType}[${index.withoutType}]"

okapi-ir/src/main/scala/org/opencypher/okapi/ir/impl/ExpressionConverter.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.opencypher.okapi.ir.api.expr._
3434
import org.opencypher.okapi.ir.impl.parse.{functions => f}
3535
import org.opencypher.okapi.ir.impl.typer.SignatureConverter.Signature
3636
import org.opencypher.okapi.ir.impl.typer.{InvalidArgument, InvalidContainerAccess, MissingParameter, UnTypedExpr, WrongNumberOfArguments}
37-
import org.opencypher.v9_0.expressions.{RegexMatch, functions}
37+
import org.opencypher.v9_0.expressions.{ExtractScope, LogicalVariable, RegexMatch, functions}
3838
import org.opencypher.v9_0.{expressions => ast}
3939
import SignatureTyping._
4040

@@ -62,18 +62,21 @@ final class ExpressionConverter(context: IRBuilderContext) {
6262
}
6363
}
6464

65-
def convert(e: ast.Expression): Expr = {
65+
def convert(e: ast.Expression) (implicit lambdaVars: Map[String, CypherType]): Expr = {
6666

6767
lazy val child0: Expr = convert(e.arguments.head)
6868

6969
lazy val child1: Expr = convert(e.arguments(1))
7070

7171
lazy val child2: Expr = convert(e.arguments(2))
7272

73-
lazy val convertedChildren: List[Expr] = e.arguments.toList.map(convert)
73+
lazy val convertedChildren: List[Expr] = e.arguments.toList.map(convert(_))
7474

7575
e match {
76-
case ast.Variable(name) => Var(name)(context.knownTypes.getOrElse(e, throw UnTypedExpr(e)))
76+
case ast.Variable(name) => lambdaVars.get(name) match {
77+
case Some(varType) => LambdaVar(name)(varType)
78+
case None => Var(name)(context.knownTypes.getOrElse(e, throw UnTypedExpr(e)))
79+
}
7780
case p@ast.Parameter(name, _) => Param(name)(parameterType(p))
7881

7982
// Literals
@@ -276,6 +279,16 @@ final class ExpressionConverter(context: IRBuilderContext) {
276279
case ast.ListSlice(list, None, Some(to)) => ListSliceTo(convert(list), convert(to))
277280
case ast.ListSlice(list, Some(from), None) => ListSliceFrom(convert(list), convert(from))
278281

282+
case ast.ListComprehension(ExtractScope(variable, innerPredicate, extractExpression), expr) =>
283+
val listExpr = convert(expr)(lambdaVars)
284+
val listInnerType = listExpr.cypherType match {
285+
case CTList(inner) => inner
286+
case err => throw IllegalArgumentException("a list to step over", err, "Wrong list comprehension type")
287+
}
288+
val updatedLambdaVars: Map[String, CypherType] = lambdaVars + (variable.name -> listInnerType)
289+
ListComprehension(convert(variable)(updatedLambdaVars), innerPredicate.map(convert(_)(updatedLambdaVars)),
290+
extractExpression.map(convert(_)(updatedLambdaVars)), listExpr)
291+
279292
case ast.ContainerIndex(container, index) =>
280293
val convertedContainer = convert(container)
281294
val elementType = convertedContainer.cypherType.material match {

okapi-ir/src/main/scala/org/opencypher/okapi/ir/impl/IRBuilderContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ final case class IRBuilderContext(
6161
patternConverter.convert(p, knownTypes, qgn.getOrElse(workingGraph.qualifiedGraphName))
6262
}
6363

64-
def convertExpression(e: ast.Expression): Expr = exprConverter.convert(e)
64+
def convertExpression(e: ast.Expression): Expr = exprConverter.convert(e)(lambdaVars = Map())
6565

6666
def schemaFor(qgn: QualifiedGraphName): PropertyGraphSchema = queryLocalCatalog.schema(qgn)
6767

okapi-ir/src/test/scala/org/opencypher/okapi/ir/impl/ExpressionConverterTest.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,29 @@ class ExpressionConverterTest extends BaseTestSuite with Neo4jAstTestSupport {
439439
)
440440
}
441441

442+
describe("list comprehension") {
443+
val intVar = LambdaVar("x")(CTInteger)
444+
it("can convert list comprehension with static mapping") {
445+
convert("[x IN [1,2] | 1]") shouldEqual ListComprehension(intVar, None, Some(IntegerLit(1)), ListLit(List(IntegerLit(1), IntegerLit(2))))
446+
}
447+
448+
it("can convert list comprehension with unary mapping") {
449+
convert("[x IN [1,2] | toString(x)]")
450+
}
451+
452+
it("can convert list comprehension with 2 var-calls") {
453+
convert("[x IN [1,2] | x + x * 2]") shouldEqual ListComprehension(intVar, None, Some(Add(intVar, Multiply(intVar, IntegerLit(2)))), ListLit(List(IntegerLit(1), IntegerLit(2))))
454+
}
455+
456+
it("can convert list comprehension with inner predicate") {
457+
convert("[x IN [1,2] WHERE x < 1 | 1]") shouldEqual ListComprehension(intVar, Some(LessThan(intVar, IntegerLit(1))), Some(IntegerLit(1)), ListLit(List(IntegerLit(1), IntegerLit(2))))
458+
}
459+
}
460+
442461
implicit def toVar(s: Symbol): Var = all.find(_.name == s.name).get
443462

444463
private def convert(e: ast.Expression): Expr =
445-
new ExpressionConverter(testContext).convert(e)
464+
new ExpressionConverter(testContext).convert(e)(Map())
446465

447466
implicit class TestExpr(expr: Expr) {
448467
def shouldEqual(other: Expr): Assertion = {

0 commit comments

Comments
 (0)