11/*
22 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33 * SPDX-License-Identifier: Apache-2.0
4+ *
5+ * Licensed under the Apache License, Version 2.0 (the "License");
6+ * you may not use this file except in compliance with the License.
7+ * You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
416 */
517
6- #include " utils .cuh"
18+ #include " device_test_base .cuh"
719
8- namespace gtest {
9- namespace gpu {
10-
11- const char *GetGpuXferLevelStr (nixl_gpu_level_t level) {
12- switch (level) {
13- case nixl_gpu_level_t ::WARP:
14- return " WARP" ;
15- case nixl_gpu_level_t ::BLOCK:
16- return " BLOCK" ;
17- case nixl_gpu_level_t ::THREAD:
18- return " THREAD" ;
19- default :
20- return " UNKNOWN" ;
21- }
22- }
23-
24- void initTiming (unsigned long long **start_time_ptr, unsigned long long **end_time_ptr) {
25- cudaMalloc (start_time_ptr, sizeof (unsigned long long ));
26- cudaMalloc (end_time_ptr, sizeof (unsigned long long ));
27- cudaMemset (*start_time_ptr, 0 , sizeof (unsigned long long ));
28- cudaMemset (*end_time_ptr, 0 , sizeof (unsigned long long ));
29- }
30-
31- void getTiming (unsigned long long *start_time_ptr,
32- unsigned long long *end_time_ptr,
33- unsigned long long &start_time_cpu,
34- unsigned long long &end_time_cpu) {
35- cudaMemcpy (&start_time_cpu, start_time_ptr, sizeof (unsigned long long ), cudaMemcpyDeviceToHost);
36- cudaMemcpy (&end_time_cpu, end_time_ptr, sizeof (unsigned long long ), cudaMemcpyDeviceToHost);
37- }
38-
39- void logResults (size_t size,
40- size_t count,
41- size_t num_iters,
42- unsigned long long start_time_cpu,
43- unsigned long long end_time_cpu) {
44- auto total_time = NS_TO_SEC(end_time_cpu - start_time_cpu);
45- double total_size = size * count * num_iters;
46- auto bandwidth = total_size / total_time / (1024 * 1024 );
47- printf (" Device API Results: %zux%zux%zu=%.0f bytes in %f seconds (%.2f MB/s)\n " ,
48- size, count, num_iters, total_size, total_time, bandwidth);
49- }
50-
51- } // namespace gpu
52- } // namespace gtest
53-
54- nixlAgentConfig DeviceApiTestBase::getConfig () {
55- return nixlAgentConfig (true ,
56- false ,
57- 0 ,
58- nixl_thread_sync_t ::NIXL_THREAD_SYNC_RW,
59- 0 ,
60- 100000 );
20+ nixlAgentConfig
21+ DeviceApiTestBase::getConfig () {
22+ return nixlAgentConfig (true , false , 0 , nixl_thread_sync_t ::NIXL_THREAD_SYNC_RW, 0 , 100000 );
6123}
6224
63- nixl_b_params_t DeviceApiTestBase::getBackendParams () {
25+ nixl_b_params_t
26+ DeviceApiTestBase::getBackendParams () {
6427 nixl_b_params_t params;
6528 params[" num_workers" ] = " 2" ;
6629 return params;
6730}
6831
69- void DeviceApiTestBase::SetUp () {
32+ void
33+ DeviceApiTestBase::SetUp () {
7034 if (cudaSetDevice (0 ) != cudaSuccess) {
7135 FAIL () << " Failed to set CUDA device 0" ;
7236 }
7337
7438 for (size_t i = 0 ; i < 2 ; i++) {
7539 agents.emplace_back (std::make_unique<nixlAgent>(getAgentName (i), getConfig ()));
7640 nixlBackendH *backend_handle = nullptr ;
77- nixl_status_t status = agents.back ()->createBackend (" UCX" , getBackendParams (), backend_handle);
41+ nixl_status_t status =
42+ agents.back ()->createBackend (" UCX" , getBackendParams (), backend_handle);
7843 ASSERT_EQ (status, NIXL_SUCCESS);
7944 EXPECT_NE (backend_handle, nullptr );
8045 backend_handles.push_back (backend_handle);
8146 }
8247}
8348
84- void DeviceApiTestBase::TearDown () {
49+ void
50+ DeviceApiTestBase::TearDown () {
8551 agents.clear ();
86- backend_handles.clear ();
87- }
88-
89- template <typename Desc>
90- nixlDescList<Desc> DeviceApiTestBase::makeDescList (const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type) {
91- nixlDescList<Desc> desc_list (mem_type);
92- for (const auto &buffer : buffers) {
93- desc_list.addDesc (Desc (buffer, buffer.getSize (), uint64_t (DEV_ID)));
94- }
95- return desc_list;
9652}
9753
98- void DeviceApiTestBase::registerMem (nixlAgent &agent, const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type) {
54+ void
55+ DeviceApiTestBase::registerMem (nixlAgent &agent,
56+ const std::vector<MemBuffer> &buffers,
57+ nixl_mem_t mem_type) {
9958 auto reg_list = makeDescList<nixlBlobDesc>(buffers, mem_type);
10059 agent.registerMem (reg_list);
10160}
10261
103- void DeviceApiTestBase::completeWireup (size_t from_agent, size_t to_agent) {
62+ void
63+ DeviceApiTestBase::completeWireup (size_t from_agent, size_t to_agent) {
10464 nixl_notifs_t notifs;
105- nixl_status_t status = getAgent (from_agent).genNotif (getAgentName (to_agent), NOTIF_MSG );
65+ nixl_status_t status = getAgent (from_agent).genNotif (getAgentName (to_agent), notifMsg );
10666 ASSERT_EQ (status, NIXL_SUCCESS) << " Failed to complete wireup" ;
10767
10868 do {
@@ -112,7 +72,8 @@ void DeviceApiTestBase::completeWireup(size_t from_agent, size_t to_agent) {
11272 } while (notifs.size () == 0 );
11373}
11474
115- void DeviceApiTestBase::exchangeMD (size_t from_agent, size_t to_agent) {
75+ void
76+ DeviceApiTestBase::exchangeMD (size_t from_agent, size_t to_agent) {
11677 for (size_t i = 0 ; i < agents.size (); i++) {
11778 nixl_blob_t md;
11879 nixl_status_t status = agents[i]->getLocalMD (md);
@@ -130,7 +91,8 @@ void DeviceApiTestBase::exchangeMD(size_t from_agent, size_t to_agent) {
13091 completeWireup (from_agent, to_agent);
13192}
13293
133- void DeviceApiTestBase::invalidateMD () {
94+ void
95+ DeviceApiTestBase::invalidateMD () {
13496 for (size_t i = 0 ; i < agents.size (); i++) {
13597 for (size_t j = 0 ; j < agents.size (); j++) {
13698 if (i == j) continue ;
@@ -140,24 +102,57 @@ void DeviceApiTestBase::invalidateMD() {
140102 }
141103}
142104
143- void DeviceApiTestBase::createRegisteredMem (nixlAgent &agent,
144- size_t size,
145- size_t count,
146- nixl_mem_t mem_type,
147- std::vector<MemBuffer> &out) {
105+ void
106+ DeviceApiTestBase::createRegisteredMem (nixlAgent &agent,
107+ size_t size,
108+ size_t count,
109+ nixl_mem_t mem_type,
110+ std::vector<MemBuffer> &out) {
148111 while (count-- != 0 ) {
149112 out.emplace_back (size, mem_type);
150113 }
151114
152115 registerMem (agent, out, mem_type);
153116}
154117
155- nixlAgent &DeviceApiTestBase::getAgent (size_t idx) {
118+ nixlAgent &
119+ DeviceApiTestBase::getAgent (size_t idx) {
156120 return *agents[idx];
157121}
158122
159- std::string DeviceApiTestBase::getAgentName (size_t idx) {
123+ std::string
124+ DeviceApiTestBase::getAgentName (size_t idx) {
160125 return absl::StrFormat (" agent_%d" , idx);
161126}
162127
163- template nixlDescList<nixlBasicDesc> DeviceApiTestBase::makeDescList<nixlBasicDesc>(const std::vector<MemBuffer> &buffers, nixl_mem_t mem_type);
128+ void
129+ DeviceApiTestBase::initTiming (unsigned long long **start_time_ptr,
130+ unsigned long long **end_time_ptr) {
131+ cudaMalloc (start_time_ptr, sizeof (unsigned long long ));
132+ cudaMalloc (end_time_ptr, sizeof (unsigned long long ));
133+ cudaMemset (*start_time_ptr, 0 , sizeof (unsigned long long ));
134+ cudaMemset (*end_time_ptr, 0 , sizeof (unsigned long long ));
135+ }
136+
137+ void
138+ DeviceApiTestBase::getTiming (unsigned long long *start_time_ptr,
139+ unsigned long long *end_time_ptr,
140+ unsigned long long &start_time_cpu,
141+ unsigned long long &end_time_cpu) {
142+ cudaMemcpy (&start_time_cpu, start_time_ptr, sizeof (unsigned long long ), cudaMemcpyDeviceToHost);
143+ cudaMemcpy (&end_time_cpu, end_time_ptr, sizeof (unsigned long long ), cudaMemcpyDeviceToHost);
144+ }
145+
146+ const char *
147+ DeviceApiTestBase::GetGpuXferLevelStr (nixl_gpu_level_t level) {
148+ switch (level) {
149+ case nixl_gpu_level_t ::WARP:
150+ return " WARP" ;
151+ case nixl_gpu_level_t ::BLOCK:
152+ return " BLOCK" ;
153+ case nixl_gpu_level_t ::THREAD:
154+ return " THREAD" ;
155+ default :
156+ return " UNKNOWN" ;
157+ }
158+ }
0 commit comments