1414
1515import gc
1616import random
17+ from typing import List , Union
1718
1819from nvflare .apis .client import Client
1920from nvflare .apis .fl_constant import ReturnCode
@@ -48,7 +49,7 @@ def __init__(
4849 task_check_period : float = 0.5 ,
4950 persist_every_n_rounds : int = 1 ,
5051 snapshot_every_n_rounds : int = 1 ,
51- order : str = RelayOrder .FIXED ,
52+ order : Union [ str , List [ str ]] = RelayOrder .FIXED ,
5253 allow_early_termination = False ,
5354 ):
5455 """A sample implementation to demonstrate how to use relay method for Cyclic Federated Learning.
@@ -66,11 +67,12 @@ def __init__(
6667 If n is 0 then no persist.
6768 snapshot_every_n_rounds (int, optional): persist the server state every n rounds. Defaults to 1.
6869 If n is 0 then no persist.
69- order (str, optional): the order of relay.
70- If FIXED means the same order for every round.
71- If RANDOM means random order for every round.
72- If RANDOM_WITHOUT_SAME_IN_A_ROW means every round the order gets shuffled but a client will never be
73- run twice in a row (in different round).
70+ order (Union[str, List[str]], optional): The order of relay.
71+ - If a string is provided:
72+ - "FIXED": Same order for every round.
73+ - "RANDOM": Random order for every round.
74+ - "RANDOM_WITHOUT_SAME_IN_A_ROW": Shuffled every round, but the client that ended the previous round never starts the next one.
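75+ - If a list of strings is provided, it is a custom relay order given as client names and used for every round; every listed client must be active, otherwise the workflow panics.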
7476 allow_early_termination: whether to allow early workflow termination from clients
7577
7678 Raises:
@@ -90,8 +92,8 @@ def __init__(
9092 if not isinstance (task_name , str ):
9193 raise TypeError ("task_name must be a string but got {}" .format (type (task_name )))
9294
93- if order not in SUPPORTED_ORDERS :
94- raise ValueError (f"order must be in { SUPPORTED_ORDERS } " )
95+ if order not in SUPPORTED_ORDERS and not isinstance ( order , list ) :
96+ raise ValueError (f"order must be in { SUPPORTED_ORDERS } or a list " )
9597
9698 self ._num_rounds = num_rounds
9799 self ._start_round = 0
@@ -131,21 +133,34 @@ def start_controller(self, fl_ctx: FLContext):
131133 fl_ctx .set_prop (AppConstants .NUM_ROUNDS , self ._num_rounds , private = True , sticky = True )
132134 self .fire_event (AppEventType .INITIAL_MODEL_LOADED , fl_ctx )
133135
134- self ._participating_clients = self ._engine .get_clients ()
136+ self ._participating_clients : List [ Client ] = self ._engine .get_clients ()
135137 if len (self ._participating_clients ) <= 1 :
136- self .system_panic ("Not enough client sites." , fl_ctx )
138+ self .system_panic (f "Not enough client sites ( { len ( self . _participating_clients ) } ) ." , fl_ctx )
137139 self ._last_client = None
138140
139- def _get_relay_orders (self , fl_ctx : FLContext ):
140- targets = list (self ._participating_clients )
141- if len (targets ) <= 1 :
142- self .system_panic ("Not enough client sites." , fl_ctx )
143- if self ._order == RelayOrder .RANDOM :
144- random .shuffle (targets )
145- elif self ._order == RelayOrder .RANDOM_WITHOUT_SAME_IN_A_ROW :
146- random .shuffle (targets )
147- if self ._last_client == targets [0 ]:
148- targets = targets .append (targets .pop (0 ))
141+ def _get_relay_orders (self , fl_ctx : FLContext ) -> Union [List [Client ], None ]:
142+ if len (self ._participating_clients ) <= 1 :
143+ self .system_panic (f"Not enough client sites ({ len (self ._participating_clients )} )." , fl_ctx )
144+ return None
145+
146+ active_clients_map = {}
147+ for t in self ._participating_clients :
148+ if not self .get_client_death_time (t .name ):
149+ active_clients_map [t .name ] = t
150+
151+ if isinstance (self ._order , list ):
152+ targets = []
153+ for c_name in self ._order :
154+ if c_name not in active_clients_map :
155+ self .system_panic (f"Required client site ({ c_name } ) is not in active clients." , fl_ctx )
156+ return None
157+ targets .append (active_clients_map [c_name ])
158+ else :
159+ targets = list (active_clients_map .values ())
160+ if self ._order == RelayOrder .RANDOM or self ._order == RelayOrder .RANDOM_WITHOUT_SAME_IN_A_ROW :
161+ random .shuffle (targets )
162+ if self ._order == RelayOrder .RANDOM_WITHOUT_SAME_IN_A_ROW and self ._last_client == targets [0 ]:
163+ targets .append (targets .pop (0 ))
149164 self ._last_client = targets [- 1 ]
150165 return targets
151166
@@ -191,22 +206,27 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
191206
192207 def control_flow (self , abort_signal : Signal , fl_ctx : FLContext ):
193208 try :
194- self .log_debug (fl_ctx , "Cyclic starting." )
209+ self .log_info (fl_ctx , "Cyclic starting." )
195210
196211 for self ._current_round in range (self ._start_round , self ._end_round ):
197212 if self ._is_done :
213+ self .log_info (fl_ctx , "Cyclic is done." )
198214 return
199215
200216 if abort_signal .triggered :
217+ self .log_info (fl_ctx , "abort signal triggered." )
201218 return
202219
203- self .log_debug (fl_ctx , "Starting current round={}." .format (self ._current_round ))
220+ self .log_info (fl_ctx , "Starting current round={}." .format (self ._current_round ))
204221 fl_ctx .set_prop (AppConstants .CURRENT_ROUND , self ._current_round , private = True , sticky = True )
205222
206223 # Task for one cyclic
207224 targets = self ._get_relay_orders (fl_ctx )
225+ if not targets :
226+ self .log_info (fl_ctx , "No cyclic targets." )
227+ return
208228 targets_names = [t .name for t in targets ]
209- self .log_debug (fl_ctx , f"Relay on { targets_names } " )
229+ self .log_info (fl_ctx , f"Relay on { targets_names } " )
210230
211231 shareable = self .shareable_generator .learnable_to_shareable (self ._last_learnable , fl_ctx )
212232 shareable .set_header (AppConstants .CURRENT_ROUND , self ._current_round )
@@ -241,10 +261,10 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
241261 # Call the self._engine to persist the snapshot of all the FLComponents
242262 self ._engine .persist_components (fl_ctx , completed = False )
243263
244- self .log_debug (fl_ctx , "Ending current round={}." .format (self ._current_round ))
264+ self .log_info (fl_ctx , "Ending current round={}." .format (self ._current_round ))
245265 gc .collect ()
246266
247- self .log_debug (fl_ctx , "Cyclic ended." )
267+ self .log_info (fl_ctx , "Cyclic ended." )
248268 except Exception as e :
249269 error_msg = f"Cyclic control_flow exception: { secure_format_exception (e )} "
250270 self .log_error (fl_ctx , error_msg )
@@ -279,12 +299,3 @@ def restore(self, state_data: dict, fl_ctx: FLContext):
279299 self ._start_round = self ._current_round
280300 finally :
281301 pass
282-
283- def handle_dead_job (self , client_name : str , fl_ctx : FLContext ):
284- super ().handle_dead_job (client_name , fl_ctx )
285-
286- new_client_list = []
287- for client in self ._participating_clients :
288- if client_name != client .name :
289- new_client_list .append (client )
290- self ._participating_clients = new_client_list
0 commit comments