From adf8d49f130eb2b95916d8ce51539fc5e7d0d38a Mon Sep 17 00:00:00 2001 From: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 20 May 2025 23:15:48 -0700 Subject: [PATCH 1/8] feature: add logging in orchestrator: experimental_calibrate --- include/svs/index/vamana/index.h | 8 ++-- include/svs/orchestrators/vamana.h | 39 ++++++++++++------ tests/integration/vamana/index_search.cpp | 50 ++++++++++++++++++----- 3 files changed, 70 insertions(+), 27 deletions(-) diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index 3a60132b..dbd8d44b 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -828,7 +828,8 @@ class VamanaIndex { const GroundTruth& groundtruth, size_t num_neighbors, double target_recall, - const CalibrationParameters& calibration_parameters = {} + const CalibrationParameters& calibration_parameters = {}, + logging::logger_ptr logger = svs::logging::get() ) { // Preallocate the destination for search. // Further, reference the search lambda in the recall lambda. 
@@ -850,7 +851,8 @@ class VamanaIndex { num_neighbors, target_recall, compute_recall, - do_search + do_search, + logger ); set_search_parameters(p); return p; @@ -997,7 +999,7 @@ auto auto_assemble( I{}, std::move(distance), std::move(threadpool), - std::move(logger)}; + logger}; auto config = lib::load_from_disk(config_path); index.apply(config); return index; diff --git a/include/svs/orchestrators/vamana.h b/include/svs/orchestrators/vamana.h index d41bd0cf..e9858eef 100644 --- a/include/svs/orchestrators/vamana.h +++ b/include/svs/orchestrators/vamana.h @@ -92,7 +92,8 @@ class VamanaInterface { size_t groundtruth_size_1, size_t num_neighbors, double target_recall, - const index::vamana::CalibrationParameters& calibration_parameters + const index::vamana::CalibrationParameters& calibration_parameters, + svs::logging::logger_ptr logger = svs::logging::get() ) = 0; virtual void reset_performance_parameters() = 0; @@ -206,7 +207,8 @@ class VamanaImpl : public manager::ManagerImpl { size_t groundtruth_size_1, size_t num_neighbors, double target_recall, - const index::vamana::CalibrationParameters& calibration_parameters + const index::vamana::CalibrationParameters& calibration_parameters, + svs::logging::logger_ptr logger = svs::logging::get() ) override { if (!lib::in(queries.type(), QueryTypes{})) { throw ANNEXCEPTION( @@ -237,7 +239,8 @@ class VamanaImpl : public manager::ManagerImpl { ), num_neighbors, target_recall, - calibration_parameters + calibration_parameters, + logger ); } ); @@ -409,7 +412,8 @@ class Vamana : public manager::IndexManager { const GraphLoaderType& graph_loader, DataLoader&& data_loader, const Distance& distance, - ThreadPoolProto threadpool_proto + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() ) { // If given an `enum` for the distance type, than we need to dispatch over that // enum. 
@@ -425,7 +429,8 @@ class Vamana : public manager::IndexManager { graph_loader, std::forward(data_loader), distance_function, - std::move(threadpool) + std::move(threadpool), + logger ); }); } else { @@ -435,7 +440,8 @@ class Vamana : public manager::IndexManager { graph_loader, std::forward(data_loader), distance, - std::move(threadpool) + std::move(threadpool), + logger ); } } @@ -478,7 +484,8 @@ class Vamana : public manager::IndexManager { DataLoader&& data_loader, Distance distance, ThreadPoolProto threadpool_proto = 1, - const Allocator& graph_allocator = {} + const Allocator& graph_allocator = {}, + svs::logging::logger_ptr logger = svs::logging::get() ) { auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); if constexpr (std::is_same_v, DistanceType>) { @@ -490,7 +497,8 @@ class Vamana : public manager::IndexManager { std::forward(data_loader), std::move(distance_function), std::move(threadpool), - graph_allocator + graph_allocator, + logger ); }); } else { @@ -500,7 +508,8 @@ class Vamana : public manager::IndexManager { std::forward(data_loader), distance, std::move(threadpool), - graph_allocator + graph_allocator, + logger ); } } @@ -537,14 +546,16 @@ class Vamana : public manager::IndexManager { const GroundTruth& groundtruth, size_t num_neighbors, double target_recall, - const index::vamana::CalibrationParameters calibration_parameters = {} + const index::vamana::CalibrationParameters calibration_parameters = {}, + svs::logging::logger_ptr logger = svs::logging::get() ) { return experimental_calibrate_impl( queries.cview(), groundtruth.cview(), num_neighbors, target_recall, - calibration_parameters + calibration_parameters, + logger ); } @@ -554,7 +565,8 @@ class Vamana : public manager::IndexManager { data::ConstSimpleDataView groundtruth, size_t num_neighbors, double target_recall, - const index::vamana::CalibrationParameters calibration_parameters + const index::vamana::CalibrationParameters calibration_parameters, + 
svs::logging::logger_ptr logger = svs::logging::get() ) { return impl_->experimental_calibrate( ConstErasedPointer{queries.data()}, @@ -565,7 +577,8 @@ class Vamana : public manager::IndexManager { groundtruth.dimensions(), num_neighbors, target_recall, - calibration_parameters + calibration_parameters, + logger ); } diff --git a/tests/integration/vamana/index_search.cpp b/tests/integration/vamana/index_search.cpp index 3b96d599..671589e7 100644 --- a/tests/integration/vamana/index_search.cpp +++ b/tests/integration/vamana/index_search.cpp @@ -41,6 +41,7 @@ #include "tests/utils/test_dataset.h" #include "tests/utils/utils.h" #include "tests/utils/vamana_reference.h" +#include "spdlog/sinks/callback_sink.h" namespace { @@ -130,6 +131,7 @@ void run_tests( const svs::data::SimpleData& queries_all, const svs::data::SimpleData& groundtruth_all, const std::vector& expected_results, + svs::logging::logger_ptr logger, bool test_calibration = false ) { // If we make a change that somehow improves accuracy, we'll want to know. 
@@ -222,7 +224,7 @@ void run_tests( c.train_prefetchers_ = false; index.experimental_calibrate( - queries, groundtruth, first_result.num_neighbors_, first_result.recall_, c + queries, groundtruth, first_result.num_neighbors_, first_result.recall_, c, logger ); auto recall = svs::k_recall_at_n( groundtruth, @@ -235,6 +237,28 @@ void run_tests( } // namespace CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { + // Set up log + std::vector captured_logs; + auto callback_sink = std::make_shared( + [&captured_logs](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + callback_sink->set_level(spdlog::level::trace); + auto test_logger = std::make_shared("test_logger", callback_sink); + test_logger->set_level(spdlog::level::trace); + + std::vector global_captured_logs; + auto global_callback_sink = std::make_shared( + [&global_captured_logs](const spdlog::details::log_msg& msg) { + global_captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + global_callback_sink->set_level(spdlog::level::trace); + auto original_logger = std::make_shared("original_logger", global_callback_sink); + original_logger->set_level(spdlog::level::trace); + svs::logging::set(original_logger); + auto distances = std::to_array({svs::L2, svs::MIP, svs::Cosine}); const auto queries = test_dataset::queries(); auto temp_dir = svs_test::temp_directory(); @@ -253,7 +277,8 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { svs::GraphLoader(test_dataset::graph_file()), svs::VectorDataLoader(test_dataset::data_svs_file()), distance_type, - 2 + 2, + test_logger ); CATCH_REQUIRE(index.size() == test_dataset::VECTORS_IN_DATA_SET); @@ -267,19 +292,19 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { verify_reconstruction(index, original_data); first = false; } - - run_tests(index, queries, groundtruth, 
expected_results.config_and_recall_, true); + run_tests(index, queries, groundtruth, expected_results.config_and_recall_, test_logger, true); index = svs::Vamana::assemble>( test_dataset::vamana_config_file(), svs::GraphLoader(test_dataset::graph_file()), svs::VectorDataLoader(test_dataset::data_svs_file()), distance_type, - svs::threads::CppAsyncThreadPool(2) + svs::threads::CppAsyncThreadPool(2), + test_logger ); run_tests( - index, queries, groundtruth, expected_results.config_and_recall_, true + index, queries, groundtruth, expected_results.config_and_recall_, test_logger, true ); index = svs::Vamana::assemble>( @@ -287,11 +312,12 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { svs::GraphLoader(test_dataset::graph_file()), svs::VectorDataLoader(test_dataset::data_svs_file()), distance_type, - svs::threads::QueueThreadPoolWrapper(2) + svs::threads::QueueThreadPoolWrapper(2), + test_logger ); run_tests( - index, queries, groundtruth, expected_results.config_and_recall_, true + index, queries, groundtruth, expected_results.config_and_recall_, test_logger, true ); // Save and reload. 
svs_test::prepare_temp_directory(); @@ -316,7 +342,8 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { svs::GraphLoader(graph_dir), svs::VectorDataLoader(data_dir), distance_type, - 1 + 1, + test_logger ); // Data Properties CATCH_REQUIRE(index.size() == test_dataset::VECTORS_IN_DATA_SET); @@ -337,12 +364,13 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { threadpool.resize(2); CATCH_REQUIRE(index.get_num_threads() == 2); run_tests( - index, queries, groundtruth, expected_results.config_and_recall_ + index, queries, groundtruth, expected_results.config_and_recall_, test_logger ); index.set_threadpool(svs::threads::SwitchNativeThreadPool(2)); run_tests( - index, queries, groundtruth, expected_results.config_and_recall_ + index, queries, groundtruth, expected_results.config_and_recall_, test_logger ); } + CATCH_REQUIRE(global_captured_logs.empty()); } From 28947c2971bd8bf440aed6b13f288fbb43ec2063 Mon Sep 17 00:00:00 2001 From: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 22 May 2025 22:44:47 -0700 Subject: [PATCH 2/8] fix: add search in orchestrator level --- include/svs/index/flat/flat.h | 15 +++-- include/svs/index/index.h | 19 +++--- include/svs/index/inverted/memory_based.h | 3 +- include/svs/index/vamana/dynamic_index.h | 15 +++-- include/svs/index/vamana/extensions.h | 22 ++++--- include/svs/index/vamana/greedy_search.h | 10 +++- include/svs/index/vamana/index.h | 21 +++++-- include/svs/index/vamana/iterator.h | 71 ++++++++++++----------- include/svs/index/vamana/vamana_build.h | 17 ++++-- include/svs/orchestrators/manager.h | 23 +++++--- tests/integration/vamana/index_search.cpp | 30 ++++++++-- tests/svs/index/index.cpp | 3 +- 12 files changed, 163 insertions(+), 86 deletions(-) diff --git a/include/svs/index/flat/flat.h b/include/svs/index/flat/flat.h index ee987d19..6878d769 100644 --- a/include/svs/index/flat/flat.h +++ b/include/svs/index/flat/flat.h 
@@ -319,7 +319,8 @@ class FlatIndex { const data::ConstSimpleDataView& queries, const search_parameters_type& search_parameters, const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - Pred predicate = lib::Returns(lib::Const()) + Pred predicate = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { const size_t data_max_size = data_.size(); @@ -347,7 +348,8 @@ class FlatIndex { scratch, search_parameters, cancel, - predicate + predicate, + logger ); start = stop; } @@ -377,7 +379,8 @@ class FlatIndex { sorter_type& scratch, const search_parameters_type& search_parameters, const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - Pred predicate = lib::Returns(lib::Const()) + Pred predicate = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { // Process all queries. threads::parallel_for( @@ -398,7 +401,8 @@ class FlatIndex { scratch, distances, cancel, - predicate + predicate, + logger ); } ); @@ -420,7 +424,8 @@ class FlatIndex { sorter_type& scratch, distance::BroadcastDistance& distance_functors, const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - Pred predicate = lib::Returns(lib::Const()) + Pred predicate = lib::Returns(lib::Const()), + logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) { assert(distance_functors.size() >= query_indices.size()); auto accessor = extensions::accessor(data_); diff --git a/include/svs/index/index.h b/include/svs/index/index.h index d19bd68a..b57bb6ff 100644 --- a/include/svs/index/index.h +++ b/include/svs/index/index.h @@ -18,6 +18,7 @@ // svs #include "svs/concepts/data.h" +#include "svs/core/logging.h" #include "svs/core/query_result.h" // stl @@ -47,7 +48,8 @@ void search_batch_into_with( svs::QueryResultView result, const Queries& queries, const search_parameters_t& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = 
lib::Returns(lib::Const()), + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) { // Assert pre-conditions. assert(result.n_queries() == queries.size()); @@ -60,10 +62,11 @@ void search_batch_into( Index& index, svs::QueryResultView result, const Queries& queries, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { svs::index::search_batch_into_with( - index, result, queries, index.get_search_parameters(), cancel + index, result, queries, index.get_search_parameters(), cancel, logger ); } @@ -74,11 +77,12 @@ svs::QueryResult search_batch_with( const Queries& queries, size_t num_neighbors, const search_parameters_t& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { auto result = svs::QueryResult{queries.size(), num_neighbors}; svs::index::search_batch_into_with( - index, result.view(), queries, search_parameters, cancel + index, result.view(), queries, search_parameters, cancel, logger ); return result; } @@ -89,10 +93,11 @@ svs::QueryResult search_batch( Index& index, const Queries& queries, size_t num_neighbors, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { return svs::index::search_batch_with( - index, queries, num_neighbors, index.get_search_parameters(), cancel + index, queries, num_neighbors, index.get_search_parameters(), cancel, logger ); } } // namespace svs::index diff --git a/include/svs/index/inverted/memory_based.h b/include/svs/index/inverted/memory_based.h index 3d3fc24c..b0a2c363 100644 --- a/include/svs/index/inverted/memory_based.h +++ b/include/svs/index/inverted/memory_based.h @@ 
-408,7 +408,8 @@ template class InvertedIndex { QueryResultView results, const Queries& queries, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) { threads::parallel_for( threadpool_, diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 9dafe705..6e79af4e 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -447,7 +447,11 @@ class MutableVamanaIndex { const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) const { return [&, prefetch_parameters]( - const auto& query, auto& accessor, auto& distance, auto& buffer + const auto& query, + auto& accessor, + auto& distance, + auto& buffer, + auto& logger ) { // Perform the greedy search using the provided resources. greedy_search( @@ -460,7 +464,8 @@ class MutableVamanaIndex { vamana::EntryPointInitializer{lib::as_const_span(entry_point_)}, internal_search_builder(), prefetch_parameters, - cancel + cancel, + logger ); // Take a pass over the search buffer to remove any deleted elements that // might remain. 
@@ -473,7 +478,8 @@ class MutableVamanaIndex { void search( const Query& query, scratchspace_type& scratch, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) const { extensions::single_search( data_, @@ -489,7 +495,8 @@ class MutableVamanaIndex { QueryResultView results, const Queries& queries, const search_parameters_type& sp, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) { threads::parallel_for( threadpool_, diff --git a/include/svs/index/vamana/extensions.h b/include/svs/index/vamana/extensions.h index 58923567..0b4631d9 100644 --- a/include/svs/index/vamana/extensions.h +++ b/include/svs/index/vamana/extensions.h @@ -19,6 +19,7 @@ #include "svs/concepts/distance.h" #include "svs/core/data.h" #include "svs/core/distance.h" +#include "svs/core/logging.h" #include "svs/core/medioid.h" #include "svs/core/query_result.h" #include "svs/index/vamana/greedy_search.h" @@ -417,9 +418,10 @@ struct VamanaSingleSearchType { Scratch& scratch, const Query& query, const Search& search, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) const { - svs::svs_invoke(*this, data, search_buffer, scratch, query, search, cancel); + svs::svs_invoke(*this, data, search_buffer, scratch, query, search, cancel, logger); } }; @@ -442,7 +444,8 @@ SVS_FORCE_INLINE void svs_invoke( Distance& distance, const Query& query, const Search& search, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { // Check if 
request to cancel the search if (cancel()) { @@ -450,7 +453,7 @@ SVS_FORCE_INLINE void svs_invoke( } // Perform graph search. auto accessor = data::GetDatumAccessor(); - search(query, accessor, distance, search_buffer); + search(query, accessor, distance, search_buffer, logger); } /// @@ -497,7 +500,8 @@ struct VamanaPerThreadBatchSearchType { QueryResultView& result, threads::UnitRange thread_indices, const Search& search, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) const { svs::svs_invoke( *this, @@ -508,7 +512,8 @@ struct VamanaPerThreadBatchSearchType { result, thread_indices, search, - cancel + cancel, + logger ); } }; @@ -533,7 +538,8 @@ void svs_invoke( QueryResultView& result, threads::UnitRange thread_indices, const Search& search, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { // Fallback implementation size_t num_neighbors = result.n_neighbors(); @@ -544,7 +550,7 @@ void svs_invoke( } // Perform search - results will be queued in the search buffer. single_search( - dataset, search_buffer, distance, queries.get_datum(i), search, cancel + dataset, search_buffer, distance, queries.get_datum(i), search, cancel, logger ); // Copy back results. 
diff --git a/include/svs/index/vamana/greedy_search.h b/include/svs/index/vamana/greedy_search.h index f12c0129..ce8c0174 100644 --- a/include/svs/index/vamana/greedy_search.h +++ b/include/svs/index/vamana/greedy_search.h @@ -19,6 +19,7 @@ #include "svs/concepts/data.h" #include "svs/concepts/distance.h" #include "svs/concepts/graph.h" +#include "svs/core/logging.h" #include "svs/index/vamana/search_buffer.h" #include @@ -132,7 +133,8 @@ void greedy_search( const Builder& builder, Tracker& search_tracker, GreedySearchPrefetchParameters prefetch_parameters = {}, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) { using I = typename Graph::index_type; @@ -223,7 +225,8 @@ void greedy_search( const Initializer& initializer, const Builder& builder = NeighborBuilder(), GreedySearchPrefetchParameters prefetch_parameters = {}, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { auto null_tracker = NullTracker{}; greedy_search( @@ -237,7 +240,8 @@ void greedy_search( builder, null_tracker, prefetch_parameters, - cancel + cancel, + logger ); } } // namespace svs::index::vamana diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index dbd8d44b..5774aed7 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -467,7 +467,11 @@ class VamanaIndex { const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) const { return [&, prefetch_parameters]( - const auto& query, auto& accessor, auto& distance, auto& buffer + const auto& query, + auto& accessor, + auto& distance, + auto& buffer, + auto& logger ) { greedy_search( graph_, @@ -479,7 +483,8 @@ class VamanaIndex { 
vamana::EntryPointInitializer{lib::as_const_span(entry_point_)}, internal_search_builder(), prefetch_parameters, - cancel + cancel, + logger ); }; } @@ -502,7 +507,8 @@ class VamanaIndex { void search( const Query& query, scratchspace_type& scratch, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) const { extensions::single_search( data_, @@ -554,7 +560,8 @@ class VamanaIndex { QueryResultView result, const Queries& queries, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) { threads::parallel_for( threadpool_, @@ -829,14 +836,16 @@ class VamanaIndex { size_t num_neighbors, double target_recall, const CalibrationParameters& calibration_parameters = {}, - logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get() ) { // Preallocate the destination for search. // Further, reference the search lambda in the recall lambda. 
auto results = svs::QueryResult{queries.size(), num_neighbors}; auto do_search = [&](const search_parameters_type& p) { - this->search(results.view(), queries, p); + this->search( + results.view(), queries, p, lib::Returns(lib::Const()), logger + ); }; auto compute_recall = [&](const search_parameters_type& p) { diff --git a/include/svs/index/vamana/iterator.h b/include/svs/index/vamana/iterator.h index 7dc52a6b..bc417124 100644 --- a/include/svs/index/vamana/iterator.h +++ b/include/svs/index/vamana/iterator.h @@ -259,40 +259,45 @@ template class BatchIterator { const auto& SVS_UNUSED(distance), std::span entry_points ) { - auto search_closure = - [&](const auto& query, const auto& accessor, auto& d, auto& buffer) { - constexpr vamana::extensions::UsesReranking< - std::remove_const_t>> - uses_reranking{}; - if constexpr (uses_reranking()) { - distance::maybe_fix_argument(d, query); - for (size_t j = 0, jmax = buffer.size(); j < jmax; ++j) { - auto& neighbor = buffer[j]; - auto id = neighbor.id(); - auto new_distance = - distance::compute(d, query, data.get_primary(id)); - neighbor.set_distance(new_distance); - } - buffer.sort(); + auto search_closure = [&](const auto& query, + const auto& accessor, + auto& d, + auto& buffer, + svs::logging::logger_ptr logger = + svs::logging::get()) { + constexpr vamana::extensions::UsesReranking< + std::remove_const_t>> + uses_reranking{}; + if constexpr (uses_reranking()) { + distance::maybe_fix_argument(d, query); + for (size_t j = 0, jmax = buffer.size(); j < jmax; ++j) { + auto& neighbor = buffer[j]; + auto id = neighbor.id(); + auto new_distance = + distance::compute(d, query, data.get_primary(id)); + neighbor.set_distance(new_distance); } - - vamana::greedy_search( - graph, - data, - accessor, - query, - d, - buffer, - RestartInitializer{entry_points, restart_search_copy}, - parent_->internal_search_builder(), - scratchspace_.prefetch_parameters, - cancel - ); - - if constexpr (Index::needs_id_translation) { - 
buffer.cleanup(); - } - }; + buffer.sort(); + } + + vamana::greedy_search( + graph, + data, + accessor, + query, + d, + buffer, + RestartInitializer{entry_points, restart_search_copy}, + parent_->internal_search_builder(), + scratchspace_.prefetch_parameters, + cancel, + logger + ); + + if constexpr (Index::needs_id_translation) { + buffer.cleanup(); + } + }; extensions::single_search( data, diff --git a/include/svs/index/vamana/vamana_build.h b/include/svs/index/vamana/vamana_build.h index b20f7bc5..cbaeb5b3 100644 --- a/include/svs/index/vamana/vamana_build.h +++ b/include/svs/index/vamana/vamana_build.h @@ -206,7 +206,7 @@ class VamanaBuilder { float alpha, Idx entry_point, logging::Level level = logging::Level::Info, - logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get() ) { construct( alpha, entry_point, threads::UnitRange{0, data_.size()}, level, logger @@ -219,7 +219,7 @@ class VamanaBuilder { Idx entry_point, const R& range, logging::Level level = logging::Level::Info, - logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get() ) { size_t num_nodes = range.size(); size_t num_batches = std::max( @@ -256,7 +256,11 @@ class VamanaBuilder { // because it seems to generally yield better results. 
auto x = timer.push_back("generate neighbors"); generate_neighbors( - threads::IteratorPair{start, stop}, params_.alpha, entry_points, timer + threads::IteratorPair{start, stop}, + params_.alpha, + entry_points, + timer, + logger ); search_time += lib::as_seconds(x.finish()); @@ -313,7 +317,8 @@ class VamanaBuilder { const R& indices, float alpha, const std::vector& entry_points, - lib::Timer& timer + lib::Timer& timer, + svs::logging::logger_ptr logger = svs::logging::get() ) { auto range = threads::StaticPartition{indices}; @@ -365,7 +370,9 @@ class VamanaBuilder { vamana::EntryPointInitializer{lib::as_const_span(entry_points)}, NeighborBuilder(), tracker, - prefetch_hint_ + prefetch_hint_, + lib::Returns(lib::Const()), + logger ); } diff --git a/include/svs/orchestrators/manager.h b/include/svs/orchestrators/manager.h index 6b08c266..54c8f2ba 100644 --- a/include/svs/orchestrators/manager.h +++ b/include/svs/orchestrators/manager.h @@ -18,6 +18,7 @@ #include "svs/core/data/simple.h" #include "svs/core/distance.h" +#include "svs/core/logging.h" #include "svs/core/query_result.h" #include "svs/lib/datatype.h" #include "svs/lib/threads/threadpool.h" @@ -79,7 +80,8 @@ template class ManagerInterface : public IFace { svs::QueryResultView results, AnonymousArray<2> data, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) = 0; // Data Interface @@ -142,7 +144,8 @@ class ManagerImpl : public ManagerInterface { QueryResultView result, AnonymousArray<2> data, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) override { // See if we have a specialization for this particular query type. 
// If so, invoke that specialization, otherwise throw @@ -152,7 +155,7 @@ class ManagerImpl : public ManagerInterface { [&](lib::Type SVS_UNUSED(type)) { const auto view = data::ConstSimpleDataView(data); svs::index::search_batch_into_with( - implementation_, result, view, search_parameters, cancel + implementation_, result, view, search_parameters, cancel, logger ); }, [&](svs::DataType data_type) { @@ -218,9 +221,12 @@ template class IndexManager { QueryResultView result, data::ConstSimpleDataView queries, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { - impl_->search(result, AnonymousArray<2>(queries), search_parameters, cancel); + impl_->search( + result, AnonymousArray<2>(queries), search_parameters, cancel, logger + ); } // This is an API compatibility trick. @@ -230,9 +236,12 @@ template class IndexManager { QueryResult search( const Queries& queries, size_t num_neighbors, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), + svs::logging::logger_ptr logger = svs::logging::get() ) { - return svs::index::search_batch(*this, queries.cview(), num_neighbors, cancel); + return svs::index::search_batch( + *this, queries.cview(), num_neighbors, cancel, logger + ); } ///// Data Interface diff --git a/tests/integration/vamana/index_search.cpp b/tests/integration/vamana/index_search.cpp index 671589e7..cae36dae 100644 --- a/tests/integration/vamana/index_search.cpp +++ b/tests/integration/vamana/index_search.cpp @@ -38,10 +38,10 @@ #include "fmt/core.h" // tests +#include "spdlog/sinks/callback_sink.h" #include "tests/utils/test_dataset.h" #include "tests/utils/utils.h" #include "tests/utils/vamana_reference.h" -#include "spdlog/sinks/callback_sink.h" namespace { @@ -236,7 +236,7 @@ void 
run_tests( } } // namespace -CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { +CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") { // Set up log std::vector captured_logs; auto callback_sink = std::make_shared( @@ -255,7 +255,8 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { } ); global_callback_sink->set_level(spdlog::level::trace); - auto original_logger = std::make_shared("original_logger", global_callback_sink); + auto original_logger = + std::make_shared("original_logger", global_callback_sink); original_logger->set_level(spdlog::level::trace); svs::logging::set(original_logger); @@ -292,7 +293,14 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { verify_reconstruction(index, original_data); first = false; } - run_tests(index, queries, groundtruth, expected_results.config_and_recall_, test_logger, true); + run_tests( + index, + queries, + groundtruth, + expected_results.config_and_recall_, + test_logger, + true + ); index = svs::Vamana::assemble>( test_dataset::vamana_config_file(), @@ -304,7 +312,12 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { ); run_tests( - index, queries, groundtruth, expected_results.config_and_recall_, test_logger, true + index, + queries, + groundtruth, + expected_results.config_and_recall_, + test_logger, + true ); index = svs::Vamana::assemble>( @@ -317,7 +330,12 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { ); run_tests( - index, queries, groundtruth, expected_results.config_and_recall_, test_logger, true + index, + queries, + groundtruth, + expected_results.config_and_recall_, + test_logger, + true ); // Save and reload. 
svs_test::prepare_temp_directory(); diff --git a/tests/svs/index/index.cpp b/tests/svs/index/index.cpp index 8e45110d..76dc4f78 100644 --- a/tests/svs/index/index.cpp +++ b/tests/svs/index/index.cpp @@ -48,7 +48,8 @@ struct TestIndex { svs::data::ConstSimpleDataView queries, SearchParameters p, const svs::lib::DefaultPredicate& cancel = - svs::lib::Returns(svs::lib::Const()) + svs::lib::Returns(svs::lib::Const()), + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() ) const { CATCH_REQUIRE(result.n_neighbors() == expected_num_neighbors_); CATCH_REQUIRE(result.n_queries() == expected_num_queries_); From 5ba032a37d814f42aab75b52a0508feacedba950 Mon Sep 17 00:00:00 2001 From: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 23 May 2025 16:21:51 -0700 Subject: [PATCH 3/8] fix: fix search more tests --- include/svs/index/flat/flat.h | 20 +++++------ include/svs/index/index.h | 24 ++++++------- include/svs/index/inverted/memory_based.h | 6 ++-- include/svs/index/vamana/dynamic_index.h | 14 ++++---- include/svs/index/vamana/extensions.h | 24 ++++++------- include/svs/index/vamana/greedy_search.h | 12 +++---- include/svs/index/vamana/index.h | 20 +++++------ include/svs/index/vamana/iterator.h | 4 +-- include/svs/index/vamana/vamana_build.h | 1 - include/svs/orchestrators/manager.h | 22 ++++++------ tests/integration/cancel.cpp | 15 ++++++--- tests/integration/exhaustive.cpp | 35 +++++++++++++++++-- tests/integration/vamana/index_search.cpp | 41 +++++++++++++++++++---- tests/svs/index/index.cpp | 4 +-- 14 files changed, 154 insertions(+), 88 deletions(-) diff --git a/include/svs/index/flat/flat.h b/include/svs/index/flat/flat.h index 6878d769..567b2de5 100644 --- a/include/svs/index/flat/flat.h +++ b/include/svs/index/flat/flat.h @@ -318,9 +318,9 @@ class FlatIndex { QueryResultView result, const data::ConstSimpleDataView& queries, const search_parameters_type& search_parameters, + svs::logging::logger_ptr logger = 
svs::logging::get(), const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - Pred predicate = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + Pred predicate = lib::Returns(lib::Const()) ) { const size_t data_max_size = data_.size(); @@ -347,9 +347,9 @@ class FlatIndex { threads::UnitRange(start, stop), scratch, search_parameters, + logger, cancel, - predicate, - logger + predicate ); start = stop; } @@ -378,9 +378,9 @@ class FlatIndex { const threads::UnitRange& data_indices, sorter_type& scratch, const search_parameters_type& search_parameters, + svs::logging::logger_ptr logger = svs::logging::get(), const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - Pred predicate = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + Pred predicate = lib::Returns(lib::Const()) ) { // Process all queries. threads::parallel_for( @@ -400,9 +400,9 @@ class FlatIndex { threads::UnitRange(query_indices), scratch, distances, + logger, cancel, - predicate, - logger + predicate ); } ); @@ -423,9 +423,9 @@ class FlatIndex { const threads::UnitRange& query_indices, sorter_type& scratch, distance::BroadcastDistance& distance_functors, + logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get(), const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - Pred predicate = lib::Returns(lib::Const()), - logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + Pred predicate = lib::Returns(lib::Const()) ) { assert(distance_functors.size() >= query_indices.size()); auto accessor = extensions::accessor(data_); diff --git a/include/svs/index/index.h b/include/svs/index/index.h index b57bb6ff..a3037a43 100644 --- a/include/svs/index/index.h +++ b/include/svs/index/index.h @@ -48,12 +48,12 @@ void search_batch_into_with( svs::QueryResultView result, const Queries& queries, const search_parameters_t& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), 
- svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { // Assert pre-conditions. assert(result.n_queries() == queries.size()); - index.search(result, queries, search_parameters, cancel); + index.search(result, queries, search_parameters, logger, cancel); } // Apply default search parameters @@ -62,11 +62,11 @@ void search_batch_into( Index& index, svs::QueryResultView result, const Queries& queries, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { svs::index::search_batch_into_with( - index, result, queries, index.get_search_parameters(), cancel, logger + index, result, queries, index.get_search_parameters(), logger, cancel ); } @@ -77,12 +77,12 @@ svs::QueryResult search_batch_with( const Queries& queries, size_t num_neighbors, const search_parameters_t& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { auto result = svs::QueryResult{queries.size(), num_neighbors}; svs::index::search_batch_into_with( - index, result.view(), queries, search_parameters, cancel, logger + index, result.view(), queries, search_parameters, logger, cancel ); return result; } @@ -93,11 +93,11 @@ svs::QueryResult search_batch( Index& index, const Queries& queries, size_t num_neighbors, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { 
return svs::index::search_batch_with( - index, queries, num_neighbors, index.get_search_parameters(), cancel, logger + index, queries, num_neighbors, index.get_search_parameters(), logger, cancel ); } } // namespace svs::index diff --git a/include/svs/index/inverted/memory_based.h b/include/svs/index/inverted/memory_based.h index b0a2c363..de358c69 100644 --- a/include/svs/index/inverted/memory_based.h +++ b/include/svs/index/inverted/memory_based.h @@ -408,8 +408,8 @@ template class InvertedIndex { QueryResultView results, const Queries& queries, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { threads::parallel_for( threadpool_, @@ -432,7 +432,7 @@ template class InvertedIndex { auto&& query = queries.get_datum(i); // Primary Index Search - index_.search(query, scratch, cancel); + index_.search(query, scratch, logger, cancel); auto& d = scratch.scratch; auto compare = distance::comparator(d); diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 6e79af4e..91e9be02 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -464,8 +464,8 @@ class MutableVamanaIndex { vamana::EntryPointInitializer{lib::as_const_span(entry_point_)}, internal_search_builder(), prefetch_parameters, - cancel, - logger + logger, + cancel ); // Take a pass over the search buffer to remove any deleted elements that // might remain. 
@@ -479,14 +479,15 @@ class MutableVamanaIndex { const Query& query, scratchspace_type& scratch, const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get() ) const { extensions::single_search( data_, scratch.buffer, scratch.scratch, query, - greedy_search_closure(scratch.prefetch_parameters, cancel) + greedy_search_closure(scratch.prefetch_parameters, cancel), + logger ); } @@ -495,8 +496,8 @@ class MutableVamanaIndex { QueryResultView results, const Queries& queries, const search_parameters_type& sp, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { threads::parallel_for( threadpool_, @@ -523,6 +524,7 @@ class MutableVamanaIndex { results, threads::UnitRange{is}, greedy_search_closure(prefetch_parameters, cancel), + logger, cancel ); } diff --git a/include/svs/index/vamana/extensions.h b/include/svs/index/vamana/extensions.h index 0b4631d9..3c781106 100644 --- a/include/svs/index/vamana/extensions.h +++ b/include/svs/index/vamana/extensions.h @@ -418,10 +418,10 @@ struct VamanaSingleSearchType { Scratch& scratch, const Query& query, const Search& search, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) const { - svs::svs_invoke(*this, data, search_buffer, scratch, query, search, cancel, logger); + svs::svs_invoke(*this, data, search_buffer, scratch, query, search, logger, cancel); } }; @@ -444,8 +444,8 @@ SVS_FORCE_INLINE void svs_invoke( Distance& distance, const Query& query, const Search& search, - const 
lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { // Check if request to cancel the search if (cancel()) { @@ -500,8 +500,8 @@ struct VamanaPerThreadBatchSearchType { QueryResultView& result, threads::UnitRange thread_indices, const Search& search, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) const { svs::svs_invoke( *this, @@ -512,8 +512,8 @@ struct VamanaPerThreadBatchSearchType { result, thread_indices, search, - cancel, - logger + logger, + cancel ); } }; @@ -538,8 +538,8 @@ void svs_invoke( QueryResultView& result, threads::UnitRange thread_indices, const Search& search, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { // Fallback implementation size_t num_neighbors = result.n_neighbors(); @@ -550,7 +550,7 @@ void svs_invoke( } // Perform search - results will be queued in the search buffer. single_search( - dataset, search_buffer, distance, queries.get_datum(i), search, cancel, logger + dataset, search_buffer, distance, queries.get_datum(i), search, logger, cancel ); // Copy back results. 
diff --git a/include/svs/index/vamana/greedy_search.h b/include/svs/index/vamana/greedy_search.h index ce8c0174..bcc817f9 100644 --- a/include/svs/index/vamana/greedy_search.h +++ b/include/svs/index/vamana/greedy_search.h @@ -133,8 +133,8 @@ void greedy_search( const Builder& builder, Tracker& search_tracker, GreedySearchPrefetchParameters prefetch_parameters = {}, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { using I = typename Graph::index_type; @@ -225,8 +225,8 @@ void greedy_search( const Initializer& initializer, const Builder& builder = NeighborBuilder(), GreedySearchPrefetchParameters prefetch_parameters = {}, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { auto null_tracker = NullTracker{}; greedy_search( @@ -240,8 +240,8 @@ void greedy_search( builder, null_tracker, prefetch_parameters, - cancel, - logger + logger, + cancel ); } } // namespace svs::index::vamana diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index 5774aed7..9ec7201b 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -483,8 +483,8 @@ class VamanaIndex { vamana::EntryPointInitializer{lib::as_const_span(entry_point_)}, internal_search_builder(), prefetch_parameters, - cancel, - logger + logger, + cancel ); }; } @@ -507,15 +507,16 @@ class VamanaIndex { void search( const Query& query, scratchspace_type& scratch, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + 
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) const { extensions::single_search( data_, scratch.buffer, scratch.scratch, query, - greedy_search_closure(scratch.prefetch_parameters, cancel) + greedy_search_closure(scratch.prefetch_parameters, cancel), + logger ); } @@ -560,8 +561,8 @@ class VamanaIndex { QueryResultView result, const Queries& queries, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { threads::parallel_for( threadpool_, @@ -598,6 +599,7 @@ class VamanaIndex { result, threads::UnitRange{is}, greedy_search_closure(prefetch_parameters, cancel), + logger, cancel ); } @@ -843,9 +845,7 @@ class VamanaIndex { auto results = svs::QueryResult{queries.size(), num_neighbors}; auto do_search = [&](const search_parameters_type& p) { - this->search( - results.view(), queries, p, lib::Returns(lib::Const()), logger - ); + this->search(results.view(), queries, p, logger); }; auto compute_recall = [&](const search_parameters_type& p) { diff --git a/include/svs/index/vamana/iterator.h b/include/svs/index/vamana/iterator.h index bc417124..8e6d9507 100644 --- a/include/svs/index/vamana/iterator.h +++ b/include/svs/index/vamana/iterator.h @@ -290,8 +290,8 @@ template class BatchIterator { RestartInitializer{entry_points, restart_search_copy}, parent_->internal_search_builder(), scratchspace_.prefetch_parameters, - cancel, - logger + logger, + cancel ); if constexpr (Index::needs_id_translation) { diff --git a/include/svs/index/vamana/vamana_build.h b/include/svs/index/vamana/vamana_build.h index cbaeb5b3..b7a076b0 100644 --- a/include/svs/index/vamana/vamana_build.h +++ b/include/svs/index/vamana/vamana_build.h @@ -371,7 +371,6 @@ class VamanaBuilder { NeighborBuilder(), tracker, 
prefetch_hint_, - lib::Returns(lib::Const()), logger ); } diff --git a/include/svs/orchestrators/manager.h b/include/svs/orchestrators/manager.h index 54c8f2ba..04643b5b 100644 --- a/include/svs/orchestrators/manager.h +++ b/include/svs/orchestrators/manager.h @@ -80,8 +80,8 @@ template class ManagerInterface : public IFace { svs::QueryResultView results, AnonymousArray<2> data, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) = 0; // Data Interface @@ -144,8 +144,8 @@ class ManagerImpl : public ManagerInterface { QueryResultView result, AnonymousArray<2> data, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) override { // See if we have a specialization for this particular query type. 
// If so, invoke that specialization, otherwise throw @@ -155,7 +155,7 @@ class ManagerImpl : public ManagerInterface { [&](lib::Type SVS_UNUSED(type)) { const auto view = data::ConstSimpleDataView(data); svs::index::search_batch_into_with( - implementation_, result, view, search_parameters, cancel, logger + implementation_, result, view, search_parameters, logger, cancel ); }, [&](svs::DataType data_type) { @@ -221,11 +221,11 @@ template class IndexManager { QueryResultView result, data::ConstSimpleDataView queries, const search_parameters_type& search_parameters, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { impl_->search( - result, AnonymousArray<2>(queries), search_parameters, cancel, logger + result, AnonymousArray<2>(queries), search_parameters, logger, cancel ); } @@ -236,11 +236,11 @@ template class IndexManager { QueryResult search( const Queries& queries, size_t num_neighbors, - const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()), - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + const lib::DefaultPredicate& cancel = lib::Returns(lib::Const()) ) { return svs::index::search_batch( - *this, queries.cview(), num_neighbors, cancel, logger + *this, queries.cview(), num_neighbors, logger, cancel ); } diff --git a/tests/integration/cancel.cpp b/tests/integration/cancel.cpp index 18a0c09f..4e98af1f 100644 --- a/tests/integration/cancel.cpp +++ b/tests/integration/cancel.cpp @@ -66,7 +66,9 @@ CATCH_TEST_CASE("Cancel", "[integration][cancel]") { auto these_groundtruth = test_dataset::get_test_set(groundtruth, expected.num_queries_); index.set_search_parameters(sp); - auto results = index.search(these_queries, expected.num_neighbors_, timeout); + auto results = index.search( + 
these_queries, expected.num_neighbors_, svs::logging::get(), timeout + ); auto recall = svs::k_recall_at_n( these_groundtruth, results, expected.num_neighbors_, expected.recall_k_ ); @@ -102,7 +104,8 @@ CATCH_TEST_CASE("Cancel", "[integration][cancel]") { auto groundtruth = test_dataset::get_test_set(groundtruth_all, queries_in_test_set); index.set_search_parameters(expected.search_parameters_); index.set_threadpool(svs::threads::DefaultThreadPool(num_threads)); - auto results = index.search(queries, expected.num_neighbors_, timeout); + auto results = + index.search(queries, expected.num_neighbors_, svs::logging::get(), timeout); auto recall = svs::k_recall_at_n( groundtruth, results, expected.num_neighbors_, expected.recall_k_ ); @@ -123,7 +126,9 @@ CATCH_TEST_CASE("Cancel", "[integration][cancel]") { auto timeout = [&]() { return ++counter >= 2; }; auto index = svs::index::flat::FlatIndex(std::move(data), svs::distance::DistanceL2{}, 1); - svs::index::search_batch_into(index, result.view(), queries.cview(), timeout); + svs::index::search_batch_into( + index, result.view(), queries.cview(), svs::logging::get(), timeout + ); // recall should be very bad due to timeout CATCH_REQUIRE(svs::k_recall_at_n(groundtruth, result) < 0.5); @@ -138,7 +143,9 @@ CATCH_TEST_CASE("Cancel", "[integration][cancel]") { svs::Flat index = svs::Flat::assemble>( svs::VectorDataLoader(test_dataset::data_svs_file()), svs::L2, 2 ); - svs::index::search_batch_into(index, result.view(), queries.cview(), timeout); + svs::index::search_batch_into( + index, result.view(), queries.cview(), svs::logging::get(), timeout + ); // recall should be very bad due to timeout CATCH_REQUIRE(svs::k_recall_at_n(groundtruth, result) < 0.5); diff --git a/tests/integration/exhaustive.cpp b/tests/integration/exhaustive.cpp index d461cc63..4d297850 100644 --- a/tests/integration/exhaustive.cpp +++ b/tests/integration/exhaustive.cpp @@ -21,6 +21,7 @@ #include "catch2/catch_test_macros.hpp" // svs +#include 
"spdlog/sinks/callback_sink.h" #include "svs/core/distance.h" #include "svs/core/recall.h" #include "svs/index/flat/flat.h" @@ -43,7 +44,11 @@ inline constexpr bool is_flat_index_v> = tr // In this test, we predicate out the even indices and only return odd indices. // The test checks that no even indices occur in the result. template -void test_predicate(Index& index, const Queries& queries) { +void test_predicate( + Index& index, + const Queries& queries, + svs::logging::logger_ptr logger = svs::logging::get() +) { const size_t num_neighbors = 10; auto result = svs::QueryResult(queries.size(), num_neighbors); @@ -55,6 +60,7 @@ void test_predicate(Index& index, const Queries& queries) { result.view(), queries.cview(), index.get_search_parameters(), + logger, []() { return false; }, predicate ); @@ -75,7 +81,8 @@ void test_flat( Index& index, const Queries& queries, const GroundTruth& groundtruth, - svs::DistanceType distance_type + svs::DistanceType distance_type, + svs::logging::logger_ptr logger = svs::logging::get() ) { // Test get distance auto dataset = svs::load_data(test_dataset::data_svs_file()); @@ -129,7 +136,7 @@ void test_flat( // Test predicated search. if constexpr (is_flat_index_v) { - test_predicate(index, queries); + test_predicate(index, queries, logger); } } } // namespace @@ -140,6 +147,28 @@ void test_flat( // Test the single-threaded implementation. 
CATCH_TEST_CASE("Flat Index Search", "[integration][exhaustive][index]") { + // Set up log + std::vector captured_logs; + auto callback_sink = std::make_shared( + [&captured_logs](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + callback_sink->set_level(spdlog::level::trace); + auto test_logger = std::make_shared("test_logger", callback_sink); + test_logger->set_level(spdlog::level::trace); + + std::vector global_captured_logs; + auto global_callback_sink = std::make_shared( + [&global_captured_logs](const spdlog::details::log_msg& msg) { + global_captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + global_callback_sink->set_level(spdlog::level::trace); + auto original_logger = + std::make_shared("original_logger", global_callback_sink); + original_logger->set_level(spdlog::level::trace); + svs::logging::set(original_logger); auto queries = test_dataset::queries(); auto data = svs::load_data(test_dataset::data_svs_file()); diff --git a/tests/integration/vamana/index_search.cpp b/tests/integration/vamana/index_search.cpp index cae36dae..8013eac2 100644 --- a/tests/integration/vamana/index_search.cpp +++ b/tests/integration/vamana/index_search.cpp @@ -132,6 +132,7 @@ void run_tests( const svs::data::SimpleData& groundtruth_all, const std::vector& expected_results, svs::logging::logger_ptr logger, + std::vector& global_captured_logs, bool test_calibration = false ) { // If we make a change that somehow improves accuracy, we'll want to know. 
@@ -184,7 +185,8 @@ void run_tests( for (auto num_threads : {1, 2}) { index.set_threadpool(Pool(num_threads)); // Float32 - auto results = index.search(queries, expected.num_neighbors_); + auto results = index.search(queries, expected.num_neighbors_, logger); + CATCH_REQUIRE(global_captured_logs.empty()); auto recall = svs::k_recall_at_n( groundtruth, results, expected.num_neighbors_, expected.recall_k_ ); @@ -194,7 +196,8 @@ void run_tests( // Test Float16 results, but only on the first iteration. // Otherwise, skip it to keep run times down. if (first) { - results = index.search(queries_f16, expected.num_neighbors_); + results = index.search(queries_f16, expected.num_neighbors_, logger); + CATCH_REQUIRE(global_captured_logs.empty()); recall = svs::k_recall_at_n( groundtruth, results, expected.num_neighbors_, expected.recall_k_ ); @@ -203,7 +206,9 @@ void run_tests( first = false; } } + CATCH_REQUIRE(global_captured_logs.empty()); } + CATCH_REQUIRE(global_captured_logs.empty()); // Make sure calibration works. 
if (!test_calibration) { @@ -226,17 +231,20 @@ void run_tests( index.experimental_calibrate( queries, groundtruth, first_result.num_neighbors_, first_result.recall_, c, logger ); + CATCH_REQUIRE(global_captured_logs.empty()); + auto recall = svs::k_recall_at_n( groundtruth, - index.search(queries, first_result.num_neighbors_), + index.search(queries, first_result.num_neighbors_, logger), first_result.num_neighbors_, first_result.recall_k_ ); CATCH_REQUIRE(recall >= first_result.recall_); + CATCH_REQUIRE(global_captured_logs.empty()); } } // namespace -CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") { +CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration][search][vamana]") { // Set up log std::vector captured_logs; auto callback_sink = std::make_shared( @@ -281,6 +289,7 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") 2, test_logger ); + CATCH_REQUIRE(global_captured_logs.empty()); CATCH_REQUIRE(index.size() == test_dataset::VECTORS_IN_DATA_SET); CATCH_REQUIRE(index.dimensions() == test_dataset::NUM_DIMENSIONS); @@ -293,14 +302,17 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") verify_reconstruction(index, original_data); first = false; } + CATCH_REQUIRE(global_captured_logs.empty()); run_tests( index, queries, groundtruth, expected_results.config_and_recall_, test_logger, + global_captured_logs, true ); + CATCH_REQUIRE(global_captured_logs.empty()); index = svs::Vamana::assemble>( test_dataset::vamana_config_file(), @@ -310,6 +322,7 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") svs::threads::CppAsyncThreadPool(2), test_logger ); + CATCH_REQUIRE(global_captured_logs.empty()); run_tests( index, @@ -317,8 +330,10 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") groundtruth, expected_results.config_and_recall_, test_logger, + global_captured_logs, true ); + 
CATCH_REQUIRE(global_captured_logs.empty()); index = svs::Vamana::assemble>( test_dataset::vamana_config_file(), @@ -328,6 +343,7 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") svs::threads::QueueThreadPoolWrapper(2), test_logger ); + CATCH_REQUIRE(global_captured_logs.empty()); run_tests( index, @@ -335,8 +351,11 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") groundtruth, expected_results.config_and_recall_, test_logger, + global_captured_logs, true ); + CATCH_REQUIRE(global_captured_logs.empty()); + // Save and reload. svs_test::prepare_temp_directory(); @@ -382,12 +401,22 @@ CATCH_TEST_CASE("Uncompressed Vamana Search", "[integration1][search][vamana]") threadpool.resize(2); CATCH_REQUIRE(index.get_num_threads() == 2); run_tests( - index, queries, groundtruth, expected_results.config_and_recall_, test_logger + index, + queries, + groundtruth, + expected_results.config_and_recall_, + test_logger, + global_captured_logs ); index.set_threadpool(svs::threads::SwitchNativeThreadPool(2)); run_tests( - index, queries, groundtruth, expected_results.config_and_recall_, test_logger + index, + queries, + groundtruth, + expected_results.config_and_recall_, + test_logger, + global_captured_logs ); } CATCH_REQUIRE(global_captured_logs.empty()); diff --git a/tests/svs/index/index.cpp b/tests/svs/index/index.cpp index 76dc4f78..79d1fe9d 100644 --- a/tests/svs/index/index.cpp +++ b/tests/svs/index/index.cpp @@ -47,9 +47,9 @@ struct TestIndex { svs::QueryResultView result, svs::data::ConstSimpleDataView queries, SearchParameters p, + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get(), const svs::lib::DefaultPredicate& cancel = - svs::lib::Returns(svs::lib::Const()), - svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + svs::lib::Returns(svs::lib::Const()) ) const { CATCH_REQUIRE(result.n_neighbors() == expected_num_neighbors_); CATCH_REQUIRE(result.n_queries() == 
expected_num_queries_); From 6e123bbaf826b946d675af257841a8fb512e5c14 Mon Sep 17 00:00:00 2001 From: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 23 May 2025 16:26:45 -0700 Subject: [PATCH 4/8] fix: remove unused test_logger --- tests/integration/exhaustive.cpp | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/tests/integration/exhaustive.cpp b/tests/integration/exhaustive.cpp index 4d297850..d4009299 100644 --- a/tests/integration/exhaustive.cpp +++ b/tests/integration/exhaustive.cpp @@ -147,28 +147,6 @@ void test_flat( // Test the single-threaded implementation. CATCH_TEST_CASE("Flat Index Search", "[integration][exhaustive][index]") { - // Set up log - std::vector captured_logs; - auto callback_sink = std::make_shared( - [&captured_logs](const spdlog::details::log_msg& msg) { - captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); - } - ); - callback_sink->set_level(spdlog::level::trace); - auto test_logger = std::make_shared("test_logger", callback_sink); - test_logger->set_level(spdlog::level::trace); - - std::vector global_captured_logs; - auto global_callback_sink = std::make_shared( - [&global_captured_logs](const spdlog::details::log_msg& msg) { - global_captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); - } - ); - global_callback_sink->set_level(spdlog::level::trace); - auto original_logger = - std::make_shared("original_logger", global_callback_sink); - original_logger->set_level(spdlog::level::trace); - svs::logging::set(original_logger); auto queries = test_dataset::queries(); auto data = svs::load_data(test_dataset::data_svs_file()); From 848151ed1de9552d7d4cc482f02becbaceed78a5 Mon Sep 17 00:00:00 2001 From: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 23 May 2025 16:36:47 -0700 Subject: [PATCH 5/8] fix: add consolidate --- include/svs/index/vamana/consolidate.h | 3 ++- include/svs/index/vamana/dynamic_index.h | 3 ++- 2 
files changed, 4 insertions(+), 2 deletions(-) diff --git a/include/svs/index/vamana/consolidate.h b/include/svs/index/vamana/consolidate.h index 73352e56..1a7ed8fe 100644 --- a/include/svs/index/vamana/consolidate.h +++ b/include/svs/index/vamana/consolidate.h @@ -362,7 +362,8 @@ void consolidate( size_t max_candidate_pool_size, float alpha, const Distance& distance, - Deleted&& is_deleted + Deleted&& is_deleted, + svs::logging::logger_ptr logger = svs::logging::get() ) { ConsolidationParameters params{200'000, prune_to, max_candidate_pool_size, alpha}; auto consolidator = GraphConsolidator{graph, data, threadpool, distance, params}; diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 91e9be02..426b4fc3 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -964,7 +964,8 @@ class MutableVamanaIndex { max_candidates_, alpha_, distance_, - check_is_deleted + check_is_deleted, + logger_ ); // After consolidation - set all `Deleted` slots to `Empty`. 
From 12c3b42d02cfa65d15ddeda9049bb26a02a3094d Mon Sep 17 00:00:00 2001 From: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 23 May 2025 16:40:54 -0700 Subject: [PATCH 6/8] fix: consolidator --- include/svs/index/vamana/consolidate.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/include/svs/index/vamana/consolidate.h b/include/svs/index/vamana/consolidate.h index 1a7ed8fe..ff15ef75 100644 --- a/include/svs/index/vamana/consolidate.h +++ b/include/svs/index/vamana/consolidate.h @@ -161,6 +161,7 @@ class GraphConsolidator { Pool& threadpool_; const Distance& distance_; ConsolidationParameters params_; + svs::logging::logger_ptr logger_; public: // Constructor @@ -169,13 +170,15 @@ class GraphConsolidator { const Data& data, Pool& threadpool, const Distance& distance, - const ConsolidationParameters& params + const ConsolidationParameters& params, + svs::logging::logger_ptr logger = svs::logging::get() ) : graph_{graph} , data_{data} , threadpool_{threadpool} , distance_{distance} - , params_{params} { + , params_{params} + , logger_{logger} { assert(graph.n_nodes() == data.size()); } @@ -366,7 +369,7 @@ void consolidate( svs::logging::logger_ptr logger = svs::logging::get() ) { ConsolidationParameters params{200'000, prune_to, max_candidate_pool_size, alpha}; - auto consolidator = GraphConsolidator{graph, data, threadpool, distance, params}; + auto consolidator = GraphConsolidator{graph, data, threadpool, distance, params, logger}; consolidator(is_deleted); } From 193f1b3474dd7f244cbc5cc4d5eb5aa5a9210233 Mon Sep 17 00:00:00 2001 From: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 23 May 2025 17:44:58 -0700 Subject: [PATCH 7/8] fix: calibrate --- include/svs/index/vamana/dynamic_index.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 426b4fc3..694704d1 100644
--- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -1041,7 +1041,8 @@ class MutableVamanaIndex { const GroundTruth& groundtruth, size_t num_neighbors, double target_recall, - const CalibrationParameters& calibration_parameters = {} + const CalibrationParameters& calibration_parameters = {}, + svs::logging::logger_ptr logger = svs::logging::get() ) { // Preallocate the destination for search. // Further, reference the search lambda in the recall lambda. @@ -1064,7 +1065,7 @@ class MutableVamanaIndex { target_recall, compute_recall, do_search, - logger_ + logger ); set_search_parameters(p); From a6459a88f11cdcc96b3b824076539c0094d802b7 Mon Sep 17 00:00:00 2001 From: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 27 May 2025 17:15:04 -0700 Subject: [PATCH 8/8] fix: add points and delete points --- include/svs/index/vamana/consolidate.h | 3 ++- include/svs/index/vamana/dynamic_index.h | 30 +++++++++++++++------- include/svs/index/vamana/index.h | 10 ++++---- include/svs/misc/dynamic_helper.h | 16 +++++++++--- include/svs/orchestrators/dynamic_vamana.h | 29 ++++++++++++++------- tests/svs/index/vamana/dynamic_index_2.cpp | 17 +++++++----- 6 files changed, 71 insertions(+), 34 deletions(-) diff --git a/include/svs/index/vamana/consolidate.h b/include/svs/index/vamana/consolidate.h index ff15ef75..1e33b5c5 100644 --- a/include/svs/index/vamana/consolidate.h +++ b/include/svs/index/vamana/consolidate.h @@ -369,7 +369,8 @@ void consolidate( svs::logging::logger_ptr logger = svs::logging::get() ) { ConsolidationParameters params{200'000, prune_to, max_candidate_pool_size, alpha}; - auto consolidator = GraphConsolidator{graph, data, threadpool, distance, params, logger}; + auto consolidator = + GraphConsolidator{graph, data, threadpool, distance, params, logger}; consolidator(is_deleted); } diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 
694704d1..b78a4c7c 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -464,7 +464,7 @@ class MutableVamanaIndex { vamana::EntryPointInitializer{lib::as_const_span(entry_point_)}, internal_search_builder(), prefetch_parameters, - logger, + logger_ ? logger_ : logger, cancel ); // Take a pass over the search buffer to remove any deleted elements that @@ -487,7 +487,7 @@ class MutableVamanaIndex { scratch.scratch, query, greedy_search_closure(scratch.prefetch_parameters, cancel), - logger + logger_ ? logger_ : logger ); } @@ -625,7 +625,10 @@ class MutableVamanaIndex { /// template std::vector add_points( - const Points& points, const ExternalIds& external_ids, bool reuse_empty = false + const Points& points, + const ExternalIds& external_ids, + bool reuse_empty = false, + svs::logging::logger_ptr logger = svs::logging::get() ) { const size_t num_points = points.size(); const size_t num_ids = external_ids.size(); @@ -699,7 +702,9 @@ class MutableVamanaIndex { GreedySearchPrefetchParameters{sp.prefetch_lookahead_, sp.prefetch_step_}; VamanaBuilder builder{ graph_, data_, distance_, parameters, threadpool_, prefetch_parameters}; - builder.construct(alpha_, entry_point(), slots, logging::Level::Trace, logger_); + builder.construct( + alpha_, entry_point(), slots, logging::Level::Trace, logger_ ? logger_ : logger + ); // Mark all added entries as valid. for (const auto& i : slots) { status_[i] = SlotMetadata::Valid; @@ -733,16 +738,20 @@ class MutableVamanaIndex { /// Delete consolidation performs the actual removal of deleted entries from the /// graph. /// - template size_t delete_entries(const T& ids) { + template + size_t + delete_entries(const T& ids, svs::logging::logger_ptr logger = svs::logging::get()) { translator_.check_external_exist(ids.begin(), ids.end()); for (auto i : ids) { - delete_entry(translator_.get_internal(i)); + delete_entry(translator_.get_internal(i), logger_ ? 
logger_ : logger); } translator_.delete_external(ids); return ids.size(); } - void delete_entry(size_t i) { + void delete_entry( + size_t i, svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + ) { SlotMetadata& meta = getindex(status_, i); assert(meta == SlotMetadata::Valid); meta = SlotMetadata::Deleted; @@ -777,7 +786,10 @@ class MutableVamanaIndex { /// @param batch_size Granularity at which points are shuffled. Setting this higher can /// improve performance but requires more working memory. /// - void compact(Idx batch_size = 1'000) { + void compact( + Idx batch_size = 1'000, + svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get() + ) { // Step 1: Compute a prefix-sum matching each valid internal index to its new // internal index. // @@ -965,7 +977,7 @@ class MutableVamanaIndex { alpha_, distance_, check_is_deleted, - logger_ + logger_ ? logger_ : logger_ ); // After consolidation - set all `Deleted` slots to `Empty`. diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index 9ec7201b..4812e1ce 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -483,7 +483,7 @@ class VamanaIndex { vamana::EntryPointInitializer{lib::as_const_span(entry_point_)}, internal_search_builder(), prefetch_parameters, - logger, + logger_ ? logger_ : logger, cancel ); }; @@ -516,7 +516,7 @@ class VamanaIndex { scratch.scratch, query, greedy_search_closure(scratch.prefetch_parameters, cancel), - logger + logger_ ? logger_ : logger ); } @@ -599,7 +599,7 @@ class VamanaIndex { result, threads::UnitRange{is}, greedy_search_closure(prefetch_parameters, cancel), - logger, + logger_ ? logger_ : logger, cancel ); } @@ -845,7 +845,7 @@ class VamanaIndex { auto results = svs::QueryResult{queries.size(), num_neighbors}; auto do_search = [&](const search_parameters_type& p) { - this->search(results.view(), queries, p, logger); + this->search(results.view(), queries, p, logger_ ? 
logger_ : logger); }; auto compute_recall = [&](const search_parameters_type& p) { @@ -861,7 +861,7 @@ class VamanaIndex { target_recall, compute_recall, do_search, - logger + logger_ ? logger_ : logger ); set_search_parameters(p); return p; diff --git a/include/svs/misc/dynamic_helper.h b/include/svs/misc/dynamic_helper.h index 7e541529..d8da4383 100644 --- a/include/svs/misc/dynamic_helper.h +++ b/include/svs/misc/dynamic_helper.h @@ -371,11 +371,15 @@ template class Referenc /// @returns The number of points added and the time spend adding those points. /// template - std::pair add_points(MutableIndex& index, size_t num_points) { + std::pair add_points( + MutableIndex& index, + size_t num_points, + svs::logging::logger_ptr logger = svs::logging::get() + ) { auto [vectors, indices] = generate(num_points); // Add the points to the index. auto tic = lib::now(); - index.add_points(vectors, indices); + index.add_points(vectors, indices, false, logger); double time = lib::time_difference(tic); return std::make_pair(indices.size(), time); } @@ -412,10 +416,14 @@ template class Referenc } template - std::pair delete_points(MutableIndex& index, size_t num_points) { + std::pair delete_points( + MutableIndex& index, + size_t num_points, + svs::logging::logger_ptr logger = svs::logging::get() + ) { auto points = get_delete_points(num_points); auto tic = svs::lib::now(); - index.delete_entries(points); + index.delete_entries(points, logger); double time = svs::lib::time_difference(tic); return std::make_pair(num_points, time); } diff --git a/include/svs/orchestrators/dynamic_vamana.h b/include/svs/orchestrators/dynamic_vamana.h index a0d72578..cde7a552 100644 --- a/include/svs/orchestrators/dynamic_vamana.h +++ b/include/svs/orchestrators/dynamic_vamana.h @@ -37,10 +37,13 @@ class DynamicVamanaInterface : public VamanaInterface { size_t dim0, size_t dim1, std::span ids, - bool reuse_empty = false + bool reuse_empty = false, + svs::logging::logger_ptr logger = 
svs::logging::get() ) = 0; - virtual void delete_points(std::span ids) = 0; + virtual void delete_points( + std::span ids, svs::logging::logger_ptr logger = svs::logging::get() + ) = 0; virtual void consolidate() = 0; virtual void compact(size_t batchsize = 1'000'000) = 0; @@ -71,13 +74,18 @@ class DynamicVamanaImpl : public VamanaImpl ids, - bool reuse_empty = false + bool reuse_empty = false, + svs::logging::logger_ptr logger = svs::logging::get() ) override { auto points = data::ConstSimpleDataView(data, dim0, dim1); - impl().add_points(points, ids, reuse_empty); + impl().add_points(points, ids, reuse_empty, logger); } - void delete_points(std::span ids) override { impl().delete_entries(ids); } + void delete_points( + std::span ids, svs::logging::logger_ptr logger = svs::logging::get() + ) override { + impl().delete_entries(ids, logger); + } void consolidate() override { impl().consolidate(); } void compact(size_t batchsize) override { impl().compact(batchsize); } @@ -174,16 +182,19 @@ class DynamicVamana : public manager::IndexManager { DynamicVamana& add_points( data::ConstSimpleDataView points, std::span ids, - bool reuse_empty = false + bool reuse_empty = false, + svs::logging::logger_ptr logger = svs::logging::get() ) { impl_->add_points( - points.data(), points.size(), points.dimensions(), ids, reuse_empty + points.data(), points.size(), points.dimensions(), ids, reuse_empty, logger ); return *this; } - DynamicVamana& delete_points(std::span ids) { - impl_->delete_points(ids); + DynamicVamana& delete_points( + std::span ids, svs::logging::logger_ptr logger = svs::logging::get() + ) { + impl_->delete_points(ids, logger); return *this; } diff --git a/tests/svs/index/vamana/dynamic_index_2.cpp b/tests/svs/index/vamana/dynamic_index_2.cpp index da509a10..07ae6550 100644 --- a/tests/svs/index/vamana/dynamic_index_2.cpp +++ b/tests/svs/index/vamana/dynamic_index_2.cpp @@ -202,13 +202,14 @@ void test_loop( const Queries& queries, size_t num_points, size_t 
consolidate_every, - size_t iterations + size_t iterations, + svs::logging::logger_ptr logger ) { size_t consolidate_count = 0; for (size_t i = 0; i < iterations; ++i) { // Add Points { - auto [points, time] = reference.add_points(index, num_points); + auto [points, time] = reference.add_points(index, num_points, logger); CATCH_REQUIRE(points <= num_points); CATCH_REQUIRE(points > num_points - reference.bucket_size()); index.debug_check_invariants(true); @@ -217,7 +218,7 @@ void test_loop( // Delete Points { - auto [points, time] = reference.delete_points(index, num_points); + auto [points, time] = reference.delete_points(index, num_points, logger); CATCH_REQUIRE(points <= num_points); CATCH_REQUIRE(points > num_points - reference.bucket_size()); index.debug_check_invariants(true); @@ -278,8 +279,10 @@ CATCH_TEST_CASE("Testing Graph Index", "[graph_index][dynamic_index]") { } ); global_callback_sink->set_level(spdlog::level::trace); - auto original_logger = svs::logging::get(); - original_logger->sinks().push_back(global_callback_sink); + auto original_logger = + std::make_shared("original_logger", global_callback_sink); + original_logger->set_level(spdlog::level::trace); + svs::logging::set(original_logger); // Load the base dataset and queries. auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); @@ -386,7 +389,9 @@ CATCH_TEST_CASE("Testing Graph Index", "[graph_index][dynamic_index]") { true ); - test_loop(index, reference, queries, div(reference.size(), modify_fraction), 2, 6); + test_loop( + index, reference, queries, div(reference.size(), modify_fraction), 2, 6, test_logger + ); // Try saving the index. svs_test::prepare_temp_directory();