From 9d7b316187cf7b25c91c9d4935bfe89cae79d48a Mon Sep 17 00:00:00 2001 From: Vasily Bondarenko Date: Wed, 14 Aug 2024 12:31:39 +0100 Subject: [PATCH] SPARK-414 support for delete - added support for delete - minor refactoring for filter transformations - minor fix for catalog to load table with resolved schema instead of empty one --- .../spark/sql/connector/RoundTripTest.java | 32 +++ .../mongodb/MongoSparkConnectorHelper.java | 3 + .../sql/connector/ExpressionConverter.java | 214 ++++++++++++++++++ .../spark/sql/connector/MongoCatalog.java | 3 +- .../spark/sql/connector/MongoTable.java | 31 ++- .../sql/connector/read/MongoScanBuilder.java | 202 +---------------- 6 files changed, 290 insertions(+), 195 deletions(-) create mode 100644 src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java index 284ad28a..c3b26fc7 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java @@ -17,8 +17,10 @@ package com.mongodb.spark.sql.connector; +import static com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorHelper.CATALOG; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertIterableEquals; import com.mongodb.spark.sql.connector.beans.BoxedBean; @@ -39,6 +41,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.Test; @@ -162,4 +165,33 @@ void testComplexBean() { .collectAsList(); assertIterableEquals(dataSetOriginal, dataSetMongo); } + + @Test + void testCatalogAccessAndDelete() { + List dataSetOriginal = + asList( + new BoxedBean((byte) 1, (short) 2, 0, 4L, 5.0f, 6.0, true), + new BoxedBean((byte) 1, (short) 2, 1, 4L, 5.0f, 6.0, true), + new BoxedBean((byte) 1, (short) 2, 2, 4L, 5.0f, 6.0, true), + new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, false), + new BoxedBean((byte) 1, (short) 2, 4, 4L, 5.0f, 6.0, false), + new BoxedBean((byte) 1, (short) 2, 5, 4L, 5.0f, 6.0, false)); + + SparkSession spark = getOrCreateSparkSession(); + Encoder encoder = Encoders.bean(BoxedBean.class); + spark + .createDataset(dataSetOriginal, encoder) + .write() + .format("mongodb") + .mode("Overwrite") + .save(); + + String tableName = CATALOG + "." + HELPER.getDatabaseName() + "." 
+ HELPER.getCollectionName(); + List rows = spark.sql("select * from " + tableName).collectAsList(); + assertEquals(6, rows.size()); + + spark.sql("delete from " + tableName + " where not booleanField and intField > 3"); + rows = spark.sql("select * from " + tableName).collectAsList(); + assertEquals(4, rows.size()); + } } diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java index d3ceddb6..fec63931 100644 --- a/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java +++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/mongodb/MongoSparkConnectorHelper.java @@ -28,6 +28,7 @@ import com.mongodb.client.model.UpdateOptions; import com.mongodb.client.model.Updates; import com.mongodb.connection.ClusterType; +import com.mongodb.spark.sql.connector.MongoCatalog; import com.mongodb.spark.sql.connector.config.MongoConfig; import java.io.File; import java.io.IOException; @@ -62,6 +63,7 @@ public class MongoSparkConnectorHelper "{_id: '%s', pk: '%s', dups: '%s', i: %d, s: '%s'}"; private static final String COMPLEX_SAMPLE_DATA_TEMPLATE = "{_id: '%s', nested: {pk: '%s', dups: '%s', i: %d}, s: '%s'}"; + public static final String CATALOG = "mongo_catalog"; private static final Logger LOGGER = LoggerFactory.getLogger(MongoSparkConnectorHelper.class); @@ -146,6 +148,7 @@ public SparkConf getSparkConf() { .set("spark.sql.streaming.checkpointLocation", getTempDirectory()) .set("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true") .set("spark.app.id", "MongoSparkConnector") + .set("spark.sql.catalog." + CATALOG, MongoCatalog.class.getCanonicalName()) .set( MongoConfig.PREFIX + MongoConfig.CONNECTION_STRING_CONFIG, getConnectionString().getConnectionString()) diff --git a/src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java b/src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java new file mode 100644 index 00000000..658d410e --- /dev/null +++ b/src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java @@ -0,0 +1,214 @@ +package com.mongodb.spark.sql.connector; + +import static com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter.createObjectToBsonValue; +import static java.lang.String.format; + +import com.mongodb.client.model.Filters; +import com.mongodb.spark.sql.connector.assertions.Assertions; +import com.mongodb.spark.sql.connector.config.WriteConfig; +import com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualNullSafe; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.In; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.sources.StringContains; +import org.apache.spark.sql.sources.StringEndsWith; +import org.apache.spark.sql.sources.StringStartsWith; 
+import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.bson.BsonValue; +import org.bson.conversions.Bson; +import org.jetbrains.annotations.Nullable; +import org.jetbrains.annotations.VisibleForTesting; + +public class ExpressionConverter { + private final StructType schema; + + public ExpressionConverter(final StructType schema) { + this.schema = schema; + } + + public FilterAndPipelineStage processFilter(final Filter filter) { + Assertions.ensureArgument(() -> filter != null, () -> "Invalid argument filter cannot be null"); + if (filter instanceof And) { + And andFilter = (And) filter; + FilterAndPipelineStage eitherLeft = processFilter(andFilter.left()); + FilterAndPipelineStage eitherRight = processFilter(andFilter.right()); + if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) { + return new FilterAndPipelineStage( + filter, Filters.and(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage())); + } + } else if (filter instanceof EqualNullSafe) { + EqualNullSafe equalNullSafe = (EqualNullSafe) filter; + String fieldName = unquoteFieldName(equalNullSafe.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, equalNullSafe.value()) + .map(bsonValue -> Filters.eq(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof EqualTo) { + EqualTo equalTo = (EqualTo) filter; + String fieldName = unquoteFieldName(equalTo.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, equalTo.value()) + .map(bsonValue -> Filters.eq(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof GreaterThan) { + GreaterThan greaterThan = (GreaterThan) filter; + String fieldName = unquoteFieldName(greaterThan.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, greaterThan.value()) + .map(bsonValue -> Filters.gt(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof GreaterThanOrEqual) { + GreaterThanOrEqual greaterThanOrEqual = (GreaterThanOrEqual) filter; + String fieldName = unquoteFieldName(greaterThanOrEqual.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, greaterThanOrEqual.value()) + .map(bsonValue -> Filters.gte(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof In) { + In inFilter = (In) filter; + String fieldName = unquoteFieldName(inFilter.attribute()); + List values = + Arrays.stream(inFilter.values()) + .map(v -> getBsonValue(fieldName, v)) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toList()); + + // Ensure all values were matched otherwise leave to Spark to filter. 
+ Bson pipelineStage = null; + if (values.size() == inFilter.values().length) { + pipelineStage = Filters.in(fieldName, values); + } + return new FilterAndPipelineStage(filter, pipelineStage); + } else if (filter instanceof IsNull) { + IsNull isNullFilter = (IsNull) filter; + String fieldName = unquoteFieldName(isNullFilter.attribute()); + return new FilterAndPipelineStage(filter, Filters.eq(fieldName, null)); + } else if (filter instanceof IsNotNull) { + IsNotNull isNotNullFilter = (IsNotNull) filter; + String fieldName = unquoteFieldName(isNotNullFilter.attribute()); + return new FilterAndPipelineStage(filter, Filters.ne(fieldName, null)); + } else if (filter instanceof LessThan) { + LessThan lessThan = (LessThan) filter; + String fieldName = unquoteFieldName(lessThan.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, lessThan.value()) + .map(bsonValue -> Filters.lt(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof LessThanOrEqual) { + LessThanOrEqual lessThanOrEqual = (LessThanOrEqual) filter; + String fieldName = unquoteFieldName(lessThanOrEqual.attribute()); + return new FilterAndPipelineStage( + filter, + getBsonValue(fieldName, lessThanOrEqual.value()) + .map(bsonValue -> Filters.lte(fieldName, bsonValue)) + .orElse(null)); + } else if (filter instanceof Not) { + Not notFilter = (Not) filter; + FilterAndPipelineStage notChild = processFilter(notFilter.child()); + if (notChild.hasPipelineStage()) { + return new FilterAndPipelineStage(filter, Filters.not(notChild.pipelineStage)); + } + } else if (filter instanceof Or) { + Or or = (Or) filter; + FilterAndPipelineStage eitherLeft = processFilter(or.left()); + FilterAndPipelineStage eitherRight = processFilter(or.right()); + if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) { + return new FilterAndPipelineStage( + filter, Filters.or(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage())); + } + } else if (filter instanceof StringContains) { + StringContains stringContains = (StringContains) filter; + String fieldName = unquoteFieldName(stringContains.attribute()); + return new FilterAndPipelineStage( + filter, Filters.regex(fieldName, format(".*%s.*", stringContains.value()))); + } else if (filter instanceof StringEndsWith) { + StringEndsWith stringEndsWith = (StringEndsWith) filter; + String fieldName = unquoteFieldName(stringEndsWith.attribute()); + return new FilterAndPipelineStage( + filter, Filters.regex(fieldName, format(".*%s$", stringEndsWith.value()))); + } else if (filter instanceof StringStartsWith) { + StringStartsWith stringStartsWith = (StringStartsWith) filter; + String fieldName = unquoteFieldName(stringStartsWith.attribute()); + return new FilterAndPipelineStage( + filter, Filters.regex(fieldName, format("^%s.*", stringStartsWith.value()))); + } + return new FilterAndPipelineStage(filter, null); + } + + @VisibleForTesting + static String unquoteFieldName(final String fieldName) { + // Spark automatically escapes hyphenated names using backticks + if (fieldName.contains("`")) { + return new Column(fieldName).toString(); + } + return fieldName; + } + + private Optional getBsonValue(final String fieldName, final Object value) { + try { + StructType localSchema = schema; + DataType localDataType = localSchema; + + for (String localFieldName : fieldName.split("\\.")) { + StructField localField = localSchema.apply(localFieldName); + localDataType = localField.dataType(); + if (localField.dataType() instanceof StructType) { + localSchema = 
(StructType) localField.dataType(); + } + } + RowToBsonDocumentConverter.ObjectToBsonValue objectToBsonValue = + createObjectToBsonValue(localDataType, WriteConfig.ConvertJson.FALSE, false); + return Optional.of(objectToBsonValue.apply(value)); + } catch (Exception e) { + // ignore + return Optional.empty(); + } + } + + /** FilterAndPipelineStage - contains an optional pipeline stage for the filter. */ + public static final class FilterAndPipelineStage { + + private final Filter filter; + private final Bson pipelineStage; + + private FilterAndPipelineStage(final Filter filter, @Nullable final Bson pipelineStage) { + this.filter = filter; + this.pipelineStage = pipelineStage; + } + + public Filter getFilter() { + return filter; + } + + public Bson getPipelineStage() { + return pipelineStage; + } + + public boolean hasPipelineStage() { + return pipelineStage != null; + } + } +} diff --git a/src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java b/src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java index 22feed44..9b4c0677 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java +++ b/src/main/java/com/mongodb/spark/sql/connector/MongoCatalog.java @@ -28,6 +28,7 @@ import com.mongodb.spark.sql.connector.config.ReadConfig; import com.mongodb.spark.sql.connector.config.WriteConfig; import com.mongodb.spark.sql.connector.exceptions.MongoSparkException; +import com.mongodb.spark.sql.connector.schema.InferSchema; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -239,7 +240,7 @@ public Table loadTable(final Identifier identifier) throws NoSuchTableException properties.put( MongoConfig.READ_PREFIX + MongoConfig.DATABASE_NAME_CONFIG, identifier.namespace()[0]); properties.put(MongoConfig.READ_PREFIX + MongoConfig.COLLECTION_NAME_CONFIG, identifier.name()); - return new MongoTable(MongoConfig.readConfig(properties)); + return new MongoTable(InferSchema.inferSchema(new CaseInsensitiveStringMap(properties)), MongoConfig.readConfig(properties)); } /** diff --git a/src/main/java/com/mongodb/spark/sql/connector/MongoTable.java b/src/main/java/com/mongodb/spark/sql/connector/MongoTable.java index 0dc7039d..76d77bc2 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/MongoTable.java +++ b/src/main/java/com/mongodb/spark/sql/connector/MongoTable.java @@ -19,7 +19,12 @@ import static java.util.Arrays.asList; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.Filters; import com.mongodb.spark.connector.Versions; +import com.mongodb.spark.sql.connector.ExpressionConverter.FilterAndPipelineStage; import com.mongodb.spark.sql.connector.config.MongoConfig; import com.mongodb.spark.sql.connector.config.ReadConfig; import com.mongodb.spark.sql.connector.config.WriteConfig; @@ -27,9 +32,12 @@ import com.mongodb.spark.sql.connector.write.MongoWriteBuilder; import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; +import org.apache.spark.sql.connector.catalog.SupportsDelete; import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.SupportsWrite; import org.apache.spark.sql.connector.catalog.Table; @@ -38,13 +46,16 @@ import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.write.LogicalWriteInfo; import 
org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.bson.Document; +import org.bson.conversions.Bson; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** Represents a MongoDB Collection. */ -final class MongoTable implements Table, SupportsWrite, SupportsRead { +final class MongoTable implements Table, SupportsWrite, SupportsRead, SupportsDelete { private static final Logger LOGGER = LoggerFactory.getLogger(MongoTable.class); private static final Set TABLE_CAPABILITY_SET = new HashSet<>(asList( TableCapability.BATCH_WRITE, @@ -179,4 +190,22 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(partitioning); return result; } + + @Override + public void deleteWhere(final Filter[] filters) { + ExpressionConverter converter = new ExpressionConverter(schema); + + List stages = Arrays.stream(filters) + .map(converter::processFilter) + .filter(FilterAndPipelineStage::hasPipelineStage) + .map(FilterAndPipelineStage::getPipelineStage) + .collect(Collectors.toList()); + Bson query = Filters.and(stages); + WriteConfig writeConfig = mongoConfig.toWriteConfig(); + + MongoClient mongoClient = writeConfig.getMongoClient(); + MongoDatabase database = mongoClient.getDatabase(writeConfig.getDatabaseName()); + MongoCollection collection = database.getCollection(writeConfig.getCollectionName()); + collection.deleteMany(query); + } } diff --git a/src/main/java/com/mongodb/spark/sql/connector/read/MongoScanBuilder.java b/src/main/java/com/mongodb/spark/sql/connector/read/MongoScanBuilder.java index f8c5c643..a557b2f4 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/read/MongoScanBuilder.java +++ b/src/main/java/com/mongodb/spark/sql/connector/read/MongoScanBuilder.java @@ -17,23 +17,19 @@ package com.mongodb.spark.sql.connector.read; -import static com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter.createObjectToBsonValue; -import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.Filters; -import com.mongodb.spark.sql.connector.assertions.Assertions; +import com.mongodb.spark.sql.connector.ExpressionConverter; +import com.mongodb.spark.sql.connector.ExpressionConverter.FilterAndPipelineStage; import com.mongodb.spark.sql.connector.config.MongoConfig; import com.mongodb.spark.sql.connector.config.ReadConfig; -import com.mongodb.spark.sql.connector.config.WriteConfig; -import com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Locale; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import org.apache.spark.sql.Column; @@ -42,36 +38,20 @@ import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.SupportsPushDownFilters; import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; -import org.apache.spark.sql.sources.And; -import org.apache.spark.sql.sources.EqualNullSafe; -import org.apache.spark.sql.sources.EqualTo; import org.apache.spark.sql.sources.Filter; -import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.GreaterThanOrEqual; -import org.apache.spark.sql.sources.In; -import org.apache.spark.sql.sources.IsNotNull; -import 
org.apache.spark.sql.sources.IsNull; -import org.apache.spark.sql.sources.LessThan; -import org.apache.spark.sql.sources.LessThanOrEqual; -import org.apache.spark.sql.sources.Not; -import org.apache.spark.sql.sources.Or; -import org.apache.spark.sql.sources.StringContains; -import org.apache.spark.sql.sources.StringEndsWith; -import org.apache.spark.sql.sources.StringStartsWith; -import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.bson.BsonDocument; -import org.bson.BsonValue; -import org.bson.conversions.Bson; import org.jetbrains.annotations.ApiStatus; -import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.VisibleForTesting; /** A builder for a {@link MongoScan}. */ @ApiStatus.Internal public final class MongoScanBuilder - implements ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns { + implements + ScanBuilder, + SupportsPushDownFilters, + SupportsPushDownRequiredColumns { private final StructType schema; private final ReadConfig readConfig; private final boolean isCaseSensitive; @@ -121,8 +101,10 @@ public Scan build() { */ @Override public Filter[] pushFilters(final Filter[] filters) { + ExpressionConverter converter = new ExpressionConverter(schema); + List processed = - Arrays.stream(filters).map(this::processFilter).collect(Collectors.toList()); + Arrays.stream(filters).map(converter::processFilter).collect(Collectors.toList()); List withPipelines = processed.stream() .filter(FilterAndPipelineStage::hasPipelineStage) @@ -166,127 +148,6 @@ private String getColumnName(final StructField field) { return field.name(); } - /** - * Processes the Filter and if possible creates the equivalent aggregation pipeline stage. - * - * @param filter the filter to be applied - * @return the FilterAndPipelineStage which contains a pipeline stage if the filter is convertible - * into an aggregation pipeline. 
- */ - private FilterAndPipelineStage processFilter(final Filter filter) { - Assertions.ensureArgument(() -> filter != null, () -> "Invalid argument filter cannot be null"); - if (filter instanceof And) { - And andFilter = (And) filter; - FilterAndPipelineStage eitherLeft = processFilter(andFilter.left()); - FilterAndPipelineStage eitherRight = processFilter(andFilter.right()); - if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) { - return new FilterAndPipelineStage( - filter, Filters.and(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage())); - } - } else if (filter instanceof EqualNullSafe) { - EqualNullSafe equalNullSafe = (EqualNullSafe) filter; - String fieldName = unquoteFieldName(equalNullSafe.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, equalNullSafe.value()) - .map(bsonValue -> Filters.eq(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof EqualTo) { - EqualTo equalTo = (EqualTo) filter; - String fieldName = unquoteFieldName(equalTo.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, equalTo.value()) - .map(bsonValue -> Filters.eq(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof GreaterThan) { - GreaterThan greaterThan = (GreaterThan) filter; - String fieldName = unquoteFieldName(greaterThan.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, greaterThan.value()) - .map(bsonValue -> Filters.gt(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof GreaterThanOrEqual) { - GreaterThanOrEqual greaterThanOrEqual = (GreaterThanOrEqual) filter; - String fieldName = unquoteFieldName(greaterThanOrEqual.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, greaterThanOrEqual.value()) - .map(bsonValue -> Filters.gte(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof In) { - In inFilter = (In) filter; - String fieldName = unquoteFieldName(inFilter.attribute()); - List values = Arrays.stream(inFilter.values()) - .map(v -> getBsonValue(fieldName, v)) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(Collectors.toList()); - - // Ensure all values were matched otherwise leave to Spark to filter. 
- Bson pipelineStage = null; - if (values.size() == inFilter.values().length) { - pipelineStage = Filters.in(fieldName, values); - } - return new FilterAndPipelineStage(filter, pipelineStage); - } else if (filter instanceof IsNull) { - IsNull isNullFilter = (IsNull) filter; - String fieldName = unquoteFieldName(isNullFilter.attribute()); - return new FilterAndPipelineStage(filter, Filters.eq(fieldName, null)); - } else if (filter instanceof IsNotNull) { - IsNotNull isNotNullFilter = (IsNotNull) filter; - String fieldName = unquoteFieldName(isNotNullFilter.attribute()); - return new FilterAndPipelineStage(filter, Filters.ne(fieldName, null)); - } else if (filter instanceof LessThan) { - LessThan lessThan = (LessThan) filter; - String fieldName = unquoteFieldName(lessThan.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, lessThan.value()) - .map(bsonValue -> Filters.lt(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof LessThanOrEqual) { - LessThanOrEqual lessThanOrEqual = (LessThanOrEqual) filter; - String fieldName = unquoteFieldName(lessThanOrEqual.attribute()); - return new FilterAndPipelineStage( - filter, - getBsonValue(fieldName, lessThanOrEqual.value()) - .map(bsonValue -> Filters.lte(fieldName, bsonValue)) - .orElse(null)); - } else if (filter instanceof Not) { - Not notFilter = (Not) filter; - FilterAndPipelineStage notChild = processFilter(notFilter.child()); - if (notChild.hasPipelineStage()) { - return new FilterAndPipelineStage(filter, Filters.not(notChild.pipelineStage)); - } - } else if (filter instanceof Or) { - Or or = (Or) filter; - FilterAndPipelineStage eitherLeft = processFilter(or.left()); - FilterAndPipelineStage eitherRight = processFilter(or.right()); - if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) { - return new FilterAndPipelineStage( - filter, Filters.or(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage())); - } - } else if (filter instanceof StringContains) { - StringContains stringContains = (StringContains) filter; - String fieldName = unquoteFieldName(stringContains.attribute()); - return new FilterAndPipelineStage( - filter, Filters.regex(fieldName, format(".*%s.*", stringContains.value()))); - } else if (filter instanceof StringEndsWith) { - StringEndsWith stringEndsWith = (StringEndsWith) filter; - String fieldName = unquoteFieldName(stringEndsWith.attribute()); - return new FilterAndPipelineStage( - filter, Filters.regex(fieldName, format(".*%s$", stringEndsWith.value()))); - } else if (filter instanceof StringStartsWith) { - StringStartsWith stringStartsWith = (StringStartsWith) filter; - String fieldName = unquoteFieldName(stringStartsWith.attribute()); - return new FilterAndPipelineStage( - filter, Filters.regex(fieldName, format("^%s.*", stringStartsWith.value()))); - } - return new FilterAndPipelineStage(filter, null); - } - @VisibleForTesting static String unquoteFieldName(final String fieldName) { // Spark automatically escapes hyphenated names using backticks @@ -295,49 +156,4 @@ static String unquoteFieldName(final String fieldName) { } return fieldName; } - - private Optional getBsonValue(final String fieldName, final Object value) { - try { - StructType localSchema = schema; - DataType localDataType = localSchema; - - for (String localFieldName : fieldName.split("\\.")) { - StructField localField = localSchema.apply(localFieldName); - localDataType = localField.dataType(); - if (localField.dataType() instanceof StructType) { - localSchema = (StructType) 
localField.dataType(); - } - } - RowToBsonDocumentConverter.ObjectToBsonValue objectToBsonValue = - createObjectToBsonValue(localDataType, WriteConfig.ConvertJson.FALSE, false); - return Optional.of(objectToBsonValue.apply(value)); - } catch (Exception e) { - // ignore - return Optional.empty(); - } - } - - /** FilterAndPipelineStage - contains an optional pipeline stage for the filter. */ - private static final class FilterAndPipelineStage { - - private final Filter filter; - private final Bson pipelineStage; - - private FilterAndPipelineStage(final Filter filter, @Nullable final Bson pipelineStage) { - this.filter = filter; - this.pipelineStage = pipelineStage; - } - - public Filter getFilter() { - return filter; - } - - public Bson getPipelineStage() { - return pipelineStage; - } - - boolean hasPipelineStage() { - return pipelineStage != null; - } - } }
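
For anyone trying the new DELETE support outside the integration test, a minimal driver along these lines should exercise the whole path: MongoCatalog resolves the table with an inferred schema, Spark translates the WHERE clause into source Filters, and MongoTable.deleteWhere turns them into a deleteMany query via ExpressionConverter. This is an illustrative sketch, not part of the patch: the catalog name, the test.people namespace, the connection string, and the active/age fields are placeholders.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public final class DeleteFromMongoExample {
  public static void main(final String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("delete-from-mongo")
        // Register MongoCatalog under an arbitrary name, as MongoSparkConnectorHelper now does.
        .config("spark.sql.catalog.mongo_catalog", "com.mongodb.spark.sql.connector.MongoCatalog")
        // Shared connection string, mirroring MongoConfig.PREFIX + CONNECTION_STRING_CONFIG.
        .config("spark.mongodb.connection.uri", "mongodb://localhost:27017")
        .getOrCreate();

    // Assumes a pre-populated test.people collection with boolean "active" and int "age" fields.
    String table = "mongo_catalog.test.people";

    // Same predicate shape as the new integration test: a NOT on a boolean column plus a
    // numeric comparison, both of which ExpressionConverter can convert to a Bson query.
    spark.sql("DELETE FROM " + table + " WHERE NOT active AND age > 30");

    Dataset<Row> remaining = spark.sql("SELECT * FROM " + table);
    System.out.println("documents left: " + remaining.count());

    spark.stop();
  }
}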
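
It may also help reviewers to see what the extracted ExpressionConverter produces for the conjuncts Spark hands to deleteWhere. The sketch below is illustrative only; the schema and filter values are made up, and only processFilter, hasPipelineStage and getPipelineStage come from this patch. MongoTable.deleteWhere combines the resulting stages with Filters.and before calling deleteMany.

import com.mongodb.spark.sql.connector.ExpressionConverter;
import com.mongodb.spark.sql.connector.ExpressionConverter.FilterAndPipelineStage;
import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.sources.Not;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public final class ConverterSketch {
  public static void main(final String[] args) {
    // Hypothetical schema matching the fields used in the DELETE example above.
    StructType schema = new StructType()
        .add("active", DataTypes.BooleanType)
        .add("age", DataTypes.IntegerType);

    ExpressionConverter converter = new ExpressionConverter(schema);

    // deleteWhere receives an array of conjuncts; each one is converted independently.
    Filter[] conjuncts = {
      new Not(new EqualTo("active", true)),
      new GreaterThan("age", 30)
    };

    for (Filter conjunct : conjuncts) {
      FilterAndPipelineStage stage = converter.processFilter(conjunct);
      if (stage.hasPipelineStage()) {
        // Prints the driver-side Bson form that deleteWhere ANDs together for deleteMany.
        System.out.println(stage.getPipelineStage());
      }
    }
  }
}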