Skip to content

Commit c9d565b

Browse files
committed
DEVICE/API: Add tests
Signed-off-by: Michal Shalev <[email protected]>
1 parent d865179 commit c9d565b

File tree

13 files changed

+1628
-829
lines changed

13 files changed

+1628
-829
lines changed

test/gtest/device_api/cuda_ptr.cuh

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#ifndef _CUDA_PTR_CUH
19+
#define _CUDA_PTR_CUH
20+
21+
#include <cuda_runtime.h>
22+
23+
template<typename T> class CudaPtr {
24+
public:
25+
explicit CudaPtr(T **ptr) : ptr_(ptr) {
26+
cudaMalloc(reinterpret_cast<void **>(ptr_), sizeof(T));
27+
cudaMemset(*ptr_, 0, sizeof(T));
28+
}
29+
30+
~CudaPtr() {
31+
if (ptr_ && *ptr_) {
32+
cudaFree(*ptr_);
33+
*ptr_ = nullptr;
34+
}
35+
}
36+
37+
CudaPtr(const CudaPtr &) = delete;
38+
CudaPtr &
39+
operator=(const CudaPtr &) = delete;
40+
41+
CudaPtr(CudaPtr &&other) noexcept : ptr_(other.ptr_) {
42+
other.ptr_ = nullptr;
43+
}
44+
45+
CudaPtr &
46+
operator=(CudaPtr &&other) noexcept {
47+
if (this != &other) {
48+
if (ptr_ && *ptr_) {
49+
cudaFree(*ptr_);
50+
}
51+
ptr_ = other.ptr_;
52+
other.ptr_ = nullptr;
53+
}
54+
return *this;
55+
}
56+
57+
T *
58+
get() const {
59+
return ptr_ ? *ptr_ : nullptr;
60+
}
61+
62+
private:
63+
T **ptr_;
64+
};
65+
66+
#endif // _CUDA_PTR_CUH
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include "device_test_base.cuh"
19+
20+
nixlAgentConfig
21+
DeviceApiTestBase::getConfig() {
22+
return nixlAgentConfig(true, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_RW, 0, 100000);
23+
}
24+
25+
nixl_b_params_t
26+
DeviceApiTestBase::getBackendParams() {
27+
nixl_b_params_t params;
28+
params["num_workers"] = "2";
29+
return params;
30+
}
31+
32+
void
33+
DeviceApiTestBase::SetUp() {
34+
cudaError_t cuda_error = cudaSetDevice(0);
35+
if (cuda_error != cudaSuccess) {
36+
FAIL() << "Failed to set CUDA device 0: " << cudaGetErrorString(cuda_error);
37+
}
38+
39+
for (size_t i = 0; i < 2; i++) {
40+
agents.emplace_back(std::make_unique<nixlAgent>(getAgentName(i), getConfig()));
41+
nixlBackendH *backend_handle = nullptr;
42+
nixl_status_t status =
43+
agents.back()->createBackend("UCX", getBackendParams(), backend_handle);
44+
ASSERT_EQ(status, NIXL_SUCCESS);
45+
EXPECT_NE(backend_handle, nullptr);
46+
backend_handles.push_back(backend_handle);
47+
}
48+
}
49+
50+
void
51+
DeviceApiTestBase::TearDown() {
52+
agents.clear();
53+
}
54+
55+
void
56+
DeviceApiTestBase::registerMem(nixlAgent &agent,
57+
const std::vector<MemBuffer> &buffers,
58+
nixl_mem_t mem_type) {
59+
auto reg_list = makeDescList<nixlBlobDesc>(buffers, mem_type);
60+
agent.registerMem(reg_list);
61+
}
62+
63+
void
64+
DeviceApiTestBase::completeWireup(size_t from_agent, size_t to_agent) {
65+
nixl_notifs_t notifs;
66+
nixl_status_t status = getAgent(from_agent).genNotif(getAgentName(to_agent), notifMsg);
67+
ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to complete wireup";
68+
69+
do {
70+
nixl_status_t ret = getAgent(to_agent).getNotifs(notifs);
71+
ASSERT_EQ(ret, NIXL_SUCCESS) << "Failed to get notifications during wireup";
72+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
73+
} while (notifs.size() == 0);
74+
}
75+
76+
void
77+
DeviceApiTestBase::exchangeMD(size_t from_agent, size_t to_agent) {
78+
for (size_t i = 0; i < agents.size(); i++) {
79+
nixl_blob_t md;
80+
nixl_status_t status = agents[i]->getLocalMD(md);
81+
ASSERT_EQ(status, NIXL_SUCCESS);
82+
83+
for (size_t j = 0; j < agents.size(); j++) {
84+
if (i == j) continue;
85+
std::string remote_agent_name;
86+
status = agents[j]->loadRemoteMD(md, remote_agent_name);
87+
ASSERT_EQ(status, NIXL_SUCCESS);
88+
EXPECT_EQ(remote_agent_name, getAgentName(i));
89+
}
90+
}
91+
92+
completeWireup(from_agent, to_agent);
93+
}
94+
95+
void
96+
DeviceApiTestBase::invalidateMD() {
97+
for (size_t i = 0; i < agents.size(); i++) {
98+
for (size_t j = 0; j < agents.size(); j++) {
99+
if (i == j) continue;
100+
nixl_status_t status = agents[j]->invalidateRemoteMD(getAgentName(i));
101+
ASSERT_EQ(status, NIXL_SUCCESS);
102+
}
103+
}
104+
}
105+
106+
void
107+
DeviceApiTestBase::createRegisteredMem(nixlAgent &agent,
108+
size_t size,
109+
size_t count,
110+
nixl_mem_t mem_type,
111+
std::vector<MemBuffer> &out) {
112+
while (count-- != 0) {
113+
out.emplace_back(size, mem_type);
114+
}
115+
116+
registerMem(agent, out, mem_type);
117+
}
118+
119+
nixlAgent &
120+
DeviceApiTestBase::getAgent(size_t idx) {
121+
return *agents[idx];
122+
}
123+
124+
std::string
125+
DeviceApiTestBase::getAgentName(size_t idx) {
126+
return absl::StrFormat("agent_%d", idx);
127+
}
128+
129+
void
130+
DeviceApiTestBase::initTiming(unsigned long long **start_time_ptr,
131+
unsigned long long **end_time_ptr) {
132+
cudaMalloc(start_time_ptr, sizeof(unsigned long long));
133+
cudaMalloc(end_time_ptr, sizeof(unsigned long long));
134+
cudaMemset(*start_time_ptr, 0, sizeof(unsigned long long));
135+
cudaMemset(*end_time_ptr, 0, sizeof(unsigned long long));
136+
}
137+
138+
void
139+
DeviceApiTestBase::getTiming(unsigned long long *start_time_ptr,
140+
unsigned long long *end_time_ptr,
141+
unsigned long long &start_time_cpu,
142+
unsigned long long &end_time_cpu) {
143+
cudaMemcpy(&start_time_cpu, start_time_ptr, sizeof(unsigned long long), cudaMemcpyDeviceToHost);
144+
cudaMemcpy(&end_time_cpu, end_time_ptr, sizeof(unsigned long long), cudaMemcpyDeviceToHost);
145+
}
146+
147+
const char *
148+
DeviceApiTestBase::getGpuXferLevelStr(nixl_gpu_level_t level) {
149+
switch (level) {
150+
case nixl_gpu_level_t::WARP:
151+
return "WARP";
152+
case nixl_gpu_level_t::BLOCK:
153+
return "BLOCK";
154+
case nixl_gpu_level_t::THREAD:
155+
return "THREAD";
156+
default:
157+
return "UNKNOWN";
158+
}
159+
}
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#ifndef _DEVICE_TEST_BASE_CUH
19+
#define _DEVICE_TEST_BASE_CUH
20+
21+
#include <gtest/gtest.h>
22+
#include "nixl.h"
23+
#include "common.h"
24+
#include <nixl_device.cuh>
25+
#include <cuda_runtime.h>
26+
#include <memory>
27+
#include <vector>
28+
#include <string>
29+
#include <thread>
30+
#include <chrono>
31+
#include <functional>
32+
#include <type_traits>
33+
#include <iomanip>
34+
#include <iostream>
35+
#include <cstdlib>
36+
#include <ctime>
37+
#include <absl/strings/str_format.h>
38+
39+
#include "mem_buffer.cuh"
40+
#include "cuda_ptr.cuh"
41+
#include "device_utils.cuh"
42+
43+
class DeviceApiTestBase : public testing::TestWithParam<nixl_gpu_level_t> {
44+
public:
45+
static const char *
46+
getGpuXferLevelStr(nixl_gpu_level_t level);
47+
48+
static const std::vector<nixl_gpu_level_t>
49+
getTestLevels() {
50+
static const std::vector<nixl_gpu_level_t> testLevels = {
51+
nixl_gpu_level_t::BLOCK,
52+
nixl_gpu_level_t::WARP,
53+
nixl_gpu_level_t::THREAD,
54+
};
55+
return testLevels;
56+
}
57+
58+
static constexpr const char *notifMsg = "notification";
59+
60+
protected:
61+
static nixlAgentConfig
62+
getConfig();
63+
nixl_b_params_t
64+
getBackendParams();
65+
void
66+
SetUp() override;
67+
void
68+
TearDown() override;
69+
70+
template<typename Desc>
71+
nixlDescList<Desc>
72+
makeDescList(const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type) {
73+
nixlDescList<Desc> desc_list(mem_type);
74+
for (const auto &buffer : buffers) {
75+
desc_list.addDesc(Desc(buffer, buffer.getSize(), uint64_t(devId)));
76+
}
77+
return desc_list;
78+
}
79+
80+
void
81+
registerMem(nixlAgent &agent, const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type);
82+
void
83+
completeWireup(size_t from_agent, size_t to_agent);
84+
void
85+
exchangeMD(size_t from_agent, size_t to_agent);
86+
void
87+
invalidateMD();
88+
89+
void
90+
createRegisteredMem(nixlAgent &agent,
91+
size_t size,
92+
size_t count,
93+
nixl_mem_t mem_type,
94+
std::vector<MemBuffer> &out);
95+
96+
nixlAgent &
97+
getAgent(size_t idx);
98+
std::string
99+
getAgentName(size_t idx);
100+
101+
void
102+
initTiming(unsigned long long **start_time_ptr, unsigned long long **end_time_ptr);
103+
void
104+
getTiming(unsigned long long *start_time_ptr,
105+
unsigned long long *end_time_ptr,
106+
unsigned long long &start_time_cpu,
107+
unsigned long long &end_time_cpu);
108+
109+
template<typename KernelFunc>
110+
nixl_status_t
111+
dispatchKernelByLevel(nixl_gpu_level_t level, KernelFunc kernel_func) {
112+
switch (level) {
113+
case nixl_gpu_level_t::BLOCK:
114+
return kernel_func(std::integral_constant<nixl_gpu_level_t, nixl_gpu_level_t::BLOCK>{});
115+
case nixl_gpu_level_t::WARP:
116+
return kernel_func(std::integral_constant<nixl_gpu_level_t, nixl_gpu_level_t::WARP>{});
117+
case nixl_gpu_level_t::THREAD:
118+
return kernel_func(
119+
std::integral_constant<nixl_gpu_level_t, nixl_gpu_level_t::THREAD>{});
120+
default:
121+
ADD_FAILURE() << "Unknown level: " << static_cast<int>(level);
122+
return NIXL_ERR_INVALID_PARAM;
123+
}
124+
}
125+
126+
protected:
127+
static constexpr size_t senderAgent = 0;
128+
static constexpr size_t receiverAgent = 1;
129+
130+
std::vector<nixlBackendH *> backend_handles;
131+
132+
private:
133+
static constexpr uint64_t devId = 0;
134+
135+
std::vector<std::unique_ptr<nixlAgent>> agents;
136+
};
137+
138+
#endif // _DEVICE_TEST_BASE_CUH

0 commit comments

Comments
 (0)