1414
1515import gc
1616import random
17+ from typing import List , Union
1718
1819from nvflare .apis .client import Client
1920from nvflare .apis .fl_constant import ReturnCode
@@ -48,7 +49,8 @@ def __init__(
4849 task_check_period : float = 0.5 ,
4950 persist_every_n_rounds : int = 1 ,
5051 snapshot_every_n_rounds : int = 1 ,
51- order : str = RelayOrder .FIXED ,
52+ order : Union [str , List [str ]] = RelayOrder .FIXED ,
53+ allow_early_termination = False ,
5254 ):
5355 """A sample implementation to demonstrate how to use relay method for Cyclic Federated Learning.
5456
@@ -65,11 +67,13 @@ def __init__(
6567 If n is 0 then no persist.
6668 snapshot_every_n_rounds (int, optional): persist the server state every n rounds. Defaults to 1.
6769 If n is 0 then no persist.
68- order (str, optional): the order of relay.
69- If FIXED means the same order for every round.
70- If RANDOM means random order for every round.
71- If RANDOM_WITHOUT_SAME_IN_A_ROW means every round the order gets shuffled but a client will never be
72- run twice in a row (in different round).
70+ order (Union[str, List[str]], optional): The order of relay.
71+ - If a string is provided:
72+ - "FIXED": Same order for every round.
73+ - "RANDOM": Random order for every round.
74+ - "RANDOM_WITHOUT_SAME_IN_A_ROW": Shuffled order, no repetition in consecutive rounds.
75+ - If a list of strings is provided, it represents a custom order for relay.
76+ allow_early_termination: whether to allow early workflow termination from clients
7377
7478 Raises:
7579 TypeError: when any of input arguments does not have correct type
@@ -88,13 +92,14 @@ def __init__(
8892 if not isinstance (task_name , str ):
8993 raise TypeError ("task_name must be a string but got {}" .format (type (task_name )))
9094
91- if order not in SUPPORTED_ORDERS :
92- raise ValueError (f"order must be in { SUPPORTED_ORDERS } " )
95+ if order not in SUPPORTED_ORDERS and not isinstance ( order , list ) :
96+ raise ValueError (f"order must be in { SUPPORTED_ORDERS } or a list " )
9397
9498 self ._num_rounds = num_rounds
9599 self ._start_round = 0
96100 self ._end_round = self ._start_round + self ._num_rounds
97101 self ._current_round = 0
102+ self ._is_done = False
98103 self ._last_learnable = None
99104 self .persistor_id = persistor_id
100105 self .shareable_generator_id = shareable_generator_id
@@ -107,6 +112,7 @@ def __init__(
107112 self ._participating_clients = None
108113 self ._last_client = None
109114 self ._order = order
115+ self ._allow_early_termination = allow_early_termination
110116
111117 def start_controller (self , fl_ctx : FLContext ):
112118 self .log_debug (fl_ctx , "starting controller" )
@@ -127,46 +133,79 @@ def start_controller(self, fl_ctx: FLContext):
127133 fl_ctx .set_prop (AppConstants .NUM_ROUNDS , self ._num_rounds , private = True , sticky = True )
128134 self .fire_event (AppEventType .INITIAL_MODEL_LOADED , fl_ctx )
129135
130- self ._participating_clients = self ._engine .get_clients ()
136+ self ._participating_clients : List [ Client ] = self ._engine .get_clients ()
131137 if len (self ._participating_clients ) <= 1 :
132138 self .system_panic ("Not enough client sites." , fl_ctx )
133139 self ._last_client = None
134140
135- def _get_relay_orders (self , fl_ctx : FLContext ):
136- targets = list (self ._participating_clients )
137- if len (targets ) <= 1 :
138- self .system_panic ("Not enough client sites." , fl_ctx )
139- if self ._order == RelayOrder .RANDOM :
140- random .shuffle (targets )
141- elif self ._order == RelayOrder .RANDOM_WITHOUT_SAME_IN_A_ROW :
142- random .shuffle (targets )
143- if self ._last_client == targets [0 ]:
144- targets = targets .append (targets .pop (0 ))
141+ def _get_relay_orders (self , fl_ctx : FLContext ) -> Union [List [Client ], None ]:
142+ if len (self ._participating_clients ) <= 1 :
143+ self .system_panic (f"Not enough client sites ({ len (self ._participating_clients )} )." , fl_ctx )
144+ return None
145+
146+ if isinstance (self ._order , list ):
147+ targets = []
148+ active_clients_map = {t .name : t for t in self ._participating_clients }
149+ for c_name in self ._order :
150+ if c_name not in active_clients_map :
151+ self .system_panic (f"Required client site ({ c_name } ) is not in active clients." , fl_ctx )
152+ return None
153+ targets .append (active_clients_map [c_name ])
154+ else :
155+ targets = list (self ._participating_clients )
156+ if self ._order == RelayOrder .RANDOM or self ._order == RelayOrder .RANDOM_WITHOUT_SAME_IN_A_ROW :
157+ random .shuffle (targets )
158+ if self ._order == RelayOrder .RANDOM_WITHOUT_SAME_IN_A_ROW and self ._last_client == targets [0 ]:
159+ targets .append (targets .pop (0 ))
145160 self ._last_client = targets [- 1 ]
146161 return targets
147162
148- def _process_result (self , client_task : ClientTask , fl_ctx : FLContext ):
149- result = client_task .result
150- rc = result .get_return_code ()
151- client_name = client_task .client .name
152-
153- # Raise errors if ReturnCode is not OK.
154- if rc and rc != ReturnCode .OK :
155- self .system_panic (
156- f"Result from { client_name } is bad, error code: { rc } . "
157- f"{ self .__class__ .__name__ } exiting at round { self ._current_round } ." ,
158- fl_ctx = fl_ctx ,
159- )
160- return False
163+ def _stop_workflow (self , task : Task ):
164+ self .cancel_task (task )
165+ self ._is_done = True
161166
167+ def _process_result (self , client_task : ClientTask , fl_ctx : FLContext ):
162168 # submitted shareable is stored in client_task.result
163169 # we need to update task.data with that shareable so the next target
164170 # will get the updated shareable
165171 task = client_task .task
166172
167- # update the global learnable with the received result (shareable)
168- # e.g. the received result could be weight_diffs, the learnable could be full weights.
169- self ._last_learnable = self .shareable_generator .shareable_to_learnable (client_task .result , fl_ctx )
173+ result = client_task .result
174+ if isinstance (result , Shareable ):
175+ # update the global learnable with the received result (shareable)
176+ # e.g. the received result could be weight_diffs, the learnable could be full weights.
177+ rc = result .get_return_code ()
178+ try :
179+ self ._last_learnable = self .shareable_generator .shareable_to_learnable (result , fl_ctx )
180+ except Exception as ex :
181+ if rc != ReturnCode .EARLY_TERMINATION :
182+ self ._stop_workflow (task )
183+ self .log_error (fl_ctx , f"exception { secure_format_exception (ex )} from shareable_to_learnable" )
184+ return
185+ else :
186+ self .log_warning (
187+ fl_ctx ,
188+ f"ignored { secure_format_exception (ex )} from shareable_to_learnable in early termination" ,
189+ )
190+
191+ if rc == ReturnCode .EARLY_TERMINATION :
192+ if self ._allow_early_termination :
193+ # the workflow is done
194+ self ._stop_workflow (task )
195+ self .log_info (fl_ctx , f"Stopping workflow due to { rc } from client { client_task .client .name } " )
196+ return
197+ else :
198+ self .log_warning (
199+ fl_ctx ,
200+ f"Ignored { rc } from client { client_task .client .name } because early termination is not allowed" ,
201+ )
202+ else :
203+ self ._stop_workflow (task )
204+ self .log_error (
205+ fl_ctx ,
206+ f"Stopping workflow due to result from client { client_task .client .name } is not a Shareable" ,
207+ )
208+ return
170209
171210 # prepare task shareable data for next client
172211 task .data = self .shareable_generator .learnable_to_shareable (self ._last_learnable , fl_ctx )
@@ -179,6 +218,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
179218 self .log_debug (fl_ctx , "Cyclic starting." )
180219
181220 for self ._current_round in range (self ._start_round , self ._end_round ):
221+ if self ._is_done :
222+ return
223+
182224 if abort_signal .triggered :
183225 return
184226
@@ -187,6 +229,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
187229
188230 # Task for one cyclic
189231 targets = self ._get_relay_orders (fl_ctx )
232+ if targets is None :
233+ return
190234 targets_names = [t .name for t in targets ]
191235 self .log_debug (fl_ctx , f"Relay on { targets_names } " )
192236
0 commit comments