Skip to content

Commit 9cec40e

Browse files
committed
fix(isthmus): handle subqueries with outer field references
fixes #382

Signed-off-by: Niels Pardon <[email protected]>
1 parent 0b61b18 commit 9cec40e

File tree

13 files changed: +260 −152 lines changed (260 additions, 152 deletions)

core/src/main/java/io/substrait/expression/FieldReference.java

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,13 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
3636
}
3737

3838
public boolean isSimpleRootReference() {
39-
return segments().size() == 1 && !inputExpression().isPresent();
39+
return segments().size() == 1
40+
&& !inputExpression().isPresent()
41+
&& !outerReferenceStepsOut().isPresent();
42+
}
43+
44+
public boolean isOuterReference() {
45+
return outerReferenceStepsOut().orElse(0) > 0;
4046
}
4147

4248
public FieldReference dereferenceStruct(int index) {

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java

Lines changed: 85 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -23,12 +23,14 @@
2323
import io.substrait.relation.Rel;
2424
import io.substrait.relation.Set;
2525
import io.substrait.relation.Sort;
26-
import io.substrait.util.EmptyVisitationContext;
26+
import io.substrait.util.VisitationContext;
2727
import java.util.ArrayList;
2828
import java.util.Collection;
2929
import java.util.Collections;
30+
import java.util.HashSet;
3031
import java.util.List;
3132
import java.util.Optional;
33+
import java.util.Stack;
3234
import java.util.stream.Collectors;
3335
import java.util.stream.IntStream;
3436
import java.util.stream.Stream;
@@ -40,6 +42,7 @@
4042
import org.apache.calcite.rel.RelFieldCollation;
4143
import org.apache.calcite.rel.RelNode;
4244
import org.apache.calcite.rel.core.AggregateCall;
45+
import org.apache.calcite.rel.core.CorrelationId;
4346
import org.apache.calcite.rel.core.JoinRelType;
4447
import org.apache.calcite.rel.logical.LogicalValues;
4548
import org.apache.calcite.rel.type.RelDataType;
@@ -59,7 +62,7 @@
5962
* visitFallback and throw UnsupportedOperationException.
6063
*/
6164
public class SubstraitRelNodeConverter
62-
extends AbstractRelVisitor<RelNode, EmptyVisitationContext, RuntimeException> {
65+
extends AbstractRelVisitor<RelNode, SubstraitRelNodeConverter.Context, RuntimeException> {
6366

6467
protected final RelDataTypeFactory typeFactory;
6568

@@ -137,42 +140,44 @@ public static RelNode convert(
137140
return relRoot.accept(
138141
new SubstraitRelNodeConverter(
139142
EXTENSION_COLLECTION, relOptCluster.getTypeFactory(), relBuilder),
140-
null);
143+
Context.newContext());
141144
}
142145

143146
@Override
144-
public RelNode visit(Filter filter, EmptyVisitationContext context) throws RuntimeException {
147+
public RelNode visit(Filter filter, Context context) throws RuntimeException {
145148
RelNode input = filter.getInput().accept(this, context);
149+
context.pushParentRelNodes(input);
146150
RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context);
147-
RelNode node = relBuilder.push(input).filter(filterCondition).build();
151+
RelNode node =
152+
relBuilder.push(input).filter(context.popCorrelationIds(), filterCondition).build();
153+
context.popParentRelNodes();
148154
return applyRemap(node, filter.getRemap());
149155
}
150156

151157
@Override
152-
public RelNode visit(NamedScan namedScan, EmptyVisitationContext context)
153-
throws RuntimeException {
158+
public RelNode visit(NamedScan namedScan, Context context) throws RuntimeException {
154159
RelNode node = relBuilder.scan(namedScan.getNames()).build();
155160
return applyRemap(node, namedScan.getRemap());
156161
}
157162

158163
@Override
159-
public RelNode visit(LocalFiles localFiles, EmptyVisitationContext context)
160-
throws RuntimeException {
164+
public RelNode visit(LocalFiles localFiles, Context context) throws RuntimeException {
161165
return visitFallback(localFiles, context);
162166
}
163167

164168
@Override
165-
public RelNode visit(EmptyScan emptyScan, EmptyVisitationContext context)
166-
throws RuntimeException {
169+
public RelNode visit(EmptyScan emptyScan, Context context) throws RuntimeException {
167170
RelDataType rowType =
168171
typeConverter.toCalcite(relBuilder.getTypeFactory(), emptyScan.getInitialSchema().struct());
169172
RelNode node = LogicalValues.create(relBuilder.getCluster(), rowType, ImmutableList.of());
170173
return applyRemap(node, emptyScan.getRemap());
171174
}
172175

173176
@Override
174-
public RelNode visit(Project project, EmptyVisitationContext context) throws RuntimeException {
177+
public RelNode visit(Project project, Context context) throws RuntimeException {
175178
RelNode child = project.getInput().accept(this, context);
179+
context.pushParentRelNodes(child);
180+
176181
Stream<RexNode> directOutputs =
177182
IntStream.range(0, child.getRowType().getFieldCount())
178183
.mapToObj(fieldIndex -> rexBuilder.makeInputRef(child, fieldIndex));
@@ -183,12 +188,17 @@ public RelNode visit(Project project, EmptyVisitationContext context) throws Run
183188
List<RexNode> rexExprs =
184189
Stream.concat(directOutputs, exprs).collect(java.util.stream.Collectors.toList());
185190

186-
RelNode node = relBuilder.push(child).project(rexExprs).build();
191+
RelNode node =
192+
relBuilder
193+
.push(child)
194+
.project(rexExprs, List.of(), false, context.popCorrelationIds())
195+
.build();
196+
context.popParentRelNodes();
187197
return applyRemap(node, project.getRemap());
188198
}
189199

190200
@Override
191-
public RelNode visit(Cross cross, EmptyVisitationContext context) throws RuntimeException {
201+
public RelNode visit(Cross cross, Context context) throws RuntimeException {
192202
RelNode left = cross.getLeft().accept(this, context);
193203
RelNode right = cross.getRight().accept(this, context);
194204
// Calcite represents CROSS JOIN as the equivalent INNER JOIN with true condition
@@ -198,14 +208,15 @@ public RelNode visit(Cross cross, EmptyVisitationContext context) throws Runtime
198208
}
199209

200210
@Override
201-
public RelNode visit(Join join, EmptyVisitationContext context) throws RuntimeException {
211+
public RelNode visit(Join join, Context context) throws RuntimeException {
202212
RelNode left = join.getLeft().accept(this, context);
203213
RelNode right = join.getRight().accept(this, context);
214+
context.pushParentRelNodes(left, right);
204215
RexNode condition =
205216
join.getCondition()
206217
.map(c -> c.accept(expressionRexConverter, context))
207218
.orElse(relBuilder.literal(true));
208-
var joinType =
219+
JoinRelType joinType =
209220
switch (join.getJoinType()) {
210221
case INNER -> JoinRelType.INNER;
211222
case LEFT -> JoinRelType.LEFT;
@@ -220,12 +231,18 @@ public RelNode visit(Join join, EmptyVisitationContext context) throws RuntimeEx
220231
default -> throw new UnsupportedOperationException(
221232
"Unsupported join type: " + join.getJoinType().name());
222233
};
223-
RelNode node = relBuilder.push(left).push(right).join(joinType, condition).build();
234+
RelNode node =
235+
relBuilder
236+
.push(left)
237+
.push(right)
238+
.join(joinType, condition, context.popCorrelationIds())
239+
.build();
240+
context.popParentRelNodes();
224241
return applyRemap(node, join.getRemap());
225242
}
226243

227244
@Override
228-
public RelNode visit(Set set, EmptyVisitationContext context) throws RuntimeException {
245+
public RelNode visit(Set set, Context context) throws RuntimeException {
229246
int numInputs = set.getInputs().size();
230247
set.getInputs()
231248
.forEach(
@@ -253,8 +270,7 @@ public RelNode visit(Set set, EmptyVisitationContext context) throws RuntimeExce
253270
}
254271

255272
@Override
256-
public RelNode visit(Aggregate aggregate, EmptyVisitationContext context)
257-
throws RuntimeException {
273+
public RelNode visit(Aggregate aggregate, Context context) throws RuntimeException {
258274
if (!PreCalciteAggregateValidator.isValidCalciteAggregate(aggregate)) {
259275
aggregate =
260276
PreCalciteAggregateValidator.PreCalciteAggregateTransformer
@@ -282,7 +298,7 @@ public RelNode visit(Aggregate aggregate, EmptyVisitationContext context)
282298
return applyRemap(node, aggregate.getRemap());
283299
}
284300

285-
private AggregateCall fromMeasure(Aggregate.Measure measure, EmptyVisitationContext context) {
301+
private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) {
286302
var eArgs = measure.getFunction().arguments();
287303
var arguments =
288304
IntStream.range(0, measure.getFunction().arguments().size())
@@ -357,7 +373,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, EmptyVisitationCont
357373
}
358374

359375
@Override
360-
public RelNode visit(Sort sort, EmptyVisitationContext context) throws RuntimeException {
376+
public RelNode visit(Sort sort, Context context) throws RuntimeException {
361377
RelNode child = sort.getInput().accept(this, context);
362378
List<RexNode> sortExpressions =
363379
sort.getSortFields().stream()
@@ -367,7 +383,7 @@ public RelNode visit(Sort sort, EmptyVisitationContext context) throws RuntimeEx
367383
return applyRemap(node, sort.getRemap());
368384
}
369385

370-
private RexNode directedRexNode(Expression.SortField sortField, EmptyVisitationContext context) {
386+
private RexNode directedRexNode(Expression.SortField sortField, Context context) {
371387
var expression = sortField.expr();
372388
var rexNode = expression.accept(expressionRexConverter, context);
373389
var sortDirection = sortField.direction();
@@ -382,7 +398,7 @@ private RexNode directedRexNode(Expression.SortField sortField, EmptyVisitationC
382398
}
383399

384400
@Override
385-
public RelNode visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException {
401+
public RelNode visit(Fetch fetch, Context context) throws RuntimeException {
386402
RelNode child = fetch.getInput().accept(this, context);
387403
var optCount = fetch.getCount();
388404
long count = optCount.orElse(-1L);
@@ -397,8 +413,7 @@ public RelNode visit(Fetch fetch, EmptyVisitationContext context) throws Runtime
397413
return applyRemap(node, fetch.getRemap());
398414
}
399415

400-
private RelFieldCollation toRelFieldCollation(
401-
Expression.SortField sortField, EmptyVisitationContext context) {
416+
private RelFieldCollation toRelFieldCollation(Expression.SortField sortField, Context context) {
402417
var expression = sortField.expr();
403418
var rex = expression.accept(expressionRexConverter, context);
404419
var sortDirection = sortField.direction();
@@ -426,7 +441,7 @@ private RelFieldCollation toRelFieldCollation(
426441
}
427442

428443
@Override
429-
public RelNode visitFallback(Rel rel, EmptyVisitationContext context) throws RuntimeException {
444+
public RelNode visitFallback(Rel rel, Context context) throws RuntimeException {
430445
throw new UnsupportedOperationException(
431446
String.format(
432447
"Rel %s of type %s not handled by visitor type %s.",
@@ -454,12 +469,49 @@ private RelNode applyRemap(RelNode relNode, Rel.Remap remap) {
454469
return relBuilder.push(relNode).project(rexList).build();
455470
}
456471

457-
private void checkRexInputRefOnly(RexNode rexNode, String context, String aggName) {
458-
if (!(rexNode instanceof RexInputRef)) {
459-
throw new UnsupportedOperationException(
460-
String.format(
461-
"Compound expression %s in %s of agg function %s is not implemented yet.",
462-
rexNode, context, aggName));
472+
public static class Context implements VisitationContext {
473+
protected final Stack<RelNode[]> parentRelations = new Stack<>();
474+
475+
protected final Stack<java.util.Set<CorrelationId>> correlationIds = new Stack<>();
476+
477+
private int subqueryDepth;
478+
479+
public static Context newContext() {
480+
return new Context();
481+
}
482+
483+
public void pushParentRelNodes(final RelNode... inputs) {
484+
parentRelations.push(inputs);
485+
this.correlationIds.push(new HashSet<>());
486+
}
487+
488+
public void popParentRelNodes() {
489+
parentRelations.pop();
463490
}
491+
492+
public RelNode[] getParentRelation(final Integer stepsOut) {
493+
return this.parentRelations.get(stepsOut - subqueryDepth);
494+
}
495+
496+
public java.util.Set<CorrelationId> popCorrelationIds() {
497+
return correlationIds.pop();
498+
}
499+
500+
public void addCorrelationId(final int stepsOut, final CorrelationId correlationId) {
501+
final int index = stepsOut - subqueryDepth;
502+
this.correlationIds.get(index).add(correlationId);
503+
}
504+
505+
public void incrementSubqueryDepth() {
506+
this.subqueryDepth++;
507+
}
508+
509+
public void decrementSubqueryDepth() {
510+
this.subqueryDepth--;
511+
}
512+
}
513+
514+
public RelBuilder getRelBuilder() {
515+
return relBuilder;
464516
}
465517
}

isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
package io.substrait.isthmus;
22

33
import io.substrait.extension.SimpleExtension;
4+
import io.substrait.isthmus.SubstraitRelNodeConverter.Context;
45
import io.substrait.plan.Plan;
56
import io.substrait.relation.NamedScan;
67
import io.substrait.relation.Rel;
@@ -91,7 +92,7 @@ public RelNode convert(Rel rel) {
9192
CalciteSchema rootSchema = toSchema(rel);
9293
RelBuilder relBuilder = createRelBuilder(rootSchema);
9394
SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder);
94-
return rel.accept(converter, null);
95+
return rel.accept(converter, Context.newContext());
9596
}
9697

9798
/**

isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
import io.substrait.expression.ExpressionCreator;
66
import io.substrait.isthmus.*;
77
import io.substrait.type.Type;
8-
import io.substrait.util.EmptyVisitationContext;
98
import java.util.ArrayList;
109
import java.util.List;
1110
import java.util.Optional;
@@ -51,7 +50,7 @@ public class CallConverters {
5150
* org.apache.calcite.rex.RexLiteral} and then re-interpreted to have the correct type.
5251
*
5352
* <p>See {@link ExpressionRexConverter#visit(Expression.UserDefinedLiteral,
54-
* EmptyVisitationContext)} for this conversion.
53+
* SubstraitRelNodeConverter.Context)} for this conversion.
5554
*
5655
* <p>When converting from Calcite to Substrait, this call converter extracts the {@link
5756
* Expression.UserDefinedLiteral} that was stored.

0 commit comments

Comments
 (0)