1
1
package io .substrait .dsl ;
2
2
3
- import com .github .bsideup .jabel .Desugar ;
4
3
import io .substrait .expression .AggregateFunctionInvocation ;
5
4
import io .substrait .expression .Expression ;
6
5
import io .substrait .expression .Expression .Cast ;
@@ -87,8 +86,8 @@ private Aggregate aggregate(
87
86
Function <Rel , List <Aggregate .Measure >> measuresFn ,
88
87
Optional <Rel .Remap > remap ,
89
88
Rel input ) {
90
- var groupings = groupingsFn .apply (input );
91
- var measures = measuresFn .apply (input );
89
+ List < Aggregate . Grouping > groupings = groupingsFn .apply (input );
90
+ List < Aggregate . Measure > measures = measuresFn .apply (input );
92
91
return Aggregate .builder ()
93
92
.groupings (groupings )
94
93
.measures (measures )
@@ -147,12 +146,27 @@ public Filter filter(Function<Rel, Expression> conditionFn, Rel.Remap remap, Rel
147
146
148
147
private Filter filter (
149
148
Function <Rel , Expression > conditionFn , Optional <Rel .Remap > remap , Rel input ) {
150
- var condition = conditionFn .apply (input );
149
+ Expression condition = conditionFn .apply (input );
151
150
return Filter .builder ().input (input ).condition (condition ).remap (remap ).build ();
152
151
}
153
152
154
- @ Desugar
155
- public record JoinInput (Rel left , Rel right ) {}
153
+ public static final class JoinInput {
154
+ private final Rel left ;
155
+ private final Rel right ;
156
+
157
+ JoinInput (Rel left , Rel right ) {
158
+ this .left = left ;
159
+ this .right = right ;
160
+ }
161
+
162
+ public Rel left () {
163
+ return left ;
164
+ }
165
+
166
+ public Rel right () {
167
+ return right ;
168
+ }
169
+ }
156
170
157
171
public Join innerJoin (Function <JoinInput , Expression > conditionFn , Rel left , Rel right ) {
158
172
return join (conditionFn , Join .JoinType .INNER , left , right );
@@ -183,7 +197,7 @@ private Join join(
183
197
Optional <Rel .Remap > remap ,
184
198
Rel left ,
185
199
Rel right ) {
186
- var condition = conditionFn .apply (new JoinInput (left , right ));
200
+ Expression condition = conditionFn .apply (new JoinInput (left , right ));
187
201
return Join .builder ()
188
202
.left (left )
189
203
.right (right )
@@ -263,7 +277,7 @@ private NestedLoopJoin nestedLoopJoin(
263
277
Optional <Rel .Remap > remap ,
264
278
Rel left ,
265
279
Rel right ) {
266
- var condition = conditionFn .apply (new JoinInput (left , right ));
280
+ Expression condition = conditionFn .apply (new JoinInput (left , right ));
267
281
return NestedLoopJoin .builder ()
268
282
.left (left )
269
283
.right (right )
@@ -291,8 +305,8 @@ private NamedScan namedScan(
291
305
Iterable <String > columnNames ,
292
306
Iterable <Type > types ,
293
307
Optional <Rel .Remap > remap ) {
294
- var struct = Type .Struct .builder ().addAllFields (types ).nullable (false ).build ();
295
- var namedStruct = NamedStruct .of (columnNames , struct );
308
+ Type . Struct struct = Type .Struct .builder ().addAllFields (types ).nullable (false ).build ();
309
+ NamedStruct namedStruct = NamedStruct .of (columnNames , struct );
296
310
return NamedScan .builder ().names (tableName ).initialSchema (namedStruct ).remap (remap ).build ();
297
311
}
298
312
@@ -315,7 +329,7 @@ private Project project(
315
329
Function <Rel , Iterable <? extends Expression >> expressionsFn ,
316
330
Optional <Rel .Remap > remap ,
317
331
Rel input ) {
318
- var expressions = expressionsFn .apply (input );
332
+ Iterable <? extends Expression > expressions = expressionsFn .apply (input );
319
333
return Project .builder ().input (input ).expressions (expressions ).remap (remap ).build ();
320
334
}
321
335
@@ -332,7 +346,7 @@ private Expand expand(
332
346
Function <Rel , Iterable <? extends Expand .ExpandField >> fieldsFn ,
333
347
Optional <Rel .Remap > remap ,
334
348
Rel input ) {
335
- var fields = fieldsFn .apply (input );
349
+ Iterable <? extends Expand . ExpandField > fields = fieldsFn .apply (input );
336
350
return Expand .builder ().input (input ).fields (fields ).remap (remap ).build ();
337
351
}
338
352
@@ -363,7 +377,7 @@ private Sort sort(
363
377
Function <Rel , Iterable <? extends Expression .SortField >> sortFieldFn ,
364
378
Optional <Rel .Remap > remap ,
365
379
Rel input ) {
366
- var condition = sortFieldFn .apply (input );
380
+ Iterable <? extends Expression . SortField > condition = sortFieldFn .apply (input );
367
381
return Sort .builder ().input (input ).sortFields (condition ).remap (remap ).build ();
368
382
}
369
383
@@ -465,7 +479,7 @@ public Switch switchExpression(
465
479
466
480
public AggregateFunctionInvocation aggregateFn (
467
481
String namespace , String key , Type outputType , Expression ... args ) {
468
- var declaration =
482
+ SimpleExtension . AggregateFunctionVariant declaration =
469
483
extensions .getAggregateFunction (SimpleExtension .FunctionAnchor .of (namespace , key ));
470
484
return AggregateFunctionInvocation .builder ()
471
485
.arguments (Arrays .stream (args ).collect (java .util .stream .Collectors .toList ()))
@@ -477,7 +491,7 @@ public AggregateFunctionInvocation aggregateFn(
477
491
}
478
492
479
493
public Aggregate .Grouping grouping (Rel input , int ... indexes ) {
480
- var columns = fieldReferences (input , indexes );
494
+ List < FieldReference > columns = fieldReferences (input , indexes );
481
495
return Aggregate .Grouping .builder ().addAllExpressions (columns ).build ();
482
496
}
483
497
@@ -486,7 +500,7 @@ public Aggregate.Grouping grouping(Expression... expressions) {
486
500
}
487
501
488
502
public Aggregate .Measure count (Rel input , int field ) {
489
- var declaration =
503
+ SimpleExtension . AggregateFunctionVariant declaration =
490
504
extensions .getAggregateFunction (
491
505
SimpleExtension .FunctionAnchor .of (
492
506
DefaultExtensionCatalog .FUNCTIONS_AGGREGATE_GENERIC , "count:any" ));
@@ -563,7 +577,7 @@ public Aggregate.Measure sum0(Expression expr) {
563
577
private Aggregate .Measure singleArgumentArithmeticAggregate (
564
578
Expression expr , String functionName , Type outputType ) {
565
579
String typeString = ToTypeString .apply (expr .getType ());
566
- var declaration =
580
+ SimpleExtension . AggregateFunctionVariant declaration =
567
581
extensions .getAggregateFunction (
568
582
SimpleExtension .FunctionAnchor .of (
569
583
DefaultExtensionCatalog .FUNCTIONS_ARITHMETIC ,
@@ -585,7 +599,7 @@ private Aggregate.Measure singleArgumentArithmeticAggregate(
585
599
586
600
public Expression .ScalarFunctionInvocation negate (Expression expr ) {
587
601
// output type of negate is the same as the input type
588
- var outputType = expr .getType ();
602
+ Type outputType = expr .getType ();
589
603
return scalarFn (
590
604
DefaultExtensionCatalog .FUNCTIONS_ARITHMETIC ,
591
605
String .format ("negate:%s" , ToTypeString .apply (outputType )),
@@ -611,12 +625,12 @@ public Expression.ScalarFunctionInvocation divide(Expression left, Expression ri
611
625
612
626
private Expression .ScalarFunctionInvocation arithmeticFunction (
613
627
String fname , Expression left , Expression right ) {
614
- var leftTypeStr = ToTypeString .apply (left .getType ());
615
- var rightTypeStr = ToTypeString .apply (right .getType ());
616
- var key = String .format ("%s:%s_%s" , fname , leftTypeStr , rightTypeStr );
628
+ String leftTypeStr = ToTypeString .apply (left .getType ());
629
+ String rightTypeStr = ToTypeString .apply (right .getType ());
630
+ String key = String .format ("%s:%s_%s" , fname , leftTypeStr , rightTypeStr );
617
631
618
- var isOutputNullable = left .getType ().nullable () || right .getType ().nullable ();
619
- var outputType = left .getType ();
632
+ boolean isOutputNullable = left .getType ().nullable () || right .getType ().nullable ();
633
+ Type outputType = left .getType ();
620
634
outputType =
621
635
isOutputNullable
622
636
? TypeCreator .asNullable (outputType )
@@ -633,14 +647,14 @@ public Expression.ScalarFunctionInvocation equal(Expression left, Expression rig
633
647
public Expression .ScalarFunctionInvocation or (Expression ... args ) {
634
648
// If any arg is nullable, the output of or is potentially nullable
635
649
// For example: false or null = null
636
- var isOutputNullable = Arrays .stream (args ).anyMatch (a -> a .getType ().nullable ());
637
- var outputType = isOutputNullable ? N .BOOLEAN : R .BOOLEAN ;
650
+ boolean isOutputNullable = Arrays .stream (args ).anyMatch (a -> a .getType ().nullable ());
651
+ Type outputType = isOutputNullable ? N .BOOLEAN : R .BOOLEAN ;
638
652
return scalarFn (DefaultExtensionCatalog .FUNCTIONS_BOOLEAN , "or:bool" , outputType , args );
639
653
}
640
654
641
655
public Expression .ScalarFunctionInvocation scalarFn (
642
656
String namespace , String key , Type outputType , FunctionArg ... args ) {
643
- var declaration =
657
+ SimpleExtension . ScalarFunctionVariant declaration =
644
658
extensions .getScalarFunction (SimpleExtension .FunctionAnchor .of (namespace , key ));
645
659
return Expression .ScalarFunctionInvocation .builder ()
646
660
.declaration (declaration )
@@ -659,7 +673,7 @@ public Expression.WindowFunctionInvocation windowFn(
659
673
WindowBound lowerBound ,
660
674
WindowBound upperBound ,
661
675
Expression ... args ) {
662
- var declaration =
676
+ SimpleExtension . WindowFunctionVariant declaration =
663
677
extensions .getWindowFunction (SimpleExtension .FunctionAnchor .of (namespace , key ));
664
678
return Expression .WindowFunctionInvocation .builder ()
665
679
.declaration (declaration )
0 commit comments