
Commit eaa7570

[2.3] fix client death handling (#2484)
* fix client death handling
* remove unused imports
* improve HB logic
1 parent 2874b12 commit eaa7570

4 files changed: +111 additions, -34 deletions

nvflare/apis/impl/controller.py

Lines changed: 80 additions & 26 deletions

@@ -77,6 +77,12 @@ def _get_client_task(target, task: Task):
     return None
 
 
+class _DeadClientStatus:
+    def __init__(self):
+        self.report_time = time.time()
+        self.dead_time = None
+
+
 class Controller(Responder, ControllerSpec, ABC):
     def __init__(self, task_check_period=0.2):
         """Manage life cycles of tasks and their destinations.

@@ -92,7 +98,7 @@ def __init__(self, task_check_period=0.2):
         self._task_lock = Lock()
         self._task_monitor = threading.Thread(target=self._monitor_tasks, args=())
         self._task_check_period = task_check_period
-        self._dead_client_reports = {}  # clients that reported the job is dead on it: name => report time
+        self._dead_clients = {}  # clients that reported the job is dead on it: name => _DeadClientStatus
         self._dead_clients_lock = Lock()  # need lock since dead_clients can be modified from different threads
         # make sure _check_tasks, process_task_request, process_submission does not interfere with each other
         self._controller_lock = Lock()

@@ -115,6 +121,28 @@ def initialize_run(self, fl_ctx: FLContext):
         self.start_controller(fl_ctx)
         self._task_monitor.start()
 
+    def client_is_dead(self, client_name: str):
+        """This method is called when a client is deemed dead.
+
+        Args:
+            client_name: name of the client
+
+        Returns: None
+
+        """
+        pass
+
+    def client_is_revived(self, client_name: str):
+        """This method is called when a client is revived.
+
+        Args:
+            client_name: name of the client
+
+        Returns: None
+
+        """
+        pass
+
     def _try_again(self) -> Tuple[str, str, Shareable]:
         # TODO: how to tell client no shareable available now?
         return "", "", None

@@ -174,9 +202,6 @@ def _do_process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[s
         if not isinstance(client, Client):
             raise TypeError("client must be an instance of Client, but got {}".format(type(client)))
 
-        with self._dead_clients_lock:
-            self._dead_client_reports.pop(client.name, None)
-
         if not isinstance(fl_ctx, FLContext):
             raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx)))
 

@@ -331,9 +356,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
         """
         # record the report and to be used by the task monitor
         with self._dead_clients_lock:
-            self.log_info(fl_ctx, f"received dead job report from client {client_name}")
-            if not self._dead_client_reports.get(client_name):
-                self._dead_client_reports[client_name] = time.time()
+            self.log_warning(fl_ctx, f"received dead job report for client {client_name}")
+            if not self._dead_clients.get(client_name):
+                self.log_warning(fl_ctx, f"client {client_name} is placed on dead client watch list")
+                self._dead_clients[client_name] = _DeadClientStatus()
 
     def process_task_check(self, task_id: str, fl_ctx: FLContext):
         with self._task_lock:

@@ -369,13 +395,6 @@ def _do_process_submission(
         if not isinstance(client, Client):
             raise TypeError("client must be an instance of Client, but got {}".format(type(client)))
 
-        # reset the dead job report!
-        # note that due to potential race conditions, a client may fail to include the job id in its
-        # heartbeat (since the job hasn't started at the time of heartbeat report), but then includes
-        # the job ID later.
-        with self._dead_clients_lock:
-            self._dead_client_reports.pop(client.name, None)
-
         if not isinstance(fl_ctx, FLContext):
             raise TypeError("fl_ctx must be an instance of FLContext, but got {}".format(type(fl_ctx)))
         if not isinstance(result, Shareable):

@@ -651,10 +670,8 @@ def cancel_task(
         """Cancel the specified task.
 
         Change the task completion_status, which will inform task monitor to clean up this task
-
-        .. note::
-
-            We only mark the task as completed and leave it to the task monitor to clean up. This is to avoid potential deadlock of task_lock.
+        We only mark the task as completed and leave it to the task monitor to clean up.
+        This is to avoid potential deadlock of task_lock.
 
         Args:
             task (Task): the task to be cancelled

@@ -979,27 +996,64 @@ def wait_for_task(self, task: Task, abort_signal: Signal):
     def _check_dead_clients(self):
         if self._engine:
             clients = self._engine.get_clients()
+            dead_clients = []
             with self._dead_clients_lock:
                 for client in clients:
-                    if self._client_still_alive(client.name):
-                        return False
+                    if not self._client_still_alive(client.name):
+                        dead_clients.append(client.name)
 
-                # All the clients are dead, abort the job run.
+            if dead_clients and len(clients) == len(dead_clients):
                 with self._engine.new_context() as fl_ctx:
                     self.system_panic("All clients are dead", fl_ctx)
-                return True
+                    return True
         return False
 
     def _client_still_alive(self, client_name):
         now = time.time()
-        report_time = self._dead_client_reports.get(client_name, None)
-        grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=30.0)
+        status = self._dead_clients.get(client_name, None)
+        grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=60.0)
 
-        if not report_time:
+        if not status:
             # this client is still alive
             return True
-        elif now - report_time < grace_period:
+
+        assert isinstance(status, _DeadClientStatus)
+        if status.dead_time:
+            return False
+
+        if now - status.report_time < grace_period:
             # this report is still fresh - consider the client to be still alive
             return True
 
+        # consider client dead
+        status.dead_time = now
+        self.logger.error(f"Client {client_name} is deemed dead!")
+        self.client_is_dead(client_name)
         return False
+
+    def get_client_death_time(self, client_name: str):
+        """Get the time that the client was deemed dead
+
+        Args:
+            client_name: name of the client
+
+        Returns: time at which the client was deemed dead; or None if the client is not dead
+
+        """
+        status = self._dead_clients.get(client_name)
+        if status:
+            assert isinstance(status, _DeadClientStatus)
+            return status.dead_time
+        return None
+
+    def process_job_heartbeat(self, fl_ctx: FLContext):
+        peer_ctx = fl_ctx.get_peer_context()
+        assert isinstance(peer_ctx, FLContext)
+        client_name = peer_ctx.get_identity_name()
+        with self._dead_clients_lock:
+            if client_name in self._dead_clients:
+                self.log_info(fl_ctx, f"Client {client_name} is removed from watch list")
+                status = self._dead_clients.pop(client_name)
+                if status.dead_time:
+                    self.log_info(fl_ctx, f"Client {client_name} is revived")
+                    self.client_is_revived(client_name)
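
To make the new watch-list behavior easier to follow, here is a small standalone sketch of the same logic (a simplified re-implementation for illustration only, not NVFlare code; the module-level functions and the "site-1" client name are hypothetical). In the real Controller, the transitions additionally invoke the new client_is_dead and client_is_revived hooks, which subclasses can override.

import time

GRACE_PERIOD = 60.0  # mirrors the default for _CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD


class DeadClientStatus:
    """Simplified stand-in for _DeadClientStatus."""

    def __init__(self):
        self.report_time = time.time()  # when the dead-job report arrived
        self.dead_time = None           # set once the grace period expires


dead_clients = {}  # client name => DeadClientStatus


def handle_dead_job(client_name):
    # a dead-job report only puts the client on the watch list
    if client_name not in dead_clients:
        dead_clients[client_name] = DeadClientStatus()


def client_still_alive(client_name):
    status = dead_clients.get(client_name)
    if status is None:
        return True  # never reported, so considered alive
    if status.dead_time:
        return False  # already declared dead
    if time.time() - status.report_time < GRACE_PERIOD:
        return True  # report is still fresh, give the client time to come back
    status.dead_time = time.time()  # grace period expired: declare the client dead
    return False


def process_job_heartbeat(client_name):
    # any heartbeat (or task request/submission) takes the client off the watch list
    dead_clients.pop(client_name, None)


handle_dead_job("site-1")
print(client_still_alive("site-1"))  # True while the grace period has not expired
process_job_heartbeat("site-1")      # contact from the client clears the report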

nvflare/apis/responder.py

Lines changed: 12 additions & 0 deletions

@@ -76,6 +76,18 @@ def process_task_check(self, task_id: str, fl_ctx: FLContext):
         """
         pass
 
+    @abstractmethod
+    def process_job_heartbeat(self, fl_ctx: FLContext):
+        """Called by the Engine to handle heartbeat received from clients.
+
+        Args:
+            fl_ctx: the FLContext
+
+        Returns: None
+
+        """
+        pass
+
     @abstractmethod
     def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
         """Called by the Engine to handle the case that the job on the client is dead.

nvflare/private/fed/client/client_runner.py

Lines changed: 7 additions & 8 deletions

@@ -413,11 +413,13 @@ def _try_run(self):
 
             time.sleep(task_fetch_interval)
 
-    def send_job_heartbeat(self, interval=30.0):
-        wait_times = int(interval / 2)
+    def send_job_heartbeat(self, interval=10.0):
         request = Shareable()
+        last_hb_time = 0.0
+        short_sleep = 1.0
+        fl_ctx = self.engine.new_context()
         while not self.asked_to_stop:
-            with self.engine.new_context() as fl_ctx:
+            if time.time() - last_hb_time >= interval:
                 self.engine.send_aux_request(
                     targets=[FQCN.ROOT_SERVER],
                     topic=ReservedTopic.JOB_HEART_BEAT,

@@ -426,11 +428,8 @@ def send_job_heartbeat(self, interval=30.0):
                     fl_ctx=fl_ctx,
                     optional=True,
                 )
-
-            for i in range(wait_times):
-                time.sleep(2)
-                if self.asked_to_stop:
-                    break
+                last_hb_time = time.time()
+            time.sleep(short_sleep)
 
     def fetch_and_run_one_task(self, fl_ctx) -> (float, bool):
         """Fetches and runs a task.

nvflare/private/fed/server/server_runner.py

Lines changed: 12 additions & 0 deletions

@@ -328,6 +328,9 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005):
                     self.log_info(fl_ctx, "no current workflow - asked client to try again later")
                     return "", "", None
 
+                if self.current_wf.responder:
+                    self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx)
+
                 task_name, task_id, task_data = self.current_wf.responder.process_task_request(client, fl_ctx)
 
                 if task_name and task_name != SpecialTaskName.TRY_AGAIN:

@@ -441,6 +444,9 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
                 self.log_info(fl_ctx, "no current workflow - dropped submission.")
                 return
 
+            if self.current_wf.responder:
+                self.current_wf.responder.process_job_heartbeat(fl_ctx)
+
             wf_id = result.get_cookie(ReservedHeaderKey.WORKFLOW, None)
             if wf_id is not None and wf_id != self.current_wf.id:
                 self.log_info(

@@ -500,6 +506,9 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
 
     def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
         self.log_debug(fl_ctx, "received client job_heartbeat")
+        with self.wf_lock:
+            if self.current_wf and self.current_wf.responder:
+                self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx)
         return make_reply(ReturnCode.OK)
 
     def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:

@@ -515,6 +524,9 @@ def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext)
                 self.log_info(fl_ctx, "no current workflow - dropped task_check.")
                 return make_reply(ReturnCode.TASK_UNKNOWN)
 
+            if self.current_wf.responder:
+                self.current_wf.responder.process_job_heartbeat(fl_ctx)
+
             # filter task result
             task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx)
             if task:
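
With these changes the server treats any inbound call from a client (task request, submission, task check, or the explicit heartbeat message) as proof of life and forwards it to the responder. The repeated guard could be factored into a helper; the sketch below is a hypothetical refactor for illustration only, not part of this commit:

class ServerRunnerSketch:
    """Hypothetical, simplified sketch of the pattern; not the real ServerRunner."""

    def __init__(self, current_wf=None):
        self.current_wf = current_wf

    def _notify_responder_heartbeat(self, fl_ctx):
        # every client contact doubles as a heartbeat for the dead-client watch list
        if self.current_wf and self.current_wf.responder:
            self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx)

    def process_task_request(self, client, fl_ctx):
        self._notify_responder_heartbeat(fl_ctx)
        ...  # continue with normal task-request handling

    def process_submission(self, client, task_name, task_id, result, fl_ctx):
        self._notify_responder_heartbeat(fl_ctx)
        ...  # continue with normal submission handling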
