
Commit 55bbdf8

Fix AV FL Issues (#2053)

* fix hb timeout; add retry logic for task result submit
* increased timeout

1 parent 6aebf9c commit 55bbdf8
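
What changed, in brief: maintenance messages (register/quit/heartbeat) now use a dedicated short timeout, and task-result submission is retried instead of failing once. Before pushing a result, the client first asks the server, via a new TASK_CHECK aux topic, whether the task is still pending, and only sends once that check succeeds. A minimal sketch of that check-then-send loop, distilled from the client_runner.py diff below (the function and parameter names here are illustrative, not part of the commit):

import time

_TASK_CHECK_RESULT_OK = 0
_TASK_CHECK_RESULT_TRY_AGAIN = 1
_TASK_CHECK_RESULT_TASK_GONE = 2

def send_with_retry(check_task, send_result, should_stop, interval=5.0):
    # Keep trying until the result is delivered, the task is gone, or we are asked to stop.
    while not should_stop():
        rc = check_task()  # cheap aux request: is the server still waiting for this task?
        if rc == _TASK_CHECK_RESULT_TASK_GONE:
            return False  # task timed out on the server; drop the result
        if rc == _TASK_CHECK_RESULT_OK and send_result():
            return True  # connection verified and push succeeded
        time.sleep(interval)  # flaky connection or server not ready; try again
    return False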

File tree

7 files changed: +156 −13 lines changed

nvflare/apis/fl_constant.py

Lines changed: 1 addition & 0 deletions

@@ -163,6 +163,7 @@ class ReservedTopic(object):
     ABORT_ASK = "__abort_task__"
     AUX_COMMAND = "__aux_command__"
     JOB_HEART_BEAT = "__job_heartbeat__"
+    TASK_CHECK = "__task_check__"


 class AdminCommandNames(object):

nvflare/apis/impl/controller.py

Lines changed: 5 additions & 0 deletions

@@ -335,6 +335,11 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
         if not self._dead_client_reports.get(client_name):
             self._dead_client_reports[client_name] = time.time()

+    def process_task_check(self, task_id: str, fl_ctx: FLContext):
+        with self._task_lock:
+            # task_id is the uuid associated with the client_task
+            return self._client_task_map.get(task_id, None)
+
     def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext):
         """Called to process a submission from one client.

nvflare/apis/responder.py

Lines changed: 13 additions & 0 deletions

@@ -63,6 +63,19 @@ def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext):
         """
         pass

+    @abstractmethod
+    def process_task_check(self, task_id: str, fl_ctx: FLContext):
+        """Called by the Engine to check whether a specified task still exists.
+
+        Args:
+            task_id: the id of the task
+            fl_ctx: the FLContext
+
+        Returns: the ClientTask object if it exists; None otherwise
+
+        """
+        pass
+
     @abstractmethod
     def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
         """Called by the Engine to handle the case that the job on the client is dead.

nvflare/private/fed/client/client_runner.py

Lines changed: 103 additions & 9 deletions

@@ -29,6 +29,10 @@
 from nvflare.security.logging import secure_format_exception
 from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector

+_TASK_CHECK_RESULT_OK = 0
+_TASK_CHECK_RESULT_TRY_AGAIN = 1
+_TASK_CHECK_RESULT_TASK_GONE = 2
+

 class ClientRunnerConfig(object):
     def __init__(

@@ -107,7 +111,8 @@ def __init__(
         self.task_lock = threading.Lock()
         self.end_run_fired = False
         self.end_run_lock = threading.Lock()
-
+        self.task_check_timeout = 5.0
+        self.task_check_interval = 5.0
         self._register_aux_message_handler(engine)

     def _register_aux_message_handler(self, engine):

@@ -473,19 +478,108 @@ def fetch_and_run_one_task(self, fl_ctx) -> (float, bool):
         if cookie_jar:
             task_reply.set_cookie_jar(cookie_jar)

-        reply_sent = self.engine.send_task_result(task_reply, fl_ctx)
-        if reply_sent:
-            self.log_info(fl_ctx, "result sent to server for task: name={}, id={}".format(task.name, task.task_id))
-        else:
-            self.log_error(
-                fl_ctx,
-                "failed to send result to server for task: name={}, id={}".format(task.name, task.task_id),
-            )
+        self._send_task_result(task_reply, task.task_id, fl_ctx)
         self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT")
         self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx)

         return task_fetch_interval, True

+    def _send_task_result(self, result: Shareable, task_id: str, fl_ctx: FLContext):
+        try_count = 1
+        while True:
+            self.log_info(fl_ctx, f"try #{try_count}: sending task result to server")
+
+            if self.asked_to_stop:
+                self.log_info(fl_ctx, "job aborted: stopped trying to send result")
+                return False
+
+            try_count += 1
+            rc = self._try_send_result_once(result, task_id, fl_ctx)
+
+            if rc == _TASK_CHECK_RESULT_OK:
+                return True
+            elif rc == _TASK_CHECK_RESULT_TASK_GONE:
+                return False
+            else:
+                # retry
+                time.sleep(self.task_check_interval)
+
+    def _try_send_result_once(self, result: Shareable, task_id: str, fl_ctx: FLContext):
+        # wait until the server is ready to receive
+        while True:
+            if self.asked_to_stop:
+                return _TASK_CHECK_RESULT_TASK_GONE
+
+            rc = self._check_task_once(task_id, fl_ctx)
+            if rc == _TASK_CHECK_RESULT_OK:
+                break
+            elif rc == _TASK_CHECK_RESULT_TASK_GONE:
+                return rc
+            else:
+                # try again
+                time.sleep(self.task_check_interval)
+
+        # try to send the result
+        self.log_info(fl_ctx, "start to send task result to server")
+        reply_sent = self.engine.send_task_result(result, fl_ctx)
+        if reply_sent:
+            self.log_info(fl_ctx, "task result sent to server")
+            return _TASK_CHECK_RESULT_OK
+        else:
+            self.log_error(fl_ctx, "failed to send task result to server - will try again")
+            return _TASK_CHECK_RESULT_TRY_AGAIN
+
+    def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int:
+        """Checks whether the server is still waiting for the specified task.
+
+        The real reason for this method is to fight against unstable network connections:
+        we try to make sure that when we send the task result to the server, the connection
+        is available. If the task check succeeds, the network connection is likely to be
+        available. Otherwise we keep retrying until the task check succeeds or the server
+        tells us that the task is gone (timed out).
+
+        Args:
+            task_id: the id of the task being checked
+            fl_ctx: the FLContext
+
+        Returns: one of the _TASK_CHECK_RESULT codes
+
+        """
+        self.log_info(fl_ctx, "checking task ...")
+        task_check_req = Shareable()
+        task_check_req.set_header(ReservedKey.TASK_ID, task_id)
+        resp = self.engine.send_aux_request(
+            targets=[FQCN.ROOT_SERVER],
+            topic=ReservedTopic.TASK_CHECK,
+            request=task_check_req,
+            timeout=self.task_check_timeout,
+            fl_ctx=fl_ctx,
+            optional=True,
+        )
+        if resp and isinstance(resp, dict):
+            reply = resp.get(FQCN.ROOT_SERVER)
+            if not isinstance(reply, Shareable):
+                self.log_error(fl_ctx, f"bad task_check reply from server: expect Shareable but got {type(reply)}")
+                return _TASK_CHECK_RESULT_TRY_AGAIN
+
+            rc = reply.get_return_code()
+            if rc == ReturnCode.OK:
+                return _TASK_CHECK_RESULT_OK
+            elif rc == ReturnCode.COMMUNICATION_ERROR:
+                self.log_error(fl_ctx, f"failed task_check: {rc}")
+                return _TASK_CHECK_RESULT_TRY_AGAIN
+            elif rc == ReturnCode.SERVER_NOT_READY:
+                self.log_error(fl_ctx, f"server rejected task_check: {rc}")
+                return _TASK_CHECK_RESULT_TRY_AGAIN
+            elif rc == ReturnCode.TASK_UNKNOWN:
+                self.log_error(fl_ctx, f"task no longer exists on server: {rc}")
+                return _TASK_CHECK_RESULT_TASK_GONE
+            else:
+                # this should never happen
+                self.log_error(fl_ctx, f"programming error: received {rc} from server")
+                return _TASK_CHECK_RESULT_OK  # try to push the result regardless
+        else:
+            self.log_error(fl_ctx, f"bad task_check reply from server: invalid resp {type(resp)}")
+            return _TASK_CHECK_RESULT_TRY_AGAIN
+
     def run(self, app_root, args):
         self.init_run(app_root, args)

nvflare/private/fed/client/communicator.py

Lines changed: 9 additions & 4 deletions

@@ -65,6 +65,7 @@ def __init__(
         cell: Cell = None,
         client_register_interval=2,
         timeout=5.0,
+        maint_msg_timeout=5.0,
     ):
         """To init the Communicator.

@@ -85,7 +86,7 @@ def __init__(
         self.compression = compression
         self.client_register_interval = client_register_interval
         self.timeout = timeout
-
+        self.maint_msg_timeout = maint_msg_timeout
         self.logger = logging.getLogger(self.__class__.__name__)

     def client_registration(self, client_name, servers, project_name):

@@ -129,7 +130,7 @@ def client_registration(self, client_name, servers, project_name):
             channel=CellChannel.SERVER_MAIN,
             topic=CellChannelTopic.Register,
             request=login_message,
-            timeout=self.timeout,
+            timeout=self.maint_msg_timeout,
         )
         return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
         if return_code == ReturnCode.UNAUTHENTICATED:

@@ -297,7 +298,7 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext):
             channel=CellChannel.SERVER_MAIN,
             topic=CellChannelTopic.Quit,
             request=quit_message,
-            timeout=self.timeout,
+            timeout=self.maint_msg_timeout,
         )
         return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
         if return_code == ReturnCode.UNAUTHENTICATED:

@@ -335,9 +336,13 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C
             channel=CellChannel.SERVER_MAIN,
             topic=CellChannelTopic.HEART_BEAT,
             request=heartbeat_message,
-            timeout=self.timeout,
+            timeout=self.maint_msg_timeout,
         )
         return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
+
+        if return_code != ReturnCode.OK:
+            self.logger.error(f"heartbeat error: {return_code}")
+
         if return_code == ReturnCode.UNAUTHENTICATED:
             unauthenticated = result.get_header(MessageHeaderKey.ERROR)
             raise FLCommunicationError("error:client_quit " + unauthenticated)

nvflare/private/fed/client/fed_client_base.py

Lines changed: 1 addition & 0 deletions

@@ -104,6 +104,7 @@ def __init__(
             cell=cell,
             client_register_interval=client_args.get("client_register_interval", 2.0),
             timeout=client_args.get("communication_timeout", 30.0),
+            maint_msg_timeout=client_args.get("maint_msg_timeout", 5.0),
         )

         self.secure_train = secure_train
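
The effect of the communicator changes is to split one timeout into two: data messages such as task results keep the long communication_timeout, while register/quit/heartbeat use the new, much shorter maint_msg_timeout, so a dead connection is noticed quickly instead of after a data-sized wait. A sketch of the relevant client_args entries (the dict itself is hypothetical; the keys and defaults are those read in the diff above):

client_args = {
    "communication_timeout": 30.0,  # data messages (e.g. task results) may be large
    "maint_msg_timeout": 5.0,       # maintenance messages are tiny; fail fast and retry
}
timeout = client_args.get("communication_timeout", 30.0)
maint_msg_timeout = client_args.get("maint_msg_timeout", 5.0)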

nvflare/private/fed/server/server_runner.py

Lines changed: 24 additions & 0 deletions

@@ -106,6 +106,8 @@ def _register_aux_message_handler(self, engine):
             topic=ReservedTopic.JOB_HEART_BEAT, message_handle_func=self._handle_job_heartbeat
         )

+        engine.register_aux_message_handler(topic=ReservedTopic.TASK_CHECK, message_handle_func=self._handle_task_check)
+
     def _execute_run(self):
         while self.current_wf_index < len(self.config.workflows):
             wf = self.config.workflows[self.current_wf_index]

@@ -500,6 +502,28 @@ def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext):
         self.log_info(fl_ctx, "received client job_heartbeat aux request")
         return make_reply(ReturnCode.OK)

+    def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
+        task_id = request.get_header(ReservedHeaderKey.TASK_ID)
+        if not task_id:
+            self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request")
+            return make_reply(ReturnCode.BAD_REQUEST_DATA)
+
+        self.log_info(fl_ctx, f"received task_check on task {task_id}")
+
+        with self.wf_lock:
+            if self.current_wf is None or self.current_wf.responder is None:
+                self.log_info(fl_ctx, "no current workflow - dropped task_check.")
+                return make_reply(ReturnCode.TASK_UNKNOWN)
+
+            # ask the responder whether the task is still pending
+            task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx)
+            if task:
+                self.log_info(fl_ctx, f"task {task_id} is still good")
+                return make_reply(ReturnCode.OK)
+            else:
+                self.log_info(fl_ctx, f"task {task_id} is not found")
+                return make_reply(ReturnCode.TASK_UNKNOWN)
+
     def abort(self, fl_ctx: FLContext, turn_to_cold: bool = False):
         self.status = "done"
         self.abort_signal.trigger(value=True)
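
Both sides together form a small request/reply protocol keyed by task_id. A toy restatement of the server's reply semantics (the function and dict are illustrative stand-ins for _handle_task_check and the live workflow state; strings stand in for ReturnCode members):

def handle_task_check(task_id, pending_tasks):
    if not task_id:
        return "BAD_REQUEST_DATA"
    if task_id in pending_tasks:
        return "OK"            # client should push the result now
    return "TASK_UNKNOWN"      # client should drop the result

pending = {"uuid-1234": object()}
assert handle_task_check("uuid-1234", pending) == "OK"
assert handle_task_check("uuid-9999", pending) == "TASK_UNKNOWN"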
