Skip to content

Commit 45a16d5

Browse files
XGBoost plugin with new API (#2725)
* Updated FOBS readme to add DatumManager, added agrpcs as secure scheme * Implemented LocalPlugin * Refactoring plugin * Fixed formats * Fixed horizontal secure isses with mismatching algather-v sizes * Added padding to the buffer so it's big enough for histograms * Format fix * Changed log level for tenseal exceptions * Fixed a typo * Added debug statements * Fixed LocalPlugin horizontal bug * Added #include <chrono> * Added docstring to BasePlugin --------- Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
1 parent 4b32f27 commit 45a16d5

38 files changed

+2209
-758
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
root = true
2+
3+
[*]
4+
charset=utf-8
5+
indent_style = space
6+
indent_size = 2
7+
insert_final_newline = true
8+
9+
[*.py]
10+
indent_style = space
11+
indent_size = 4
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
cmake_minimum_required(VERSION 3.19)
2+
project(xgb_nvflare LANGUAGES CXX C VERSION 1.0)
3+
set(CMAKE_CXX_STANDARD 17)
4+
set(CMAKE_BUILD_TYPE Debug)
5+
6+
option(GOOGLE_TEST "Build google tests" OFF)
7+
8+
file(GLOB_RECURSE LIB_SRC "src/*.cc")
9+
10+
add_library(nvflare SHARED ${LIB_SRC})
11+
set_target_properties(nvflare PROPERTIES
12+
CXX_STANDARD 17
13+
CXX_STANDARD_REQUIRED ON
14+
POSITION_INDEPENDENT_CODE ON
15+
ENABLE_EXPORTS ON
16+
)
17+
target_include_directories(nvflare PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include)
18+
19+
if (APPLE)
20+
add_link_options("LINKER:-object_path_lto,$<TARGET_PROPERTY:NAME>_lto.o")
21+
add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache")
22+
endif ()
23+
24+
#-- Unit Tests
25+
if(GOOGLE_TEST)
26+
find_package(GTest REQUIRED)
27+
enable_testing()
28+
add_executable(nvflare_test)
29+
target_link_libraries(nvflare_test PRIVATE nvflare)
30+
31+
32+
target_include_directories(nvflare_test PRIVATE ${xgb_nvflare_SOURCE_DIR}/src/include)
33+
34+
add_subdirectory(${xgb_nvflare_SOURCE_DIR}/tests)
35+
36+
add_test(
37+
NAME TestNvflarePlugins
38+
COMMAND nvflare_test
39+
WORKING_DIRECTORY ${xgb_nvflare_BINARY_DIR})
40+
41+
endif()
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Build Instruction
2+
3+
cd NVFlare/integration/xgboost/encryption_plugins
4+
mkdir build
5+
cd build
6+
cmake ..
7+
make
8+
9+
The library is libxgb_nvflare.so
File renamed without changes.
File renamed without changes.
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
/**
2+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <iostream>
17+
#include <cstring>
18+
#include "dam.h"
19+
20+
21+
void print_hex(const uint8_t *buffer, std::size_t size) {
22+
std::cout << std::hex;
23+
for (int i = 0; i < size; i++) {
24+
int c = buffer[i];
25+
std::cout << c << " ";
26+
}
27+
std::cout << std::endl << std::dec;
28+
}
29+
30+
void print_buffer(const uint8_t *buffer, std::size_t size) {
31+
if (size <= 64) {
32+
std::cout << "Whole buffer: " << size << " bytes" << std::endl;
33+
print_hex(buffer, size);
34+
return;
35+
}
36+
37+
std::cout << "First chunk, Total: " << size << " bytes" << std::endl;
38+
print_hex(buffer, 32);
39+
std::cout << "Last chunk, Offset: " << size-16 << " bytes" << std::endl;
40+
print_hex(buffer+size-32, 32);
41+
}
42+
43+
size_t align(const size_t length) {
44+
return ((length + 7)/8)*8;
45+
}
46+
47+
// DamEncoder ======
48+
void DamEncoder::AddBuffer(const Buffer &buffer) {
49+
if (debug_) {
50+
std::cout << "AddBuffer called, size: " << buffer.buf_size << std::endl;
51+
}
52+
if (encoded_) {
53+
std::cout << "Buffer is already encoded" << std::endl;
54+
return;
55+
}
56+
// print_buffer(buffer, buf_size);
57+
entries_.emplace_back(kDataTypeBuffer, static_cast<const uint8_t *>(buffer.buffer), buffer.buf_size);
58+
}
59+
60+
void DamEncoder::AddFloatArray(const std::vector<double> &value) {
61+
if (debug_) {
62+
std::cout << "AddFloatArray called, size: " << value.size() << std::endl;
63+
}
64+
65+
if (encoded_) {
66+
std::cout << "Buffer is already encoded" << std::endl;
67+
return;
68+
}
69+
// print_buffer(reinterpret_cast<uint8_t *>(value.data()), value.size() * 8);
70+
entries_.emplace_back(kDataTypeFloatArray, reinterpret_cast<const uint8_t *>(value.data()), value.size());
71+
}
72+
73+
void DamEncoder::AddIntArray(const std::vector<int64_t> &value) {
74+
if (debug_) {
75+
std::cout << "AddIntArray called, size: " << value.size() << std::endl;
76+
}
77+
78+
if (encoded_) {
79+
std::cout << "Buffer is already encoded" << std::endl;
80+
return;
81+
}
82+
// print_buffer(buffer, buf_size);
83+
entries_.emplace_back(kDataTypeIntArray, reinterpret_cast<const uint8_t *>(value.data()), value.size());
84+
}
85+
86+
void DamEncoder::AddBufferArray(const std::vector<Buffer> &value) {
87+
if (debug_) {
88+
std::cout << "AddBufferArray called, size: " << value.size() << std::endl;
89+
}
90+
91+
if (encoded_) {
92+
std::cout << "Buffer is already encoded" << std::endl;
93+
return;
94+
}
95+
size_t size = 0;
96+
for (auto &buf: value) {
97+
size += buf.buf_size;
98+
}
99+
size += 8*value.size();
100+
entries_.emplace_back(kDataTypeBufferArray, reinterpret_cast<const uint8_t *>(&value), size);
101+
}
102+
103+
104+
std::uint8_t * DamEncoder::Finish(size_t &size) {
105+
encoded_ = true;
106+
107+
size = CalculateSize();
108+
auto buf = static_cast<uint8_t *>(calloc(size, 1));
109+
auto pointer = buf;
110+
auto sig = local_version_ ? kSignatureLocal : kSignature;
111+
memcpy(pointer, sig, strlen(sig));
112+
memcpy(pointer+8, &size, 8);
113+
memcpy(pointer+16, &data_set_id_, 8);
114+
115+
pointer += kPrefixLen;
116+
for (auto& entry : entries_) {
117+
std::size_t len;
118+
if (entry.data_type == kDataTypeBufferArray) {
119+
auto buffers = reinterpret_cast<const std::vector<Buffer> *>(entry.pointer);
120+
memcpy(pointer, &entry.data_type, 8);
121+
pointer += 8;
122+
auto array_size = static_cast<int64_t>(buffers->size());
123+
memcpy(pointer, &array_size, 8);
124+
pointer += 8;
125+
auto sizes = reinterpret_cast<int64_t *>(pointer);
126+
for (auto &item : *buffers) {
127+
*sizes = static_cast<int64_t>(item.buf_size);
128+
sizes++;
129+
}
130+
len = 8*buffers->size();
131+
auto buf_ptr = pointer + len;
132+
for (auto &item : *buffers) {
133+
if (item.buf_size > 0) {
134+
memcpy(buf_ptr, item.buffer, item.buf_size);
135+
}
136+
buf_ptr += item.buf_size;
137+
len += item.buf_size;
138+
}
139+
} else {
140+
memcpy(pointer, &entry.data_type, 8);
141+
pointer += 8;
142+
memcpy(pointer, &entry.size, 8);
143+
pointer += 8;
144+
len = entry.size * entry.ItemSize();
145+
if (len) {
146+
memcpy(pointer, entry.pointer, len);
147+
}
148+
}
149+
pointer += align(len);
150+
}
151+
152+
if ((pointer - buf) != size) {
153+
std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl;
154+
return nullptr;
155+
}
156+
157+
return buf;
158+
}
159+
160+
std::size_t DamEncoder::CalculateSize() {
161+
std::size_t size = kPrefixLen;
162+
163+
for (auto& entry : entries_) {
164+
size += 16; // The Type and Len
165+
auto len = entry.size * entry.ItemSize();
166+
size += align(len);
167+
}
168+
169+
return size;
170+
}
171+
172+
173+
// DamDecoder ======
174+
175+
DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size, bool local_version, bool debug) {
176+
local_version_ = local_version;
177+
buffer_ = buffer;
178+
buf_size_ = size;
179+
pos_ = buffer + kPrefixLen;
180+
debug_ = debug;
181+
182+
if (size >= kPrefixLen) {
183+
memcpy(&len_, buffer + 8, 8);
184+
memcpy(&data_set_id_, buffer + 16, 8);
185+
} else {
186+
len_ = 0;
187+
data_set_id_ = 0;
188+
}
189+
}
190+
191+
bool DamDecoder::IsValid() const {
192+
auto sig = local_version_ ? kSignatureLocal : kSignature;
193+
return buf_size_ >= kPrefixLen && memcmp(buffer_, sig, strlen(sig)) == 0;
194+
}
195+
196+
Buffer DamDecoder::DecodeBuffer() {
197+
auto type = *reinterpret_cast<int64_t *>(pos_);
198+
if (type != kDataTypeBuffer) {
199+
std::cout << "Data type " << type << " doesn't match bytes" << std::endl;
200+
return {};
201+
}
202+
pos_ += 8;
203+
204+
auto size = *reinterpret_cast<int64_t *>(pos_);
205+
pos_ += 8;
206+
207+
if (size == 0) {
208+
return {};
209+
}
210+
211+
auto ptr = reinterpret_cast<void *>(pos_);
212+
pos_ += align(size);
213+
return{ ptr, static_cast<std::size_t>(size)};
214+
}
215+
216+
std::vector<int64_t> DamDecoder::DecodeIntArray() {
217+
auto type = *reinterpret_cast<int64_t *>(pos_);
218+
if (type != kDataTypeIntArray) {
219+
std::cout << "Data type " << type << " doesn't match Int Array" << std::endl;
220+
return {};
221+
}
222+
pos_ += 8;
223+
224+
auto array_size = *reinterpret_cast<int64_t *>(pos_);
225+
pos_ += 8;
226+
auto ptr = reinterpret_cast<int64_t *>(pos_);
227+
pos_ += align(8 * array_size);
228+
return {ptr, ptr + array_size};
229+
}
230+
231+
std::vector<double> DamDecoder::DecodeFloatArray() {
232+
auto type = *reinterpret_cast<int64_t *>(pos_);
233+
if (type != kDataTypeFloatArray) {
234+
std::cout << "Data type " << type << " doesn't match Float Array" << std::endl;
235+
return {};
236+
}
237+
pos_ += 8;
238+
239+
auto array_size = *reinterpret_cast<int64_t *>(pos_);
240+
pos_ += 8;
241+
242+
auto ptr = reinterpret_cast<double *>(pos_);
243+
pos_ += align(8 * array_size);
244+
return {ptr, ptr + array_size};
245+
}
246+
247+
std::vector<Buffer> DamDecoder::DecodeBufferArray() {
248+
auto type = *reinterpret_cast<int64_t *>(pos_);
249+
if (type != kDataTypeBufferArray) {
250+
std::cout << "Data type " << type << " doesn't match Bytes Array" << std::endl;
251+
return {};
252+
}
253+
pos_ += 8;
254+
255+
auto num = *reinterpret_cast<int64_t *>(pos_);
256+
pos_ += 8;
257+
258+
auto size_ptr = reinterpret_cast<int64_t *>(pos_);
259+
auto buf_ptr = pos_ + 8 * num;
260+
size_t total_size = 8 * num;
261+
auto result = std::vector<Buffer>(num);
262+
for (int i = 0; i < num; i++) {
263+
auto size = size_ptr[i];
264+
if (buf_size_ > 0) {
265+
result[i].buf_size = size;
266+
result[i].buffer = buf_ptr;
267+
buf_ptr += size;
268+
}
269+
total_size += size;
270+
}
271+
272+
pos_ += align(total_size);
273+
return result;
274+
}

0 commit comments

Comments
 (0)