diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 5f30a1931..838fc13eb 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -149,6 +149,11 @@ public OUTPUT visit(Expression.StructLiteral expr) throws EXCEPTION { return visitFallback(expr); } + @Override + public OUTPUT visit(Expression.StructNested expr) throws EXCEPTION { + return visitFallback(expr); + } + @Override public OUTPUT visit(Expression.Switch expr) throws EXCEPTION { return visitFallback(expr); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index aa9e69148..7ad8c01d9 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -574,6 +574,27 @@ public R accept(ExpressionVisitor visitor) throws } } + @Value.Immutable + abstract static class StructNested implements Expression { + public abstract List fields(); + + public Type getType() { + return Type.withNullability(false) + .struct( + fields().stream() + .map(Expression::getType) + .collect(java.util.stream.Collectors.toList())); + } + + public static ImmutableExpression.StructNested.Builder builder() { + return ImmutableExpression.StructNested.builder(); + } + + public R accept(ExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract static class UserDefinedLiteral implements Literal { public abstract ByteString value(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index b27a241a2..0723a6dec 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -61,6 +61,8 @@ public interface ExpressionVisitor { R visit(Expression.StructLiteral expr) throws E; + R visit(Expression.StructNested expr) throws E; + R visit(Expression.UserDefinedLiteral expr) throws E; R visit(Expression.Switch expr) throws E; diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index da86d704e..d0b7be3b7 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -72,6 +72,12 @@ private Expression lit(Consumer consumer) { return Expression.newBuilder().setLiteral(builder).build(); } + private Expression nested(Consumer consumer) { + var builder = Expression.Nested.newBuilder(); + consumer.accept(builder); + return Expression.newBuilder().setNested(builder).build(); + } + @Override public Expression visit(io.substrait.expression.Expression.BoolLiteral expr) { return lit(bldr -> bldr.setNullable(expr.nullable()).setBoolean(expr.value())); @@ -323,6 +329,18 @@ public Expression visit(io.substrait.expression.Expression.StructLiteral expr) { }); } + @Override + public Expression visit(io.substrait.expression.Expression.StructNested expr) { + return nested( + bldr -> { + var values = + expr.fields().stream() + .map(this::toProto) + .collect(java.util.stream.Collectors.toList()); + bldr.setStruct(Expression.Nested.Struct.newBuilder().addAllFields(values)); + }); + } + @Override public Expression visit(io.substrait.expression.Expression.UserDefinedLiteral expr) { var typeReference = diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 29bbe1a8c..484f16a95 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -173,6 +173,14 @@ public Optional visit(Expression.StructLiteral expr) throws EXCEPTIO return visitLiteral(expr); } + @Override + public Optional visit(Expression.StructNested expr) throws EXCEPTION { + var expressions = visitExprList(expr.fields()); + return expressions.map( + expressionList -> + Expression.StructNested.builder().from(expr).fields(expressionList).build()); + } + @Override public Optional visit(Expression.UserDefinedLiteral expr) throws EXCEPTION { return visitLiteral(expr); diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 9a5416f69..105eeab8f 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -133,7 +133,7 @@ public Rel from(io.substrait.proto.Rel rel) { protected Rel newRead(ReadRel rel) { if (rel.hasVirtualTable()) { var virtualTable = rel.getVirtualTable(); - if (virtualTable.getValuesCount() == 0) { + if (virtualTable.getValuesCount() == 0 && virtualTable.getExpressionsCount() == 0) { return newEmptyScan(rel); } else { return newVirtualTable(rel); @@ -417,17 +417,33 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { protected VirtualTableScan newVirtualTable(ReadRel rel) { var virtualTable = rel.getVirtualTable(); + // If both values and expressions are set, raise an error + if (virtualTable.getValuesCount() > 0 && virtualTable.getExpressionsCount() > 0) { + throw new IllegalArgumentException( + "Virtual table cannot have both values and expressions set"); + } + var virtualTableSchema = newNamedStruct(rel); + var converter = new ProtoExpressionConverter(lookup, extensions, virtualTableSchema.struct(), this); - List structLiterals = new ArrayList<>(virtualTable.getValuesCount()); + + List expressions = + new ArrayList<>(virtualTable.getValuesCount() + virtualTable.getExpressionsCount()); + for (var struct : virtualTable.getValuesList()) { - structLiterals.add( + expressions.add( ImmutableExpression.StructLiteral.builder() .fields( - struct.getFieldsList().stream() - .map(converter::from) - .collect(java.util.stream.Collectors.toList())) + struct.getFieldsList().stream().map(converter::from).collect(Collectors.toList())) + .build()); + } + + for (var expr : virtualTable.getExpressionsList()) { + expressions.add( + ImmutableExpression.StructNested.builder() + .fields( + expr.getFieldsList().stream().map(converter::from).collect(Collectors.toList())) .build()); } @@ -438,7 +454,7 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { rel.hasBestEffortFilter() ? converter.from(rel.getBestEffortFilter()) : null)) .filter(Optional.ofNullable(rel.hasFilter() ? converter.from(rel.getFilter()) : null)) .initialSchema(NamedStruct.fromProto(rel.getBaseSchema(), protoTypeConverter)) - .rows(structLiterals); + .rows(expressions); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index 88e78d7ff..f7df7a6c1 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -4,12 +4,13 @@ import io.substrait.type.Type; import io.substrait.type.TypeVisitor; import java.util.List; +import java.util.Objects; import org.immutables.value.Value; @Value.Immutable public abstract class VirtualTableScan extends AbstractReadRel { - public abstract List getRows(); + public abstract List getRows(); /** * @@ -29,9 +30,9 @@ protected void check() { == NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct()); var rows = getRows(); - assert rows.size() > 0 - && names.stream().noneMatch(s -> s == null) - && rows.stream().noneMatch(r -> r == null) + assert !rows.isEmpty() + && names.stream().noneMatch(Objects::isNull) + && rows.stream().noneMatch(Objects::isNull) && rows.stream() .allMatch(r -> NamedFieldCountingTypeVisitor.countNames(r.getType()) == names.size()); } diff --git a/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java b/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java index ca6fceaa7..d5bd3c129 100644 --- a/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java +++ b/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java @@ -1,5 +1,12 @@ package io.substrait.relation; +import static io.substrait.expression.ExpressionCreator.bool; +import static io.substrait.expression.ExpressionCreator.fp32; +import static io.substrait.expression.ExpressionCreator.fp64; +import static io.substrait.expression.ExpressionCreator.i8; +import static io.substrait.expression.ExpressionCreator.i16; +import static io.substrait.expression.ExpressionCreator.i32; +import static io.substrait.expression.ExpressionCreator.i64; import static io.substrait.expression.ExpressionCreator.list; import static io.substrait.expression.ExpressionCreator.map; import static io.substrait.expression.ExpressionCreator.string; @@ -25,6 +32,13 @@ void check() { NamedStruct.of( Arrays.stream( new String[] { + "bool_field", + "i8_field", + "i16_field", + "i32_field", + "i64_field", + "fp32_field", + "fp64_field", "string", "struct", "struct_field1", @@ -37,6 +51,13 @@ void check() { }) .collect(Collectors.toList()), R.struct( + R.BOOLEAN, + R.I8, + R.I16, + R.I32, + R.I64, + R.FP32, + R.FP64, R.STRING, R.struct(R.STRING, R.STRING), R.list(R.struct(R.STRING)), @@ -44,6 +65,13 @@ void check() { .addRows( struct( false, + bool(false, true), + i8(false, 42), + i16(false, 1234), + i32(false, 123456), + i64(false, 9876543210L), + fp32(false, 3.14f), + fp64(false, 2.718281828), string(false, "string_val"), struct( false, diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index afb8c16ba..e65584f47 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -345,13 +345,21 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } override def visit(virtualTableScan: relation.VirtualTableScan): LogicalPlan = { - val rows = virtualTableScan.getRows.asScala.map( - row => + val rows = virtualTableScan.getRows.asScala.map { + case structLit: SExpression.StructLiteral => InternalRow.fromSeq( - row - .fields() - .asScala - .map(field => field.accept(expressionConverter).asInstanceOf[Literal].value))) + structLit.fields.asScala + .map(field => field.accept(expressionConverter).asInstanceOf[Literal].value) + ) + case structNested: SExpression.StructNested => + InternalRow.fromSeq( + structNested.fields.asScala + .map(expr => expr.accept(expressionConverter)) + ) + case other => + throw new UnsupportedOperationException( + s"Unsupported row type in VirtualTableScan: ${other.getClass}") + } virtualTableScan.getInitialSchema match { case ns: NamedStruct if ns.names().isEmpty && rows.length == 1 => OneRowRelation()