Skip to content

Commit 998ec6f

Browse files
committed
support insert_or_assign score for lru
1 parent 8c96d9c commit 998ec6f

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

include/merlin_hashtable.cuh

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,22 @@ class HashTable : public HashTableBase<K, V, S> {
10651065
const score_type* scores = nullptr, // (n)
10661066
cudaStream_t stream = 0, bool unique_key = true,
10671067
bool ignore_evict_strategy = false) {
1068+
if (ignore_evict_strategy) {
1069+
insert_or_assign_impl<EvictStrategy::kCustomized>(n, keys, values, scores,
1070+
stream, unique_key);
1071+
} else {
1072+
insert_or_assign_impl<evict_strategy>(n, keys, values, scores, stream,
1073+
unique_key);
1074+
}
1075+
}
1076+
1077+
template <int evict_strategy_>
1078+
void insert_or_assign_impl(const size_type n,
1079+
const key_type* keys, // (n)
1080+
const value_type* values, // (n, DIM)
1081+
const score_type* scores = nullptr, // (n)
1082+
cudaStream_t stream = 0, bool unique_key = true,
1083+
bool ignore_evict_strategy = false) {
10681084
if (n == 0) {
10691085
return;
10701086
}
@@ -1092,7 +1108,7 @@ class HashTable : public HashTableBase<K, V, S> {
10921108
}
10931109

10941110
using Selector = KernelSelector_Upsert<key_type, value_type, score_type,
1095-
evict_strategy, ArchTag>;
1111+
evict_strategy_, ArchTag>;
10961112
if (Selector::callable(unique_key,
10971113
static_cast<uint32_t>(options_.max_bucket_size),
10981114
static_cast<uint32_t>(options_.dim))) {
@@ -1105,7 +1121,7 @@ class HashTable : public HashTableBase<K, V, S> {
11051121
Selector::select_kernel(kernelParams, stream);
11061122
} else {
11071123
using Selector = SelectUpsertKernelWithIO<key_type, value_type,
1108-
score_type, evict_strategy>;
1124+
score_type, evict_strategy_>;
11091125
Selector::execute_kernel(
11101126
load_factor, options_.block_size, options_.max_bucket_size,
11111127
table_->buckets_num, options_.dim, stream, n, d_table_,
@@ -1142,7 +1158,7 @@ class HashTable : public HashTableBase<K, V, S> {
11421158
constexpr uint32_t BLOCK_SIZE = 128;
11431159

11441160
upsert_kernel_lock_key_hybrid<key_type, value_type, score_type,
1145-
BLOCK_SIZE, evict_strategy>
1161+
BLOCK_SIZE, evict_strategy_>
11461162
<<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
11471163
table_->buckets, table_->buckets_size, table_->buckets_num,
11481164
options_.max_bucket_size, options_.dim, keys, d_dst, scores,
@@ -1153,7 +1169,7 @@ class HashTable : public HashTableBase<K, V, S> {
11531169
const size_t N = n * TILE_SIZE;
11541170
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
11551171

1156-
upsert_kernel<key_type, value_type, score_type, evict_strategy,
1172+
upsert_kernel<key_type, value_type, score_type, evict_strategy_,
11571173
TILE_SIZE><<<grid_size, block_size, 0, stream>>>(
11581174
d_table_, table_->buckets, options_.max_bucket_size,
11591175
table_->buckets_num, options_.dim, keys, d_dst, scores,

0 commit comments

Comments
 (0)