diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 204b70385641e..59cd581f9dc0e 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -46,6 +46,8 @@ record CmdLineArgs( int indexThreads, boolean reindex, boolean forceMerge, + float filterSelectivity, + long seed, VectorSimilarityFunction vectorSpace, int quantizeBits, VectorEncoding vectorEncoding, @@ -75,6 +77,8 @@ record CmdLineArgs( static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); + static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity"); + static final ParseField SEED_FIELD = new ParseField("seed"); static CmdLineArgs fromXContent(XContentParser parser) throws IOException { Builder builder = PARSER.apply(parser, null); @@ -106,6 +110,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD); + PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD); + PARSER.declareLong(Builder::setSeed, SEED_FIELD); } @Override @@ -136,6 +142,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + builder.field(EARLY_TERMINATION_FIELD.getPreferredName(), earlyTermination); + builder.field(FILTER_SELECTIVITY_FIELD.getPreferredName(), filterSelectivity); + builder.field(SEED_FIELD.getPreferredName(), seed); return builder.endObject(); } @@ -167,6 +176,8 @@ static class Builder { private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; private int dimensions; private boolean earlyTermination; + private float filterSelectivity = 1f; + private long seed = 1751900822751L; public Builder setDocVectors(String docVectors) { this.docVectors = PathUtils.get(docVectors); @@ -278,6 +289,16 @@ public Builder setEarlyTermination(Boolean patience) { return this; } + public Builder setFilterSelectivity(float filterSelectivity) { + this.filterSelectivity = filterSelectivity; + return this; + } + + public Builder setSeed(long seed) { + this.seed = seed; + return this; + } + public CmdLineArgs build() { if (docVectors == null) { throw new IllegalArgumentException("Document vectors path must be provided"); @@ -305,6 +326,8 @@ public CmdLineArgs build() { indexThreads, reindex, forceMerge, + filterSelectivity, + seed, vectorSpace, quantizeBits, vectorEncoding, diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index ee1ed65a4f97e..fe20f895d3ea9 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -178,10 +178,20 @@ public static void main(String[] args) throws Exception { ? cmdLineArgs.nProbes() : new int[] { 0 }; String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT); - Results indexResults = new Results(cmdLineArgs.docVectors().getFileName().toString(), indexType, cmdLineArgs.numDocs()); + Results indexResults = new Results( + cmdLineArgs.docVectors().getFileName().toString(), + indexType, + cmdLineArgs.numDocs(), + cmdLineArgs.filterSelectivity() + ); Results[] results = new Results[nProbes.length]; for (int i = 0; i < nProbes.length; i++) { - results[i] = new Results(cmdLineArgs.docVectors().getFileName().toString(), indexType, cmdLineArgs.numDocs()); + results[i] = new Results( + cmdLineArgs.docVectors().getFileName().toString(), + indexType, + cmdLineArgs.numDocs(), + cmdLineArgs.filterSelectivity() + ); } logger.info("Running KNN index tester with arguments: " + cmdLineArgs); Codec codec = createCodec(cmdLineArgs); @@ -244,7 +254,8 @@ public String toString() { "avg_cpu_count", "QPS", "recall", - "visited" }; + "visited", + "filter_selectivity" }; // Calculate appropriate column widths based on headers and data @@ -274,7 +285,8 @@ public String toString() { String.format(Locale.ROOT, "%.2f", queryResult.avgCpuCount), String.format(Locale.ROOT, "%.2f", queryResult.qps), String.format(Locale.ROOT, "%.2f", queryResult.avgRecall), - String.format(Locale.ROOT, "%.2f", queryResult.averageVisited) }; + String.format(Locale.ROOT, "%.2f", queryResult.averageVisited), + String.format(Locale.ROOT, "%.2f", queryResult.filterSelectivity), }; } printBlock(sb, searchHeaders, queryResultsArray); @@ -339,6 +351,7 @@ private int[] calculateColumnWidths(String[] headers, String[]... data) { static class Results { final String indexType, indexName; final int numDocs; + final float filterSelectivity; long indexTimeMS; long forceMergeTimeMS; int numSegments; @@ -350,10 +363,11 @@ static class Results { double netCpuTimeMS; double avgCpuCount; - Results(String indexName, String indexType, int numDocs) { + Results(String indexName, String indexType, int numDocs, float filterSelectivity) { this.indexName = indexName; this.indexType = indexType; this.numDocs = numDocs; + this.filterSelectivity = filterSelectivity; } } diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index 7cf8a5846cba3..fb84df66b0138 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -32,17 +33,29 @@ import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource; import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.FixedBitSet; import org.elasticsearch.common.io.Channels; import org.elasticsearch.core.PathUtils; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -64,9 +77,11 @@ import java.nio.file.Path; import java.nio.file.attribute.FileTime; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Random; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; @@ -88,8 +103,8 @@ class KnnSearcher { private final Path queryPath; private final int numDocs; private final int numQueryVectors; - private final long randomSeed = 42; - private final float selectivity = 1f; + private final long randomSeed; + private final float selectivity; private final int topK; private final int efSearch; private final int nProbe; @@ -120,9 +135,12 @@ class KnnSearcher { this.indexType = cmdLineArgs.indexType(); this.searchThreads = cmdLineArgs.searchThreads(); this.numSearchers = cmdLineArgs.numSearchers(); + this.randomSeed = cmdLineArgs.seed(); + this.selectivity = cmdLineArgs.filterSelectivity(); } void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException { + Query filterQuery = this.selectivity < 1f ? generateRandomQuery(new Random(randomSeed), indexPath, numDocs, selectivity) : null; TopDocs[] results = new TopDocs[numQueryVectors]; int[][] resultIds = new int[numQueryVectors][]; long elapsed, totalCpuTimeMS, totalVisited = 0; @@ -164,10 +182,10 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { targetReader.next(targetBytes); - doVectorQuery(targetBytes, searcher, earlyTermination); + doVectorQuery(targetBytes, searcher, filterQuery, earlyTermination); } else { targetReader.next(target); - doVectorQuery(target, searcher, earlyTermination); + doVectorQuery(target, searcher, filterQuery, earlyTermination); } } targetReader.reset(); @@ -180,7 +198,7 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th for (int s = 0; s < numSearchers; s++) { queryConsumers[s] = i -> { try { - results[i] = doVectorQuery(queries[i], searcher, earlyTermination); + results[i] = doVectorQuery(queries[i], searcher, filterQuery, earlyTermination); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -194,7 +212,7 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th for (int s = 0; s < numSearchers; s++) { queryConsumers[s] = i -> { try { - results[i] = doVectorQuery(queries[i], searcher, earlyTermination); + results[i] = doVectorQuery(queries[i], searcher, filterQuery, earlyTermination); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -274,7 +292,7 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th } } logger.info("checking results"); - int[][] nn = getOrCalculateExactNN(offsetByteSize); + int[][] nn = getOrCalculateExactNN(offsetByteSize, filterQuery); finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0; finalResults.avgRecall = checkResults(resultIds, nn, topK); finalResults.qps = (1000f * numQueryVectors) / elapsed; @@ -284,7 +302,34 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed; } - private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException { + private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity) throws IOException { + FixedBitSet bitSet = new FixedBitSet(size); + for (int i = 0; i < size; i++) { + if (random.nextFloat() < selectivity) { + bitSet.set(i); + } else { + bitSet.clear(i); + } + } + + try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) { + BitSet[] segmentDocs = new BitSet[reader.leaves().size()]; + for (var leafContext : reader.leaves()) { + var storedFields = leafContext.reader().storedFields(); + FixedBitSet segmentBitSet = new FixedBitSet(reader.maxDoc()); + for (int d = 0; d < leafContext.reader().maxDoc(); d++) { + int docID = Integer.parseInt(storedFields.document(d, Set.of(ID_FIELD)).get(ID_FIELD)); + if (bitSet.get(docID)) { + segmentBitSet.set(d); + } + } + segmentDocs[leafContext.ord] = segmentBitSet; + } + return new BitSetQuery(segmentDocs); + } + } + + private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes, Query filterQuery) throws IOException { // look in working directory for cached nn file String hash = Integer.toString( Objects.hash( @@ -312,9 +357,9 @@ private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOExcept // checking low-precision recall int[][] nn; if (vectorEncoding.equals(VectorEncoding.BYTE)) { - nn = computeExactNNByte(queryPath, vectorFileOffsetBytes); + nn = computeExactNNByte(queryPath, filterQuery, vectorFileOffsetBytes); } else { - nn = computeExactNN(queryPath, vectorFileOffsetBytes); + nn = computeExactNN(queryPath, filterQuery, vectorFileOffsetBytes); } writeExactNN(nn, nnPath); long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms @@ -333,7 +378,7 @@ private boolean isNewer(Path path, Path... others) throws IOException { return true; } - TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException { + TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, Query filterQuery, boolean earlyTermination) throws IOException { Query knnQuery; if (overSamplingFactor > 1f) { throw new IllegalArgumentException("oversampling factor > 1 is not supported for byte vectors"); @@ -346,7 +391,7 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermin vector, topK, efSearch, - null, + filterQuery, DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy() ); if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) { @@ -360,7 +405,7 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermin return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs); } - TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException { + TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery, boolean earlyTermination) throws IOException { Query knnQuery; int topK = this.topK; if (overSamplingFactor > 1f) { @@ -369,14 +414,14 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, boolean earlyTermi } int efSearch = Math.max(topK, this.efSearch); if (indexType == KnnIndexTester.IndexType.IVF) { - knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, null, nProbe); + knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, nProbe); } else { knnQuery = new ESKnnFloatVectorQuery( VECTOR_FIELD, vector, topK, efSearch, - null, + filterQuery, DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy() ); if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) { @@ -449,7 +494,7 @@ private void writeExactNN(int[][] nn, Path nnPath) throws IOException { } } - private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException { + private int[][] computeExactNN(Path queryPath, Query filterQuery, int vectorFileOffsetBytes) throws IOException { int[][] result = new int[numQueryVectors][]; try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) { List> tasks = new ArrayList<>(); @@ -463,7 +508,7 @@ private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws for (int i = 0; i < numQueryVectors; i++) { float[] queryVector = new float[dim]; queryReader.next(queryVector); - tasks.add(new ComputeNNFloatTask(i, topK, queryVector, result, reader, similarityFunction)); + tasks.add(new ComputeNNFloatTask(i, topK, queryVector, result, reader, filterQuery, similarityFunction)); } ForkJoinPool.commonPool().invokeAll(tasks); } @@ -471,7 +516,7 @@ private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws } } - private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException { + private int[][] computeExactNNByte(Path queryPath, Query filterQuery, int vectorFileOffsetBytes) throws IOException { int[][] result = new int[numQueryVectors][]; try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) { List> tasks = new ArrayList<>(); @@ -480,7 +525,7 @@ private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) th for (int i = 0; i < numQueryVectors; i++) { byte[] queryVector = new byte[dim]; queryReader.next(queryVector); - tasks.add(new ComputeNNByteTask(i, queryVector, result, reader, similarityFunction)); + tasks.add(new ComputeNNByteTask(i, queryVector, result, reader, filterQuery, similarityFunction)); } ForkJoinPool.commonPool().invokeAll(tasks); } @@ -495,6 +540,7 @@ static class ComputeNNFloatTask implements Callable { private final int[][] result; private final IndexReader reader; private final VectorSimilarityFunction similarityFunction; + private final Query filterQuery; private final int topK; ComputeNNFloatTask( @@ -503,6 +549,7 @@ static class ComputeNNFloatTask implements Callable { float[] query, int[][] result, IndexReader reader, + Query filterQuery, VectorSimilarityFunction similarityFunction ) { this.queryOrd = queryOrd; @@ -510,6 +557,7 @@ static class ComputeNNFloatTask implements Callable { this.result = result; this.reader = reader; this.similarityFunction = similarityFunction; + this.filterQuery = filterQuery; this.topK = topK; } @@ -520,6 +568,11 @@ public Void call() { var queryVector = new ConstKnnFloatValueSource(query); var docVectors = new FloatKnnVectorFieldSource(VECTOR_FIELD); Query query = new FunctionQuery(new FloatVectorSimilarityFunction(similarityFunction, queryVector, docVectors)); + if (filterQuery != null) { + query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.SHOULD) + .add(filterQuery, BooleanClause.Occur.FILTER) + .build(); + } var topDocs = searcher.search(query, topK); result[queryOrd] = getResultIds(topDocs, reader.storedFields()); if ((queryOrd + 1) % 10 == 0) { @@ -538,13 +591,22 @@ static class ComputeNNByteTask implements Callable { private final byte[] query; private final int[][] result; private final IndexReader reader; + private final Query filterQuery; private final VectorSimilarityFunction similarityFunction; - ComputeNNByteTask(int queryOrd, byte[] query, int[][] result, IndexReader reader, VectorSimilarityFunction similarityFunction) { + ComputeNNByteTask( + int queryOrd, + byte[] query, + int[][] result, + IndexReader reader, + Query filterQuery, + VectorSimilarityFunction similarityFunction + ) { this.queryOrd = queryOrd; this.query = query; this.result = result; this.reader = reader; + this.filterQuery = filterQuery; this.similarityFunction = similarityFunction; } @@ -556,6 +618,11 @@ public Void call() { var queryVector = new ConstKnnByteVectorValueSource(query); var docVectors = new ByteKnnVectorFieldSource(VECTOR_FIELD); Query query = new FunctionQuery(new ByteVectorSimilarityFunction(similarityFunction, queryVector, docVectors)); + if (filterQuery != null) { + query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.SHOULD) + .add(filterQuery, BooleanClause.Occur.FILTER) + .build(); + } var topDocs = searcher.search(query, topK); result[queryOrd] = getResultIds(topDocs, reader.storedFields()); if ((queryOrd + 1) % 10 == 0) { @@ -582,4 +649,57 @@ static int[] getResultIds(TopDocs topDocs, StoredFields storedFields) throws IOE return resultIds; } + private static class BitSetQuery extends Query { + private final BitSet[] segmentDocs; + + BitSetQuery(BitSet[] segmentDocs) { + this.segmentDocs = segmentDocs; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new ConstantScoreWeight(this, boost) { + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + var bitSet = segmentDocs[context.ord]; + var cardinality = bitSet.cardinality(); + var scorer = new ConstantScoreScorer(score(), scoreMode, new BitSetIterator(bitSet, cardinality)); + return new ScorerSupplier() { + @Override + public Scorer get(long leadCost) throws IOException { + return scorer; + } + + @Override + public long cost() { + return cardinality; + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + }; + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public String toString(String field) { + return "BitSetQuery"; + } + + @Override + public boolean equals(Object other) { + return sameClassAs(other) && Arrays.equals(segmentDocs, ((BitSetQuery) other).segmentDocs); + } + + @Override + public int hashCode() { + return 31 * classHash() + Arrays.hashCode(segmentDocs); + } + } + }