@@ -119,12 +119,6 @@ using EraseIfPredict = bool (*)(
119119 const S& threshold // /< The threshold to compare with the `score` argument.
120120);
121121
122- #if THRUST_VERSION >= 101600
123- static constexpr auto & thrust_par = thrust::cuda::par_nosync;
124- #else
125- static constexpr auto & thrust_par = thrust::cuda::par;
126- #endif
127-
128122/* *
129123 * A HierarchicalKV hash table is a concurrent and hierarchical hash table that
130124 * is powered by GPUs and can use HBM and host memory as storage for key-value
@@ -327,7 +321,7 @@ class HashTable {
327321 reinterpret_cast <uintptr_t *>(d_dst));
328322 thrust::device_ptr<int > d_src_offset_ptr (d_src_offset);
329323
330- thrust::sort_by_key (thrust_par .on (stream), d_dst_ptr, d_dst_ptr + n,
324+ thrust::sort_by_key (thrust::cuda::par_nosync .on (stream), d_dst_ptr, d_dst_ptr + n,
331325 d_src_offset_ptr, thrust::less<uintptr_t >());
332326 }
333327
@@ -561,7 +555,7 @@ class HashTable {
561555 thrust::device_ptr<uintptr_t > dst_ptr (reinterpret_cast <uintptr_t *>(dst));
562556 thrust::device_ptr<int > src_offset_ptr (src_offset);
563557
564- thrust::sort_by_key (thrust_par .on (stream), dst_ptr, dst_ptr + n,
558+ thrust::sort_by_key (thrust::cuda::par_nosync .on (stream), dst_ptr, dst_ptr + n,
565559 src_offset_ptr, thrust::less<uintptr_t >());
566560 }
567561
@@ -655,7 +649,7 @@ class HashTable {
655649 reinterpret_cast <uintptr_t *>(d_table_value_addrs));
656650 thrust::device_ptr<int > param_key_index_ptr (param_key_index);
657651
658- thrust::sort_by_key (thrust_par .on (stream), table_value_ptr,
652+ thrust::sort_by_key (thrust::cuda::par_nosync .on (stream), table_value_ptr,
659653 table_value_ptr + n, param_key_index_ptr,
660654 thrust::less<uintptr_t >());
661655 }
@@ -825,7 +819,7 @@ class HashTable {
825819 reinterpret_cast <uintptr_t *>(d_dst));
826820 thrust::device_ptr<int > d_src_offset_ptr (d_src_offset);
827821
828- thrust::sort_by_key (thrust_par .on (stream), d_dst_ptr, d_dst_ptr + n,
822+ thrust::sort_by_key (thrust::cuda::par_nosync .on (stream), d_dst_ptr, d_dst_ptr + n,
829823 d_src_offset_ptr, thrust::less<uintptr_t >());
830824 }
831825
@@ -926,7 +920,7 @@ class HashTable {
926920 reinterpret_cast <uintptr_t *>(src));
927921 thrust::device_ptr<int > dst_offset_ptr (dst_offset);
928922
929- thrust::sort_by_key (thrust_par .on (stream), src_ptr, src_ptr + n,
923+ thrust::sort_by_key (thrust::cuda::par_nosync .on (stream), src_ptr, src_ptr + n,
930924 dst_offset_ptr, thrust::less<uintptr_t >());
931925 }
932926
@@ -1278,7 +1272,7 @@ class HashTable {
12781272
12791273 for (size_type start_i = 0 ; start_i < N; start_i += step) {
12801274 size_type end_i = std::min (start_i + step, N);
1281- h_size += thrust::reduce (thrust_par .on (stream), size_ptr + start_i,
1275+ h_size += thrust::reduce (thrust::cuda::par_nosync .on (stream), size_ptr + start_i,
12821276 size_ptr + end_i, 0 , thrust::plus<int >());
12831277 }
12841278
@@ -1594,7 +1588,7 @@ class HashTable {
15941588
15951589 thrust::device_ptr<int > size_ptr (table_->buckets_size );
15961590
1597- int size = thrust::reduce (thrust_par .on (stream), size_ptr, size_ptr + N, 0 ,
1591+ int size = thrust::reduce (thrust::cuda::par_nosync .on (stream), size_ptr, size_ptr + N, 0 ,
15981592 thrust::plus<int >());
15991593
16001594 CudaCheckError ();
0 commit comments