Skip to content

Commit 16730ec

Browse files
yanchengnv, YuanTingHsieh, and chesterxgchen
authored
[2.3] Improve dead client handling (#2501)
* improve dead client handling * remove unused import * fix docstring --------- Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]> Co-authored-by: Chester Chen <[email protected]>
1 parent ed670ac commit 16730ec

File tree

9 files changed

+112
-92
lines changed

9 files changed

+112
-92
lines changed

nvflare/apis/fl_constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ class ServerCommandKey(object):
236236
CLIENTS = "clients"
237237
COLLECTOR = "collector"
238238
TURN_TO_COLD = "__turn_to_cold__"
239+
REASON = "reason"
239240

240241

241242
class FedEventHeader(object):

nvflare/apis/impl/controller.py

Lines changed: 58 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _get_client_task(target, task: Task):
8080
class _DeadClientStatus:
8181
def __init__(self):
8282
self.report_time = time.time()
83-
self.dead_time = None
83+
self.death_time = None
8484

8585

8686
class Controller(Responder, ControllerSpec, ABC):
@@ -356,8 +356,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
356356
"""
357357
# record the report and to be used by the task monitor
358358
with self._dead_clients_lock:
359-
self.log_warning(fl_ctx, f"received dead job report for client {client_name}")
360-
if not self._dead_clients.get(client_name):
359+
if self._dead_clients.get(client_name):
360+
# already on watch list
361+
self.log_warning(fl_ctx, f"discarded dead job report for client {client_name}: already on watch list")
362+
else:
361363
self.log_warning(fl_ctx, f"client {client_name} is placed on dead client watch list")
362364
self._dead_clients[client_name] = _DeadClientStatus()
363365

@@ -856,12 +858,36 @@ def relay_and_wait(
856858
self.wait_for_task(task, abort_signal)
857859

858860
def _monitor_tasks(self):
861+
grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=60.0)
859862
while not self._all_done:
860-
clients_all_dead = self._check_dead_clients()
861-
if not clients_all_dead:
863+
self._determine_dead_clients(grace_period)
864+
if self._all_clients_dead():
865+
with self._engine.new_context() as fl_ctx:
866+
self.system_panic("All clients are dead", fl_ctx)
867+
return
868+
else:
862869
self._check_tasks()
863870
time.sleep(self._task_check_period)
864871

872+
def _determine_dead_clients(self, grace_period):
873+
if not self._dead_clients:
874+
return
875+
876+
now = time.time()
877+
with self._dead_clients_lock:
878+
for client_name, status in self._dead_clients.items():
879+
if status.death_time:
880+
# already dead
881+
continue
882+
883+
if now - status.report_time < grace_period:
884+
# this report is still fresh - consider the client to be still alive
885+
continue
886+
887+
# consider client dead
888+
status.death_time = now
889+
self.logger.error(f"Client {client_name} is deemed dead!")
890+
865891
def _check_tasks(self):
866892
with self._controller_lock:
867893
self._do_check_tasks()
@@ -897,7 +923,7 @@ def _do_check_tasks(self):
897923
# check whether clients that the task is waiting are all dead
898924
dead_clients = self._get_task_dead_clients(task)
899925
if dead_clients:
900-
self.logger.info(f"client {dead_clients} is dead - set task {task.name} to TIMEOUT")
926+
self.logger.info(f"clients {dead_clients} dead - set task {task.name} to TIMEOUT")
901927
task.completion_status = TaskCompletionStatus.CLIENT_DEAD
902928
exit_tasks.append(task)
903929
continue
@@ -949,22 +975,21 @@ def _get_task_dead_clients(self, task: Task):
949975
return None
950976

951977
dead_clients = []
952-
with self._dead_clients_lock:
953-
for target in task.targets:
954-
ct = _get_client_task(target, task)
955-
if ct is not None and ct.result_received_time:
956-
# response has been received from this client
957-
continue
958-
959-
# either we have not sent the task to this client or we have not received response
960-
# is the client already dead?
961-
if self._client_still_alive(target):
962-
# this client is still alive
963-
# we let the task continue its course since we still have live clients
964-
return None
965-
else:
966-
# this client is dead - remember it
967-
dead_clients.append(target)
978+
for target in task.targets:
979+
ct = _get_client_task(target, task)
980+
if ct is not None and ct.result_received_time:
981+
# response has been received from this client
982+
continue
983+
984+
# either we have not sent the task to this client or we have not received response
985+
# is the client already dead?
986+
if self.get_client_death_time(target):
987+
# this client is dead - remember it
988+
dead_clients.append(target)
989+
else:
990+
# this client is still alive
991+
# we let the task continue its course since we still have live clients
992+
return None
968993

969994
return dead_clients
970995

@@ -993,46 +1018,18 @@ def wait_for_task(self, task: Task, abort_signal: Signal):
9931018
break
9941019
time.sleep(self._task_check_period)
9951020

996-
def _check_dead_clients(self):
1021+
def _all_clients_dead(self):
9971022
if self._engine:
9981023
clients = self._engine.get_clients()
999-
dead_clients = []
1000-
with self._dead_clients_lock:
1001-
for client in clients:
1002-
if not self._client_still_alive(client.name):
1003-
dead_clients.append(client.name)
1004-
1005-
if dead_clients and len(clients) == len(dead_clients):
1006-
with self._engine.new_context() as fl_ctx:
1007-
self.system_panic("All clients are dead", fl_ctx)
1008-
return True
1009-
return False
1010-
1011-
def _client_still_alive(self, client_name):
1012-
now = time.time()
1013-
status = self._dead_clients.get(client_name, None)
1014-
grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=60.0)
1015-
1016-
if not status:
1017-
# this client is still alive
1018-
return True
1019-
1020-
assert isinstance(status, _DeadClientStatus)
1021-
if status.dead_time:
1022-
return False
1023-
1024-
if now - status.report_time < grace_period:
1025-
# this report is still fresh - consider the client to be still alive
1024+
for client in clients:
1025+
if not self.get_client_death_time(client.name):
1026+
# this client is still alive
1027+
return False
10261028
return True
1027-
1028-
# consider client dead
1029-
status.dead_time = now
1030-
self.logger.error(f"Client {client_name} is deemed dead!")
1031-
self.client_is_dead(client_name)
10321029
return False
10331030

10341031
def get_client_death_time(self, client_name: str):
1035-
"""Get the time that the client was deemed dead
1032+
"""Get the time that the client was deemed dead/disconnected
10361033
10371034
Args:
10381035
client_name: name of the client
@@ -1042,18 +1039,17 @@ def get_client_death_time(self, client_name: str):
10421039
"""
10431040
status = self._dead_clients.get(client_name)
10441041
if status:
1045-
assert isinstance(status, _DeadClientStatus)
1046-
return status.dead_time
1042+
return status.death_time
10471043
return None
10481044

1049-
def process_job_heartbeat(self, fl_ctx: FLContext):
1045+
def process_job_heartbeat(self, fl_ctx: FLContext, reason: str):
10501046
peer_ctx = fl_ctx.get_peer_context()
10511047
assert isinstance(peer_ctx, FLContext)
10521048
client_name = peer_ctx.get_identity_name()
10531049
with self._dead_clients_lock:
10541050
if client_name in self._dead_clients:
1055-
self.log_info(fl_ctx, f"Client {client_name} is removed from watch list")
1051+
self.log_info(fl_ctx, f"Client {client_name} is removed from watch list: {reason=}")
10561052
status = self._dead_clients.pop(client_name)
1057-
if status.dead_time:
1058-
self.log_info(fl_ctx, f"Client {client_name} is revived")
1053+
if status.death_time:
1054+
self.log_info(fl_ctx, f"Client {client_name} is revived: {reason=}")
10591055
self.client_is_revived(client_name)

nvflare/apis/responder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,12 @@ def process_task_check(self, task_id: str, fl_ctx: FLContext):
7777
pass
7878

7979
@abstractmethod
80-
def process_job_heartbeat(self, fl_ctx: FLContext):
80+
def process_job_heartbeat(self, fl_ctx: FLContext, reason: str):
8181
"""Called by the Engine to handle heartbeat received from clients.
8282
8383
Args:
8484
fl_ctx: the FLContext
85+
reason: reason of the HB
8586
8687
Returns: None
8788

nvflare/app_common/workflows/cyclic_ctl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,15 @@ def start_controller(self, fl_ctx: FLContext):
139139
self._last_client = None
140140

141141
def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
142-
if len(self._participating_clients) <= 1:
143-
self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx)
144-
return None
145-
146142
active_clients_map = {}
147143
for t in self._participating_clients:
148144
if not self.get_client_death_time(t.name):
149145
active_clients_map[t.name] = t
150146

147+
if len(active_clients_map) <= 1:
148+
self.system_panic(f"Not enough client sites (active_clients={len(active_clients_map)}).", fl_ctx)
149+
return None
150+
151151
if isinstance(self._order, list):
152152
targets = []
153153
for c_name in self._order:

nvflare/private/fed/server/fed_server.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
WorkspaceConstants,
3636
)
3737
from nvflare.apis.fl_context import FLContext
38-
from nvflare.apis.shareable import Shareable
3938
from nvflare.apis.workspace import Workspace
4039
from nvflare.fuel.common.exit_codes import ProcessExitCode
4140
from nvflare.fuel.f3.cellnet.cell import Cell, Message
@@ -234,7 +233,7 @@ def fl_shutdown(self):
234233
self.shutdown = True
235234
start = time.time()
236235
while self.client_manager.clients:
237-
# Wait for the clients to shutdown and quite first.
236+
# Wait for the clients to shut down and quite first.
238237
time.sleep(0.1)
239238
if time.time() - start > self.shutdown_period:
240239
self.logger.info("There are still clients connected. But shutdown the server after timeout.")
@@ -580,26 +579,17 @@ def _sync_client_jobs(self, request, client_token):
580579
# this is a dict: token => nvflare.apis.client.Client
581580
client = participating_clients.get(client_token, None)
582581
if client:
583-
self._notify_dead_job(client, job_id)
582+
self._notify_dead_job(client, job_id, "missing job on client")
584583

585584
return jobs_need_abort
586585

587-
def _notify_dead_job(self, client, job_id: str):
586+
def _notify_dead_job(self, client, job_id: str, reason: str):
588587
try:
589-
with self.engine.lock:
590-
shareable = Shareable()
591-
shareable.set_header(ServerCommandKey.FL_CLIENT, client.name)
592-
fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id])
593-
request = new_cell_message({}, fobs.dumps(shareable))
594-
self.cell.fire_and_forget(
595-
targets=fqcn,
596-
channel=CellChannel.SERVER_COMMAND,
597-
topic=ServerCommandNames.HANDLE_DEAD_JOB,
598-
message=request,
599-
optional=True,
600-
)
601-
except Exception:
602-
self.logger.info("Could not connect to server runner process")
588+
self.engine.notify_dead_job(job_id, client.name, reason)
589+
except Exception as ex:
590+
self.logger.info(
591+
f"Failed to notify_dead_job to runner process of job {job_id}: {secure_format_exception(ex)}"
592+
)
603593

604594
def notify_dead_client(self, client):
605595
"""Called to do further processing of the dead client
@@ -618,7 +608,7 @@ def notify_dead_client(self, client):
618608
assert isinstance(process_info, dict)
619609
participating_clients = process_info.get(RunProcessKey.PARTICIPANTS, None)
620610
if participating_clients and client.token in participating_clients:
621-
self._notify_dead_job(client, job_id)
611+
self._notify_dead_job(client, job_id, "client dead")
622612

623613
def start_run(self, job_id, run_root, conf, args, snapshot):
624614
# Create the FL Engine

nvflare/private/fed/server/server_commands.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def process(self, data: Shareable, fl_ctx: FLContext):
264264
265265
"""
266266
client_name = data.get_header(ServerCommandKey.FL_CLIENT)
267+
reason = data.get_header(ServerCommandKey.REASON)
268+
self.logger.warning(f"received dead job notification: {client_name=}; {reason=}")
267269
server_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
268270
if server_runner:
269271
server_runner.handle_dead_job(client_name, fl_ctx)

nvflare/private/fed/server/server_engine.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,19 @@ def update_job_run_status(self):
583583
message=request,
584584
)
585585

586+
def notify_dead_job(self, job_id: str, client_name: str, reason: str):
587+
shareable = Shareable()
588+
shareable.set_header(ServerCommandKey.FL_CLIENT, client_name)
589+
shareable.set_header(ServerCommandKey.REASON, reason)
590+
self.send_command_to_child_runner_process(
591+
job_id=job_id,
592+
command_name=ServerCommandNames.HANDLE_DEAD_JOB,
593+
command_data=shareable,
594+
timeout=0.0,
595+
optional=True,
596+
)
597+
self.logger.warning(f"notified SJ of dead-job: {job_id=}; {client_name=}; {reason=}")
598+
586599
def send_command_to_child_runner_process(
587600
self, job_id: str, command_name: str, command_data, timeout=5.0, optional=False
588601
):
@@ -594,7 +607,7 @@ def send_command_to_child_runner_process(
594607
targets=fqcn,
595608
channel=CellChannel.SERVER_COMMAND,
596609
topic=command_name,
597-
request=request,
610+
message=request,
598611
optional=optional,
599612
)
600613
return None

nvflare/private/fed/server/server_runner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005):
329329
return "", "", None
330330

331331
if self.current_wf.responder:
332-
self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx)
332+
self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx, reason="getTask")
333333

334334
task_name, task_id, task_data = self.current_wf.responder.process_task_request(client, fl_ctx)
335335

@@ -371,7 +371,6 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
371371
try:
372372
if self.current_wf is None:
373373
return
374-
375374
self.current_wf.responder.handle_dead_job(client_name=client_name, fl_ctx=fl_ctx)
376375
except Exception as e:
377376
self.log_exception(
@@ -445,7 +444,7 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
445444
return
446445

447446
if self.current_wf.responder:
448-
self.current_wf.responder.process_job_heartbeat(fl_ctx)
447+
self.current_wf.responder.process_job_heartbeat(fl_ctx, "submitTask")
449448

450449
wf_id = result.get_cookie(ReservedHeaderKey.WORKFLOW, None)
451450
if wf_id is not None and wf_id != self.current_wf.id:
@@ -508,7 +507,7 @@ def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContex
508507
self.log_debug(fl_ctx, "received client job_heartbeat")
509508
with self.wf_lock:
510509
if self.current_wf and self.current_wf.responder:
511-
self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx)
510+
self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx, reason="jobHeartbeat")
512511
return make_reply(ReturnCode.OK)
513512

514513
def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
@@ -525,7 +524,7 @@ def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext)
525524
return make_reply(ReturnCode.TASK_UNKNOWN)
526525

527526
if self.current_wf.responder:
528-
self.current_wf.responder.process_job_heartbeat(fl_ctx)
527+
self.current_wf.responder.process_job_heartbeat(fl_ctx, "taskCheck")
529528

530529
# filter task result
531530
task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx)

nvflare/private/fed/server/sys_cmd.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ def get_spec(self):
6868
authz_func=self.authorize_server_operation,
6969
visible=True,
7070
),
71+
CommandSpec(
72+
name="dead",
73+
description="send dead client msg to SJ",
74+
usage="dead <client-name>",
75+
handler_func=self.dead_client,
76+
authz_func=self.must_be_project_admin,
77+
visible=False,
78+
),
7179
],
7280
)
7381

@@ -156,3 +164,13 @@ def report_resources(self, conn: Connection, args: List[str]):
156164
table = conn.append_table(["Sites", "Resources"])
157165
for k, v in site_resources.items():
158166
table.add_row([str(k), str(v)])
167+
168+
def dead_client(self, conn: Connection, args: List[str]):
169+
if len(args) != 3:
170+
conn.append_error(f"Usage: {args[0]} client_name job_id")
171+
return
172+
client_name = args[1]
173+
job_id = args[2]
174+
engine = conn.app_ctx
175+
engine.notify_dead_job(job_id, client_name, f"AdminCommand: {args[0]}")
176+
conn.append_string(f"called notify_dead_job for client {client_name=} {job_id=}")

0 commit comments

Comments (0)