diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ebeeaa90..2c372826 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,18 +3,18 @@ name: HNSW CI on: [push, pull_request] jobs: - test_python: - runs-on: ${{ matrix.os }} + python: + runs-on: ${{ matrix.os }}-latest strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + os: [ubuntu, windows, macos] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.2"] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - + - name: Build and install run: python -m pip install . @@ -36,11 +36,34 @@ jobs: python -m unittest discover -v --start-directory tests/python --pattern "bindings_test*.py" shell: bash - test_cpp: - runs-on: ${{ matrix.os }} + cpp: + runs-on: ${{ matrix.os }}-latest strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu, windows, macos] + sanitizer: [no_sanitizers, asan_ubsan] + exceptions: [no_exceptions, with_exceptions] + compiler: [clang, gcc, msvc] + exclude: + # No sanitizers on Windows + - os: windows + sanitizer: asan_ubsan + # No sanitizers on macOS -- might not be well supported on arm64 + - os: macos + sanitizer: asan_ubsan + # No clang or gcc on Windows + - os: windows + compiler: clang + - os: windows + compiler: gcc + # No MSVC on Unix + - os: ubuntu + compiler: msvc + - os: macos + compiler: msvc + # No GCC on macOS + - os: macos + compiler: gcc steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -51,11 +74,48 @@ jobs: run: | mkdir build cd build - cmake .. - if [ "$RUNNER_OS" == "Windows" ]; then - cmake --build ./ --config Release + cmake_cmd=( cmake -S .. -B . 
) + if [[ "${RUNNER_OS}" != "Windows" ]]; then + c_compiler="${{matrix.compiler}}" + if [[ "${c_compiler}" == "gcc" ]]; then + cxx_compiler=g++ + elif [[ "${c_compiler}" == "clang" ]]; then + cxx_compiler=clang++ + else + echo "Invalid compiler ${c_compiler} for OS ${RUNNER_OS}" >&2 + exit 1 + fi + cmake_cmd+=( + -DCMAKE_BUILD_TYPE=RelWithDebInfo + -DCMAKE_C_COMPILER=${c_compiler} + -DCMAKE_CXX_COMPILER=${cxx_compiler} + ) + fi + if [[ "${{ matrix.sanitizer }}" == "asan_ubsan" ]]; then + cmake_cmd+=( -DENABLE_ASAN=ON -DENABLE_UBSAN=ON ) + fi + if [[ "${{ matrix.exceptions }}" == "with_exceptions" ]]; then + cmake_cmd+=( -DHNSWLIB_ENABLE_EXCEPTIONS=ON ) + else + cmake_cmd+=( -DHNSWLIB_ENABLE_EXCEPTIONS=OFF ) + fi + if [[ "${RUNNER_OS}" != "Windows" ]]; then + cmake_cmd+=( -G Ninja ) + fi + "${cmake_cmd[@]}" + + # This is essential for debugging the build, e.g. tracking down + # if exceptions are enabled or NDEBUG is specified when it should not + # be. + echo + echo "Contents of CMakeCache.txt:" + echo + cat CMakeCache.txt + echo + if [[ "${RUNNER_OS}" == "Windows" ]]; then + cmake --build ./ --config RelWithDebInfo --verbose else - make + ninja -v fi shell: bash @@ -67,26 +127,16 @@ jobs: shell: bash - name: Test - timeout-minutes: 15 + # Without sanitizers, 15 minutes might be sufficient. + timeout-minutes: 30 run: | cd build - if [ "$RUNNER_OS" == "Windows" ]; then - cp ./Release/* ./ + if [[ "${RUNNER_OS}" == "Windows" ]]; then + cp "./RelWithDebInfo/"* ./ fi - ./example_search - ./example_filter - ./example_replace_deleted - ./example_mt_search - ./example_mt_filter - ./example_mt_replace_deleted - ./example_multivector_search - ./example_epsilon_search - ./searchKnnCloserFirst_test - ./searchKnnWithFilter_test - ./multiThreadLoad_test - ./multiThread_replace_test + ctest --build-config "RelWithDebInfo" + # These tests could be ctest-enabled in CMakeLists.txt, but that + # requires auto-generating test data and installing numpy for that. 
./test_updates ./test_updates update - ./multivector_search_test - ./epsilon_search_test shell: bash diff --git a/.gitignore b/.gitignore index d46c9890..9ca7d189 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,5 @@ var/ .vs/ **.DS_Store *.pyc +venv/ +tests/cpp/data/ diff --git a/CMakeLists.txt b/CMakeLists.txt index be0d40f0..a34a67f2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,95 @@ project(hnswlib include(GNUInstallDirs) include(CheckCXXCompilerFlag) +# These example/test targets catch exceptions, so exceptions should always be +# enabled building these files even if they are disabled in other targets. +# We check that each target included in this list is a real target. +set(HNSWLIB_TARGETS_REQUIRING_EXCEPTIONS + example_mt_filter + example_mt_replace_deleted + example_mt_search + multiThread_replace_test + test_updates) + +set(EXAMPLE_NAMES + example_epsilon_search + example_filter + example_mt_filter + example_mt_replace_deleted + example_mt_search + example_multivector_search + example_replace_deleted + example_search + ) + +set(TEST_NAMES + epsilon_search_test + multiThread_replace_test + multiThreadLoad_test + multivector_search_test + resize_test + searchKnnCloserFirst_test + searchKnnWithFilter_test + ) + +function(add_cxx_flags) + foreach(flag IN LISTS ARGN) + string(APPEND CMAKE_CXX_FLAGS " ${flag}") + endforeach() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" PARENT_SCOPE) +endfunction() + +# Adds an example or test target. The target name parameter is followed by +# the list of source files. Automatically links with the hnswlib library. +# Also decides whether to enable exceptions when building the target. +# If HNSWLIB_ENABLE_EXCEPTIONS is ON, exceptions are always enabled. +# If HNSWLIB_ENABLE_EXCEPTIONS is OFF, exceptions are only enabled for the +# specific targets listed in HNSWLIB_TARGETS_REQUIRING_EXCEPTIONS. +function(add_example_or_test TARGET_NAME ...) 
+ add_executable(${ARGV}) + target_link_libraries(${TARGET_NAME} hnswlib) + list(FIND HNSWLIB_TARGETS_REQUIRING_EXCEPTIONS "${TARGET_NAME}" found_at_index) + if(found_at_index GREATER -1) + if(NOT HNSWLIB_ENABLE_EXCEPTIONS) + message("Enabling exceptions for target ${TARGET_NAME} as a special case") + endif() + set(should_enable_exceptions ON) + else() + set(should_enable_exceptions "${HNSWLIB_ENABLE_EXCEPTIONS}") + endif() + if(should_enable_exceptions) + target_compile_options("${TARGET_NAME}" PUBLIC ${ENABLE_EXCEPTIONS_FLAGS}) + else() + target_compile_options("${TARGET_NAME}" PUBLIC ${DISABLE_EXCEPTIONS_FLAGS}) + endif() + if(NOT ${TARGET_NAME} MATCHES "^(main|test_updates)$") + # test_updates is not included here as a ctest-enabled test because it + # requires generating test data using update_gen_data.py, which requires + # installing numpy, which should probably be done in a virtual + # environment. Also test_updates needs to be invoked twice: without + # arguments, and with one "update" argument. This is currently handled + # in the GitHub Actions build.yml file. + add_test( + NAME ${TARGET_NAME} + COMMAND ${TARGET_NAME} + ) + endif() +endfunction() + + +option(HNSWLIB_ENABLE_EXCEPTIONS "Whether to enable exceptions in hnswlib" ON) +if(HNSWLIB_ENABLE_EXCEPTIONS) + message("Exceptions are enabled using HNSWLIB_ENABLE_EXCEPTIONS=ON (default)") +else() + message("Exceptions are disabled using HNSWLIB_ENABLE_EXCEPTIONS=OFF") +endif() +option(ENABLE_ASAN "Whether to enable AddressSanitizer" OFF) +option(ENABLE_UBSAN "Whether to enable UndefinedBehaviorSanitizer" OFF) +option(ENABLE_TSAN "Whether to enable ThreadSanitizer" OFF) +option(ENABLE_MSAN "Whether to enable MemorySanitizer" OFF) + +set(CMAKE_CXX_STANDARD 11) + add_library(hnswlib INTERFACE) add_library(hnswlib::hnswlib ALIAS hnswlib) @@ -31,75 +120,136 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) else() option(HNSWLIB_EXAMPLES "Build examples and tests." 
OFF) endif() -if(HNSWLIB_EXAMPLES) - set(CMAKE_CXX_STANDARD 11) +if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + set(ENABLE_EXCEPTIONS_FLAGS /EHsc) + set(DISABLE_EXCEPTIONS_FLAGS /GR- /D_HAS_EXCEPTIONS=0) +else() + set(ENABLE_EXCEPTIONS_FLAGS -fexceptions) + set(DISABLE_EXCEPTIONS_FLAGS -fno-exceptions) +endif() + +# Turn on assertions in the RelWithDebInfo build type. +foreach(NDEBUG_FLAG_STR IN ITEMS "/DNDEBUG" "/D NDEBUG" "-DNDEBUG") + string(REPLACE "${NDEBUG_FLAG_STR}" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO + "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") +endforeach() +string(STRIP "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" CMAKE_CXX_FLAGS_RELWITHDEBINFO) +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" + CACHE STRING "Flags for RelWithDebInfo configuration." FORCE) + +# Start with an empty value of CMAKE_CXX_FLAGS, not the value from the cache. +# It will not override any "default" flags -- those will come from +# per-build-type variables (CMAKE_CXX_FLAGS_${CMAKE_BUILD_TYPE}). +set(CMAKE_CXX_FLAGS "") + +if(HNSWLIB_EXAMPLES) + message("Building examples and tests") + message("System architecture: ${CMAKE_HOST_SYSTEM_PROCESSOR}") + enable_testing() if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") - SET( CMAKE_CXX_FLAGS "-Ofast -std=c++11 -DHAVE_CXX0X -openmp -fpic -ftree-vectorize" ) - check_cxx_compiler_flag("-march=native" COMPILER_SUPPORT_NATIVE_FLAG) - if(COMPILER_SUPPORT_NATIVE_FLAG) - SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native" ) - message("set -march=native flag") - else() - check_cxx_compiler_flag("-mcpu=apple-m1" COMPILER_SUPPORT_M1_FLAG) - if(COMPILER_SUPPORT_M1_FLAG) - SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=apple-m1" ) - message("set -mcpu=apple-m1 flag") + # We enable optimizations in all build types, even in Debug. 
+ add_cxx_flags(-O3 -ffast-math -openmp -ftree-vectorize) + check_cxx_compiler_flag("-march=native" COMPILER_SUPPORT_NATIVE_FLAG) + if(COMPILER_SUPPORT_NATIVE_FLAG) + add_cxx_flags(-march=native) + message("set -march=native flag") + else() + check_cxx_compiler_flag("-mcpu=apple-m1" COMPILER_SUPPORT_M1_FLAG) + if(COMPILER_SUPPORT_M1_FLAG) + add_cxx_flags(-mcpu=apple-m1) + message("set -mcpu=apple-m1 flag") + endif() endif() - endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-Ofast -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) + # We enable optimizations in all build types, even in Debug. + add_cxx_flags( + -Ofast -lrt -march=native -w -fopenmp -ftree-vectorize + -ftree-vectorizer-verbose=0) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) + if (NOT HNSWLIB_ENABLE_EXCEPTIONS) + # Do not enable exceptions by default. We will enable them on a + # case by case basis when needed. + foreach(config IN ITEMS Debug Release RelWithDebInfo MinSizeRel) + string(TOUPPER ${config} config_upper) + set(FLAGS_VAR "CMAKE_CXX_FLAGS_${config_upper}") + string(REPLACE "/EHsc" "" ${FLAGS_VAR} "${${FLAGS_VAR}}") + set(${FLAGS_VAR} "${${FLAGS_VAR}}" CACHE STRING + "Flags for ${config} configuration." 
FORCE) + endforeach() + endif() + add_cxx_flags(/O2 /W1 /openmp) endif() + add_cxx_flags(-DHAVE_CXX0X) - # examples - add_executable(example_search examples/cpp/example_search.cpp) - target_link_libraries(example_search hnswlib) - - add_executable(example_epsilon_search examples/cpp/example_epsilon_search.cpp) - target_link_libraries(example_epsilon_search hnswlib) - - add_executable(example_multivector_search examples/cpp/example_multivector_search.cpp) - target_link_libraries(example_multivector_search hnswlib) - - add_executable(example_filter examples/cpp/example_filter.cpp) - target_link_libraries(example_filter hnswlib) - - add_executable(example_replace_deleted examples/cpp/example_replace_deleted.cpp) - target_link_libraries(example_replace_deleted hnswlib) - - add_executable(example_mt_search examples/cpp/example_mt_search.cpp) - target_link_libraries(example_mt_search hnswlib) + if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + # Clang/GCC specific flags. + add_cxx_flags(-fpic) + if(ENABLE_ASAN) + add_cxx_flags(-fsanitize=address) + endif() + if(ENABLE_UBSAN) + add_cxx_flags(-fsanitize=undefined) + endif() + if(ENABLE_TSAN) + add_cxx_flags(-fsanitize=thread) + endif() + if(ENABLE_MSAN) + add_cxx_flags(-fsanitize=memory) + endif() + if(ENABLE_ASAN OR ENABLE_UBSAN) + add_cxx_flags(-DHNSWLIB_USE_PREFETCH=0) + endif() - add_executable(example_mt_filter examples/cpp/example_mt_filter.cpp) - target_link_libraries(example_mt_filter hnswlib) + add_cxx_flags(-Wall -Wextra -Wpedantic -Werror) - add_executable(example_mt_replace_deleted examples/cpp/example_mt_replace_deleted.cpp) - target_link_libraries(example_mt_replace_deleted hnswlib) + # Unused functions in header files might still be used by other code + # including those header files. + add_cxx_flags(-Wno-unused-function) - # tests - add_executable(multivector_search_test tests/cpp/multivector_search_test.cpp) - target_link_libraries(multivector_search_test hnswlib) + # Unused parameters are OK. 
+ add_cxx_flags(-Wno-unused-parameter) - add_executable(epsilon_search_test tests/cpp/epsilon_search_test.cpp) - target_link_libraries(epsilon_search_test hnswlib) + # TODO: re-enable and fix comparisons of integers of different + # signedness. Not using -Wno-error=sign-compare here, because that will + # produce a lot of warnings. + add_cxx_flags(-Wno-sign-compare) - add_executable(test_updates tests/cpp/updates_test.cpp) - target_link_libraries(test_updates hnswlib) + if(CMAKE_BUILD_TYPE MATCHES "^(Release|MinSizeRel)$") + # For build types that disable assertions some variables might look + # like they are not being used. RelWithDebInfo is not included in + # this list because we specifically enable assertions for that + # build type. + add_cxx_flags(-Wno-unused-variable) + add_cxx_flags(-Wno-unused-but-set-variable) + endif() + endif() - add_executable(searchKnnCloserFirst_test tests/cpp/searchKnnCloserFirst_test.cpp) - target_link_libraries(searchKnnCloserFirst_test hnswlib) + foreach(example_name IN LISTS EXAMPLE_NAMES) + add_example_or_test("${example_name}" "examples/cpp/${example_name}.cpp") + endforeach() - add_executable(searchKnnWithFilter_test tests/cpp/searchKnnWithFilter_test.cpp) - target_link_libraries(searchKnnWithFilter_test hnswlib) + foreach(test_name IN LISTS TEST_NAMES) + add_example_or_test("${test_name}" "tests/cpp/${test_name}.cpp") + endforeach() - add_executable(multiThreadLoad_test tests/cpp/multiThreadLoad_test.cpp) - target_link_libraries(multiThreadLoad_test hnswlib) + # This test deviates from the above pattern of naming test executables. + add_example_or_test(test_updates tests/cpp/updates_test.cpp) - add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp) - target_link_libraries(multiThread_replace_test hnswlib) + # For historical reasons, the "main" program links with sift_1b.cpp. 
+ add_example_or_test(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp) - add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp) - target_link_libraries(main hnswlib) + foreach(target_name IN LISTS HNSWLIB_TARGETS_REQUIRING_EXCEPTIONS) + if(NOT TARGET ${target_name}) + message(FATAL_ERROR + "Target '${target_name}' included in " + "HNSWLIB_TARGETS_REQUIRING_EXCEPTIONS does not exist. " + "Please check if this is a typo.") + endif() + endforeach() endif() + +# Persist CMAKE_CXX_FLAGS in the cache for debuggability. +string(STRIP "${CMAKE_CXX_FLAGS}" CMAKE_CXX_FLAGS) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" + CACHE STRING "Flags used by the CXX compiler during all build types." FORCE) diff --git a/examples/cpp/example_epsilon_search.cpp b/examples/cpp/example_epsilon_search.cpp index 49eec408..4d1d5347 100644 --- a/examples/cpp/example_epsilon_search.cpp +++ b/examples/cpp/example_epsilon_search.cpp @@ -53,7 +53,7 @@ int main() { } std::cout << "Query #" << i << "\n"; hnswlib::EpsilonSearchStopCondition stop_condition(epsilon2, min_num_candidates, max_elements); - std::vector> result = + std::vector> result = alg_hnsw->searchStopConditionClosest(query_data, stop_condition); size_t num_vectors = result.size(); std::cout << "Found " << num_vectors << " vectors\n"; diff --git a/examples/cpp/example_mt_replace_deleted.cpp b/examples/cpp/example_mt_replace_deleted.cpp index 40a94ce7..a4766ee8 100644 --- a/examples/cpp/example_mt_replace_deleted.cpp +++ b/examples/cpp/example_mt_replace_deleted.cpp @@ -69,7 +69,7 @@ int main() { int num_threads = 20; // Number of threads for operations with index // Initing index with allow_replace_deleted=true - int seed = 100; + int seed = 100; hnswlib::L2Space space(dim); hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, seed, true); diff --git a/examples/cpp/example_multivector_search.cpp b/examples/cpp/example_multivector_search.cpp index 06aafe0b..94f9e540 100644 --- 
a/examples/cpp/example_multivector_search.cpp +++ b/examples/cpp/example_multivector_search.cpp @@ -63,7 +63,7 @@ int main() { } std::cout << "Query #" << i << "\n"; hnswlib::MultiVectorSearchStopCondition stop_condition(space, num_docs, ef_collection); - std::vector> result = + std::vector> result = alg_hnsw->searchStopConditionClosest(query_data, stop_condition); size_t num_vectors = result.size(); diff --git a/examples/cpp/example_replace_deleted.cpp b/examples/cpp/example_replace_deleted.cpp index 64c995bb..42974383 100644 --- a/examples/cpp/example_replace_deleted.cpp +++ b/examples/cpp/example_replace_deleted.cpp @@ -9,7 +9,7 @@ int main() { int ef_construction = 200; // Controls index search speed/build speed tradeoff // Initing index with allow_replace_deleted=true - int seed = 100; + int seed = 100; hnswlib::L2Space space(dim); hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, seed, true); diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 371847ad..cff0a67d 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -1,4 +1,7 @@ #pragma once + +#include "hnswlib.h" + #include #include #include @@ -51,7 +54,7 @@ class BruteforceSearch : public AlgorithmInterface { size_per_element_ = data_size_ + sizeof(labeltype); data_ = (char *) malloc(maxElements * size_per_element_); if (data_ == nullptr) - throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); + HNSWLIB_THROW_RUNTIME_ERROR("Not enough memory: BruteforceSearch failed to allocate data"); cur_element_count = 0; } @@ -61,7 +64,7 @@ class BruteforceSearch : public AlgorithmInterface { } - void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { + Status addPointNoExceptions(const void *datapoint, labeltype label, bool replace_deleted = false) override { int idx; { std::unique_lock lock(index_lock); @@ -71,7 +74,7 @@ class BruteforceSearch : public AlgorithmInterface { 
idx = search->second; } else { if (cur_element_count >= maxelements_) { - throw std::runtime_error("The number of elements exceeds the specified limit\n"); + return Status("The number of elements exceeds the specified limit"); } idx = cur_element_count; dict_external_to_internal[label] = idx; @@ -80,6 +83,7 @@ class BruteforceSearch : public AlgorithmInterface { } memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); + return OkStatus(); } @@ -103,8 +107,9 @@ class BruteforceSearch : public AlgorithmInterface { } - std::priority_queue> - searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + using DistanceLabelPriorityQueue = typename AlgorithmInterface::DistanceLabelPriorityQueue; + StatusOr + searchKnnNoExceptions(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const override { assert(k <= cur_element_count); std::priority_queue> topResults; dist_t lastdist = std::numeric_limits::max(); @@ -125,7 +130,7 @@ class BruteforceSearch : public AlgorithmInterface { } - void saveIndex(const std::string &location) { + Status saveIndexNoExceptions(const std::string &location) override { std::ofstream output(location, std::ios::binary); std::streampos position; @@ -136,6 +141,7 @@ class BruteforceSearch : public AlgorithmInterface { output.write(data_, maxelements_ * size_per_element_); output.close(); + return OkStatus(); } @@ -153,7 +159,7 @@ class BruteforceSearch : public AlgorithmInterface { size_per_element_ = data_size_ + sizeof(labeltype); data_ = (char *) malloc(maxelements_ * size_per_element_); if (data_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); + HNSWLIB_THROW_RUNTIME_ERROR("Not enough memory: loadIndex failed to allocate data"); input.read(data_, maxelements_ * size_per_element_); diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 
f516df59..eb8dfe20 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -124,8 +124,9 @@ class HierarchicalNSW : public AlgorithmInterface { offsetLevel0_ = 0; data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory"); + if (data_level0_memory_ == nullptr) { + HNSWLIB_THROW_RUNTIME_ERROR("Not enough memory to allocate for level 0"); + } cur_element_count = 0; @@ -136,8 +137,11 @@ class HierarchicalNSW : public AlgorithmInterface { maxlevel_ = -1; linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + if (linkLists_ == nullptr) { + HNSWLIB_THROW_RUNTIME_ERROR( + "Not enough memory: HierarchicalNSW failed to allocate linklists"); + } + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); mult_ = 1 / log(1.0 * M_); revSize_ = 1.0 / mult_; @@ -155,7 +159,9 @@ class HierarchicalNSW : public AlgorithmInterface { if (element_levels_[i] > 0) free(linkLists_[i]); } - free(linkLists_); + if (linkLists_) { + free(linkLists_); + } linkLists_ = nullptr; cur_element_count = 0; visited_list_pool_.reset(nullptr); @@ -264,18 +270,22 @@ class HierarchicalNSW : public AlgorithmInterface { size_t size = getListCount((linklistsizeint*)data); tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE +#if HNSWLIB_USE_PREFETCH _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); +#endif #endif for (size_t j = 0; j < size; j++) { tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; #ifdef USE_SSE +#if HNSWLIB_USE_PREFETCH _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); 
_mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); +#endif #endif if (visited_array[candidate_id] == visited_array_tag) continue; visited_array[candidate_id] = visited_array_tag; @@ -285,7 +295,9 @@ class HierarchicalNSW : public AlgorithmInterface { if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { candidateSet.emplace(-dist1, candidate_id); #ifdef USE_SSE +#if HNSWLIB_USE_PREFETCH _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); +#endif #endif if (!isMarkedDeleted(candidate_id)) @@ -322,7 +334,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if (bare_bone_search || + if (bare_bone_search || (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) { char* ep_data = getDataByInternalId(ep_id); dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_); @@ -368,19 +380,23 @@ class HierarchicalNSW : public AlgorithmInterface { } #ifdef USE_SSE +#if HNSWLIB_USE_PREFETCH _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); _mm_prefetch((char *) (data + 2), _MM_HINT_T0); +#endif #endif for (size_t j = 1; j <= size; j++) { int candidate_id = *(data + j); // if (candidate_id == 0) continue; #ifdef USE_SSE +#if HNSWLIB_USE_PREFETCH _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); //////////// +#endif #endif if (!(visited_array[candidate_id] == visited_array_tag)) { visited_array[candidate_id] = visited_array_tag; @@ -398,12 +414,14 @@ class HierarchicalNSW : public AlgorithmInterface { if (flag_consider_candidate) { candidate_set.emplace(-dist, candidate_id); #ifdef 
USE_SSE +#if HNSWLIB_USE_PREFETCH _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + offsetLevel0_, /////////// _MM_HINT_T0); //////////////////////// +#endif #endif - if (bare_bone_search || + if (bare_bone_search || (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) { top_candidates.emplace(dist, candidate_id); if (!bare_bone_search && stop_condition) { @@ -503,7 +521,7 @@ class HierarchicalNSW : public AlgorithmInterface { } - tableint mutuallyConnectNewElement( + StatusOr mutuallyConnectNewElement( const void *data_point, tableint cur_c, std::priority_queue, std::vector>, CompareByFirst> &top_candidates, @@ -512,7 +530,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_t Mcurmax = level ? maxM_ : maxM0_; getNeighborsByHeuristic2(top_candidates, M_); if (top_candidates.size() > M_) - throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); + return Status("Should be not be more than M_ candidates returned by the heuristic"); std::vector selectedNeighbors; selectedNeighbors.reserve(M_); @@ -537,15 +555,15 @@ class HierarchicalNSW : public AlgorithmInterface { ll_cur = get_linklist(cur_c, level); if (*ll_cur && !isUpdate) { - throw std::runtime_error("The newly inserted element should have blank link list"); + return Status("The newly inserted element should have blank link list"); } setListCount(ll_cur, selectedNeighbors.size()); tableint *data = (tableint *) (ll_cur + 1); for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { if (data[idx] && !isUpdate) - throw std::runtime_error("Possible memory corruption"); + return Status("Possible memory corruption"); if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); + return Status("Trying to make a link on a non-existent level"); data[idx] = selectedNeighbors[idx]; } @@ -563,11 +581,11 @@ class 
HierarchicalNSW : public AlgorithmInterface { size_t sz_link_list_other = getListCount(ll_other); if (sz_link_list_other > Mcurmax) - throw std::runtime_error("Bad value of sz_link_list_other"); + return Status("Bad value of sz_link_list_other"); if (selectedNeighbors[idx] == cur_c) - throw std::runtime_error("Trying to connect an element to itself"); + return Status("Trying to connect an element to itself"); if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); + return Status("Trying to make a link on a non-existent level"); tableint *data = (tableint *) (ll_other + 1); @@ -629,10 +647,16 @@ class HierarchicalNSW : public AlgorithmInterface { return next_closest_entry_point; } - void resizeIndex(size_t new_max_elements) { + Status status = resizeIndexNoExceptions(new_max_elements); + if (!status.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(status.message()); + } + } + + Status resizeIndexNoExceptions(size_t new_max_elements) { if (new_max_elements < cur_element_count) - throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); + return Status("Cannot resize, max element is less than the current number of elements"); visited_list_pool_.reset(new VisitedListPool(1, new_max_elements)); @@ -643,16 +667,17 @@ class HierarchicalNSW : public AlgorithmInterface { // Reallocate base layer char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); if (data_level0_memory_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + return Status("Not enough memory: resizeIndex failed to allocate base layer"); data_level0_memory_ = data_level0_memory_new; // Reallocate all other layers char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); if (linkLists_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex 
failed to allocate other layers"); + return Status("Not enough memory: resizeIndex failed to allocate other layers"); linkLists_ = linkLists_new; max_elements_ = new_max_elements; + return OkStatus(); } size_t indexFileSize() const { @@ -682,7 +707,8 @@ class HierarchicalNSW : public AlgorithmInterface { return size; } - void saveIndex(const std::string &location) { + Status saveIndexNoExceptions(const std::string &location) override { + std::ofstream output(location, std::ios::binary); writeBinaryPOD(output, offsetLevel0_); @@ -709,14 +735,22 @@ class HierarchicalNSW : public AlgorithmInterface { output.write(linkLists_[i], linkListSize); } output.close(); + return OkStatus(); } void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { + Status status = loadIndexNoExceptions(location, s, max_elements_i); + if (!status.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(status.message()); + } + } + + Status loadIndexNoExceptions(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { std::ifstream input(location, std::ios::binary); if (!input.is_open()) - throw std::runtime_error("Cannot open file"); + return Status("Cannot open file"); clear(); // get file size: @@ -754,7 +788,7 @@ class HierarchicalNSW : public AlgorithmInterface { input.seekg(cur_element_count * size_data_per_element_, input.cur); for (size_t i = 0; i < cur_element_count; i++) { if (input.tellg() < 0 || input.tellg() >= total_filesize) { - throw std::runtime_error("Index seems to be corrupted or unsupported"); + return Status("Index seems to be corrupted or unsupported"); } unsigned int linkListSize; @@ -766,7 +800,7 @@ class HierarchicalNSW : public AlgorithmInterface { // throw exception if it either corrupted or old index if (input.tellg() != total_filesize) - throw std::runtime_error("Index seems to be corrupted or unsupported"); + return Status("Index seems to be corrupted or unsupported"); input.clear(); /// Optional check end @@ -775,7 +809,7 @@ 
class HierarchicalNSW : public AlgorithmInterface { data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + return Status("Not enough memory: loadIndex failed to allocate level0"); input.read(data_level0_memory_, cur_element_count * size_data_per_element_); size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); @@ -788,7 +822,7 @@ class HierarchicalNSW : public AlgorithmInterface { linkLists_ = (char **) malloc(sizeof(void *) * max_elements); if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + return Status("Not enough memory: loadIndex failed to allocate linklists"); element_levels_ = std::vector(max_elements); revSize_ = 1.0 / mult_; ef_ = 10; @@ -803,7 +837,7 @@ class HierarchicalNSW : public AlgorithmInterface { element_levels_[i] = linkListSize / size_links_per_element_; linkLists_[i] = (char *) malloc(linkListSize); if (linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + return Status("Not enough memory: loadIndex failed to allocate linklist"); input.read(linkLists_[i], linkListSize); } } @@ -817,19 +851,29 @@ class HierarchicalNSW : public AlgorithmInterface { input.close(); - return; + return OkStatus(); } + template + std::vector + getDataByLabel(labeltype label) const { + auto result = getDataByLabelNoExceptions(label); + if (!result.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(result.status().message()); + } + return std::move(result.value()); + } template - std::vector getDataByLabel(labeltype label) const { + StatusOr> + getDataByLabelNoExceptions(labeltype label) const { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); - + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if 
(search == label_lookup_.end() || isMarkedDeleted(search->second)) { - throw std::runtime_error("Label not found"); + return Status("Label not found"); } tableint internalId = search->second; lock_table.unlock(); @@ -845,88 +889,61 @@ class HierarchicalNSW : public AlgorithmInterface { return data; } + void markDelete(labeltype label) { + Status status = markDeleteNoExceptions(label); + if (!status.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(status.message()); + } + } /* * Marks an element with the given label deleted, does NOT really change the current graph. */ - void markDelete(labeltype label) { + Status markDeleteNoExceptions(labeltype label) { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); + return Status("Label not found"); } tableint internalId = search->second; lock_table.unlock(); - markDeletedInternal(internalId); + return markDeletedInternal(internalId); } - - /* - * Uses the last 16 bits of the memory for the linked list size to store the mark, - * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. 
- */ - void markDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (!isMarkedDeleted(internalId)) { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur |= DELETE_MARK; - num_deleted_ += 1; - if (allow_replace_deleted_) { - std::unique_lock lock_deleted_elements(deleted_elements_lock); - deleted_elements.insert(internalId); - } - } else { - throw std::runtime_error("The requested to delete element is already deleted"); + void unmarkDelete(labeltype label) { + auto status = unmarkDeleteNoExceptions(label); + if (!status.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(status.message()); } } - /* * Removes the deleted mark of the node, does NOT really change the current graph. - * + * * Note: the method is not safe to use when replacement of deleted elements is enabled, * because elements marked as deleted can be completely removed by addPoint */ - void unmarkDelete(labeltype label) { + Status unmarkDeleteNoExceptions(labeltype label) { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); + return Status("Label not found"); } tableint internalId = search->second; lock_table.unlock(); - unmarkDeletedInternal(internalId); + return unmarkDeletedInternal(internalId); } - /* - * Remove the deleted mark of the node. 
- */ - void unmarkDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (isMarkedDeleted(internalId)) { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; - *ll_cur &= ~DELETE_MARK; - num_deleted_ -= 1; - if (allow_replace_deleted_) { - std::unique_lock lock_deleted_elements(deleted_elements_lock); - deleted_elements.erase(internalId); - } - } else { - throw std::runtime_error("The requested to undelete element is not deleted"); - } - } - - /* * Checks the first 16 bits of the memory to see if the element is marked deleted. */ @@ -950,16 +967,21 @@ class HierarchicalNSW : public AlgorithmInterface { * Adds point. Updates the point if it is already in the index. * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point */ - void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { - if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { - throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); + Status addPointNoExceptions( + const void *data_point, labeltype label, bool replace_deleted = false) + override { + if (!allow_replace_deleted_ && replace_deleted) { + return Status("Replacement of deleted elements is disabled in constructor"); } // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); if (!replace_deleted) { - addPoint(data_point, label, -1); - return; + auto status_or_new_point = addPointWithLevel(data_point, label, -1); + if (!status_or_new_point.ok()) { + return status_or_new_point.status(); + } + return OkStatus(); } // check if there is vacant place tableint internal_id_replaced; @@ -974,7 +996,7 @@ class HierarchicalNSW : public AlgorithmInterface { // if there is no vacant place then add or update point // else add point to vacant place if (!is_vacant_place) { - addPoint(data_point, label, -1); + 
addPointWithLevel(data_point, label, -1); } else { // we assume that there are no concurrent operations on deleted element labeltype label_replaced = getExternalLabel(internal_id_replaced); @@ -985,13 +1007,21 @@ class HierarchicalNSW : public AlgorithmInterface { label_lookup_[label] = internal_id_replaced; lock_table.unlock(); - unmarkDeletedInternal(internal_id_replaced); - updatePoint(data_point, internal_id_replaced, 1.0); + Status delete_status = unmarkDeletedInternal(internal_id_replaced); + if (!delete_status.ok()) { + return delete_status; + } + Status update_status = + updatePoint(data_point, internal_id_replaced, 1.0); + if (!update_status.ok()) { + return update_status; + } } + return OkStatus(); } - void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { + Status updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { // update the feature vector associated with existing point with new vector memcpy(getDataByInternalId(internalId), dataPoint, data_size_); @@ -999,7 +1029,7 @@ class HierarchicalNSW : public AlgorithmInterface { tableint entryPointCopy = enterpoint_node_; // If point to be updated is entry point and graph just contains single element then just return. 
if (entryPointCopy == internalId && cur_element_count == 1) - return; + return OkStatus(); int elemLevel = element_levels_[internalId]; std::uniform_real_distribution distribution(0.0, 1.0); @@ -1066,11 +1096,11 @@ class HierarchicalNSW : public AlgorithmInterface { } } - repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); + return repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); } - void repairConnectionsForUpdate( + Status repairConnectionsForUpdate( const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, @@ -1089,11 +1119,15 @@ class HierarchicalNSW : public AlgorithmInterface { int size = getListCount(data); tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE +#if HNSWLIB_USE_PREFETCH _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); +#endif #endif for (int i = 0; i < size; i++) { #ifdef USE_SSE +#if HNSWLIB_USE_PREFETCH _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); +#endif #endif tableint cand = datal[i]; dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); @@ -1108,7 +1142,7 @@ class HierarchicalNSW : public AlgorithmInterface { } if (dataPointLevel > maxLevel) - throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); + return Status("Level of item to be updated cannot be bigger than max level"); for (int level = dataPointLevel; level >= 0; level--) { std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( @@ -1132,9 +1166,14 @@ class HierarchicalNSW : public AlgorithmInterface { filteredTopCandidates.pop(); } - currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + auto statusOrCurObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + if (!statusOrCurObj.ok()) { + return statusOrCurObj.status(); + } + currObj = 
statusOrCurObj.value(); } } + return OkStatus(); } @@ -1144,12 +1183,15 @@ class HierarchicalNSW : public AlgorithmInterface { int size = getListCount(data); std::vector result(size); tableint *ll = (tableint *) (data + 1); - memcpy(result.data(), ll, size * sizeof(tableint)); + if (size > 0) { + memcpy(result.data(), ll, size * sizeof(tableint)); + } return result; } - tableint addPoint(const void *data_point, labeltype label, int level) { + // This internal function adds a point at a specific level. TODO: document the special values of `level` (callers in this class pass -1). + StatusOr addPointWithLevel(const void *data_point, labeltype label, int level) { tableint cur_c = 0; { // Checking if the element with the same label already exists @@ -1160,21 +1202,32 @@ class HierarchicalNSW : public AlgorithmInterface { tableint existingInternalId = search->second; if (allow_replace_deleted_) { if (isMarkedDeleted(existingInternalId)) { - throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); + return Status( + "Can't use addPoint to update deleted elements if " + "replacement of deleted elements is enabled."); } } lock_table.unlock(); if (isMarkedDeleted(existingInternalId)) { - unmarkDeletedInternal(existingInternalId); + Status delete_status = + unmarkDeletedInternal(existingInternalId); + if (!delete_status.ok()) { + return delete_status; + } + } + Status update_status = + updatePoint(data_point, existingInternalId, 1.0); + if (!update_status.ok()) { + return update_status; } - updatePoint(data_point, existingInternalId, 1.0); return existingInternalId; } if (cur_element_count >= max_elements_) { - throw std::runtime_error("The number of elements exceeds the specified limit"); + return Status( + "The number of elements exceeds the specified limit"); } cur_c = cur_element_count; @@ -1204,8 +1257,9 @@ class HierarchicalNSW : public AlgorithmInterface { if (curlevel) { linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); - if
(linkLists_[cur_c] == nullptr) - throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + if (linkLists_[cur_c] == nullptr) { + return Status("Not enough memory: addPoint failed to allocate linklist"); + } memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); } @@ -1225,7 +1279,7 @@ class HierarchicalNSW : public AlgorithmInterface { for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); + return Status("cand error"); dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { curdist = d; @@ -1240,7 +1294,7 @@ class HierarchicalNSW : public AlgorithmInterface { bool epDeleted = isMarkedDeleted(enterpoint_copy); for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { if (level > maxlevelcopy || level < 0) // possible? - throw std::runtime_error("Level error"); + return Status("Level error"); std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( currObj, data_point, level); @@ -1249,7 +1303,13 @@ class HierarchicalNSW : public AlgorithmInterface { if (top_candidates.size() > ef_construction_) top_candidates.pop(); } - currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); + + auto statusOrCurrObj = + mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); + if (!statusOrCurrObj.ok()) { + return statusOrCurrObj.status(); + } + currObj = statusOrCurrObj.value(); } } else { // Do nothing for the first element @@ -1265,9 +1325,13 @@ class HierarchicalNSW : public AlgorithmInterface { return cur_c; } + using DistanceLabelPriorityQueue = typename AlgorithmInterface::DistanceLabelPriorityQueue; - std::priority_queue> - searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + virtual StatusOr + searchKnnNoExceptions( + const void *query_data, + size_t k, + 
BaseFilterFunctor* isIdAllowed = nullptr) const override { std::priority_queue> result; if (cur_element_count == 0) return result; @@ -1289,7 +1353,7 @@ class HierarchicalNSW : public AlgorithmInterface { for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); + return Status("cand error"); dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { @@ -1322,12 +1386,25 @@ class HierarchicalNSW : public AlgorithmInterface { return result; } + using DistanceLabelVector = typename AlgorithmInterface::DistanceLabelVector; + + DistanceLabelVector searchStopConditionClosest( + const void *query_data, + BaseSearchStopCondition& stop_condition, + BaseFilterFunctor* isIdAllowed = nullptr) const { + auto result = searchStopConditionClosestNoExceptions( + query_data, stop_condition, isIdAllowed); + if (!result.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(result.status().message()); + } + return std::move(result.value()); + } - std::vector> - searchStopConditionClosest( - const void *query_data, - BaseSearchStopCondition& stop_condition, - BaseFilterFunctor* isIdAllowed = nullptr) const { + StatusOr + searchStopConditionClosestNoExceptions( + const void *query_data, + BaseSearchStopCondition& stop_condition, + BaseFilterFunctor* isIdAllowed = nullptr) const { std::vector> result; if (cur_element_count == 0) return result; @@ -1349,7 +1426,7 @@ class HierarchicalNSW : public AlgorithmInterface { for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); + return Status("cand error"); dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { @@ -1407,5 +1484,48 @@ class HierarchicalNSW : public AlgorithmInterface { } std::cout << "integrity ok, checked " << connections_checked << " connections\n"; } + +private: + + /* + * Uses the last 16 
bits of the memory for the linked list size to store the mark, + * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. + */ + Status markDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (!isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + num_deleted_ += 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.insert(internalId); + } + } else { + return Status("The requested to delete element is already deleted"); + } + return OkStatus(); + } + + /* + * Remove the deleted mark of the node. + */ + Status unmarkDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; + *ll_cur &= ~DELETE_MARK; + num_deleted_ -= 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.erase(internalId); + } + } else { + return Status("The requested to undelete element is not deleted"); + } + return OkStatus(); + } + }; + } // namespace hnswlib diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 7ccfbba5..a605ca5d 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -1,5 +1,11 @@ #pragma once +// We use prefetch instructions by default, but we allow their use to be +// disabled by setting HNSWLIB_USE_PREFETCH to 0. +#ifndef HNSWLIB_USE_PREFETCH +#define HNSWLIB_USE_PREFETCH 1 +#endif + // https://github.com/nmslib/hnswlib/pull/508 // This allows others to provide their own error stream (e.g. 
RcppHNSW) #ifndef HNSWLIB_ERR_OVERRIDE @@ -119,9 +125,88 @@ static bool AVX512Capable() { #include #include #include +#include #include +#include + +#if defined(__EXCEPTIONS) || _HAS_EXCEPTIONS == 1 +#define HNSWLIB_THROW_RUNTIME_ERROR(message) throw std::runtime_error(message) +#else +#define HNSWLIB_THROW_RUNTIME_ERROR(message) do { \ + fprintf(stderr, \ + "FATAL: hnswlib compiled without exception support. " \ + "Use ...NoExceptions functions. " \ + "Exception message: %s", \ + (message)); \ + abort(); \ +} while (false) +#endif + +#if __cplusplus >= 201703L +#define HNSWLIB_NODISCARD [[nodiscard]] +#else +#define HNSWLIB_NODISCARD +#endif namespace hnswlib { + +// A lightweight Status class inspired by Abseil's Status class. +class HNSWLIB_NODISCARD Status { +public: + Status() : message_(nullptr) {} + + // Constructor with an error message (nullptr is interpreted as OK status). + Status(const char* message) : message_(message) {} + + // Returns true if the status is OK. + bool ok() const { return !message_; } + + // Returns the error message, or nullptr if OK. + const char* message() const { return message_; } + +private: + // nullptr if OK, a message otherwise. + const char* message_; +}; + +Status OkStatus() { return Status(); } + +template +class StatusOr { +public: + // Default constructor + StatusOr() : status_(), value_() {} + + // Constructor with a value + StatusOr(T value) : status_(), value_(value) {} + + // Constructor with an error status + StatusOr(const char* error) : status_(error), value_() {} + StatusOr(Status status) : status_(status), value_() {} + + // Returns true if the status is OK. + bool ok() const { return status_.ok(); } + + // Returns the value if the status is OK, undefined behavior otherwise. 
+ T&& value() { + return std::move(value_); + } + + const T& value() const { + return value_; + } + + T operator*() const { + return value(); + } + + Status status() const { return status_; } + +private: + Status status_; + T value_; +}; + typedef size_t labeltype; // This can be extended to store state for filtering (e.g. from a std::set) @@ -186,39 +271,91 @@ class SpaceInterface { template class AlgorithmInterface { public: - virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; + virtual Status addPointNoExceptions( + const void *datapoint, labeltype label, bool replace_deleted = false) = 0; + + virtual void addPoint( + const void *datapoint, labeltype label, bool replace_deleted = false) { + auto status = addPointNoExceptions(datapoint, label, replace_deleted); + if (!status.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(status.message()); + } + } - virtual std::priority_queue> - searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; + using DistanceLabelPair = std::pair; - // Return k nearest neighbor in the order of closer fist - virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; + // A priority queue of (distance, label) pairs. The largest element at the + // top corresponds to the element furthest from the query. + using DistanceLabelPriorityQueue = std::priority_queue; - virtual void saveIndex(const std::string &location) = 0; - virtual ~AlgorithmInterface(){ + // A vector of (distance, label) pairs. 
+ using DistanceLabelVector = std::vector; + + virtual DistanceLabelPriorityQueue searchKnn( + const void* query_data, + size_t k, + BaseFilterFunctor* isIdAllowed = nullptr) const { + auto result = searchKnnNoExceptions(query_data, k, isIdAllowed); + if (!result.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(result.status().message()); + } + return std::move(result.value()); } -}; -template -std::vector> -AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, - BaseFilterFunctor* isIdAllowed) const { - std::vector> result; - - // here searchKnn returns the result in the order of further first - auto ret = searchKnn(query_data, k, isIdAllowed); - { + virtual StatusOr searchKnnNoExceptions( + const void* query_data, + size_t k, + BaseFilterFunctor* isIdAllowed = nullptr) const = 0; + + // Return k nearest neighbor in the order of closest neighbor first. + virtual DistanceLabelVector searchKnnCloserFirst( + const void* query_data, + size_t k, + BaseFilterFunctor* isIdAllowed = nullptr) { + auto result = + searchKnnCloserFirstNoExceptions(query_data, k, isIdAllowed); + if (!result.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(result.status().message()); + } + return std::move(result.value()); + } + + virtual StatusOr searchKnnCloserFirstNoExceptions( + const void* query_data, + size_t k, + BaseFilterFunctor* isIdAllowed = nullptr) const { + + // Here searchKnn returns the result in the order of further first. 
+ auto status_or_result = searchKnnNoExceptions(query_data, k, isIdAllowed); + if (!status_or_result.ok()) { + return status_or_result.status(); + } + auto ret = std::move(status_or_result.value()); + + DistanceLabelVector final_vector; size_t sz = ret.size(); - result.resize(sz); + final_vector.resize(sz); while (!ret.empty()) { - result[--sz] = ret.top(); + final_vector[--sz] = ret.top(); ret.pop(); } + + return final_vector; } - return result; -} + virtual void saveIndex(const std::string &location) { + Status status = saveIndexNoExceptions(location); + if (!status.ok()) { + HNSWLIB_THROW_RUNTIME_ERROR(status.message()); + } + } + + virtual Status saveIndexNoExceptions(const std::string &location) = 0; + + virtual ~AlgorithmInterface(){ + } +}; + } // namespace hnswlib #include "space_l2.h" diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index 0e6834c1..7547c5e6 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -158,7 +158,7 @@ InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void __m512 sum512 = _mm512_set1_ps(0); size_t loop = qty16 / 4; - + while (loop--) { __m512 v1 = _mm512_loadu_ps(pVect1); __m512 v2 = _mm512_loadu_ps(pVect2); diff --git a/hnswlib/stop_condition.h b/hnswlib/stop_condition.h index acc80ebe..7d8d5a3b 100644 --- a/hnswlib/stop_condition.h +++ b/hnswlib/stop_condition.h @@ -257,7 +257,7 @@ class EpsilonSearchStopCondition : public BaseSearchStopCondition { return flag_consider_candidate; } - bool should_remove_extra() { + bool should_remove_extra() override { bool flag_remove_extra = curr_num_items_ > max_num_candidates_; return flag_remove_extra; } diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 56ce9beb..babf9741 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -75,8 +75,10 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn inline void assert_true(bool expr, const std::string & msg) { - if (expr == false) 
throw std::runtime_error("Unpickle Error: " + msg); - return; + // Not using HNSWLIB_THROW_RUNTIME_ERROR here, because it expects a static + // string constant, and because we currently always compile the Python + // bindings with exceptions enabled. + if (!expr) throw std::runtime_error("Unpickle Error: " + msg); } @@ -100,7 +102,7 @@ inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, snprintf(msg, sizeof(msg), "Input vector data wrong shape. Number of dimensions %d. Data must be a 1D or 2D array.", buffer.ndim); - throw std::runtime_error(msg); + HNSWLIB_THROW_RUNTIME_ERROR(msg); } if (buffer.ndim == 2) { *rows = buffer.shape[0]; @@ -124,7 +126,7 @@ inline std::vector get_input_ids_and_check_shapes(const py::object& ids_ snprintf(msg, sizeof(msg), "The input label shape %d does not match the input data vector shape %d", ids_numpy.ndim, feature_rows); - throw std::runtime_error(msg); + HNSWLIB_THROW_RUNTIME_ERROR(msg); } // extract data if (ids_numpy.ndim == 1) { @@ -171,7 +173,7 @@ class Index { l2space = new hnswlib::InnerProductSpace(dim); normalize = true; } else { - throw std::runtime_error("Space name must be one of l2, ip, or cosine."); + HNSWLIB_THROW_RUNTIME_ERROR("Space name must be one of l2, ip, or cosine."); } appr_alg = NULL; ep_added = true; @@ -196,7 +198,7 @@ class Index { size_t random_seed, bool allow_replace_deleted) { if (appr_alg) { - throw std::runtime_error("The index is already initiated."); + HNSWLIB_THROW_RUNTIME_ERROR("The index is already initiated."); } cur_l = 0; appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed, allow_replace_deleted); @@ -258,7 +260,7 @@ class Index { get_input_array_shapes(buffer, &rows, &features); if (features != dim) - throw std::runtime_error("Wrong dimensionality of the vectors"); + HNSWLIB_THROW_RUNTIME_ERROR("Wrong dimensionality of the vectors"); // avoid using threads when the number of additions is small: if (rows <= num_threads * 4) 
{ @@ -556,7 +558,7 @@ class Index { for (size_t i = 0; i < appr_alg->cur_element_count; i++) { if (label_lookup_val_npy.data()[i] < 0) { - throw std::runtime_error("Internal id cannot be negative!"); + HNSWLIB_THROW_RUNTIME_ERROR("Internal id cannot be negative!"); } else { appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i])); } @@ -583,7 +585,7 @@ class Index { } else { appr_alg->linkLists_[i] = (char*)malloc(linkListSize); if (appr_alg->linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + HNSWLIB_THROW_RUNTIME_ERROR("Not enough memory: loadIndex failed to allocate linklist"); memcpy(appr_alg->linkLists_[i], link_list_npy.data() + link_npy_offsets[i], linkListSize); } @@ -644,7 +646,7 @@ class Index { std::priority_queue> result = appr_alg->searchKnn( (void*)items.data(row), k, p_idFilter); if (result.size() != k) - throw std::runtime_error( + HNSWLIB_THROW_RUNTIME_ERROR( "Cannot return the results in a contiguous 2D array. Probably ef or M is too small"); for (int i = k - 1; i >= 0; i--) { auto& result_tuple = result.top(); @@ -664,7 +666,7 @@ class Index { std::priority_queue> result = appr_alg->searchKnn( (void*)(norm_array.data() + start_idx), k, p_idFilter); if (result.size() != k) - throw std::runtime_error( + HNSWLIB_THROW_RUNTIME_ERROR( "Cannot return the results in a contiguous 2D array. 
Probably ef or M is too small"); for (int i = k - 1; i >= 0; i--) { auto& result_tuple = result.top(); @@ -748,7 +750,7 @@ class BFIndex { space = new hnswlib::InnerProductSpace(dim); normalize = true; } else { - throw std::runtime_error("Space name must be one of l2, ip, or cosine."); + HNSWLIB_THROW_RUNTIME_ERROR("Space name must be one of l2, ip, or cosine."); } alg = NULL; index_inited = false; @@ -781,7 +783,7 @@ class BFIndex { void init_new_index(const size_t maxElements) { if (alg) { - throw std::runtime_error("The index is already initiated."); + HNSWLIB_THROW_RUNTIME_ERROR("The index is already initiated."); } cur_l = 0; alg = new hnswlib::BruteforceSearch(space, maxElements); @@ -806,7 +808,7 @@ class BFIndex { get_input_array_shapes(buffer, &rows, &features); if (features != dim) - throw std::runtime_error("Wrong dimensionality of the vectors"); + HNSWLIB_THROW_RUNTIME_ERROR("Wrong dimensionality of the vectors"); std::vector ids = get_input_ids_and_check_shapes(ids_, rows); @@ -1004,7 +1006,7 @@ PYBIND11_PLUGIN(hnswlib) { }, [](py::tuple t) { // __setstate__ if (t.size() != 1) - throw std::runtime_error("Invalid state!"); + HNSWLIB_THROW_RUNTIME_ERROR("Invalid state!"); return Index::createFromParams(t[0].cast()); })) diff --git a/setup.py b/setup.py index d96aea49..0900adc6 100644 --- a/setup.py +++ b/setup.py @@ -118,9 +118,12 @@ def build_extensions(self): print(f'flag: {m1_flag} is not available') else: print(f'flag: {BuildExt.compiler_flag_native} is available') + # Enable exceptions. + opts.append("-fexceptions") elif ct == 'msvc': opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()) - + # Enable exceptions. 
+ opts.append('/EHsc') for ext in self.extensions: ext.extra_compile_args.extend(opts) ext.extra_link_args.extend(BuildExt.link_opts.get(ct, [])) diff --git a/test_all_build_types.sh b/test_all_build_types.sh new file mode 100755 index 00000000..4daeb837 --- /dev/null +++ b/test_all_build_types.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +set -euo pipefail + +source_path="${PWD}" +for c_compiler in clang gcc; do + for build_type in Debug Release RelWithDebInfo MinSizeRel; do + # "tsan" is currently not present here because it has some unresolved + # issues. + for sanitizers in "no_sanitizers" "asan_ubsan"; do + if [[ "${c_compiler}" == "gcc" && + "${sanitizers}" == "tsan" ]]; then + continue + fi + echo ------------------------------------------------------------------- + echo "Starting ${c_compiler} ${build_type} ${sanitizers}" + echo ------------------------------------------------------------------- + echo + "${c_compiler}" --version + echo + + build_path="${source_path}/build/${build_type}_${c_compiler}_${sanitizers}" + rm -rf "${build_path}" + mkdir -p "${build_path}" + cd "${build_path}" + if [[ "${c_compiler}" == "gcc" ]]; then + cxx_compiler=g++ + else + cxx_compiler=clang++ + fi + cmake_cmd=( + cmake -G Ninja + "-DCMAKE_BUILD_TYPE=${build_type}" + "-DCMAKE_C_COMPILER=${c_compiler}" + "-DCMAKE_CXX_COMPILER=${cxx_compiler}" + -S "${source_path}" + -B "${build_path}" + ) + if [[ "${sanitizers}" == "asan_ubsan" ]]; then + cmake_cmd+=( -DENABLE_ASAN=ON -DENABLE_UBSAN=ON ) + fi + if [[ "${sanitizers}" == "tsan" ]]; then + cmake_cmd+=( -DENABLE_TSAN=ON ) + fi + ( set -x; "${cmake_cmd[@]}" ) + time ( set -x; ninja -j8 ) + time ( set -x; ctest ) + cd "${source_path}" + echo + echo ------------------------------------------------------------------- + echo "Finished ${c_compiler} ${build_type} ${sanitizers}" + echo ------------------------------------------------------------------- + echo + echo + done + done +done diff --git a/tests/cpp/epsilon_search_test.cpp 
b/tests/cpp/epsilon_search_test.cpp index 38df6246..e3aec706 100644 --- a/tests/cpp/epsilon_search_test.cpp +++ b/tests/cpp/epsilon_search_test.cpp @@ -51,9 +51,8 @@ int main() { hnswlib::EpsilonSearchStopCondition stop_condition(epsilon2, min_num_candidates, max_num_candidates); std::vector> result_hnsw = alg_hnsw->searchStopConditionClosest(query_data, stop_condition); - + // check that returned results are in epsilon region - size_t num_vectors = result_hnsw.size(); std::unordered_set hnsw_labels; for (auto pair: result_hnsw) { float dist = pair.first; @@ -63,7 +62,7 @@ int main() { } std::priority_queue> result_brute = alg_brute->searchKnn(query_data, max_elements); - + // check recall std::unordered_set gt_labels; while (!result_brute.empty()) { @@ -95,9 +94,8 @@ int main() { int min_candidates_small = 500; for (size_t i = 0; i < max_elements; i++) { hnswlib::EpsilonSearchStopCondition stop_condition(epsilon2_small, min_candidates_small, max_num_candidates); - std::vector> result = + std::vector> result = alg_hnsw->searchStopConditionClosest(alg_hnsw->getDataByInternalId(i), stop_condition); - size_t num_vectors = result.size(); // get closest distance float dist = -1; if (!result.empty()) { diff --git a/tests/cpp/multiThreadLoad_test.cpp b/tests/cpp/multiThreadLoad_test.cpp index 4d2b4aa2..d4f37d4b 100644 --- a/tests/cpp/multiThreadLoad_test.cpp +++ b/tests/cpp/multiThreadLoad_test.cpp @@ -1,6 +1,7 @@ #include "../../hnswlib/hnswlib.h" #include #include +#include int main() { @@ -8,15 +9,15 @@ int main() { int d = 16; int max_elements = 1000; - std::mt19937 rng; - rng.seed(47); - std::uniform_real_distribution<> distrib_real; - hnswlib::L2Space space(d); hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * max_elements); + std::unique_ptr> alg_hnsw_holder(alg_hnsw); - std::cout << "Building index" << std::endl; int num_threads = 40; + + std::mt19937 seeding_rng(314159265); + + std::cout << "Building index" << std::endl; int num_labels 
= 10; int num_iterations = 10; @@ -27,12 +28,16 @@ int main() { // will add/update element with the same label simultaneously while (true) { // add elements by batches - std::uniform_int_distribution<> distrib_int(start_label, start_label + num_labels - 1); std::vector threads; for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + unsigned int rng_seed = seeding_rng(); threads.push_back( std::thread( - [&] { + [=] { + std::uniform_int_distribution<> distrib_int( + start_label, start_label + num_labels - 1); + std::uniform_real_distribution<> distrib_real; + std::mt19937 rng(rng_seed); for (int iter = 0; iter < num_iterations; iter++) { std::vector data(d); hnswlib::labeltype label = distrib_int(rng); @@ -55,13 +60,15 @@ int main() { } // insert remaining elements if needed + std::uniform_real_distribution<> main_distrib_real; + std::mt19937 main_rng(seeding_rng()); for (hnswlib::labeltype label = 0; label < max_elements; label++) { auto search = alg_hnsw->label_lookup_.find(label); if (search == alg_hnsw->label_lookup_.end()) { std::cout << "Adding " << label << std::endl; std::vector data(d); for (int i = 0; i < d; i++) { - data[i] = distrib_real(rng); + data[i] = main_distrib_real(main_rng); } alg_hnsw->addPoint(data.data(), label); } @@ -69,7 +76,7 @@ int main() { std::cout << "Index is created" << std::endl; - bool stop_threads = false; + std::atomic stop_threads{false}; std::vector threads; // create threads that will do markDeleted and unmarkDeleted of random elements @@ -78,9 +85,11 @@ int main() { num_threads = 20; int chunk_size = max_elements / num_threads; for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + unsigned int rng_seed = seeding_rng(); threads.push_back( std::thread( - [&, thread_id] { + [=, &stop_threads] { + std::mt19937 rng(rng_seed); std::uniform_int_distribution<> distrib_int(0, chunk_size - 1); int start_id = thread_id * chunk_size; std::vector marked_deleted(chunk_size); @@ -103,11 +112,15 @@ int main() { 
// create threads that will add and update random elements std::cout << "Starting add and update elements threads" << std::endl; num_threads = 20; - std::uniform_int_distribution<> distrib_int_add(max_elements, 2 * max_elements - 1); for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + unsigned int rng_seed = seeding_rng(); threads.push_back( std::thread( - [&] { + [=, &stop_threads] { + std::mt19937 rng(rng_seed); + std::uniform_int_distribution<> distrib_int_add( + max_elements, 2 * max_elements - 1); + std::uniform_real_distribution<> distrib_real; std::vector data(d); while (!stop_threads) { hnswlib::labeltype label = distrib_int_add(rng); @@ -119,7 +132,7 @@ int main() { float max_val = *max_element(data.begin(), data.end()); // never happens but prevents compiler from deleting unused code if (max_val > 10) { - throw std::runtime_error("Unexpected value in data"); + HNSWLIB_THROW_RUNTIME_ERROR("Unexpected value in data"); } } } @@ -130,11 +143,12 @@ int main() { std::cout << "Sleep and continue operations with index" << std::endl; int sleep_ms = 60 * 1000; std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); + std::cout << "Stopping threads and waiting for them to join" << std::endl; stop_threads = true; for (auto &thread : threads) { thread.join(); } - + std::cout << "Finish" << std::endl; return 0; } diff --git a/tests/cpp/multiThread_replace_test.cpp b/tests/cpp/multiThread_replace_test.cpp index 203cdb0d..20571f1c 100644 --- a/tests/cpp/multiThread_replace_test.cpp +++ b/tests/cpp/multiThread_replace_test.cpp @@ -112,7 +112,7 @@ int main() { delete alg_hnsw; } - + std::cout << "Finish" << std::endl; delete[] batch1; diff --git a/tests/cpp/resize_test.cpp b/tests/cpp/resize_test.cpp new file mode 100644 index 00000000..d63cefba --- /dev/null +++ b/tests/cpp/resize_test.cpp @@ -0,0 +1,72 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "hnswlib/hnswlib.h" + +// Performs the index resize test. 
+void TestRandomSelf() { + + constexpr int kDim = 16; + constexpr int kNumElements = 10000; + constexpr int kM = 16; + constexpr int kEfConstruction = 100; + constexpr int kEfSearch = 20; + + // Set up a random number generator. + std::mt19937 rng; + std::uniform_real_distribution distrib_real; + + // Generate random data. + std::vector data(kNumElements * kDim); + for (int i = 0; i < kNumElements * kDim; ++i) { + data[i] = distrib_real(rng); + } + + // Initialize the HNSW index. + hnswlib::L2Space space(kDim); + // Initialize with half the maximum elements. + auto* alg_hnsw = new hnswlib::HierarchicalNSW( + &space, kNumElements / 2, kM, kEfConstruction); + std::unique_ptr> alg_hnsw_holder(alg_hnsw); + + alg_hnsw->setEf(kEfSearch); + + // Add the first half of the data to the index. + const int first_batch_size = kNumElements / 2; + std::cout << "Adding first batch of " << first_batch_size << " elements." + << std::endl; + for (int i = 0; i < first_batch_size; ++i) { + alg_hnsw->addPoint(data.data() + (i * kDim), i); + } + + // Resize the index and add the second batch + std::cout << "Resizing the index to " << kNumElements << "." << std::endl; + alg_hnsw->resizeIndex(kNumElements); + + const int second_batch_size = kNumElements - first_batch_size; + std::cout << "Adding the second batch of " << second_batch_size + << " elements." << std::endl; + for (int i = first_batch_size; i < kNumElements; ++i) { + alg_hnsw->addPoint(data.data() + (i * kDim), i); + } + + // Final validation - ensure all points are retrievable + std::cout << "Final validation of all elements..." << std::endl; + for (int i = 0; i < kNumElements; ++i) { + auto result = alg_hnsw->searchKnn(data.data() + (i * kDim), 1); + assert(!result.empty() && result.top().second == i); + } + + std::cout << "Resize test completed successfully!" << std::endl; +} + +int main() { + TestRandomSelf(); + std::cout << "\nAll test runs completed successfully!" 
<< std::endl; + return 0; +} diff --git a/tests/cpp/sift_1b.cpp b/tests/cpp/sift_1b.cpp index c0f296c2..3ee33752 100644 --- a/tests/cpp/sift_1b.cpp +++ b/tests/cpp/sift_1b.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include "../../hnswlib/hnswlib.h" @@ -155,7 +156,6 @@ get_gt( vector>> &answers, size_t k) { (vector>>(qsize)).swap(answers); - DISTFUNC fstdistfunc_ = l2space.get_dist_func(); cout << qsize << "\n"; for (int i = 0; i < qsize; i++) { for (int j = 0; j < k; j++) { diff --git a/tests/cpp/update_gen_data.py b/tests/cpp/update_gen_data.py index 6f51bbbe..e2a4563c 100644 --- a/tests/cpp/update_gen_data.py +++ b/tests/cpp/update_gen_data.py @@ -34,4 +34,4 @@ def normalized(a, axis=-1, order=2): queries.tofile('data/queries.bin') np.int32(topk).tofile('data/gt.bin') with open("data/config.txt", "w") as file: - file.write("%d %d %d %d %d" %(N, dummy_data_multiplier, N_queries, d, K)) \ No newline at end of file + file.write("%d %d %d %d %d" %(N, dummy_data_multiplier, N_queries, d, K)) diff --git a/tests/cpp/updates_test.cpp b/tests/cpp/updates_test.cpp index 4dff2f85..e939b132 100644 --- a/tests/cpp/updates_test.cpp +++ b/tests/cpp/updates_test.cpp @@ -26,7 +26,7 @@ class StopW { * only handles a subset of functionality (no reductions etc) * Process ids from start (inclusive) to end (EXCLUSIVE) * - * The method is borrowed from nmslib + * The method is borrowed from nmslib */ template inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { diff --git a/tests/python/bindings_test.py b/tests/python/bindings_test.py index f9b3092f..43d9bc03 100644 --- a/tests/python/bindings_test.py +++ b/tests/python/bindings_test.py @@ -64,5 +64,5 @@ def testRandomSelf(self): labels, distances = p.knn_query(data, k=1) self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3) - + os.remove(index_path) diff --git a/tests/python/bindings_test_spaces.py b/tests/python/bindings_test_spaces.py index 
c3cceb87..901cadcc 100644 --- a/tests/python/bindings_test_spaces.py +++ b/tests/python/bindings_test_spaces.py @@ -34,6 +34,6 @@ def testRandomSelf(self): # Query the elements for themselves and measure recall: labels, distances = p.knn_query(np.asarray(data2[-1:]), k=5) - - diff=np.mean(np.abs(distances-expected_distances)) + + diff=np.mean(np.abs(distances-expected_distances)) self.assertAlmostEqual(diff, 0, delta=1e-3) diff --git a/tests/python/draw_git_test_plots.py b/tests/python/draw_git_test_plots.py index c91c8f5d..d2bc11cd 100644 --- a/tests/python/draw_git_test_plots.py +++ b/tests/python/draw_git_test_plots.py @@ -14,11 +14,11 @@ def plot_data_from_file(file_path): # Create a subplot for each column fig, axes = plt.subplots(num_columns, 1, figsize=(10, 6 * num_columns)) - + # In case there is only one column, axes will not be an array, so we convert it if num_columns == 1: axes = [axes] - + for i, ax in enumerate(axes): idx=0 ax.scatter(np.asarray(data.index,dtype=np.int64)%rep_size, data[i], label=f'Column {i+1}') @@ -45,4 +45,4 @@ def scan_and_plot(directory): plot_data_from_file(file) print(f'Plot saved for {file}') # Replace 'your_folder_path' with the path to the folder containing the .txt files -scan_and_plot('./') \ No newline at end of file +scan_and_plot('./') diff --git a/tests/python/git_tester.py b/tests/python/git_tester.py index e7657fee..d4afa919 100644 --- a/tests/python/git_tester.py +++ b/tests/python/git_tester.py @@ -16,15 +16,15 @@ print(idx, commit.hash, name) for commit in commits: - commit_time = commit.author_date.strftime("%Y-%m-%d %H:%M:%S") + commit_time = commit.author_date.strftime("%Y-%m-%d %H:%M:%S") author_name = commit.author.name name = "auth:"+author_name+"_"+commit_time+"_msg:"+commit.msg.replace('\n', ' ').replace('\r', ' ').replace(",", ";") print("\nProcessing", commit.hash, name) - + if os.path.exists("build"): shutil.rmtree("build") os.system(f"git checkout {commit.hash}") - + # Checking we have actually 
switched the branch: current_commit=list(Repository('.').traverse_commits())[-1] if current_commit.hash != commit.hash: @@ -33,7 +33,7 @@ print("git checkout failed!!!!") print("git checkout failed!!!!") continue - + print("\n\n--------------------\n\n") ret = os.system("python -m pip install .") print("Install result:", ret) @@ -52,4 +52,3 @@ os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 4 -t 64') os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 128 -t 1') os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 128 -t 64') - diff --git a/tests/python/speedtest.py b/tests/python/speedtest.py index 8d16cfc3..d5f81dd1 100644 --- a/tests/python/speedtest.py +++ b/tests/python/speedtest.py @@ -55,11 +55,10 @@ tt=time.time()-t0 times.append(tt) recall=np.sum(labels.reshape(-1)==np.arange(len(qdata)))/len(qdata) - print(f"{tt} seconds, recall= {recall}") - + print(f"{tt} seconds, recall= {recall}") + str_out=f"{np.mean(times)}, {np.median(times)}, {np.std(times)}, {construction_time}, {recall}, {name}" print(str_out) with open (f"log2_{dim}_t{threads}.txt","a") as f: f.write(str_out+"\n") f.flush() -