diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 8a5e19bf8..420aae48d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,24 +1,12 @@ package io.substrait.isthmus; -import com.google.common.annotations.VisibleForTesting; -import io.substrait.isthmus.sql.SubstraitSqlValidator; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.ImmutablePlan.Builder; import io.substrait.plan.Plan.Version; import io.substrait.plan.PlanProtoConverter; import io.substrait.proto.Plan; -import java.util.List; -import org.apache.calcite.plan.hep.HepPlanner; -import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.prepare.Prepare; -import org.apache.calcite.rel.RelRoot; -import org.apache.calcite.rel.rules.CoreRules; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql2rel.SqlToRelConverter; -import org.apache.calcite.sql2rel.StandardConvertletTable; /** Take a SQL statement and a set of table definitions and return a substrait plan. */ public class SqlToSubstrait extends SqlConverterBase { @@ -32,69 +20,35 @@ public SqlToSubstrait(FeatureBoard features) { } public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException { - SqlValidator validator = new SubstraitSqlValidator(catalogReader); - return executeInner(sql, validator, catalogReader); + return executeInner(sql, catalogReader); } - List sqlToRelNode(String sql, Prepare.CatalogReader catalogReader) - throws SqlParseException { - SqlValidator validator = new SubstraitSqlValidator(catalogReader); - return sqlToRelNode(sql, validator, catalogReader); - } - - private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogReader catalogReader) + private Plan executeInner(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException { Builder builder = io.substrait.plan.Plan.builder(); builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build()); // TODO: consider case in which one sql passes conversion while others don't - sqlToRelNode(sql, validator, catalogReader).stream() + SubstraitSqlToCalcite.convertSelects(sql, catalogReader).stream() .map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard)) .forEach(root -> builder.addRoots(root)); PlanProtoConverter planToProto = new PlanProtoConverter(); - return planToProto.toProto(builder.build()); } - private List sqlToRelNode( - String sql, SqlValidator validator, Prepare.CatalogReader catalogReader) - throws SqlParseException { - SqlParser parser = SqlParser.create(sql, parserConfig); - SqlNodeList parsedList = parser.parseStmtList(); - SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); - List roots = - parsedList.stream() - .map(parsed -> getBestExpRelRoot(converter, parsed)) - .collect(java.util.stream.Collectors.toList()); - return roots; - } - - @VisibleForTesting - SqlToRelConverter createSqlToRelConverter( - SqlValidator validator, Prepare.CatalogReader catalogReader) { - SqlToRelConverter converter = - new SqlToRelConverter( - null, - validator, - catalogReader, - relOptCluster, - StandardConvertletTable.INSTANCE, - converterConfig); - return converter; - } - - @VisibleForTesting - static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) { - RelRoot root = converter.convertQuery(parsed, true, true); - { - // RelBuilder seems to implicitly use the rule below, - // need to add to avoid discrepancies in assertFullRoundTrip - HepProgram program = HepProgram.builder().addRuleInstance(CoreRules.PROJECT_REMOVE).build(); - HepPlanner hepPlanner = new HepPlanner(program); - hepPlanner.setRoot(root.rel); - root = root.withRel(hepPlanner.findBestExp()); - } - return root; - } + // @VisibleForTesting + // static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) { + // RelRoot root = converter.convertQuery(parsed, true, true); + // { + // // RelBuilder seems to implicitly use the rule below, + // // need to add to avoid discrepancies in assertFullRoundTrip + // HepProgram program = + // HepProgram.builder().addRuleInstance(CoreRules.PROJECT_REMOVE).build(); + // HepPlanner hepPlanner = new HepPlanner(program); + // hepPlanner.setRoot(root.rel); + // root = root.withRel(hepPlanner.findBestExp()); + // } + // return root; + // } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSelectStatementParser.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSelectStatementParser.java new file mode 100644 index 000000000..812c4170c --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSelectStatementParser.java @@ -0,0 +1,26 @@ +package io.substrait.isthmus.sql; + +import java.util.List; +import org.apache.calcite.avatica.util.Casing; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.validate.SqlConformanceEnum; + +/** Utility class for parsing SELECT statements to {@link org.apache.calcite.rel.RelRoot}s */ +public class SubstraitSelectStatementParser { + + private static final SqlParser.Config PARSER_CONFIG = + SqlParser.config() + // TODO: switch to Casing.UNCHANGED + .withUnquotedCasing(Casing.TO_UPPER) + // use LENIENT conformance to allow for parsing a wide variety of dialects + .withConformance(SqlConformanceEnum.LENIENT); + + /** Parse one or more SELECT statements */ + public static List parseSelectStatements(String selectStatements) + throws SqlParseException { + SqlParser parser = SqlParser.create(selectStatements, PARSER_CONFIG); + return parser.parseStmtList(); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java new file mode 100644 index 000000000..f647b0f4e --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java @@ -0,0 +1,93 @@ +package io.substrait.isthmus.sql; + +import io.substrait.isthmus.SubstraitTypeSystem; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgram; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.sql2rel.StandardConvertletTable; + +public class SubstraitSqlToCalcite { + + public static RelRoot convertSelect(String selectStatement, Prepare.CatalogReader catalogReader) + throws SqlParseException { + return convertSelect(selectStatement, catalogReader, createRelOptCluster()); + } + + public static RelRoot convertSelect( + String selectStatement, Prepare.CatalogReader catalogReader, RelOptCluster cluster) + throws SqlParseException { + List sqlNodes = SubstraitSelectStatementParser.parseSelectStatements(selectStatement); + if (sqlNodes.size() != 1) { + throw new IllegalArgumentException( + String.format("Expected one SELECT statement, found: %d", sqlNodes.size())); + } + List relRoots = convert(sqlNodes, catalogReader, cluster); + // as there was only 1 select statement, there should only be 1 root + return relRoots.get(0); + } + + public static List convertSelects( + String selectStatements, Prepare.CatalogReader catalogReader) throws SqlParseException { + return convertSelects(selectStatements, catalogReader, createRelOptCluster()); + } + + public static List convertSelects( + String selectStatements, Prepare.CatalogReader catalogReader, RelOptCluster cluster) + throws SqlParseException { + List sqlNodes = SubstraitSelectStatementParser.parseSelectStatements(selectStatements); + return convert(sqlNodes, catalogReader, cluster); + } + + static List convert( + List selectStatements, Prepare.CatalogReader catalogReader, RelOptCluster cluster) { + RelOptTable.ViewExpander viewExpander = null; + SqlToRelConverter converter = + new SqlToRelConverter( + viewExpander, + new SubstraitSqlValidator(catalogReader), + catalogReader, + cluster, + StandardConvertletTable.INSTANCE, + SqlToRelConverter.CONFIG); + // apply validation + boolean needsValidation = true; + // query is the root of the tree + boolean top = true; + return selectStatements.stream() + .map( + sqlNode -> removeUnnecessaryProjects(converter.convertQuery(sqlNode, needsValidation, top))) + .collect(Collectors.toList()); + } + + static RelOptCluster createRelOptCluster() { + RexBuilder rexBuilder = + new RexBuilder(new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM)); + HepProgram program = HepProgram.builder().build(); + RelOptPlanner emptyPlanner = new HepPlanner(program); + return RelOptCluster.create(emptyPlanner, rexBuilder); + } + + static RelRoot removeUnnecessaryProjects(RelRoot root) { + return root.withRel(removeUnnecessaryProjects(root.rel)); + } + + static RelNode removeUnnecessaryProjects(RelNode root) { + HepProgram program = HepProgram.builder().addRuleInstance(CoreRules.PROJECT_REMOVE).build(); + HepPlanner planner = new HepPlanner(program); + planner.setRoot(root); + return planner.findBestExp(); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java index 8fb704977..b74ca2411 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -11,10 +12,7 @@ public class ApplyJoinPlanTest extends PlanTestBase { - private static RelRoot getCalcitePlan(String sql) throws SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); - return s.sqlToRelNode(sql, TPCDS_CATALOG).get(0); - } + static SqlToSubstrait s = new SqlToSubstrait(); private static void validateOuterRef( Map fieldAccessDepthMap, String refName, String colName, int depth) { @@ -53,16 +51,15 @@ public void lateralJoinQuery() throws SqlParseException { */ // validate outer reference map - RelRoot root = getCalcitePlan(sql); + RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG); Map fieldAccessDepthMap = buildOuterFieldRefMap(root); Assertions.assertEquals(1, fieldAccessDepthMap.size()); validateOuterRef(fieldAccessDepthMap, "$cor0", "SS_ITEM_SK", 1); // TODO validate end to end conversion - SqlToSubstrait sE2E = new SqlToSubstrait(); Assertions.assertThrows( UnsupportedOperationException.class, - () -> sE2E.execute(sql, TPCDS_CATALOG), + () -> s.execute(sql, TPCDS_CATALOG), "Lateral join is not supported"); } @@ -74,7 +71,7 @@ public void outerApplyQuery() throws SqlParseException { + "FROM store_sales OUTER APPLY\n" + " (select i_item_sk from item where item.i_item_sk = store_sales.ss_item_sk)"; - RelRoot root = getCalcitePlan(sql); + RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG); Map fieldAccessDepthMap = buildOuterFieldRefMap(root); Assertions.assertEquals(1, fieldAccessDepthMap.size()); @@ -83,7 +80,7 @@ public void outerApplyQuery() throws SqlParseException { // TODO validate end to end conversion Assertions.assertThrows( UnsupportedOperationException.class, - () -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG), + () -> s.execute(sql, TPCDS_CATALOG), "APPLY is not supported"); } @@ -112,7 +109,7 @@ public void nestedApplyJoinQuery() throws SqlParseException { LogicalFilter(condition=[AND(=($4, $cor0.I_ITEM_SK), =($4, $cor2.SS_ITEM_SK))]) LogicalTableScan(table=[[tpcds, PROMOTION]]) */ - RelRoot root = getCalcitePlan(sql); + RelRoot root = SubstraitSqlToCalcite.convertSelect(sql, TPCDS_CATALOG); Map fieldAccessDepthMap = buildOuterFieldRefMap(root); Assertions.assertEquals(3, fieldAccessDepthMap.size()); @@ -123,7 +120,7 @@ public void nestedApplyJoinQuery() throws SqlParseException { // TODO validate end to end conversion Assertions.assertThrows( UnsupportedOperationException.class, - () -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG), + () -> s.execute(sql, TPCDS_CATALOG), "APPLY is not supported"); } @@ -138,7 +135,7 @@ public void crossApplyQuery() throws SqlParseException { // TODO validate end to end conversion Assertions.assertThrows( UnsupportedOperationException.class, - () -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG), + () -> s.execute(sql, TPCDS_CATALOG), "APPLY is not supported"); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index ee89f608c..4f08191b7 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.Plan; import io.substrait.relation.NamedScan; import java.util.List; @@ -25,10 +26,8 @@ void preserveNamesFromSql() throws Exception { String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b"; List expectedNames = List.of("a", "B"); - List calciteRelRoots = s.sqlToRelNode(query, catalogReader); - assertEquals(1, calciteRelRoots.size()); - - org.apache.calcite.rel.RelRoot calciteRelRoot1 = calciteRelRoots.get(0); + org.apache.calcite.rel.RelRoot calciteRelRoot1 = + SubstraitSqlToCalcite.convertSelect(query, catalogReader); assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames()); io.substrait.plan.Plan.Root substraitRelRoot = diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index 86b8874b2..739448535 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -2,10 +2,9 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import java.io.IOException; -import java.util.List; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; @@ -24,11 +23,8 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE // verify that the query works generally assertFullRoundTrip(query); - SqlToSubstrait sqlConverter = new SqlToSubstrait(); - List relRoots = sqlConverter.sqlToRelNode(query, TPCH_CATALOG); - assertEquals(1, relRoots.size()); - RelRoot planRoot = relRoots.get(0); - RelNode originalPlan = planRoot.rel; + RelRoot relRoot = SubstraitSqlToCalcite.convertSelect(query, TPCH_CATALOG); + RelNode originalPlan = relRoot.rel; // Create a program to apply the AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule. // This will introduce a SqlSumEmptyIsZeroAggFunction to the plan. @@ -46,6 +42,6 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE assertDoesNotThrow( () -> // Conversion of the new plan should succeed - SubstraitRelVisitor.convert(RelRoot.of(newPlan, planRoot.kind), EXTENSION_COLLECTION)); + SubstraitRelVisitor.convert(RelRoot.of(newPlan, relRoot.kind), EXTENSION_COLLECTION)); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 27ae7eea9..96d2eb50f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -12,6 +12,7 @@ import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; import io.substrait.isthmus.sql.SubstraitSqlDialect; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.Plan; import io.substrait.plan.Plan.Root; import io.substrait.plan.PlanProtoConverter; @@ -91,7 +92,7 @@ protected Plan assertProtoPlanRoundrip( Plan plan = new ProtoPlanConverter(extensions).from(protoPlan1); io.substrait.proto.Plan protoPlan2 = new PlanProtoConverter().toProto(plan); assertEquals(protoPlan1, protoPlan2); - List rootRels = s.sqlToRelNode(query, catalogReader); + List rootRels = SubstraitSqlToCalcite.convertSelects(query, catalogReader); assertEquals(rootRels.size(), plan.getRoots().size()); for (int i = 0; i < rootRels.size(); i++) { Plan.Root rootRel = SubstraitRelVisitor.convert(rootRels.get(i), extensions); @@ -130,9 +131,7 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( SqlToSubstrait s = new SqlToSubstrait(); // 1. SQL -> Calcite RelRoot - List relRoots = s.sqlToRelNode(query, catalogReader); - assertEquals(1, relRoots.size()); - RelRoot relRoot1 = relRoots.get(0); + RelRoot relRoot1 = SubstraitSqlToCalcite.convertSelect(query, catalogReader); // 2. Calcite RelRoot -> Substrait Rel Plan.Root pojo1 = SubstraitRelVisitor.convert(relRoot1, extensions); @@ -175,13 +174,10 @@ protected void assertFullRoundTrip(String query, String createStatements) */ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalogReader) throws SqlParseException { - SqlToSubstrait sqlConverter = new SqlToSubstrait(); ExtensionCollector extensionCollector = new ExtensionCollector(); // SQL -> Calcite 1 - List relRoots = sqlConverter.sqlToRelNode(sqlQuery, catalogReader); - assertEquals(1, relRoots.size()); - RelRoot calcite1 = relRoots.get(0); + RelRoot calcite1 = SubstraitSqlToCalcite.convertSelect(sqlQuery, catalogReader); // Calcite 1 -> Substrait POJO 1 Plan.Root pojo1 = SubstraitRelVisitor.convert(calcite1, extensions);