Skip to content

Commit 8c96d9c

Browse files
committed
[Fix] make find_and_update standalone & codes format
1 parent ec37b27 commit 8c96d9c

File tree

8 files changed

+192
-70
lines changed

8 files changed

+192
-70
lines changed

include/merlin/allocator.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ class BaseAllocator {
5959

6060
class DefaultAllocator : public virtual BaseAllocator {
6161
public:
62-
DefaultAllocator() {};
63-
~DefaultAllocator() override {};
62+
DefaultAllocator(){};
63+
~DefaultAllocator() override{};
6464

6565
void alloc(const MemoryType type, void** ptr, size_t size,
6666
unsigned int pinned_flags = cudaHostAllocDefault) override {

include/merlin_hashtable.cuh

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -645,16 +645,42 @@ class HashTableBase {
645645
* @endparblock
646646
* @param stream The CUDA stream that is used to execute the operation.
647647
* @param unique_key If all keys in the same batch are unique.
648-
* @param update_score If true then update the found keys in the table, and
649-
* will use scores as input.
650648
*
651649
*/
652650
virtual void find(const size_type n, const key_type* keys, // (n)
653651
value_type** values, // (n)
654652
bool* founds, // (n)
655653
score_type* scores = nullptr, // (n)
656-
cudaStream_t stream = 0, bool unique_key = true,
657-
bool update_score = false) = 0;
654+
cudaStream_t stream = 0, bool unique_key = true) const = 0;
655+
656+
/**
657+
* @brief Searches the hash table for the specified keys and returns address
658+
* of the values, and will update the scores.
659+
*
660+
* @note When a key is missing, the data in @p values won't change.
661+
* @warning This API returns internal addresses for high-performance but
662+
* thread-unsafe. The caller is responsible for guaranteeing data consistency.
663+
*
664+
* @param n The number of key-value-score tuples to search.
665+
* @param keys The keys to search on GPU-accessible memory with shape (n).
666+
* @param values The addresses of values to search on GPU-accessible memory
667+
* with shape (n).
668+
* @param founds The status that indicates if the keys are found on
669+
* GPU-accessible memory with shape (n).
670+
* @param scores The scores to search on GPU-accessible memory with shape (n).
671+
* @parblock
672+
* If @p scores is `nullptr`, the score for each key will not be returned.
673+
* @endparblock
674+
* @param stream The CUDA stream that is used to execute the operation.
675+
* @param unique_key If all keys in the same batch are unique.
676+
*
677+
*/
678+
virtual void find_and_update(const size_type n, const key_type* keys, // (n)
679+
value_type** values, // (n)
680+
bool* founds, // (n)
681+
score_type* scores = nullptr, // (n)
682+
cudaStream_t stream = 0,
683+
bool unique_key = true) = 0;
658684

659685
/**
660686
* @brief Checks if there are elements with key equivalent to `keys` in the
@@ -2559,16 +2585,13 @@ class HashTable : public HashTableBase<K, V, S> {
25592585
* @endparblock
25602586
* @param stream The CUDA stream that is used to execute the operation.
25612587
* @param unique_key If all keys in the same batch are unique.
2562-
* @param update_score If true then update the found keys in the table, and
2563-
* will use scores as input.
25642588
*
25652589
*/
25662590
void find(const size_type n, const key_type* keys, // (n)
25672591
value_type** values, // (n)
25682592
bool* founds, // (n)
25692593
score_type* scores = nullptr, // (n)
2570-
cudaStream_t stream = 0, bool unique_key = true,
2571-
bool update_score = false) {
2594+
cudaStream_t stream = 0, bool unique_key = true) const {
25722595
if (n == 0) {
25732596
return;
25742597
}
@@ -2578,18 +2601,14 @@ class HashTable : public HashTableBase<K, V, S> {
25782601
lock_ptr = std::make_unique<read_shared_lock>(mutex_, stream);
25792602
}
25802603

2581-
if (update_score) {
2582-
check_evict_strategy(scores);
2583-
}
2584-
25852604
constexpr uint32_t MinBucketCapacityFilter = sizeof(VecD_Load) / sizeof(D);
25862605
if (unique_key && options_.max_bucket_size >= MinBucketCapacityFilter) {
25872606
constexpr uint32_t BLOCK_SIZE = 128U;
25882607
tlp_lookup_ptr_kernel_with_filter<key_type, value_type, score_type,
25892608
evict_strategy>
25902609
<<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
25912610
table_->buckets, table_->buckets_num, options_.max_bucket_size,
2592-
options_.dim, keys, values, scores, founds, n, update_score,
2611+
options_.dim, keys, values, scores, founds, n, false,
25932612
global_epoch_);
25942613
} else {
25952614
using Selector = SelectLookupPtrKernel<key_type, value_type, score_type>;
@@ -2609,6 +2628,62 @@ class HashTable : public HashTableBase<K, V, S> {
26092628
CudaCheckError();
26102629
}
26112630

2631+
/**
2632+
* @brief Searches the hash table for the specified keys and returns address
2633+
* of the values, and will update the scores.
2634+
*
2635+
* @note When a key is missing, the data in @p values won't change.
2636+
* @warning This API returns internal addresses for high-performance but
2637+
* thread-unsafe. The caller is responsible for guaranteeing data consistency.
2638+
*
2639+
* @param n The number of key-value-score tuples to search.
2640+
* @param keys The keys to search on GPU-accessible memory with shape (n).
2641+
* @param values The addresses of values to search on GPU-accessible memory
2642+
* with shape (n).
2643+
* @param founds The status that indicates if the keys are found on
2644+
* GPU-accessible memory with shape (n).
2645+
* @param scores The scores to search on GPU-accessible memory with shape (n).
2646+
* @parblock
2647+
* If @p scores is `nullptr`, the score for each key will not be returned.
2648+
* @endparblock
2649+
* @param stream The CUDA stream that is used to execute the operation.
2650+
* @param unique_key If all keys in the same batch are unique.
2651+
*
2652+
*/
2653+
void find_and_update(const size_type n, const key_type* keys, // (n)
2654+
value_type** values, // (n)
2655+
bool* founds, // (n)
2656+
score_type* scores = nullptr, // (n)
2657+
cudaStream_t stream = 0, bool unique_key = true) {
2658+
if (n == 0) {
2659+
return;
2660+
}
2661+
2662+
std::unique_ptr<read_shared_lock> lock_ptr;
2663+
if (options_.api_lock) {
2664+
lock_ptr = std::make_unique<read_shared_lock>(mutex_, stream);
2665+
}
2666+
2667+
check_evict_strategy(scores);
2668+
2669+
constexpr uint32_t MinBucketCapacityFilter = sizeof(VecD_Load) / sizeof(D);
2670+
if (unique_key && options_.max_bucket_size >= MinBucketCapacityFilter) {
2671+
constexpr uint32_t BLOCK_SIZE = 128U;
2672+
tlp_lookup_ptr_kernel_with_filter<key_type, value_type, score_type,
2673+
evict_strategy>
2674+
<<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
2675+
table_->buckets, table_->buckets_num, options_.max_bucket_size,
2676+
options_.dim, keys, values, scores, founds, n, true,
2677+
global_epoch_);
2678+
} else {
2679+
throw std::runtime_error(
2680+
"Not support update score when keys are not unique or bucket "
2681+
"capacity is small.");
2682+
}
2683+
2684+
CudaCheckError();
2685+
}
2686+
26122687
/**
26132688
* @brief Checks if there are elements with key equivalent to `keys` in the
26142689
* table.

tests/accum_or_assign_test.cc.cu

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -972,14 +972,16 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
972972

973973
CUDA_CHECK(cudaMalloc(&d_accum_or_assigns_temp, TEMP_KEY_NUM * sizeof(bool)));
974974

975-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_base.data()),
976-
BASE_KEY_NUM, true_ratio);
975+
test_util::create_random_bools<K>(
976+
reinterpret_cast<bool*>(h_accum_or_assigns_base.data()), BASE_KEY_NUM,
977+
true_ratio);
977978
test_util::create_keys_in_one_buckets<K, S, V, DIM>(
978979
h_keys_base.data(), h_scores_base.data(), h_vectors_base.data(),
979980
BASE_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0, 0x3FFFFFFFFFFFFFFF);
980981

981-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_test.data()),
982-
TEST_KEY_NUM, true_ratio);
982+
test_util::create_random_bools<K>(
983+
reinterpret_cast<bool*>(h_accum_or_assigns_test.data()), TEST_KEY_NUM,
984+
true_ratio);
983985
test_util::create_keys_in_one_buckets<K, S, V, DIM>(
984986
h_keys_test.data(), h_scores_test.data(), h_vectors_test.data(),
985987
TEST_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0x3FFFFFFFFFFFFFFF,
@@ -1189,11 +1191,13 @@ void test_evict_strategy_lfu_basic(size_t max_hbm_for_vectors, int key_start) {
11891191

11901192
CUDA_CHECK(cudaMalloc(&d_accum_or_assigns_temp, TEMP_KEY_NUM * sizeof(bool)));
11911193

1192-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_base.data()),
1193-
BASE_KEY_NUM, true_ratio);
1194+
test_util::create_random_bools<K>(
1195+
reinterpret_cast<bool*>(h_accum_or_assigns_base.data()), BASE_KEY_NUM,
1196+
true_ratio);
11941197

1195-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_test.data()),
1196-
TEST_KEY_NUM, true_ratio);
1198+
test_util::create_random_bools<K>(
1199+
reinterpret_cast<bool*>(h_accum_or_assigns_test.data()), TEST_KEY_NUM,
1200+
true_ratio);
11971201

11981202
for (int i = 0; i < TEST_TIMES; i++) {
11991203
test_util::create_keys_in_one_buckets_lfu<K, S, V, DIM>(
@@ -1416,14 +1420,16 @@ void test_evict_strategy_epochlru_basic(size_t max_hbm_for_vectors,
14161420

14171421
CUDA_CHECK(cudaMalloc(&d_accum_or_assigns_temp, TEMP_KEY_NUM * sizeof(bool)));
14181422

1419-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_base.data()),
1420-
BASE_KEY_NUM, true_ratio);
1423+
test_util::create_random_bools<K>(
1424+
reinterpret_cast<bool*>(h_accum_or_assigns_base.data()), BASE_KEY_NUM,
1425+
true_ratio);
14211426
test_util::create_keys_in_one_buckets<K, S, V, DIM>(
14221427
h_keys_base.data(), h_scores_base.data(), h_vectors_base.data(),
14231428
BASE_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0, 0x3FFFFFFFFFFFFFFF);
14241429

1425-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_test.data()),
1426-
TEST_KEY_NUM, true_ratio);
1430+
test_util::create_random_bools<K>(
1431+
reinterpret_cast<bool*>(h_accum_or_assigns_test.data()), TEST_KEY_NUM,
1432+
true_ratio);
14271433
test_util::create_keys_in_one_buckets<K, S, V, DIM>(
14281434
h_keys_test.data(), h_scores_test.data(), h_vectors_test.data(),
14291435
TEST_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0x3FFFFFFFFFFFFFFF,
@@ -1645,11 +1651,13 @@ void test_evict_strategy_epochlfu_basic(size_t max_hbm_for_vectors,
16451651

16461652
CUDA_CHECK(cudaMalloc(&d_accum_or_assigns_temp, TEMP_KEY_NUM * sizeof(bool)));
16471653

1648-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_base.data()),
1649-
BASE_KEY_NUM, true_ratio);
1654+
test_util::create_random_bools<K>(
1655+
reinterpret_cast<bool*>(h_accum_or_assigns_base.data()), BASE_KEY_NUM,
1656+
true_ratio);
16501657

1651-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_test.data()),
1652-
TEST_KEY_NUM, true_ratio);
1658+
test_util::create_random_bools<K>(
1659+
reinterpret_cast<bool*>(h_accum_or_assigns_test.data()), TEST_KEY_NUM,
1660+
true_ratio);
16531661

16541662
test_util::create_keys_in_one_buckets_lfu<K, S, V, DIM>(
16551663
h_keys_base.data(), h_scores_base.data(), h_vectors_base.data(),
@@ -1913,10 +1921,12 @@ void test_evict_strategy_customized_basic(size_t max_hbm_for_vectors,
19131921
CUDA_CHECK(cudaMalloc(&d_accum_or_assigns_temp, TEMP_KEY_NUM * sizeof(bool)));
19141922
CUDA_CHECK(cudaMalloc(&d_found_temp, TEMP_KEY_NUM * sizeof(bool)));
19151923

1916-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_base.data()),
1917-
BASE_KEY_NUM, true_ratio);
1918-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_test.data()),
1919-
TEST_KEY_NUM, true_ratio);
1924+
test_util::create_random_bools<K>(
1925+
reinterpret_cast<bool*>(h_accum_or_assigns_base.data()), BASE_KEY_NUM,
1926+
true_ratio);
1927+
test_util::create_random_bools<K>(
1928+
reinterpret_cast<bool*>(h_accum_or_assigns_test.data()), TEST_KEY_NUM,
1929+
true_ratio);
19201930

19211931
test_util::create_keys_in_one_buckets<K, S, V, DIM>(
19221932
h_keys_base.data(), h_scores_base.data(), h_vectors_base.data(),
@@ -2171,8 +2181,9 @@ void test_evict_strategy_customized_advanced(size_t max_hbm_for_vectors,
21712181
cudaMalloc(&d_vectors_temp, TEMP_KEY_NUM * sizeof(V) * options.dim));
21722182
CUDA_CHECK(cudaMalloc(&d_accum_or_assigns_temp, TEMP_KEY_NUM * sizeof(bool)));
21732183

2174-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_base.data()),
2175-
BASE_KEY_NUM, base_true_ratio);
2184+
test_util::create_random_bools<K>(
2185+
reinterpret_cast<bool*>(h_accum_or_assigns_base.data()), BASE_KEY_NUM,
2186+
base_true_ratio);
21762187
test_util::create_keys_in_one_buckets<K, S, V, DIM>(
21772188
h_keys_base.data(), h_scores_base.data(), h_vectors_base.data(),
21782189
BASE_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0, 0x3FFFFFFFFFFFFFFF);
@@ -2182,8 +2193,9 @@ void test_evict_strategy_customized_advanced(size_t max_hbm_for_vectors,
21822193
h_scores_base[i] = base_score_start + i;
21832194
}
21842195

2185-
test_util::create_random_bools<K>(reinterpret_cast<bool*>(h_accum_or_assigns_test.data()),
2186-
TEST_KEY_NUM, test_true_ratio);
2196+
test_util::create_random_bools<K>(
2197+
reinterpret_cast<bool*>(h_accum_or_assigns_test.data()), TEST_KEY_NUM,
2198+
test_true_ratio);
21872199
test_util::create_keys_in_one_buckets<K, S, V, DIM>(
21882200
h_keys_test.data(), h_scores_test.data(), h_vectors_test.data(),
21892201
TEST_KEY_NUM, INIT_CAPACITY, BUCKET_MAX_SIZE, 1, 0x3FFFFFFFFFFFFFFF,

tests/assign_score_test.cc.cu

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,9 @@ void test_evict_strategy_customized_basic(size_t max_hbm_for_vectors,
917917
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());
918918

919919
auto expected_range = test_util::range<S, TEMP_KEY_NUM>(base_score_start);
920-
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end(), expected_range.begin()));
920+
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(),
921+
h_scores_temp_sorted.end(),
922+
expected_range.begin()));
921923
for (int i = 0; i < dump_counter; i++) {
922924
for (int j = 0; j < options.dim; j++) {
923925
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
@@ -958,8 +960,11 @@ void test_evict_strategy_customized_basic(size_t max_hbm_for_vectors,
958960
std::vector<S> h_scores_temp_sorted(h_scores_temp);
959961
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());
960962

961-
auto expected_range_test = test_util::range<S, TEST_KEY_NUM>(test_score_start);
962-
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end(), expected_range_test.begin()));
963+
auto expected_range_test =
964+
test_util::range<S, TEST_KEY_NUM>(test_score_start);
965+
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(),
966+
h_scores_temp_sorted.end(),
967+
expected_range_test.begin()));
963968
for (int i = 0; i < dump_counter; i++) {
964969
for (int j = 0; j < options.dim; j++) {
965970
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
@@ -1104,7 +1109,9 @@ void test_evict_strategy_customized_advanced(size_t max_hbm_for_vectors,
11041109
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());
11051110

11061111
auto expected_range = test_util::range<S, TEMP_KEY_NUM>(base_score_start);
1107-
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end(), expected_range.begin()));
1112+
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(),
1113+
h_scores_temp_sorted.end(),
1114+
expected_range.begin()));
11081115
for (int i = 0; i < dump_counter; i++) {
11091116
for (int j = 0; j < options.dim; j++) {
11101117
ASSERT_EQ(h_vectors_temp[i * options.dim + j],

tests/find_or_insert_ptr_lock_test.cc.cu

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,7 +2153,9 @@ void test_evict_strategy_customized_basic(size_t max_hbm_for_vectors,
21532153
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());
21542154

21552155
auto expected_range = test_util::range<S, TEMP_KEY_NUM>(base_score_start);
2156-
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end(), expected_range.begin()));
2156+
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(),
2157+
h_scores_temp_sorted.end(),
2158+
expected_range.begin()));
21572159
for (int i = 0; i < dump_counter; i++) {
21582160
for (int j = 0; j < options.dim; j++) {
21592161
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
@@ -2197,8 +2199,11 @@ void test_evict_strategy_customized_basic(size_t max_hbm_for_vectors,
21972199
std::vector<S> h_scores_temp_sorted(h_scores_temp);
21982200
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());
21992201

2200-
auto expected_range_test = test_util::range<S, TEST_KEY_NUM>(test_score_start);
2201-
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end(), expected_range_test.begin()));
2202+
auto expected_range_test =
2203+
test_util::range<S, TEST_KEY_NUM>(test_score_start);
2204+
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(),
2205+
h_scores_temp_sorted.end(),
2206+
expected_range_test.begin()));
22022207
for (int i = 0; i < dump_counter; i++) {
22032208
for (int j = 0; j < options.dim; j++) {
22042209
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
@@ -2344,7 +2349,9 @@ void test_evict_strategy_customized_advanced(size_t max_hbm_for_vectors,
23442349
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());
23452350

23462351
auto expected_range = test_util::range<S, TEMP_KEY_NUM>(base_score_start);
2347-
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end(), expected_range.begin()));
2352+
ASSERT_TRUE(std::equal(h_scores_temp_sorted.begin(),
2353+
h_scores_temp_sorted.end(),
2354+
expected_range.begin()));
23482355
for (int i = 0; i < dump_counter; i++) {
23492356
for (int j = 0; j < options.dim; j++) {
23502357
ASSERT_EQ(h_vectors_temp[i * options.dim + j],

0 commit comments

Comments
 (0)