@@ -80,7 +80,7 @@ def _get_client_task(target, task: Task):
8080class _DeadClientStatus :
8181 def __init__ (self ):
8282 self .report_time = time .time ()
83- self .dead_time = None
83+ self .death_time = None
8484
8585
8686class Controller (Responder , ControllerSpec , ABC ):
@@ -356,8 +356,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
356356 """
357357 # record the report and to be used by the task monitor
358358 with self ._dead_clients_lock :
359- self .log_warning (fl_ctx , f"received dead job report for client { client_name } " )
360- if not self ._dead_clients .get (client_name ):
359+ if self ._dead_clients .get (client_name ):
360+ # already on watch list
361+ self .log_warning (fl_ctx , f"discarded dead job report for client { client_name } : already on watch list" )
362+ else :
361363 self .log_warning (fl_ctx , f"client { client_name } is placed on dead client watch list" )
362364 self ._dead_clients [client_name ] = _DeadClientStatus ()
363365
@@ -856,12 +858,36 @@ def relay_and_wait(
856858 self .wait_for_task (task , abort_signal )
857859
858860 def _monitor_tasks (self ):
861+ grace_period = ConfigService .get_float_var (name = _CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD , default = 60.0 )
859862 while not self ._all_done :
860- clients_all_dead = self ._check_dead_clients ()
861- if not clients_all_dead :
863+ self ._determine_dead_clients (grace_period )
864+ if self ._all_clients_dead ():
865+ with self ._engine .new_context () as fl_ctx :
866+ self .system_panic ("All clients are dead" , fl_ctx )
867+ return
868+ else :
862869 self ._check_tasks ()
863870 time .sleep (self ._task_check_period )
864871
872+ def _determine_dead_clients (self , grace_period ):
873+ if not self ._dead_clients :
874+ return
875+
876+ now = time .time ()
877+ with self ._dead_clients_lock :
878+ for client_name , status in self ._dead_clients .items ():
879+ if status .death_time :
880+ # already dead
881+ continue
882+
883+ if now - status .report_time < grace_period :
884+ # this report is still fresh - consider the client to be still alive
885+ continue
886+
887+ # consider client dead
888+ status .death_time = now
889+ self .logger .error (f"Client { client_name } is deemed dead!" )
890+
865891 def _check_tasks (self ):
866892 with self ._controller_lock :
867893 self ._do_check_tasks ()
@@ -897,7 +923,7 @@ def _do_check_tasks(self):
897923 # check whether clients that the task is waiting are all dead
898924 dead_clients = self ._get_task_dead_clients (task )
899925 if dead_clients :
900- self .logger .info (f"client { dead_clients } is dead - set task { task .name } to TIMEOUT" )
926+ self .logger .info (f"clients { dead_clients } dead - set task { task .name } to TIMEOUT" )
901927 task .completion_status = TaskCompletionStatus .CLIENT_DEAD
902928 exit_tasks .append (task )
903929 continue
@@ -949,22 +975,21 @@ def _get_task_dead_clients(self, task: Task):
949975 return None
950976
951977 dead_clients = []
952- with self ._dead_clients_lock :
953- for target in task .targets :
954- ct = _get_client_task (target , task )
955- if ct is not None and ct .result_received_time :
956- # response has been received from this client
957- continue
958-
959- # either we have not sent the task to this client or we have not received response
960- # is the client already dead?
961- if self ._client_still_alive (target ):
962- # this client is still alive
963- # we let the task continue its course since we still have live clients
964- return None
965- else :
966- # this client is dead - remember it
967- dead_clients .append (target )
978+ for target in task .targets :
979+ ct = _get_client_task (target , task )
980+ if ct is not None and ct .result_received_time :
981+ # response has been received from this client
982+ continue
983+
984+ # either we have not sent the task to this client or we have not received response
985+ # is the client already dead?
986+ if self .get_client_death_time (target ):
987+ # this client is dead - remember it
988+ dead_clients .append (target )
989+ else :
990+ # this client is still alive
991+ # we let the task continue its course since we still have live clients
992+ return None
968993
969994 return dead_clients
970995
@@ -993,46 +1018,18 @@ def wait_for_task(self, task: Task, abort_signal: Signal):
9931018 break
9941019 time .sleep (self ._task_check_period )
9951020
996- def _check_dead_clients (self ):
1021+ def _all_clients_dead (self ):
9971022 if self ._engine :
9981023 clients = self ._engine .get_clients ()
999- dead_clients = []
1000- with self ._dead_clients_lock :
1001- for client in clients :
1002- if not self ._client_still_alive (client .name ):
1003- dead_clients .append (client .name )
1004-
1005- if dead_clients and len (clients ) == len (dead_clients ):
1006- with self ._engine .new_context () as fl_ctx :
1007- self .system_panic ("All clients are dead" , fl_ctx )
1008- return True
1009- return False
1010-
1011- def _client_still_alive (self , client_name ):
1012- now = time .time ()
1013- status = self ._dead_clients .get (client_name , None )
1014- grace_period = ConfigService .get_float_var (name = _CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD , default = 60.0 )
1015-
1016- if not status :
1017- # this client is still alive
1018- return True
1019-
1020- assert isinstance (status , _DeadClientStatus )
1021- if status .dead_time :
1022- return False
1023-
1024- if now - status .report_time < grace_period :
1025- # this report is still fresh - consider the client to be still alive
1024+ for client in clients :
1025+ if not self .get_client_death_time (client .name ):
1026+ # this client is still alive
1027+ return False
10261028 return True
1027-
1028- # consider client dead
1029- status .dead_time = now
1030- self .logger .error (f"Client { client_name } is deemed dead!" )
1031- self .client_is_dead (client_name )
10321029 return False
10331030
10341031 def get_client_death_time (self , client_name : str ):
1035- """Get the time that the client was deemed dead
1032+ """Get the time that the client was deemed dead/disconnected
10361033
10371034 Args:
10381035 client_name: name of the client
@@ -1042,18 +1039,17 @@ def get_client_death_time(self, client_name: str):
10421039 """
10431040 status = self ._dead_clients .get (client_name )
10441041 if status :
1045- assert isinstance (status , _DeadClientStatus )
1046- return status .dead_time
1042+ return status .death_time
10471043 return None
10481044
1049- def process_job_heartbeat (self , fl_ctx : FLContext ):
1045+ def process_job_heartbeat (self , fl_ctx : FLContext , reason : str ):
10501046 peer_ctx = fl_ctx .get_peer_context ()
10511047 assert isinstance (peer_ctx , FLContext )
10521048 client_name = peer_ctx .get_identity_name ()
10531049 with self ._dead_clients_lock :
10541050 if client_name in self ._dead_clients :
1055- self .log_info (fl_ctx , f"Client { client_name } is removed from watch list" )
1051+ self .log_info (fl_ctx , f"Client { client_name } is removed from watch list: { reason = } " )
10561052 status = self ._dead_clients .pop (client_name )
1057- if status .dead_time :
1058- self .log_info (fl_ctx , f"Client { client_name } is revived" )
1053+ if status .death_time :
1054+ self .log_info (fl_ctx , f"Client { client_name } is revived: { reason = } " )
10591055 self .client_is_revived (client_name )
0 commit comments