Skip to content

Commit b7907b5

Browse files
committed
External storage API with updates from 2023-02-09 discussion with GDS.
1 parent 6101f96 commit b7907b5

File tree

3 files changed

+189
-9
lines changed

3 files changed

+189
-9
lines changed

include/merlin/core_kernels.cuh

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,8 +1029,9 @@ struct SelectUpsertKernelWithIO {
10291029
*/
10301030
template <class K, class V, class M, uint32_t TILE_SIZE = 4>
10311031
__global__ void upsert_kernel(const Table<K, V, M>* __restrict table,
1032-
const K* __restrict keys, V** __restrict vectors,
1033-
const M* __restrict metas,
1032+
const K* __restrict keys,
1033+
K* __restrict evicted_keys,
1034+
V** __restrict vectors, const M* __restrict metas,
10341035
int* __restrict src_offset, size_t N) {
10351036
Bucket<K, V, M>* buckets = table->buckets;
10361037
int* buckets_size = table->buckets_size;
@@ -1168,6 +1169,10 @@ __global__ void upsert_kernel(const Table<K, V, M>* __restrict table,
11681169
// override_result == OverrideResult::SUCCESS
11691170

11701171
if (rank == src_lane) {
1172+
if (evicted_keys) {
1173+
evicted_keys[key_idx] =
1174+
bucket->keys[key_pos].load(cuda::std::memory_order_relaxed);
1175+
}
11711176
bucket->keys[key_pos].store(insert_key,
11721177
cuda::std::memory_order_relaxed);
11731178
*(vectors + key_idx) = (bucket->vectors + key_pos * dim);
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright (c) 2022, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <cstdint>
19+
#include <type_traits>
20+
#include "merlin/memory_pool.cuh"
21+
22+
namespace nv {
23+
namespace merlin {
24+
25+
template <class Key, class Value>
26+
class ExternalStorage {
27+
public:
28+
using size_type = size_t;
29+
using key_type = Key;
30+
using value_type = Value;
31+
32+
using dev_mem_pool_type = MemoryPool<DeviceAllocator<char>>;
33+
using host_mem_pool_type = MemoryPool<HostAllocator<char>>;
34+
35+
const size_type value_dim;
36+
37+
ExternalStorage() = delete;
38+
39+
/**
40+
* Constructs external storage object.
41+
*
42+
* @param value_dim The dimensionality of the values. In other words, each
43+
* value stored is exactly `value_dim * sizeof(value_type)` bytes large.
44+
*/
45+
ExternalStorage(const size_type value_dim) : value_dim{value_dim} {}
46+
47+
/**
48+
* @brief Inserts key/value pairs into the external storage that are about to
49+
* be evicted from the Merlin hashtable. If a key/value pair already exists,
50+
* overwrites the current value.
51+
*
52+
* @param dev_mem_pool Memory pool for temporarily allocating device memory.
53+
* @param host_mem_pool Memory pool for temporarily allocating host memory.
54+
* @param hkvs_is_pure_hbm True if the Merlin hashtable store is currently
55+
* operating in pure HBM mode, false otherwise. In pure HBM mode, all `values`
56+
* pointers are GUARANTEED to point to device memory.
57+
* @param n Number of key/value slots provided in other arguments.
58+
* @param d_masked_keys Device pointer to an (n)-sized array of keys.
59+
* Key-Value slots that should be ignored have the key set to `EMPTY_KEY`.
60+
* @param d_values Device pointer to an (n)-sized array containing pointers to
61+
* respectively a memory location where the current values for a key are
62+
* stored. Each pointer points to a vector of length `value_dim`. Pointers
63+
* *can* be set to `nullptr` for slots where the corresponding key equated to
64+
* the `EMPTY_KEY`. The memory locations can be device or host memory (see
65+
* also `hkvs_is_pure_hbm`).
66+
* @param stream Stream that MUST be used for queuing asynchronous CUDA
67+
* operations. If only the input arguments or resources obtained from
68+
* respectively `dev_mem_pool` and `host_mem_pool` are used for such
69+
* operations, it is not necessary to synchronize the stream prior to
70+
* returning from the function.
71+
*/
72+
virtual void insert_or_assign(dev_mem_pool_type& dev_mem_pool,
73+
host_mem_pool_type& host_mem_pool,
74+
bool hkvs_is_pure_hbm, size_type n,
75+
const key_type* d_masked_keys, // (n)
76+
const value_type* const* d_values, // (n)
77+
cudaStream_t stream) = 0;
78+
79+
/**
80+
* @brief Attempts to find the supplied `d_keys` if the corresponding
81+
* `d_founds`-flag is `false` and fills the stored into the supplied memory
82+
* locations (i.e. in `d_values`).
83+
*
84+
* @param dev_mem_pool Memory pool for temporarily allocating device memory.
85+
* @param host_mem_pool Memory pool for temporarily allocating host memory.
86+
* @param n Number of key/value slots provided in other arguments.
87+
* @param d_keys Device pointer to an (n)-sized array of keys.
88+
* @param d_values Device pointer to an (n * value_dim)-sized array to store
89+
* the retrieved `d_values`. For slots where the corresponding `d_founds`-flag
90+
* is not `false`, the value may already have been assigned and, thus, MUST
91+
* not be altered.
92+
* @param d_founds Device pointer to an (n)-sized array which indicates
93+
* whether the corresponding `d_values` slot is already filled or not. So, if
94+
* and only if `d_founds` is still false, the implementation shall attempt to
95+
* retrieve and fill in the value for the corresponding key. If a key/value
96+
* was retrieved successfully from external storage, the implementation MUST
97+
* also set `d_founds` to `true`.
98+
* @param stream Stream that MUST be used for queuing asynchronous CUDA
99+
* operations. If only the input arguments or resources obtained from
100+
* respectively `dev_mem_pool` and `host_mem_pool` are used for such
101+
* operations, it is not necessary to synchronize the stream prior to
102+
* returning from the function.
103+
*/
104+
virtual void find(dev_mem_pool_type& dev_mem_pool,
105+
host_mem_pool_type& host_mem_pool, size_type n,
106+
const key_type* d_keys, // (n)
107+
value_type* d_values, // (n * value_dim)
108+
bool* d_founds, // (n)
109+
cudaStream_t stream) = 0;
110+
111+
/**
112+
* @brief Attempts to erase the entries associated with the supplied `d_keys`.
113+
* For keys do not exist nothing happens. It is permissible for this function
114+
* to be implemented asynchronously (i.e., to return before the actual
115+
* deletion has happened).
116+
*
117+
* @param dev_mem_pool Memory pool for temporarily allocating device memory.
118+
* @param host_mem_pool Memory pool for temporarily allocating host memory.
119+
* @param n Number of keys provided in `d_keys` arguments.
120+
* @param d_keys Device pointer to an (n)-sized array of keys. This pointer is
121+
* only guarnteed to be valid for the duration of the call. If easure is
122+
* implemented asynchronously, you must make a copy and manage its lifetime
123+
* yourself.
124+
*/
125+
virtual void erase_async(dev_mem_pool_type& dev_mem_pool,
126+
host_mem_pool_type& host_mem_pool, size_type n,
127+
const key_type* d_keys, cudaStream_t stream) = 0;
128+
};
129+
130+
} // namespace merlin
131+
} // namespace nv

include/merlin_hashtable.cuh

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <shared_mutex>
2727
#include <type_traits>
2828
#include "merlin/core_kernels.cuh"
29+
#include "merlin/external_storage.cuh"
2930
#include "merlin/flexible_buffer.cuh"
3031
#include "merlin/memory_pool.cuh"
3132
#include "merlin/types.cuh"
@@ -152,6 +153,8 @@ class HashTable {
152153
using DeviceMemoryPool = MemoryPool<DeviceAllocator<char>>;
153154
using HostMemoryPool = MemoryPool<HostAllocator<char>>;
154155

156+
using external_storage_type = ExternalStorage<K, V>;
157+
155158
#if THRUST_VERSION >= 101600
156159
static constexpr auto thrust_par = thrust::cuda::par_nosync;
157160
#else
@@ -169,6 +172,8 @@ class HashTable {
169172
* table object.
170173
*/
171174
~HashTable() {
175+
unlink_external_storage();
176+
172177
if (initialized_) {
173178
CUDA_CHECK(cudaDeviceSynchronize());
174179

@@ -299,21 +304,33 @@ class HashTable {
299304
load_factor, options_.block_size, stream, n, c_table_index_, d_table_,
300305
keys, reinterpret_cast<const value_type*>(values), metas);
301306
} else {
302-
const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int))};
307+
const size_type dev_ws_base_size{n * (sizeof(value_type*) + sizeof(int))};
308+
const size_type dev_ws_size{dev_ws_base_size +
309+
(ext_store_ ? n : 0) * sizeof(key_type)};
303310
auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
304311
auto d_dst{dev_ws.get<value_type**>(0)};
305312
auto d_src_offset{reinterpret_cast<int*>(d_dst + n)};
313+
auto d_evicted_keys{reinterpret_cast<key_type*>(d_src_offset + n)};
306314

307-
CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));
315+
CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_base_size, stream));
316+
CUDA_CHECK(cudaMemsetAsync(d_evicted_keys, 0xFF,
317+
dev_ws_size - dev_ws_base_size, stream));
308318

309319
{
310320
const size_t block_size = options_.block_size;
311321
const size_t N = n * TILE_SIZE;
312322
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
313323

314324
upsert_kernel<key_type, value_type, meta_type, TILE_SIZE>
315-
<<<grid_size, block_size, 0, stream>>>(d_table_, keys, d_dst, metas,
316-
d_src_offset, N);
325+
<<<grid_size, block_size, 0, stream>>>(
326+
d_table_, keys, ext_store_ ? d_evicted_keys : nullptr, d_dst,
327+
metas, d_src_offset, N);
328+
}
329+
330+
if (ext_store_) {
331+
ext_store_->insert_or_assign(
332+
*dev_mem_pool_, *host_mem_pool_, table_->is_pure_hbm, n,
333+
d_evicted_keys, reinterpret_cast<value_type**>(d_dst), stream);
317334
}
318335

319336
{
@@ -326,16 +343,17 @@ class HashTable {
326343
}
327344

328345
if (options_.io_by_cpu) {
329-
const size_type host_ws_size{dev_ws_size +
346+
const size_type host_ws_size{dev_ws_base_size +
330347
n * sizeof(value_type) * dim()};
331348
auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
332349
auto h_dst{host_ws.get<value_type**>(0)};
333350
auto h_src_offset{reinterpret_cast<int*>(h_dst + n)};
334351
auto h_values{reinterpret_cast<value_type*>(h_src_offset + n)};
335352

336-
CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_size,
353+
CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_base_size,
337354
cudaMemcpyDeviceToHost, stream));
338-
CUDA_CHECK(cudaMemcpyAsync(h_values, values, host_ws_size - dev_ws_size,
355+
CUDA_CHECK(cudaMemcpyAsync(h_values, values,
356+
host_ws_size - dev_ws_base_size,
339357
cudaMemcpyDeviceToHost, stream));
340358
CUDA_CHECK(cudaStreamSynchronize(stream));
341359

@@ -547,6 +565,11 @@ class HashTable {
547565
}
548566
}
549567

568+
if (ext_store_) {
569+
ext_store_->find(*dev_mem_pool_, *host_mem_pool_, n, keys, values, founds,
570+
stream);
571+
}
572+
550573
CudaCheckError();
551574
}
552575

@@ -576,6 +599,10 @@ class HashTable {
576599
table_->bucket_max_size, table_->buckets_num, N);
577600
}
578601

602+
if (ext_store_) {
603+
ext_store_->erase_async(*dev_mem_pool_, *host_mem_pool_, n, keys, stream);
604+
}
605+
579606
CudaCheckError();
580607
return;
581608
}
@@ -1097,6 +1124,21 @@ class HashTable {
10971124
return total_count;
10981125
}
10991126

1127+
void link_external_storage(
1128+
std::shared_ptr<external_storage_type>& ext_store) {
1129+
MERLIN_CHECK(
1130+
ext_store->value_dim == dim(),
1131+
"Provided external storage value dimension is not incompatible!");
1132+
1133+
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
1134+
ext_store_ = ext_store;
1135+
}
1136+
1137+
void unlink_external_storage() {
1138+
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
1139+
ext_store_.reset();
1140+
}
1141+
11001142
private:
11011143
inline bool is_fast_mode() const noexcept { return table_->is_pure_hbm; }
11021144

@@ -1173,6 +1215,8 @@ class HashTable {
11731215
int c_table_index_ = -1;
11741216
std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
11751217
std::unique_ptr<HostMemoryPool> host_mem_pool_;
1218+
1219+
std::shared_ptr<external_storage_type> ext_store_;
11761220
};
11771221

11781222
} // namespace merlin

0 commit comments

Comments
 (0)