@@ -77,6 +77,12 @@ def _get_client_task(target, task: Task):
7777 return None
7878
7979
80+ class _DeadClientStatus :
81+ def __init__ (self ):
82+ self .report_time = time .time ()
83+ self .dead_time = None
84+
85+
8086class Controller (Responder , ControllerSpec , ABC ):
8187 def __init__ (self , task_check_period = 0.2 ):
8288 """Manage life cycles of tasks and their destinations.
@@ -92,7 +98,7 @@ def __init__(self, task_check_period=0.2):
9298 self ._task_lock = Lock ()
9399 self ._task_monitor = threading .Thread (target = self ._monitor_tasks , args = ())
94100 self ._task_check_period = task_check_period
95- self ._dead_client_reports = {} # clients that reported the job is dead on it: name => report time
101+ self ._dead_clients = {} # clients that reported the job is dead on it: name => _DeadClientStatus
96102 self ._dead_clients_lock = Lock () # need lock since dead_clients can be modified from different threads
97103 # make sure _check_tasks, process_task_request, process_submission does not interfere with each other
98104 self ._controller_lock = Lock ()
@@ -115,6 +121,28 @@ def initialize_run(self, fl_ctx: FLContext):
115121 self .start_controller (fl_ctx )
116122 self ._task_monitor .start ()
117123
124+ def client_is_dead (self , client_name : str ):
125+ """This method is called when a client is deemed dead.
126+
127+ Args:
128+ client_name: name of the client
129+
130+ Returns: None
131+
132+ """
133+ pass
134+
135+ def client_is_revived (self , client_name : str ):
136+ """This method is called when a client is revived.
137+
138+ Args:
139+ client_name: name of the client
140+
141+ Returns: None
142+
143+ """
144+ pass
145+
118146 def _try_again (self ) -> Tuple [str , str , Shareable ]:
119147 # TODO: how to tell client no shareable available now?
120148 return "" , "" , None
@@ -174,9 +202,6 @@ def _do_process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[s
174202 if not isinstance (client , Client ):
175203 raise TypeError ("client must be an instance of Client, but got {}" .format (type (client )))
176204
177- with self ._dead_clients_lock :
178- self ._dead_client_reports .pop (client .name , None )
179-
180205 if not isinstance (fl_ctx , FLContext ):
181206 raise TypeError ("fl_ctx must be an instance of FLContext, but got {}" .format (type (fl_ctx )))
182207
@@ -331,9 +356,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
331356 """
332357 # record the report and to be used by the task monitor
333358 with self ._dead_clients_lock :
334- self .log_info (fl_ctx , f"received dead job report from client { client_name } " )
335- if not self ._dead_client_reports .get (client_name ):
336- self ._dead_client_reports [client_name ] = time .time ()
359+ self .log_warning (fl_ctx , f"received dead job report for client { client_name } " )
360+ if not self ._dead_clients .get (client_name ):
361+ self .log_warning (fl_ctx , f"client { client_name } is placed on dead client watch list" )
362+ self ._dead_clients [client_name ] = _DeadClientStatus ()
337363
338364 def process_task_check (self , task_id : str , fl_ctx : FLContext ):
339365 with self ._task_lock :
@@ -369,13 +395,6 @@ def _do_process_submission(
369395 if not isinstance (client , Client ):
370396 raise TypeError ("client must be an instance of Client, but got {}" .format (type (client )))
371397
372- # reset the dead job report!
373- # note that due to potential race conditions, a client may fail to include the job id in its
374- # heartbeat (since the job hasn't started at the time of heartbeat report), but then includes
375- # the job ID later.
376- with self ._dead_clients_lock :
377- self ._dead_client_reports .pop (client .name , None )
378-
379398 if not isinstance (fl_ctx , FLContext ):
380399 raise TypeError ("fl_ctx must be an instance of FLContext, but got {}" .format (type (fl_ctx )))
381400 if not isinstance (result , Shareable ):
@@ -651,10 +670,8 @@ def cancel_task(
651670 """Cancel the specified task.
652671
653672 Change the task completion_status, which will inform task monitor to clean up this task
654-
655- .. note::
656-
657- We only mark the task as completed and leave it to the task monitor to clean up. This is to avoid potential deadlock of task_lock.
673+ We only mark the task as completed and leave it to the task monitor to clean up.
674+ This is to avoid potential deadlock of task_lock.
658675
659676 Args:
660677 task (Task): the task to be cancelled
@@ -979,27 +996,64 @@ def wait_for_task(self, task: Task, abort_signal: Signal):
979996 def _check_dead_clients (self ):
980997 if self ._engine :
981998 clients = self ._engine .get_clients ()
999+ dead_clients = []
9821000 with self ._dead_clients_lock :
9831001 for client in clients :
984- if self ._client_still_alive (client .name ):
985- return False
1002+ if not self ._client_still_alive (client .name ):
1003+ dead_clients . append ( client . name )
9861004
987- # All the clients are dead, abort the job run.
1005+ if dead_clients and len ( clients ) == len ( dead_clients ):
9881006 with self ._engine .new_context () as fl_ctx :
9891007 self .system_panic ("All clients are dead" , fl_ctx )
990- return True
1008+ return True
9911009 return False
9921010
9931011 def _client_still_alive (self , client_name ):
9941012 now = time .time ()
995- report_time = self ._dead_client_reports .get (client_name , None )
996- grace_period = ConfigService .get_float_var (name = _CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD , default = 30 .0 )
1013+ status = self ._dead_clients .get (client_name , None )
1014+ grace_period = ConfigService .get_float_var (name = _CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD , default = 60 .0 )
9971015
998- if not report_time :
1016+ if not status :
9991017 # this client is still alive
10001018 return True
1001- elif now - report_time < grace_period :
1019+
1020+ assert isinstance (status , _DeadClientStatus )
1021+ if status .dead_time :
1022+ return False
1023+
1024+ if now - status .report_time < grace_period :
10021025 # this report is still fresh - consider the client to be still alive
10031026 return True
10041027
1028+ # consider client dead
1029+ status .dead_time = now
1030+ self .logger .error (f"Client { client_name } is deemed dead!" )
1031+ self .client_is_dead (client_name )
10051032 return False
1033+
1034+ def get_client_death_time (self , client_name : str ):
1035+ """Get the time that the client was deemed dead
1036+
1037+ Args:
1038+ client_name: name of the client
1039+
1040+ Returns: time at which the client was deemed dead; or None if the client is not dead
1041+
1042+ """
1043+ status = self ._dead_clients .get (client_name )
1044+ if status :
1045+ assert isinstance (status , _DeadClientStatus )
1046+ return status .dead_time
1047+ return None
1048+
1049+ def process_job_heartbeat (self , fl_ctx : FLContext ):
1050+ peer_ctx = fl_ctx .get_peer_context ()
1051+ assert isinstance (peer_ctx , FLContext )
1052+ client_name = peer_ctx .get_identity_name ()
1053+ with self ._dead_clients_lock :
1054+ if client_name in self ._dead_clients :
1055+ self .log_info (fl_ctx , f"Client { client_name } is removed from watch list" )
1056+ status = self ._dead_clients .pop (client_name )
1057+ if status .dead_time :
1058+ self .log_info (fl_ctx , f"Client { client_name } is revived" )
1059+ self .client_is_revived (client_name )
0 commit comments