23
23
import io .substrait .relation .Rel ;
24
24
import io .substrait .relation .Set ;
25
25
import io .substrait .relation .Sort ;
26
- import io .substrait .util .EmptyVisitationContext ;
26
+ import io .substrait .util .VisitationContext ;
27
27
import java .util .ArrayList ;
28
28
import java .util .Collection ;
29
29
import java .util .Collections ;
30
+ import java .util .HashSet ;
30
31
import java .util .List ;
31
32
import java .util .Optional ;
33
+ import java .util .Stack ;
32
34
import java .util .stream .Collectors ;
33
35
import java .util .stream .IntStream ;
34
36
import java .util .stream .Stream ;
40
42
import org .apache .calcite .rel .RelFieldCollation ;
41
43
import org .apache .calcite .rel .RelNode ;
42
44
import org .apache .calcite .rel .core .AggregateCall ;
45
+ import org .apache .calcite .rel .core .CorrelationId ;
43
46
import org .apache .calcite .rel .core .JoinRelType ;
44
47
import org .apache .calcite .rel .logical .LogicalValues ;
45
48
import org .apache .calcite .rel .type .RelDataType ;
59
62
* visitFallback and throw UnsupportedOperationException.
60
63
*/
61
64
public class SubstraitRelNodeConverter
62
- extends AbstractRelVisitor <RelNode , EmptyVisitationContext , RuntimeException > {
65
+ extends AbstractRelVisitor <RelNode , SubstraitRelNodeConverter . Context , RuntimeException > {
63
66
64
67
protected final RelDataTypeFactory typeFactory ;
65
68
@@ -137,42 +140,44 @@ public static RelNode convert(
137
140
return relRoot .accept (
138
141
new SubstraitRelNodeConverter (
139
142
EXTENSION_COLLECTION , relOptCluster .getTypeFactory (), relBuilder ),
140
- null );
143
+ Context . newContext () );
141
144
}
142
145
143
146
@ Override
144
- public RelNode visit (Filter filter , EmptyVisitationContext context ) throws RuntimeException {
147
+ public RelNode visit (Filter filter , Context context ) throws RuntimeException {
145
148
RelNode input = filter .getInput ().accept (this , context );
149
+ context .pushParentRelNodes (input );
146
150
RexNode filterCondition = filter .getCondition ().accept (expressionRexConverter , context );
147
- RelNode node = relBuilder .push (input ).filter (filterCondition ).build ();
151
+ RelNode node =
152
+ relBuilder .push (input ).filter (context .popCorrelationIds (), filterCondition ).build ();
153
+ context .popParentRelNodes ();
148
154
return applyRemap (node , filter .getRemap ());
149
155
}
150
156
151
157
@ Override
152
- public RelNode visit (NamedScan namedScan , EmptyVisitationContext context )
153
- throws RuntimeException {
158
+ public RelNode visit (NamedScan namedScan , Context context ) throws RuntimeException {
154
159
RelNode node = relBuilder .scan (namedScan .getNames ()).build ();
155
160
return applyRemap (node , namedScan .getRemap ());
156
161
}
157
162
158
163
@ Override
159
- public RelNode visit (LocalFiles localFiles , EmptyVisitationContext context )
160
- throws RuntimeException {
164
+ public RelNode visit (LocalFiles localFiles , Context context ) throws RuntimeException {
161
165
return visitFallback (localFiles , context );
162
166
}
163
167
164
168
@ Override
165
- public RelNode visit (EmptyScan emptyScan , EmptyVisitationContext context )
166
- throws RuntimeException {
169
+ public RelNode visit (EmptyScan emptyScan , Context context ) throws RuntimeException {
167
170
RelDataType rowType =
168
171
typeConverter .toCalcite (relBuilder .getTypeFactory (), emptyScan .getInitialSchema ().struct ());
169
172
RelNode node = LogicalValues .create (relBuilder .getCluster (), rowType , ImmutableList .of ());
170
173
return applyRemap (node , emptyScan .getRemap ());
171
174
}
172
175
173
176
@ Override
174
- public RelNode visit (Project project , EmptyVisitationContext context ) throws RuntimeException {
177
+ public RelNode visit (Project project , Context context ) throws RuntimeException {
175
178
RelNode child = project .getInput ().accept (this , context );
179
+ context .pushParentRelNodes (child );
180
+
176
181
Stream <RexNode > directOutputs =
177
182
IntStream .range (0 , child .getRowType ().getFieldCount ())
178
183
.mapToObj (fieldIndex -> rexBuilder .makeInputRef (child , fieldIndex ));
@@ -183,12 +188,17 @@ public RelNode visit(Project project, EmptyVisitationContext context) throws Run
183
188
List <RexNode > rexExprs =
184
189
Stream .concat (directOutputs , exprs ).collect (java .util .stream .Collectors .toList ());
185
190
186
- RelNode node = relBuilder .push (child ).project (rexExprs ).build ();
191
+ RelNode node =
192
+ relBuilder
193
+ .push (child )
194
+ .project (rexExprs , List .of (), false , context .popCorrelationIds ())
195
+ .build ();
196
+ context .popParentRelNodes ();
187
197
return applyRemap (node , project .getRemap ());
188
198
}
189
199
190
200
@ Override
191
- public RelNode visit (Cross cross , EmptyVisitationContext context ) throws RuntimeException {
201
+ public RelNode visit (Cross cross , Context context ) throws RuntimeException {
192
202
RelNode left = cross .getLeft ().accept (this , context );
193
203
RelNode right = cross .getRight ().accept (this , context );
194
204
// Calcite represents CROSS JOIN as the equivalent INNER JOIN with true condition
@@ -198,14 +208,15 @@ public RelNode visit(Cross cross, EmptyVisitationContext context) throws Runtime
198
208
}
199
209
200
210
@ Override
201
- public RelNode visit (Join join , EmptyVisitationContext context ) throws RuntimeException {
211
+ public RelNode visit (Join join , Context context ) throws RuntimeException {
202
212
RelNode left = join .getLeft ().accept (this , context );
203
213
RelNode right = join .getRight ().accept (this , context );
214
+ context .pushParentRelNodes (left , right );
204
215
RexNode condition =
205
216
join .getCondition ()
206
217
.map (c -> c .accept (expressionRexConverter , context ))
207
218
.orElse (relBuilder .literal (true ));
208
- var joinType =
219
+ JoinRelType joinType =
209
220
switch (join .getJoinType ()) {
210
221
case INNER -> JoinRelType .INNER ;
211
222
case LEFT -> JoinRelType .LEFT ;
@@ -220,12 +231,18 @@ public RelNode visit(Join join, EmptyVisitationContext context) throws RuntimeEx
220
231
default -> throw new UnsupportedOperationException (
221
232
"Unsupported join type: " + join .getJoinType ().name ());
222
233
};
223
- RelNode node = relBuilder .push (left ).push (right ).join (joinType , condition ).build ();
234
+ RelNode node =
235
+ relBuilder
236
+ .push (left )
237
+ .push (right )
238
+ .join (joinType , condition , context .popCorrelationIds ())
239
+ .build ();
240
+ context .popParentRelNodes ();
224
241
return applyRemap (node , join .getRemap ());
225
242
}
226
243
227
244
@ Override
228
- public RelNode visit (Set set , EmptyVisitationContext context ) throws RuntimeException {
245
+ public RelNode visit (Set set , Context context ) throws RuntimeException {
229
246
int numInputs = set .getInputs ().size ();
230
247
set .getInputs ()
231
248
.forEach (
@@ -253,8 +270,7 @@ public RelNode visit(Set set, EmptyVisitationContext context) throws RuntimeExce
253
270
}
254
271
255
272
@ Override
256
- public RelNode visit (Aggregate aggregate , EmptyVisitationContext context )
257
- throws RuntimeException {
273
+ public RelNode visit (Aggregate aggregate , Context context ) throws RuntimeException {
258
274
if (!PreCalciteAggregateValidator .isValidCalciteAggregate (aggregate )) {
259
275
aggregate =
260
276
PreCalciteAggregateValidator .PreCalciteAggregateTransformer
@@ -282,7 +298,7 @@ public RelNode visit(Aggregate aggregate, EmptyVisitationContext context)
282
298
return applyRemap (node , aggregate .getRemap ());
283
299
}
284
300
285
- private AggregateCall fromMeasure (Aggregate .Measure measure , EmptyVisitationContext context ) {
301
+ private AggregateCall fromMeasure (Aggregate .Measure measure , Context context ) {
286
302
var eArgs = measure .getFunction ().arguments ();
287
303
var arguments =
288
304
IntStream .range (0 , measure .getFunction ().arguments ().size ())
@@ -357,7 +373,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, EmptyVisitationCont
357
373
}
358
374
359
375
@ Override
360
- public RelNode visit (Sort sort , EmptyVisitationContext context ) throws RuntimeException {
376
+ public RelNode visit (Sort sort , Context context ) throws RuntimeException {
361
377
RelNode child = sort .getInput ().accept (this , context );
362
378
List <RexNode > sortExpressions =
363
379
sort .getSortFields ().stream ()
@@ -367,7 +383,7 @@ public RelNode visit(Sort sort, EmptyVisitationContext context) throws RuntimeEx
367
383
return applyRemap (node , sort .getRemap ());
368
384
}
369
385
370
- private RexNode directedRexNode (Expression .SortField sortField , EmptyVisitationContext context ) {
386
+ private RexNode directedRexNode (Expression .SortField sortField , Context context ) {
371
387
var expression = sortField .expr ();
372
388
var rexNode = expression .accept (expressionRexConverter , context );
373
389
var sortDirection = sortField .direction ();
@@ -382,7 +398,7 @@ private RexNode directedRexNode(Expression.SortField sortField, EmptyVisitationC
382
398
}
383
399
384
400
@ Override
385
- public RelNode visit (Fetch fetch , EmptyVisitationContext context ) throws RuntimeException {
401
+ public RelNode visit (Fetch fetch , Context context ) throws RuntimeException {
386
402
RelNode child = fetch .getInput ().accept (this , context );
387
403
var optCount = fetch .getCount ();
388
404
long count = optCount .orElse (-1L );
@@ -397,8 +413,7 @@ public RelNode visit(Fetch fetch, EmptyVisitationContext context) throws Runtime
397
413
return applyRemap (node , fetch .getRemap ());
398
414
}
399
415
400
- private RelFieldCollation toRelFieldCollation (
401
- Expression .SortField sortField , EmptyVisitationContext context ) {
416
+ private RelFieldCollation toRelFieldCollation (Expression .SortField sortField , Context context ) {
402
417
var expression = sortField .expr ();
403
418
var rex = expression .accept (expressionRexConverter , context );
404
419
var sortDirection = sortField .direction ();
@@ -426,7 +441,7 @@ private RelFieldCollation toRelFieldCollation(
426
441
}
427
442
428
443
@ Override
429
- public RelNode visitFallback (Rel rel , EmptyVisitationContext context ) throws RuntimeException {
444
+ public RelNode visitFallback (Rel rel , Context context ) throws RuntimeException {
430
445
throw new UnsupportedOperationException (
431
446
String .format (
432
447
"Rel %s of type %s not handled by visitor type %s." ,
@@ -454,12 +469,49 @@ private RelNode applyRemap(RelNode relNode, Rel.Remap remap) {
454
469
return relBuilder .push (relNode ).project (rexList ).build ();
455
470
}
456
471
457
- private void checkRexInputRefOnly (RexNode rexNode , String context , String aggName ) {
458
- if (!(rexNode instanceof RexInputRef )) {
459
- throw new UnsupportedOperationException (
460
- String .format (
461
- "Compound expression %s in %s of agg function %s is not implemented yet." ,
462
- rexNode , context , aggName ));
472
+ public static class Context implements VisitationContext {
473
+ protected final Stack <RelNode []> parentRelations = new Stack <>();
474
+
475
+ protected final Stack <java .util .Set <CorrelationId >> correlationIds = new Stack <>();
476
+
477
+ private int subqueryDepth ;
478
+
479
+ public static Context newContext () {
480
+ return new Context ();
481
+ }
482
+
483
+ public void pushParentRelNodes (final RelNode ... inputs ) {
484
+ parentRelations .push (inputs );
485
+ this .correlationIds .push (new HashSet <>());
486
+ }
487
+
488
+ public void popParentRelNodes () {
489
+ parentRelations .pop ();
463
490
}
491
+
492
+ public RelNode [] getParentRelation (final Integer stepsOut ) {
493
+ return this .parentRelations .get (stepsOut - subqueryDepth );
494
+ }
495
+
496
+ public java .util .Set <CorrelationId > popCorrelationIds () {
497
+ return correlationIds .pop ();
498
+ }
499
+
500
+ public void addCorrelationId (final int stepsOut , final CorrelationId correlationId ) {
501
+ final int index = stepsOut - subqueryDepth ;
502
+ this .correlationIds .get (index ).add (correlationId );
503
+ }
504
+
505
+ public void incrementSubqueryDepth () {
506
+ this .subqueryDepth ++;
507
+ }
508
+
509
+ public void decrementSubqueryDepth () {
510
+ this .subqueryDepth --;
511
+ }
512
+ }
513
+
514
+ public RelBuilder getRelBuilder () {
515
+ return relBuilder ;
464
516
}
465
517
}
0 commit comments