Skip to content

Commit 0966722

Browse files
authored
[2.4] Integrated ReliableMessage with XGBoost (#2399)
* Integrated ReliableMessage with XGBoost * Addressed comments in PR * Removed RequestSender and moved Sender to nvflare.apis.util * Removed _check_reply * Fixed a test error caused by not having a mock engine * Removed the extra self.engine initialization * Removed duplicate assignment of self.engine
1 parent 08c202f commit 0966722

File tree

10 files changed

+250
-104
lines changed

10 files changed

+250
-104
lines changed

nvflare/apis/utils/reliable_message.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def _send_request(
304304
# keep sending the request until a positive ack or result is received
305305
num_tries = 0
306306
while True:
307-
if abort_signal.triggered:
307+
if abort_signal and abort_signal.triggered:
308308
return make_reply(ReturnCode.TASK_ABORTED)
309309

310310
ack = engine.send_aux_request(
@@ -338,7 +338,7 @@ def _send_request(
338338
return _error_reply(ReturnCode.COMMUNICATION_ERROR, f"Max send retries ({cls._max_retries}) reached")
339339
start = time.time()
340340
while time.time() - start < cls._query_interval:
341-
if abort_signal.triggered:
341+
if abort_signal and abort_signal.triggered:
342342
return make_reply(ReturnCode.TASK_ABORTED)
343343
time.sleep(0.1)
344344

@@ -355,6 +355,7 @@ def _query_result(
355355
) -> Shareable:
356356

357357
# Querying phase - try to get result
358+
engine = fl_ctx.get_engine()
358359
query = Shareable()
359360
query.set_header(HEADER_TX, receiver.tx_id)
360361
query.set_header(HEADER_OP, OP_QUERY)
@@ -367,7 +368,7 @@ def _query_result(
367368
# check other condition and/or send query to ask for result.
368369
return receiver.result
369370

370-
if abort_signal.triggered:
371+
if abort_signal and abort_signal.triggered:
371372
return make_reply(ReturnCode.TASK_ABORTED)
372373

373374
# send a query. The ack of the query could be the result itself, or a status report.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from nvflare.apis.fl_context import FLContext
15+
from nvflare.apis.shareable import Shareable
16+
from nvflare.apis.signal import Signal
17+
from nvflare.apis.utils.reliable_message import ReliableMessage
18+
from nvflare.apis.utils.sender import Sender
19+
from nvflare.fuel.f3.cellnet.fqcn import FQCN
20+
21+
22+
class ReliableSender(Sender):
23+
def __init__(self, max_request_workers=20, query_interval=5, max_retries=5, max_tx_time=300.0):
24+
"""Constructor
25+
26+
Args:
27+
max_request_workers: Number of concurrent request worker threads
28+
query_interval: Retry/query interval
29+
max_retries: Number of retries
30+
max_tx_time: Max transmitting time
31+
"""
32+
33+
super().__init__()
34+
self.max_request_workers = max_request_workers
35+
self.query_interval = query_interval
36+
self.max_retries = max_retries
37+
self.max_tx_time = max_tx_time
38+
self.enabled = False
39+
40+
def send_request(
41+
self, target: str, topic: str, req: Shareable, timeout: float, fl_ctx: FLContext, abort_signal: Signal
42+
) -> Shareable:
43+
44+
if not self.enabled:
45+
ReliableMessage.enable(
46+
fl_ctx,
47+
max_request_workers=self.max_request_workers,
48+
query_interval=self.query_interval,
49+
max_retries=self.max_retries,
50+
max_tx_time=self.max_tx_time,
51+
)
52+
self.enabled = True
53+
54+
return ReliableMessage.send_request(FQCN.ROOT_SERVER, topic, req, timeout, abort_signal, fl_ctx)

nvflare/apis/utils/sender.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from abc import ABC, abstractmethod
15+
from typing import Optional
16+
17+
from nvflare.apis.fl_component import FLComponent
18+
from nvflare.apis.fl_context import FLContext
19+
from nvflare.apis.shareable import Shareable
20+
from nvflare.apis.signal import Signal
21+
from nvflare.fuel.f3.cellnet.fqcn import FQCN
22+
23+
24+
class Sender(FLComponent, ABC):
25+
"""An abstract class to send request"""
26+
27+
@abstractmethod
28+
def send_request(
29+
self, target: str, topic: str, req: Shareable, timeout: float, fl_ctx: FLContext, abort_signal: Signal
30+
) -> Optional[Shareable]:
31+
"""Send a request to target. This is an abstract method. Derived class must implement this method
32+
33+
Args:
34+
target: The destination
35+
topic: Topic for the request
36+
req: the request Shareable
37+
timeout: Timeout of the request in seconds
38+
fl_ctx: FLContext for the transaction
39+
abort_signal: used for checking whether the job is aborted.
40+
41+
Returns:
42+
The reply in Shareable
43+
44+
"""
45+
pass
46+
47+
def send_to_server(
48+
self, topic: str, req: Shareable, timeout: float, fl_ctx: FLContext, abort_signal: Signal
49+
) -> Optional[Shareable]:
50+
"""Send an XGB request to the server.
51+
52+
Args:
53+
topic: The topic of the request
54+
req: the request Shareable
55+
timeout: The timeout value for the request
56+
fl_ctx: The FLContext for the request
57+
abort_signal: used for checking whether the job is aborted.
58+
59+
Returns: reply from the server
60+
"""
61+
62+
return self.send_request(FQCN.ROOT_SERVER, topic, req, timeout, fl_ctx, abort_signal)
63+
64+
65+
class SimpleSender(Sender):
66+
def __init__(self):
67+
super().__init__()
68+
69+
def send_request(
70+
self, target: str, topic: str, req: Shareable, timeout: float, fl_ctx: FLContext, abort_signal: Signal
71+
) -> Optional[Shareable]:
72+
73+
engine = fl_ctx.get_engine()
74+
reply = engine.send_aux_request(
75+
targets=[target],
76+
topic=topic,
77+
request=req,
78+
timeout=timeout,
79+
fl_ctx=fl_ctx,
80+
)
81+
82+
# send_aux_request returns multiple replies in a dict
83+
if reply:
84+
return reply.get(target)
85+
else:
86+
return None

nvflare/app_opt/xgboost/histogram_based_v2/adaptor.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
from typing import Tuple
1919

2020
from nvflare.apis.fl_component import FLComponent
21+
from nvflare.apis.fl_constant import ReturnCode
2122
from nvflare.apis.fl_context import FLContext
2223
from nvflare.apis.shareable import Shareable
2324
from nvflare.apis.signal import Signal
25+
from nvflare.apis.utils.sender import Sender
2426
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
2527
from nvflare.app_opt.xgboost.histogram_based_v2.runner import XGBRunner
26-
from nvflare.app_opt.xgboost.histogram_based_v2.sender import Sender
2728
from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_object_type, check_positive_int
2829

2930

@@ -279,7 +280,7 @@ class XGBClientAdaptor(XGBAdaptor, ABC):
279280
XGBClientAdaptor specifies commonly required methods for client adaptor implementations.
280281
"""
281282

282-
def __init__(self):
283+
def __init__(self, req_timeout: float):
283284
"""Constructor of XGBClientAdaptor"""
284285
XGBAdaptor.__init__(self)
285286
self.engine = None
@@ -288,6 +289,7 @@ def __init__(self):
288289
self.rank = None
289290
self.num_rounds = None
290291
self.world_size = None
292+
self.req_timeout = req_timeout
291293

292294
def set_sender(self, sender: Sender):
293295
"""Set the sender to be used to send XGB operation requests to the server.
@@ -314,6 +316,8 @@ def configure(self, config: dict, fl_ctx: FLContext):
314316
Returns:
315317
None
316318
"""
319+
self.engine = fl_ctx.get_engine()
320+
317321
ws = config.get(Constant.CONF_KEY_WORLD_SIZE)
318322
if not ws:
319323
raise RuntimeError("world_size is not configured")
@@ -345,8 +349,22 @@ def _send_request(self, op: str, req: Shareable) -> bytes:
345349
Returns:
346350
operation result
347351
"""
348-
reply = self.sender.send_to_server(op, req, self.abort_signal)
352+
req.set_header(Constant.MSG_KEY_XGB_OP, op)
353+
354+
with self.engine.new_context() as fl_ctx:
355+
reply = self.sender.send_to_server(
356+
Constant.TOPIC_XGB_REQUEST, req, self.req_timeout, fl_ctx, self.abort_signal
357+
)
358+
349359
if isinstance(reply, Shareable):
360+
rc = reply.get_return_code()
361+
if rc != ReturnCode.OK:
362+
raise RuntimeError(f"received error return code: {rc}")
363+
364+
reply_op = reply.get_header(Constant.MSG_KEY_XGB_OP)
365+
if reply_op != op:
366+
raise RuntimeError(f"received op {reply_op} != expected op {op}")
367+
350368
rcv_buf = reply.get(Constant.PARAM_KEY_RCV_BUF)
351369
if not isinstance(rcv_buf, bytes):
352370
raise RuntimeError(f"invalid rcv_buf for {op=}: expect bytes but got {type(rcv_buf)}")

nvflare/app_opt/xgboost/histogram_based_v2/adaptor_controller.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from nvflare.apis.impl.controller import ClientTask, Controller, Task
2323
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
2424
from nvflare.apis.signal import Signal
25+
from nvflare.apis.utils.reliable_message import ReliableMessage
2526
from nvflare.app_opt.xgboost.histogram_based_v2.adaptor import XGBServerAdaptor
2627
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
2728
from nvflare.fuel.utils.validation_utils import check_number_range, check_object_type, check_positive_number, check_str
@@ -218,6 +219,16 @@ def start_controller(self, fl_ctx: FLContext):
218219
message_handle_func=self._process_client_done,
219220
)
220221

222+
ReliableMessage.enable(fl_ctx)
223+
ReliableMessage.register_request_handler(
224+
topic=Constant.TOPIC_XGB_REQUEST,
225+
handler_f=self._process_xgb_request,
226+
)
227+
ReliableMessage.register_request_handler(
228+
topic=Constant.TOPIC_CLIENT_DONE,
229+
handler_f=self._process_client_done,
230+
)
231+
221232
def _trigger_stop(self, fl_ctx: FLContext, error=None):
222233
# first trigger the abort_signal to tell all components (mainly the controller's control_flow and adaptor)
223234
# that check this signal to abort.

nvflare/app_opt/xgboost/histogram_based_v2/adaptor_executor.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,25 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
from nvflare.apis.event_type import EventType
1615
from nvflare.apis.executor import Executor
1716
from nvflare.apis.fl_constant import ReturnCode
1817
from nvflare.apis.fl_context import FLContext
1918
from nvflare.apis.shareable import Shareable, make_reply
2019
from nvflare.apis.signal import Signal
20+
from nvflare.apis.utils.sender import Sender, SimpleSender
2121
from nvflare.app_opt.xgboost.histogram_based_v2.adaptor import XGBClientAdaptor
2222
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
23-
from nvflare.app_opt.xgboost.histogram_based_v2.sender import Sender
2423
from nvflare.fuel.f3.cellnet.fqcn import FQCN
24+
from nvflare.fuel.utils.validation_utils import check_str
2525
from nvflare.security.logging import secure_format_exception
2626

2727

2828
class XGBExecutor(Executor):
2929
def __init__(
3030
self,
3131
adaptor_component_id: str,
32+
sender_id: str = None,
3233
configure_task_name=Constant.CONFIG_TASK_NAME,
3334
start_task_name=Constant.START_TASK_NAME,
3435
req_timeout=100.0,
@@ -37,11 +38,17 @@ def __init__(
3738
3839
Args:
3940
adaptor_component_id: the component ID of client target adaptor
41+
sender_id: The sender component id
4042
configure_task_name: name of the config task
4143
start_task_name: name of the start task
4244
"""
4345
Executor.__init__(self)
4446
self.adaptor_component_id = adaptor_component_id
47+
48+
if sender_id:
49+
check_str("sender_id", sender_id)
50+
self.sender_id = sender_id
51+
4552
self.req_timeout = req_timeout
4653
self.configure_task_name = configure_task_name
4754
self.start_task_name = start_task_name
@@ -78,9 +85,12 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
7885
)
7986
return
8087

88+
sender = self._get_sender(fl_ctx)
89+
if not sender:
90+
return
91+
8192
adaptor.set_abort_signal(self.abort_signal)
82-
engine = fl_ctx.get_engine()
83-
adaptor.set_sender(Sender(engine, self.req_timeout))
93+
adaptor.set_sender(sender)
8494
adaptor.initialize(fl_ctx)
8595
self.adaptor = adaptor
8696
elif event_type == EventType.END_RUN:
@@ -168,3 +178,31 @@ def _notify_client_done(self, rc, fl_ctx: FLContext):
168178
fl_ctx=fl_ctx,
169179
optional=True,
170180
)
181+
182+
def _get_sender(self, fl_ctx: FLContext) -> Sender:
183+
"""Get request sender to be used by this executor.
184+
185+
Args:
186+
fl_ctx: the FL context
187+
188+
Returns:
189+
A sender object
190+
"""
191+
192+
if self.sender_id:
193+
engine = fl_ctx.get_engine()
194+
sender = engine.get_component(self.sender_id)
195+
if not sender:
196+
self.system_panic(f"cannot get component for {self.sender_id}", fl_ctx)
197+
else:
198+
if not isinstance(sender, Sender):
199+
self.system_panic(
200+
f"invalid component '{self.sender_id}': expect {Sender.__name__} but got {type(sender)}",
201+
fl_ctx,
202+
)
203+
sender = None
204+
205+
else:
206+
sender = SimpleSender()
207+
208+
return sender

nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,17 @@ def __init__(
8383
self,
8484
int_server_grpc_options=None,
8585
in_process=False,
86+
req_timeout=100,
8687
):
8788
"""Constructor method to initialize the object.
8889
8990
Args:
9091
int_server_grpc_options: An optional list of key-value pairs (`channel_arguments`
9192
in gRPC Core runtime) to configure the gRPC channel of internal `GrpcServer`.
9293
in_process (bool): Specifies whether to start the `XGBRunner` in the same process or not.
94+
req_timeout: Request timeout
9395
"""
94-
XGBClientAdaptor.__init__(self)
96+
XGBClientAdaptor.__init__(self, req_timeout)
9597
self.int_server_grpc_options = int_server_grpc_options
9698
self.in_process = in_process
9799
self.internal_xgb_server = None

0 commit comments

Comments
 (0)