2626#include < shared_mutex>
2727#include < type_traits>
2828#include " merlin/core_kernels.cuh"
29+ #include " merlin/external_storage.cuh"
2930#include " merlin/flexible_buffer.cuh"
3031#include " merlin/memory_pool.cuh"
3132#include " merlin/types.cuh"
@@ -152,6 +153,8 @@ class HashTable {
152153 using DeviceMemoryPool = MemoryPool<DeviceAllocator<char >>;
153154 using HostMemoryPool = MemoryPool<HostAllocator<char >>;
154155
156+ using external_storage_type = ExternalStorage<K, V>;
157+
155158#if THRUST_VERSION >= 101600
156159 static constexpr auto thrust_par = thrust::cuda::par_nosync;
157160#else
@@ -169,6 +172,8 @@ class HashTable {
169172 * table object.
170173 */
171174 ~HashTable () {
175+ unlink_external_storage ();
176+
172177 if (initialized_) {
173178 CUDA_CHECK (cudaDeviceSynchronize ());
174179
@@ -299,21 +304,33 @@ class HashTable {
299304 load_factor, options_.block_size , stream, n, c_table_index_, d_table_,
300305 keys, reinterpret_cast <const value_type*>(values), metas);
301306 } else {
302- const size_type dev_ws_size{n * (sizeof (value_type*) + sizeof (int ))};
307+ const size_type dev_ws_base_size{n * (sizeof (value_type*) + sizeof (int ))};
308+ const size_type dev_ws_size{dev_ws_base_size +
309+ (ext_store_ ? n : 0 ) * sizeof (key_type)};
303310 auto dev_ws{dev_mem_pool_->get_workspace <1 >(dev_ws_size, stream)};
304311 auto d_dst{dev_ws.get <value_type**>(0 )};
305312 auto d_src_offset{reinterpret_cast <int *>(d_dst + n)};
313+ auto d_evicted_keys{reinterpret_cast <key_type*>(d_src_offset + n)};
306314
307- CUDA_CHECK (cudaMemsetAsync (d_dst, 0 , dev_ws_size, stream));
315+ CUDA_CHECK (cudaMemsetAsync (d_dst, 0 , dev_ws_base_size, stream));
316+ CUDA_CHECK (cudaMemsetAsync (d_evicted_keys, 0xFF ,
317+ dev_ws_size - dev_ws_base_size, stream));
308318
309319 {
310320 const size_t block_size = options_.block_size ;
311321 const size_t N = n * TILE_SIZE;
312322 const size_t grid_size = SAFE_GET_GRID_SIZE (N, block_size);
313323
314324 upsert_kernel<key_type, value_type, meta_type, TILE_SIZE>
315- <<<grid_size, block_size, 0 , stream>>> (d_table_, keys, d_dst, metas,
316- d_src_offset, N);
325+ <<<grid_size, block_size, 0 , stream>>> (
326+ d_table_, keys, ext_store_ ? d_evicted_keys : nullptr , d_dst,
327+ metas, d_src_offset, N);
328+ }
329+
330+ if (ext_store_) {
331+ ext_store_->insert_or_assign (
332+ *dev_mem_pool_, *host_mem_pool_, table_->is_pure_hbm , n,
333+ d_evicted_keys, reinterpret_cast <value_type**>(d_dst), stream);
317334 }
318335
319336 {
@@ -326,16 +343,17 @@ class HashTable {
326343 }
327344
328345 if (options_.io_by_cpu ) {
329- const size_type host_ws_size{dev_ws_size +
346+ const size_type host_ws_size{dev_ws_base_size +
330347 n * sizeof (value_type) * dim ()};
331348 auto host_ws{host_mem_pool_->get_workspace <1 >(host_ws_size, stream)};
332349 auto h_dst{host_ws.get <value_type**>(0 )};
333350 auto h_src_offset{reinterpret_cast <int *>(h_dst + n)};
334351 auto h_values{reinterpret_cast <value_type*>(h_src_offset + n)};
335352
336- CUDA_CHECK (cudaMemcpyAsync (h_dst, d_dst, dev_ws_size ,
353+ CUDA_CHECK (cudaMemcpyAsync (h_dst, d_dst, dev_ws_base_size ,
337354 cudaMemcpyDeviceToHost, stream));
338- CUDA_CHECK (cudaMemcpyAsync (h_values, values, host_ws_size - dev_ws_size,
355+ CUDA_CHECK (cudaMemcpyAsync (h_values, values,
356+ host_ws_size - dev_ws_base_size,
339357 cudaMemcpyDeviceToHost, stream));
340358 CUDA_CHECK (cudaStreamSynchronize (stream));
341359
@@ -547,6 +565,11 @@ class HashTable {
547565 }
548566 }
549567
568+ if (ext_store_) {
569+ ext_store_->find (*dev_mem_pool_, *host_mem_pool_, n, keys, values, founds,
570+ stream);
571+ }
572+
550573 CudaCheckError ();
551574 }
552575
@@ -576,6 +599,10 @@ class HashTable {
576599 table_->bucket_max_size , table_->buckets_num , N);
577600 }
578601
602+ if (ext_store_) {
603+ ext_store_->erase_async (*dev_mem_pool_, *host_mem_pool_, n, keys, stream);
604+ }
605+
579606 CudaCheckError ();
580607 return ;
581608 }
@@ -1097,6 +1124,21 @@ class HashTable {
10971124 return total_count;
10981125 }
10991126
1127+ void link_external_storage (
1128+ std::shared_ptr<external_storage_type>& ext_store) {
1129+ MERLIN_CHECK (
1130+ ext_store->value_dim == dim (),
1131+ " Provided external storage value dimension is not incompatible!" );
1132+
1133+ std::unique_lock<std::shared_timed_mutex> lock (mutex_);
1134+ ext_store_ = ext_store;
1135+ }
1136+
1137+ void unlink_external_storage () {
1138+ std::unique_lock<std::shared_timed_mutex> lock (mutex_);
1139+ ext_store_.reset ();
1140+ }
1141+
11001142 private:
11011143 inline bool is_fast_mode () const noexcept { return table_->is_pure_hbm ; }
11021144
@@ -1173,6 +1215,8 @@ class HashTable {
11731215 int c_table_index_ = -1 ;
11741216 std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
11751217 std::unique_ptr<HostMemoryPool> host_mem_pool_;
1218+
1219+ std::shared_ptr<external_storage_type> ext_store_;
11761220};
11771221
11781222} // namespace merlin
0 commit comments