Skip to content

Commit 53ecfd5

Browse files
authored
[2.4] Ported newest ReliableMessage to 2.4 (#2815)
* Ported 2.5 ReliableMessage to 2.4 * Fixed base_v2 to work with XGBoost 2.11 * Updated copyright year * Updated doc to mention v2.11 * Corrected version number to 2.1.1
1 parent b9d01cb commit 53ecfd5

File tree

12 files changed

+337
-163
lines changed

12 files changed

+337
-163
lines changed

examples/advanced/xgboost/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ These examples show how to use [NVIDIA FLARE](https://nvflare.readthedocs.io/en/
1111
They use [XGBoost](https://github.com/dmlc/xgboost),
1212
which is an optimized distributed gradient boosting library.
1313

14+
The code was tested with XGBoost V2.1.1. It may not work with other versions of XGBoost.
15+
1416
### HIGGS
1517
The examples illustrate a binary classification task based on [HIGGS dataset](https://archive.ics.uci.edu/dataset/280/higgs).
1618
This dataset contains 11 million instances, each with 28 attributes.

nvflare/apis/utils/reliable_message.py

Lines changed: 176 additions & 79 deletions
Large diffs are not rendered by default.

nvflare/app_opt/xgboost/histogram_based/executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -
260260
self.log_info(fl_ctx, f"server address is {self._server_address}")
261261

262262
communicator_env = {
263-
"xgboost_communicator": "federated",
263+
"dmlc_communicator": "federated",
264264
"federated_server_address": f"{self._server_address}:{xgb_fl_server_port}",
265265
"federated_world_size": self.world_size,
266266
"federated_rank": self.rank,

nvflare/app_opt/xgboost/histogram_based_v2/adaptor_controller.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,23 +209,15 @@ def start_controller(self, fl_ctx: FLContext):
209209
adaptor.initialize(fl_ctx)
210210
self.adaptor = adaptor
211211

212-
engine = fl_ctx.get_engine()
213-
engine.register_aux_message_handler(
214-
topic=Constant.TOPIC_XGB_REQUEST,
215-
message_handle_func=self._process_xgb_request,
216-
)
217-
engine.register_aux_message_handler(
218-
topic=Constant.TOPIC_CLIENT_DONE,
219-
message_handle_func=self._process_client_done,
220-
)
221-
222212
ReliableMessage.register_request_handler(
223213
topic=Constant.TOPIC_XGB_REQUEST,
224214
handler_f=self._process_xgb_request,
215+
fl_ctx=fl_ctx,
225216
)
226217
ReliableMessage.register_request_handler(
227218
topic=Constant.TOPIC_CLIENT_DONE,
228219
handler_f=self._process_client_done,
220+
fl_ctx=fl_ctx,
229221
)
230222

231223
def _trigger_stop(self, fl_ctx: FLContext, error=None):

nvflare/app_opt/xgboost/histogram_based_v2/defs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class Constant:
2727
CONF_KEY_NUM_ROUNDS = "num_rounds"
2828

2929
# default component config values
30-
CONFIG_TASK_TIMEOUT = 10
30+
CONFIG_TASK_TIMEOUT = 60
3131
START_TASK_TIMEOUT = 10
3232
XGB_SERVER_READY_TIMEOUT = 5.0
3333

nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
11
/*!
2-
* Copyright 2022 XGBoost contributors
3-
* needs to match file in https://github.com/dmlc/xgboost/blob/v2.0.3/plugin/federated/federated.proto
2+
* Copyright 2022-2023 XGBoost contributors
43
*/
54
syntax = "proto3";
65

7-
package xgboost.federated;
6+
package xgboost.collective.federated;
87

98
service Federated {
109
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
10+
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
1111
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
1212
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
1313
}
1414

1515
enum DataType {
16-
INT8 = 0;
17-
UINT8 = 1;
18-
INT32 = 2;
19-
UINT32 = 3;
20-
INT64 = 4;
21-
UINT64 = 5;
22-
FLOAT = 6;
23-
DOUBLE = 7;
16+
HALF = 0;
17+
FLOAT = 1;
18+
DOUBLE = 2;
19+
LONG_DOUBLE = 3;
20+
INT8 = 4;
21+
INT16 = 5;
22+
INT32 = 6;
23+
INT64 = 7;
24+
UINT8 = 8;
25+
UINT16 = 9;
26+
UINT32 = 10;
27+
UINT64 = 11;
2428
}
2529

2630
enum ReduceOperation {
@@ -43,6 +47,17 @@ message AllgatherReply {
4347
bytes receive_buffer = 1;
4448
}
4549

50+
message AllgatherVRequest {
51+
// An incrementing counter that is unique to each round to operations.
52+
uint64 sequence_number = 1;
53+
int32 rank = 2;
54+
bytes send_buffer = 3;
55+
}
56+
57+
message AllgatherVReply {
58+
bytes receive_buffer = 1;
59+
}
60+
4661
message AllreduceRequest {
4762
// An incrementing counter that is unique to each round to operations.
4863
uint64 sequence_number = 1;
@@ -67,4 +82,4 @@ message BroadcastRequest {
6782

6883
message BroadcastReply {
6984
bytes receive_buffer = 1;
70-
}
85+
}

nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
# -*- coding: utf-8 -*-
1516
# Generated by the protocol buffer compiler. DO NOT EDIT!
1617
# source: federated.proto
17-
# Protobuf Python Version: 4.25.0
18+
# Protobuf Python Version: 4.25.1
1819
"""Generated protocol buffer code."""
1920
from google.protobuf import descriptor as _descriptor
2021
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,29 +28,33 @@
2728

2829

2930

30-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x11xgboost.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xbc\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12.\n\tdata_type\x18\x04 \x01(\x0e\x32\x1b.xgboost.federated.DataType\x12<\n\x10reduce_operation\x18\x05 \x01(\x0e\x32\".xgboost.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c*d\n\x08\x44\x61taType\x12\x08\n\x04INT8\x10\x00\x12\t\n\x05UINT8\x10\x01\x12\t\n\x05INT32\x10\x02\x12\n\n\x06UINT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06UINT64\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\x90\x02\n\tFederated\x12U\n\tAllgather\x12#.xgboost.federated.AllgatherRequest\x1a!.xgboost.federated.AllgatherReply\"\x00\x12U\n\tAllreduce\x12#.xgboost.federated.AllreduceRequest\x1a!.xgboost.federated.AllreduceReply\"\x00\x12U\n\tBroadcast\x12#.xgboost.federated.BroadcastRequest\x1a!.xgboost.federated.BroadcastReply\"\x00\x62\x06proto3')
31+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x1cxgboost.collective.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"O\n\x11\x41llgatherVRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\")\n\x0f\x41llgatherVReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xd2\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x39\n\tdata_type\x18\x04 \x01(\x0e\x32&.xgboost.collective.federated.DataType\x12G\n\x10reduce_operation\x18\x05 \x01(\x0e\x32-.xgboost.collective.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c*\x96\x01\n\x08\x44\x61taType\x12\x08\n\x04HALF\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06\x44OUBLE\x10\x02\x12\x0f\n\x0bLONG_DOUBLE\x10\x03\x12\x08\n\x04INT8\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\t\n\x05UINT8\x10\x08\x12\n\n\x06UINT16\x10\t\x12\n\n\x06UINT32\x10\n\x12\n\n\x06UINT64\x10\x0b*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\xc2\x03\n\tFederated\x12k\n\tAllgather\x12..xgboost.collective.federated.AllgatherRequest\x1a,.xgboost.collective.federated.AllgatherReply\"\x00\x12n\n\nAllgatherV\x12/.xgboost.collective.federated.AllgatherVRequest\x1a-.xgboost.collective.federated.AllgatherVReply\"\x00\x12k\n\tAllreduce\x12..xgboost.collective.federated.AllreduceRequest\x1a,.xgboost.collective.federated.AllreduceReply\"\x00\x12k\n\tBroadcast\x12..xgboost.collective.federated.BroadcastRequest\x1a,.xgboost.collective.federated.BroadcastReply\"\x00\x62\x06proto3')
3132

3233
_globals = globals()
3334
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
3435
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'federated_pb2', _globals)
3536
if _descriptor._USE_C_DESCRIPTORS == False:
3637
DESCRIPTOR._options = None
37-
_globals['_DATATYPE']._serialized_start=529
38-
_globals['_DATATYPE']._serialized_end=629
39-
_globals['_REDUCEOPERATION']._serialized_start=631
40-
_globals['_REDUCEOPERATION']._serialized_end=725
41-
_globals['_ALLGATHERREQUEST']._serialized_start=38
42-
_globals['_ALLGATHERREQUEST']._serialized_end=116
43-
_globals['_ALLGATHERREPLY']._serialized_start=118
44-
_globals['_ALLGATHERREPLY']._serialized_end=158
45-
_globals['_ALLREDUCEREQUEST']._serialized_start=161
46-
_globals['_ALLREDUCEREQUEST']._serialized_end=349
47-
_globals['_ALLREDUCEREPLY']._serialized_start=351
48-
_globals['_ALLREDUCEREPLY']._serialized_end=391
49-
_globals['_BROADCASTREQUEST']._serialized_start=393
50-
_globals['_BROADCASTREQUEST']._serialized_end=485
51-
_globals['_BROADCASTREPLY']._serialized_start=487
52-
_globals['_BROADCASTREPLY']._serialized_end=527
53-
_globals['_FEDERATED']._serialized_start=728
54-
_globals['_FEDERATED']._serialized_end=1000
38+
_globals['_DATATYPE']._serialized_start=687
39+
_globals['_DATATYPE']._serialized_end=837
40+
_globals['_REDUCEOPERATION']._serialized_start=839
41+
_globals['_REDUCEOPERATION']._serialized_end=933
42+
_globals['_ALLGATHERREQUEST']._serialized_start=49
43+
_globals['_ALLGATHERREQUEST']._serialized_end=127
44+
_globals['_ALLGATHERREPLY']._serialized_start=129
45+
_globals['_ALLGATHERREPLY']._serialized_end=169
46+
_globals['_ALLGATHERVREQUEST']._serialized_start=171
47+
_globals['_ALLGATHERVREQUEST']._serialized_end=250
48+
_globals['_ALLGATHERVREPLY']._serialized_start=252
49+
_globals['_ALLGATHERVREPLY']._serialized_end=293
50+
_globals['_ALLREDUCEREQUEST']._serialized_start=296
51+
_globals['_ALLREDUCEREQUEST']._serialized_end=506
52+
_globals['_ALLREDUCEREPLY']._serialized_start=508
53+
_globals['_ALLREDUCEREPLY']._serialized_end=548
54+
_globals['_BROADCASTREQUEST']._serialized_start=550
55+
_globals['_BROADCASTREQUEST']._serialized_end=642
56+
_globals['_BROADCASTREPLY']._serialized_start=644
57+
_globals['_BROADCASTREPLY']._serialized_end=684
58+
_globals['_FEDERATED']._serialized_start=936
59+
_globals['_FEDERATED']._serialized_end=1386
5560
# @@protoc_insertion_point(module_scope)

nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@ DESCRIPTOR: _descriptor.FileDescriptor
77

88
class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
99
__slots__ = ()
10+
HALF: _ClassVar[DataType]
11+
FLOAT: _ClassVar[DataType]
12+
DOUBLE: _ClassVar[DataType]
13+
LONG_DOUBLE: _ClassVar[DataType]
1014
INT8: _ClassVar[DataType]
11-
UINT8: _ClassVar[DataType]
15+
INT16: _ClassVar[DataType]
1216
INT32: _ClassVar[DataType]
13-
UINT32: _ClassVar[DataType]
1417
INT64: _ClassVar[DataType]
18+
UINT8: _ClassVar[DataType]
19+
UINT16: _ClassVar[DataType]
20+
UINT32: _ClassVar[DataType]
1521
UINT64: _ClassVar[DataType]
16-
FLOAT: _ClassVar[DataType]
17-
DOUBLE: _ClassVar[DataType]
1822

1923
class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
2024
__slots__ = ()
@@ -24,14 +28,18 @@ class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
2428
BITWISE_AND: _ClassVar[ReduceOperation]
2529
BITWISE_OR: _ClassVar[ReduceOperation]
2630
BITWISE_XOR: _ClassVar[ReduceOperation]
31+
HALF: DataType
32+
FLOAT: DataType
33+
DOUBLE: DataType
34+
LONG_DOUBLE: DataType
2735
INT8: DataType
28-
UINT8: DataType
36+
INT16: DataType
2937
INT32: DataType
30-
UINT32: DataType
3138
INT64: DataType
39+
UINT8: DataType
40+
UINT16: DataType
41+
UINT32: DataType
3242
UINT64: DataType
33-
FLOAT: DataType
34-
DOUBLE: DataType
3543
MAX: ReduceOperation
3644
MIN: ReduceOperation
3745
SUM: ReduceOperation
@@ -55,6 +63,22 @@ class AllgatherReply(_message.Message):
5563
receive_buffer: bytes
5664
def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ...
5765

66+
class AllgatherVRequest(_message.Message):
67+
__slots__ = ("sequence_number", "rank", "send_buffer")
68+
SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int]
69+
RANK_FIELD_NUMBER: _ClassVar[int]
70+
SEND_BUFFER_FIELD_NUMBER: _ClassVar[int]
71+
sequence_number: int
72+
rank: int
73+
send_buffer: bytes
74+
def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ...
75+
76+
class AllgatherVReply(_message.Message):
77+
__slots__ = ("receive_buffer",)
78+
RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int]
79+
receive_buffer: bytes
80+
def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ...
81+
5882
class AllreduceRequest(_message.Message):
5983
__slots__ = ("sequence_number", "rank", "send_buffer", "data_type", "reduce_operation")
6084
SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int]

0 commit comments

Comments
 (0)