From 537283572fa38cb7cc35fdba2fe187c18515be6d Mon Sep 17 00:00:00 2001 From: aleliu Date: Thu, 23 Oct 2025 00:53:07 -0700 Subject: [PATCH] support insert_or_assign score for lru --- include/merlin_hashtable.cuh | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index fd6be65e..791224f2 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -1065,6 +1065,22 @@ class HashTable : public HashTableBase { const score_type* scores = nullptr, // (n) cudaStream_t stream = 0, bool unique_key = true, bool ignore_evict_strategy = false) { + if (ignore_evict_strategy) { + insert_or_assign_impl( + n, keys, values, scores, stream, unique_key, ignore_evict_strategy); + } else { + insert_or_assign_impl(n, keys, values, scores, stream, + unique_key, ignore_evict_strategy); + } + } + + template + void insert_or_assign_impl(const size_type n, + const key_type* keys, // (n) + const value_type* values, // (n, DIM) + const score_type* scores, // (n) + cudaStream_t stream, bool unique_key, + bool ignore_evict_strategy) { if (n == 0) { return; } @@ -1092,7 +1108,7 @@ class HashTable : public HashTableBase { } using Selector = KernelSelector_Upsert; + evict_strategy_, ArchTag>; if (Selector::callable(unique_key, static_cast(options_.max_bucket_size), static_cast(options_.dim))) { @@ -1105,7 +1121,7 @@ class HashTable : public HashTableBase { Selector::select_kernel(kernelParams, stream); } else { using Selector = SelectUpsertKernelWithIO; + score_type, evict_strategy_>; Selector::execute_kernel( load_factor, options_.block_size, options_.max_bucket_size, table_->buckets_num, options_.dim, stream, n, d_table_, @@ -1142,7 +1158,7 @@ class HashTable : public HashTableBase { constexpr uint32_t BLOCK_SIZE = 128; upsert_kernel_lock_key_hybrid + BLOCK_SIZE, evict_strategy_> <<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>( table_->buckets, table_->buckets_size, table_->buckets_num, options_.max_bucket_size, options_.dim, keys, d_dst, scores, @@ -1153,7 +1169,7 @@ class HashTable : public HashTableBase { const size_t N = n * TILE_SIZE; const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size); - upsert_kernel<<>>( d_table_, table_->buckets, options_.max_bucket_size, table_->buckets_num, options_.dim, keys, d_dst, scores,