Skip to content

Commit 9144cfb

Browse files
committed
Add missing files
1 parent fe5f47c commit 9144cfb

File tree

2 files changed

+446
-0
lines changed

2 files changed

+446
-0
lines changed
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <test_utils.hpp>
18+
19+
#include <cuco/static_map.cuh>
20+
21+
#include <cuda/functional>
22+
#include <thrust/device_vector.h>
23+
#include <thrust/host_vector.h>
24+
#include <thrust/sequence.h>
25+
26+
#include <catch2/catch_template_test_macros.hpp>
27+
28+
using size_type = std::size_t;
29+
30+
template <class Container>
31+
__global__ void test_retrieve_if_kernel(
32+
Container container_ref,
33+
typename Container::key_type* keys_begin,
34+
std::size_t num_keys,
35+
typename Container::key_type* stencil_begin,
36+
typename Container::key_type* output_probe,
37+
typename Container::value_type* output_match,
38+
cuda::atomic<int, cuda::thread_scope_device>* atomic_counter)
39+
{
40+
using key_type = typename Container::key_type;
41+
namespace cg = cooperative_groups;
42+
43+
auto const block = cg::this_thread_block();
44+
auto const pred = [] __device__(key_type k) { return k % 2 == 0; };
45+
46+
container_ref.retrieve_if<128>(block,
47+
keys_begin,
48+
keys_begin + num_keys,
49+
stencil_begin,
50+
pred,
51+
output_probe,
52+
output_match,
53+
*atomic_counter);
54+
}
55+
56+
template <class Container>
57+
__global__ void test_retrieve_if_all_false_kernel(
58+
Container container_ref,
59+
typename Container::key_type* keys_begin,
60+
std::size_t num_keys,
61+
typename Container::key_type* stencil_begin,
62+
typename Container::key_type* output_probe,
63+
typename Container::value_type* output_match,
64+
cuda::atomic<int, cuda::thread_scope_device>* atomic_counter)
65+
{
66+
using key_type = typename Container::key_type;
67+
namespace cg = cooperative_groups;
68+
69+
auto const block = cg::this_thread_block();
70+
auto const always_false = [] __device__(key_type) { return false; };
71+
72+
container_ref.retrieve_if<128>(block,
73+
keys_begin,
74+
keys_begin + num_keys,
75+
stencil_begin,
76+
always_false,
77+
output_probe,
78+
output_match,
79+
*atomic_counter);
80+
}
81+
82+
template <class Container>
83+
__global__ void test_retrieve_if_all_true_kernel(
84+
Container container_ref,
85+
typename Container::key_type* keys_begin,
86+
std::size_t num_keys,
87+
typename Container::key_type* stencil_begin,
88+
typename Container::key_type* output_probe,
89+
typename Container::value_type* output_match,
90+
cuda::atomic<int, cuda::thread_scope_device>* atomic_counter)
91+
{
92+
using key_type = typename Container::key_type;
93+
namespace cg = cooperative_groups;
94+
95+
auto const block = cg::this_thread_block();
96+
auto const always_true = [] __device__(key_type) { return true; };
97+
98+
container_ref.retrieve_if<128>(block,
99+
keys_begin,
100+
keys_begin + num_keys,
101+
stencil_begin,
102+
always_true,
103+
output_probe,
104+
output_match,
105+
*atomic_counter);
106+
}
107+
108+
TEMPLATE_TEST_CASE_SIG("static_map retrieve_if",
109+
"",
110+
((typename Key, typename T), Key, T),
111+
(int32_t, int32_t),
112+
(int64_t, int64_t))
113+
{
114+
constexpr size_type num_keys{400};
115+
116+
using container_type = cuco::static_map<Key, T>;
117+
using value_type = typename container_type::value_type;
118+
119+
container_type container{num_keys * 2, cuco::empty_key<Key>{-1}, cuco::empty_value<T>{-1}};
120+
121+
auto keys_begin = thrust::counting_iterator<Key>(1);
122+
auto vals_begin = thrust::counting_iterator<T>(1);
123+
auto pairs_begin = thrust::make_zip_iterator(thrust::make_tuple(keys_begin, vals_begin));
124+
125+
container.insert(pairs_begin, pairs_begin + num_keys);
126+
127+
SECTION("Testing retrieve_if with even predicate")
128+
{
129+
thrust::device_vector<Key> input_keys(keys_begin, keys_begin + num_keys);
130+
thrust::device_vector<Key> stencil_values(keys_begin, keys_begin + num_keys);
131+
thrust::device_vector<Key> probed_keys(num_keys);
132+
thrust::device_vector<value_type> matched_pairs(num_keys);
133+
134+
cuda::atomic<int, cuda::thread_scope_device>* d_atomic_counter;
135+
CUCO_CUDA_TRY(
136+
cudaMalloc(&d_atomic_counter, sizeof(cuda::atomic<int, cuda::thread_scope_device>)));
137+
CUCO_CUDA_TRY(
138+
cudaMemset(d_atomic_counter, 0, sizeof(cuda::atomic<int, cuda::thread_scope_device>)));
139+
140+
auto const container_ref = container.ref(cuco::op::retrieve);
141+
142+
test_retrieve_if_kernel<<<1, 128>>>(container_ref,
143+
thrust::raw_pointer_cast(input_keys.data()),
144+
num_keys,
145+
thrust::raw_pointer_cast(stencil_values.data()),
146+
thrust::raw_pointer_cast(probed_keys.data()),
147+
thrust::raw_pointer_cast(matched_pairs.data()),
148+
d_atomic_counter);
149+
CUCO_CUDA_TRY(cudaDeviceSynchronize());
150+
151+
int h_counter;
152+
CUCO_CUDA_TRY(cudaMemcpy(&h_counter, d_atomic_counter, sizeof(int), cudaMemcpyDeviceToHost));
153+
154+
// Should retrieve even numbers only
155+
REQUIRE(h_counter > 0);
156+
REQUIRE(h_counter <= static_cast<int>(num_keys));
157+
158+
CUCO_CUDA_TRY(cudaFree(d_atomic_counter));
159+
}
160+
161+
SECTION("Testing retrieve_if with always false predicate")
162+
{
163+
thrust::device_vector<Key> input_keys(keys_begin, keys_begin + num_keys);
164+
thrust::device_vector<Key> stencil_values(keys_begin, keys_begin + num_keys);
165+
thrust::device_vector<Key> probed_keys(num_keys);
166+
thrust::device_vector<value_type> matched_pairs(num_keys);
167+
168+
cuda::atomic<int, cuda::thread_scope_device>* d_atomic_counter;
169+
CUCO_CUDA_TRY(
170+
cudaMalloc(&d_atomic_counter, sizeof(cuda::atomic<int, cuda::thread_scope_device>)));
171+
CUCO_CUDA_TRY(
172+
cudaMemset(d_atomic_counter, 0, sizeof(cuda::atomic<int, cuda::thread_scope_device>)));
173+
174+
auto const container_ref = container.ref(cuco::op::retrieve);
175+
176+
test_retrieve_if_all_false_kernel<<<1, 128>>>(container_ref,
177+
thrust::raw_pointer_cast(input_keys.data()),
178+
num_keys,
179+
thrust::raw_pointer_cast(stencil_values.data()),
180+
thrust::raw_pointer_cast(probed_keys.data()),
181+
thrust::raw_pointer_cast(matched_pairs.data()),
182+
d_atomic_counter);
183+
CUCO_CUDA_TRY(cudaDeviceSynchronize());
184+
185+
int h_counter;
186+
CUCO_CUDA_TRY(cudaMemcpy(&h_counter, d_atomic_counter, sizeof(int), cudaMemcpyDeviceToHost));
187+
188+
// Should retrieve nothing
189+
REQUIRE(h_counter == 0);
190+
191+
CUCO_CUDA_TRY(cudaFree(d_atomic_counter));
192+
}
193+
194+
SECTION("Testing retrieve_if with always true predicate")
195+
{
196+
thrust::device_vector<Key> input_keys(keys_begin, keys_begin + num_keys);
197+
thrust::device_vector<Key> stencil_values(keys_begin, keys_begin + num_keys);
198+
thrust::device_vector<Key> probed_keys(num_keys);
199+
thrust::device_vector<value_type> matched_pairs(num_keys);
200+
201+
cuda::atomic<int, cuda::thread_scope_device>* d_atomic_counter;
202+
CUCO_CUDA_TRY(
203+
cudaMalloc(&d_atomic_counter, sizeof(cuda::atomic<int, cuda::thread_scope_device>)));
204+
CUCO_CUDA_TRY(
205+
cudaMemset(d_atomic_counter, 0, sizeof(cuda::atomic<int, cuda::thread_scope_device>)));
206+
207+
auto const container_ref = container.ref(cuco::op::retrieve);
208+
209+
test_retrieve_if_all_true_kernel<<<1, 128>>>(container_ref,
210+
thrust::raw_pointer_cast(input_keys.data()),
211+
num_keys,
212+
thrust::raw_pointer_cast(stencil_values.data()),
213+
thrust::raw_pointer_cast(probed_keys.data()),
214+
thrust::raw_pointer_cast(matched_pairs.data()),
215+
d_atomic_counter);
216+
CUCO_CUDA_TRY(cudaDeviceSynchronize());
217+
218+
int h_counter;
219+
CUCO_CUDA_TRY(cudaMemcpy(&h_counter, d_atomic_counter, sizeof(int), cudaMemcpyDeviceToHost));
220+
221+
// Should retrieve all keys that exist in the container
222+
REQUIRE(h_counter == static_cast<int>(num_keys));
223+
224+
CUCO_CUDA_TRY(cudaFree(d_atomic_counter));
225+
}
226+
}

0 commit comments

Comments
 (0)