Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
*.jar
target
**/.DS_Store
bin
.project
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@
<version>10.2.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.lucene</groupId>
<artifactId>lucene-backward-codecs</artifactId>
<version>10.2.0</version>
</dependency>
<dependency>
<groupId>org.apache.lucene</groupId>
<artifactId>lucene-test-framework</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

Expand Down Expand Up @@ -43,17 +42,25 @@ public class CuVS2510GPUVectorsFormat extends KnnVectorsFormat {
static final IndexType DEFAULT_INDEX_TYPE = IndexType.CAGRA;

static CuVSResources resources = cuVSResourcesOrNull();

/** The format for storing, reading, and merging raw vectors on disk. */
private static final FlatVectorsFormat flatVectorsFormat =
new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);
static final LuceneProvider LUCENE_PROVIDER;
static final FlatVectorsFormat FLAT_VECTORS_FORMAT;

final int maxDimensions = 4096;
final int cuvsWriterThreads;
final int intGraphDegree;
final int graphDegree;
final CuVS2510GPUVectorsWriter.IndexType indexType; // the index type to build, when writing

static {
try {
LUCENE_PROVIDER = LuceneProvider.getInstance("99");
FLAT_VECTORS_FORMAT =
LUCENE_PROVIDER.getLuceneFlatVectorsFormatInstance(DefaultFlatVectorScorer.INSTANCE);
} catch (Exception e) {
throw new ExceptionInInitializerError(e.getMessage());
}
}

/**
* Initializes the {@link CuVS2510GPUVectorsFormat} with default parameter values.
*
Expand Down Expand Up @@ -92,7 +99,7 @@ public CuVS2510GPUVectorsFormat(
@Override
public CuVS2510GPUVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
checkSupported();
var flatWriter = flatVectorsFormat.fieldsWriter(state);
var flatWriter = FLAT_VECTORS_FORMAT.fieldsWriter(state);
return new CuVS2510GPUVectorsWriter(
state, cuvsWriterThreads, intGraphDegree, graphDegree, indexType, resources, flatWriter);
}
Expand All @@ -103,7 +110,7 @@ public CuVS2510GPUVectorsWriter fieldsWriter(SegmentWriteState state) throws IOE
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
checkSupported();
return new CuVS2510GPUVectorsReader(state, resources, flatVectorsFormat.fieldsReader(state));
return new CuVS2510GPUVectorsReader(state, resources, FLAT_VECTORS_FORMAT.fieldsReader(state));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import static com.nvidia.cuvs.lucene.CuVS2510GPUVectorsFormat.CUVS_META_CODEC_EXT;
import static com.nvidia.cuvs.lucene.CuVS2510GPUVectorsFormat.CUVS_META_CODEC_NAME;
import static com.nvidia.cuvs.lucene.CuVS2510GPUVectorsFormat.VERSION_CURRENT;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.apache.lucene.index.VectorEncoding.FLOAT32;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
Expand Down Expand Up @@ -68,6 +67,9 @@ public class CuVS2510GPUVectorsWriter extends KnnVectorsWriter {
/** The name of the CUVS component for the info-stream * */
private static final String CUVS_COMPONENT = "CUVS";

private static final LuceneProvider LUCENE_PROVIDER;
private static final List<VectorSimilarityFunction> VECTOR_SIMILARITY_FUNCTIONS;

// The minimum number of vectors in the dataset required before
// we attempt to build a Cagra index
static final int MIN_CAGRA_INDEX_SIZE = 2;
Expand All @@ -85,6 +87,15 @@ public class CuVS2510GPUVectorsWriter extends KnnVectorsWriter {
private final InfoStream infoStream;
private boolean finished;

static {
try {
LUCENE_PROVIDER = LuceneProvider.getInstance("99");
VECTOR_SIMILARITY_FUNCTIONS = LUCENE_PROVIDER.getSimilarityFunctions();
} catch (Exception e) {
throw new ExceptionInInitializerError(e.getMessage());
}
}

/**
* The cuVS index Types.
*/
Expand Down Expand Up @@ -442,8 +453,8 @@ private void writeMeta(
}

static int distFuncToOrd(VectorSimilarityFunction func) {
for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) {
if (SIMILARITY_FUNCTIONS.get(i).equals(func)) {
for (int i = 0; i < VECTOR_SIMILARITY_FUNCTIONS.size(); i++) {
if (VECTOR_SIMILARITY_FUNCTIONS.get(i).equals(func)) {
return (byte) i;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
package com.nvidia.cuvs.lucene;

import static com.nvidia.cuvs.lucene.Utils.cuVSResourcesOrNull;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;

import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.LibraryException;
Expand All @@ -18,9 +15,6 @@
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

Expand All @@ -43,22 +37,36 @@ public class Lucene99AcceleratedHNSWVectorsFormat extends KnnVectorsFormat {
static final String HNSW_META_CODEC_EXT = "vem";
static final String HNSW_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex";
static final String HNSW_INDEX_EXT = "vex";
static final LuceneProvider LUCENE_PROVIDER;

private static CuVSResources resources = cuVSResourcesOrNull();

/** The format for storing, reading, and merging raw vectors on disk. */
private static final FlatVectorsFormat flatVectorsFormat =
new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);
private static final FlatVectorsFormat FLAT_VECTORS_FORMAT;
private static final Integer MAX_CONN;
private static final Integer BEAM_WIDTH;
private static final Integer NUM_MERGE_WORKERS;

private final int maxDimensions = 4096;
private final int cuvsWriterThreads;
private final int intGraphDegree;
private final int graphDegree;
private final int hnswLayers;

private final int maxConn;
private final int beamWidth;

static {
try {
LUCENE_PROVIDER = LuceneProvider.getInstance("99");
MAX_CONN = LUCENE_PROVIDER.getStaticIntParam("DEFAULT_MAX_CONN");
BEAM_WIDTH = LUCENE_PROVIDER.getStaticIntParam("DEFAULT_BEAM_WIDTH");
NUM_MERGE_WORKERS = LUCENE_PROVIDER.getStaticIntParam("DEFAULT_BEAM_WIDTH");
FLAT_VECTORS_FORMAT =
LUCENE_PROVIDER.getLuceneFlatVectorsFormatInstance(DefaultFlatVectorScorer.INSTANCE);
} catch (Exception e) {
throw new ExceptionInInitializerError(e.getMessage());
}
}

/**
* Initializes {@link Lucene99AcceleratedHNSWVectorsFormat} with default values.
*
Expand All @@ -70,8 +78,8 @@ public Lucene99AcceleratedHNSWVectorsFormat() {
DEFAULT_INTERMEDIATE_GRAPH_DEGREE,
DEFAULT_GRAPH_DEGREE,
DEFAULT_HNSW_GRAPH_LAYERS,
DEFAULT_MAX_CONN,
DEFAULT_BEAM_WIDTH);
MAX_CONN,
BEAM_WIDTH);
}

/**
Expand Down Expand Up @@ -105,7 +113,7 @@ public Lucene99AcceleratedHNSWVectorsFormat(
*/
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
var flatWriter = flatVectorsFormat.fieldsWriter(state);
var flatWriter = FLAT_VECTORS_FORMAT.fieldsWriter(state);
if (supported()) {
log.info("cuVS is supported so using the Lucene99AcceleratedHNSWVectorsWriter");
return new Lucene99AcceleratedHNSWVectorsWriter(
Expand All @@ -114,8 +122,13 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
log.warning(
"GPU based indexing not supported, falling back to using the Lucene99HnswVectorsWriter");
// TODO: Make num merge workers configurable.
return new Lucene99HnswVectorsWriter(
state, maxConn, beamWidth, flatWriter, DEFAULT_NUM_MERGE_WORKER, null);
try {
return LUCENE_PROVIDER.getLuceneHnswVectorsWriterInstance(
state, maxConn, beamWidth, flatWriter, NUM_MERGE_WORKERS, null);
} catch (Exception e) {
// maybe there is a better suited option to throwing RuntimeException? Need to explore.
throw new RuntimeException(e.getMessage());
}
}
}

Expand All @@ -124,7 +137,13 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
*/
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
try {
return LUCENE_PROVIDER.getLuceneHnswVectorsReaderInstance(
state, FLAT_VECTORS_FORMAT.fieldsReader(state));
} catch (Exception e) {
// maybe there is a better suited option to throwing RuntimeException? Need to explore.
throw new RuntimeException(e.getMessage());
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static com.nvidia.cuvs.lucene.Lucene99AcceleratedHNSWVectorsFormat.HNSW_INDEX_EXT;
import static com.nvidia.cuvs.lucene.Lucene99AcceleratedHNSWVectorsFormat.HNSW_META_CODEC_EXT;
import static com.nvidia.cuvs.lucene.Lucene99AcceleratedHNSWVectorsFormat.HNSW_META_CODEC_NAME;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.apache.lucene.index.VectorEncoding.FLOAT32;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
Expand Down Expand Up @@ -37,7 +36,6 @@
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
Expand Down Expand Up @@ -76,6 +74,10 @@ public class Lucene99AcceleratedHNSWVectorsWriter extends KnnVectorsWriter {
/** The name of the CUVS component for the info-stream * */
private static final String CUVS_COMPONENT = "CUVS";

private static final LuceneProvider LUCENE_PROVIDER;
private static final Integer VERSION_CURRENT;
private static final List<VectorSimilarityFunction> VECTOR_SIMILARITY_FUNCTIONS;

private final int cuvsWriterThreads;
private final int intGraphDegree;
private final int graphDegree;
Expand All @@ -90,6 +92,16 @@ public class Lucene99AcceleratedHNSWVectorsWriter extends KnnVectorsWriter {
private String vemFileName;
private String vexFileName;

static {
try {
LUCENE_PROVIDER = LuceneProvider.getInstance("99");
VERSION_CURRENT = LUCENE_PROVIDER.getStaticIntParam("VERSION_CURRENT");
VECTOR_SIMILARITY_FUNCTIONS = LUCENE_PROVIDER.getSimilarityFunctions();
} catch (Exception e) {
throw new ExceptionInInitializerError(e.getMessage());
}
}

/**
* Initializes {@link Lucene99AcceleratedHNSWVectorsWriter}
*
Expand Down Expand Up @@ -136,13 +148,13 @@ public Lucene99AcceleratedHNSWVectorsWriter(
CodecUtil.writeIndexHeader(
hnswMeta,
HNSW_META_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader(
hnswVectorIndex,
HNSW_INDEX_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);

Expand Down Expand Up @@ -662,8 +674,8 @@ private void writeEmpty(FieldInfo fieldInfo) throws IOException {
}

static int distFuncToOrd(VectorSimilarityFunction func) {
for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) {
if (SIMILARITY_FUNCTIONS.get(i).equals(func)) {
for (int i = 0; i < VECTOR_SIMILARITY_FUNCTIONS.size(); i++) {
if (VECTOR_SIMILARITY_FUNCTIONS.get(i).equals(func)) {
return (byte) i;
}
}
Expand Down
Loading