Skip to content

Commit 40cd796

Browse files
committed
Simplify thrust calls.
1 parent 40fb02b commit 40cd796

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

include/merlin_hashtable.cuh

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,6 @@ using EraseIfPredict = bool (*)(
119119
const S& threshold ///< The threshold to compare with the `score` argument.
120120
);
121121

122-
#if THRUST_VERSION >= 101600
123-
static constexpr auto& thrust_par = thrust::cuda::par_nosync;
124-
#else
125-
static constexpr auto& thrust_par = thrust::cuda::par;
126-
#endif
127-
128122
/**
129123
* A HierarchicalKV hash table is a concurrent and hierarchical hash table that
130124
* is powered by GPUs and can use HBM and host memory as storage for key-value
@@ -327,7 +321,7 @@ class HashTable {
327321
reinterpret_cast<uintptr_t*>(d_dst));
328322
thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);
329323

330-
thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n,
324+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n,
331325
d_src_offset_ptr, thrust::less<uintptr_t>());
332326
}
333327

@@ -561,7 +555,7 @@ class HashTable {
561555
thrust::device_ptr<uintptr_t> dst_ptr(reinterpret_cast<uintptr_t*>(dst));
562556
thrust::device_ptr<int> src_offset_ptr(src_offset);
563557

564-
thrust::sort_by_key(thrust_par.on(stream), dst_ptr, dst_ptr + n,
558+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), dst_ptr, dst_ptr + n,
565559
src_offset_ptr, thrust::less<uintptr_t>());
566560
}
567561

@@ -655,7 +649,7 @@ class HashTable {
655649
reinterpret_cast<uintptr_t*>(d_table_value_addrs));
656650
thrust::device_ptr<int> param_key_index_ptr(param_key_index);
657651

658-
thrust::sort_by_key(thrust_par.on(stream), table_value_ptr,
652+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), table_value_ptr,
659653
table_value_ptr + n, param_key_index_ptr,
660654
thrust::less<uintptr_t>());
661655
}
@@ -825,7 +819,7 @@ class HashTable {
825819
reinterpret_cast<uintptr_t*>(d_dst));
826820
thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);
827821

828-
thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n,
822+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n,
829823
d_src_offset_ptr, thrust::less<uintptr_t>());
830824
}
831825

@@ -926,7 +920,7 @@ class HashTable {
926920
reinterpret_cast<uintptr_t*>(src));
927921
thrust::device_ptr<int> dst_offset_ptr(dst_offset);
928922

929-
thrust::sort_by_key(thrust_par.on(stream), src_ptr, src_ptr + n,
923+
thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), src_ptr, src_ptr + n,
930924
dst_offset_ptr, thrust::less<uintptr_t>());
931925
}
932926

@@ -1278,7 +1272,7 @@ class HashTable {
12781272

12791273
for (size_type start_i = 0; start_i < N; start_i += step) {
12801274
size_type end_i = std::min(start_i + step, N);
1281-
h_size += thrust::reduce(thrust_par.on(stream), size_ptr + start_i,
1275+
h_size += thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr + start_i,
12821276
size_ptr + end_i, 0, thrust::plus<int>());
12831277
}
12841278

@@ -1594,7 +1588,7 @@ class HashTable {
15941588

15951589
thrust::device_ptr<int> size_ptr(table_->buckets_size);
15961590

1597-
int size = thrust::reduce(thrust_par.on(stream), size_ptr, size_ptr + N, 0,
1591+
int size = thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr, size_ptr + N, 0,
15981592
thrust::plus<int>());
15991593

16001594
CudaCheckError();

0 commit comments

Comments
 (0)