1+ /*
2+ * Copyright (c) 2025, NVIDIA CORPORATION.
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License");
5+ * you may not use this file except in compliance with the License.
6+ * You may obtain a copy of the License at
7+ *
8+ * http://www.apache.org/licenses/LICENSE-2.0
9+ *
10+ * Unless required by applicable law or agreed to in writing, software
11+ * distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ */
16+
17+ #include < test_utils.hpp>
18+
19+ #include < cuco/static_map.cuh>
20+
21+ #include < cuda/functional>
22+ #include < thrust/device_vector.h>
23+ #include < thrust/host_vector.h>
24+ #include < thrust/sequence.h>
25+
26+ #include < catch2/catch_template_test_macros.hpp>
27+
28+ using size_type = std::size_t ;
29+
30+ template <class Container >
31+ __global__ void test_retrieve_if_kernel (
32+ Container container_ref,
33+ typename Container::key_type* keys_begin,
34+ std::size_t num_keys,
35+ typename Container::key_type* stencil_begin,
36+ typename Container::key_type* output_probe,
37+ typename Container::value_type* output_match,
38+ cuda::atomic<int , cuda::thread_scope_device>* atomic_counter)
39+ {
40+ using key_type = typename Container::key_type;
41+ namespace cg = cooperative_groups;
42+
43+ auto const block = cg::this_thread_block ();
44+ auto const pred = [] __device__ (key_type k) { return k % 2 == 0 ; };
45+
46+ container_ref.retrieve_if <128 >(block,
47+ keys_begin,
48+ keys_begin + num_keys,
49+ stencil_begin,
50+ pred,
51+ output_probe,
52+ output_match,
53+ *atomic_counter);
54+ }
55+
56+ template <class Container >
57+ __global__ void test_retrieve_if_all_false_kernel (
58+ Container container_ref,
59+ typename Container::key_type* keys_begin,
60+ std::size_t num_keys,
61+ typename Container::key_type* stencil_begin,
62+ typename Container::key_type* output_probe,
63+ typename Container::value_type* output_match,
64+ cuda::atomic<int , cuda::thread_scope_device>* atomic_counter)
65+ {
66+ using key_type = typename Container::key_type;
67+ namespace cg = cooperative_groups;
68+
69+ auto const block = cg::this_thread_block ();
70+ auto const always_false = [] __device__ (key_type) { return false ; };
71+
72+ container_ref.retrieve_if <128 >(block,
73+ keys_begin,
74+ keys_begin + num_keys,
75+ stencil_begin,
76+ always_false,
77+ output_probe,
78+ output_match,
79+ *atomic_counter);
80+ }
81+
82+ template <class Container >
83+ __global__ void test_retrieve_if_all_true_kernel (
84+ Container container_ref,
85+ typename Container::key_type* keys_begin,
86+ std::size_t num_keys,
87+ typename Container::key_type* stencil_begin,
88+ typename Container::key_type* output_probe,
89+ typename Container::value_type* output_match,
90+ cuda::atomic<int , cuda::thread_scope_device>* atomic_counter)
91+ {
92+ using key_type = typename Container::key_type;
93+ namespace cg = cooperative_groups;
94+
95+ auto const block = cg::this_thread_block ();
96+ auto const always_true = [] __device__ (key_type) { return true ; };
97+
98+ container_ref.retrieve_if <128 >(block,
99+ keys_begin,
100+ keys_begin + num_keys,
101+ stencil_begin,
102+ always_true,
103+ output_probe,
104+ output_match,
105+ *atomic_counter);
106+ }
107+
108+ TEMPLATE_TEST_CASE_SIG (" static_map retrieve_if" ,
109+ " " ,
110+ ((typename Key, typename T), Key, T),
111+ (int32_t , int32_t ),
112+ (int64_t , int64_t ))
113+ {
114+ constexpr size_type num_keys{400 };
115+
116+ using container_type = cuco::static_map<Key, T>;
117+ using value_type = typename container_type::value_type;
118+
119+ container_type container{num_keys * 2 , cuco::empty_key<Key>{-1 }, cuco::empty_value<T>{-1 }};
120+
121+ auto keys_begin = thrust::counting_iterator<Key>(1 );
122+ auto vals_begin = thrust::counting_iterator<T>(1 );
123+ auto pairs_begin = thrust::make_zip_iterator (thrust::make_tuple (keys_begin, vals_begin));
124+
125+ container.insert (pairs_begin, pairs_begin + num_keys);
126+
127+ SECTION (" Testing retrieve_if with even predicate" )
128+ {
129+ thrust::device_vector<Key> input_keys (keys_begin, keys_begin + num_keys);
130+ thrust::device_vector<Key> stencil_values (keys_begin, keys_begin + num_keys);
131+ thrust::device_vector<Key> probed_keys (num_keys);
132+ thrust::device_vector<value_type> matched_pairs (num_keys);
133+
134+ cuda::atomic<int , cuda::thread_scope_device>* d_atomic_counter;
135+ CUCO_CUDA_TRY (
136+ cudaMalloc (&d_atomic_counter, sizeof (cuda::atomic<int , cuda::thread_scope_device>)));
137+ CUCO_CUDA_TRY (
138+ cudaMemset (d_atomic_counter, 0 , sizeof (cuda::atomic<int , cuda::thread_scope_device>)));
139+
140+ auto const container_ref = container.ref (cuco::op::retrieve);
141+
142+ test_retrieve_if_kernel<<<1 , 128 >>> (container_ref,
143+ thrust::raw_pointer_cast (input_keys.data ()),
144+ num_keys,
145+ thrust::raw_pointer_cast (stencil_values.data ()),
146+ thrust::raw_pointer_cast (probed_keys.data ()),
147+ thrust::raw_pointer_cast (matched_pairs.data ()),
148+ d_atomic_counter);
149+ CUCO_CUDA_TRY (cudaDeviceSynchronize ());
150+
151+ int h_counter;
152+ CUCO_CUDA_TRY (cudaMemcpy (&h_counter, d_atomic_counter, sizeof (int ), cudaMemcpyDeviceToHost));
153+
154+ // Should retrieve even numbers only
155+ REQUIRE (h_counter > 0 );
156+ REQUIRE (h_counter <= static_cast <int >(num_keys));
157+
158+ CUCO_CUDA_TRY (cudaFree (d_atomic_counter));
159+ }
160+
161+ SECTION (" Testing retrieve_if with always false predicate" )
162+ {
163+ thrust::device_vector<Key> input_keys (keys_begin, keys_begin + num_keys);
164+ thrust::device_vector<Key> stencil_values (keys_begin, keys_begin + num_keys);
165+ thrust::device_vector<Key> probed_keys (num_keys);
166+ thrust::device_vector<value_type> matched_pairs (num_keys);
167+
168+ cuda::atomic<int , cuda::thread_scope_device>* d_atomic_counter;
169+ CUCO_CUDA_TRY (
170+ cudaMalloc (&d_atomic_counter, sizeof (cuda::atomic<int , cuda::thread_scope_device>)));
171+ CUCO_CUDA_TRY (
172+ cudaMemset (d_atomic_counter, 0 , sizeof (cuda::atomic<int , cuda::thread_scope_device>)));
173+
174+ auto const container_ref = container.ref (cuco::op::retrieve);
175+
176+ test_retrieve_if_all_false_kernel<<<1 , 128 >>> (container_ref,
177+ thrust::raw_pointer_cast (input_keys.data ()),
178+ num_keys,
179+ thrust::raw_pointer_cast (stencil_values.data ()),
180+ thrust::raw_pointer_cast (probed_keys.data ()),
181+ thrust::raw_pointer_cast (matched_pairs.data ()),
182+ d_atomic_counter);
183+ CUCO_CUDA_TRY (cudaDeviceSynchronize ());
184+
185+ int h_counter;
186+ CUCO_CUDA_TRY (cudaMemcpy (&h_counter, d_atomic_counter, sizeof (int ), cudaMemcpyDeviceToHost));
187+
188+ // Should retrieve nothing
189+ REQUIRE (h_counter == 0 );
190+
191+ CUCO_CUDA_TRY (cudaFree (d_atomic_counter));
192+ }
193+
194+ SECTION (" Testing retrieve_if with always true predicate" )
195+ {
196+ thrust::device_vector<Key> input_keys (keys_begin, keys_begin + num_keys);
197+ thrust::device_vector<Key> stencil_values (keys_begin, keys_begin + num_keys);
198+ thrust::device_vector<Key> probed_keys (num_keys);
199+ thrust::device_vector<value_type> matched_pairs (num_keys);
200+
201+ cuda::atomic<int , cuda::thread_scope_device>* d_atomic_counter;
202+ CUCO_CUDA_TRY (
203+ cudaMalloc (&d_atomic_counter, sizeof (cuda::atomic<int , cuda::thread_scope_device>)));
204+ CUCO_CUDA_TRY (
205+ cudaMemset (d_atomic_counter, 0 , sizeof (cuda::atomic<int , cuda::thread_scope_device>)));
206+
207+ auto const container_ref = container.ref (cuco::op::retrieve);
208+
209+ test_retrieve_if_all_true_kernel<<<1 , 128 >>> (container_ref,
210+ thrust::raw_pointer_cast (input_keys.data ()),
211+ num_keys,
212+ thrust::raw_pointer_cast (stencil_values.data ()),
213+ thrust::raw_pointer_cast (probed_keys.data ()),
214+ thrust::raw_pointer_cast (matched_pairs.data ()),
215+ d_atomic_counter);
216+ CUCO_CUDA_TRY (cudaDeviceSynchronize ());
217+
218+ int h_counter;
219+ CUCO_CUDA_TRY (cudaMemcpy (&h_counter, d_atomic_counter, sizeof (int ), cudaMemcpyDeviceToHost));
220+
221+ // Should retrieve all keys that exist in the container
222+ REQUIRE (h_counter == static_cast <int >(num_keys));
223+
224+ CUCO_CUDA_TRY (cudaFree (d_atomic_counter));
225+ }
226+ }
0 commit comments