Skip to content

Commit 9e02ba2

Browse files
YuanTingHsiehIsaacYangSLA
authored and committed
Fix handle dead job logic
1 parent eaa7570 commit 9e02ba2

File tree

1 file changed

+45
-34
lines changed

1 file changed

+45
-34
lines changed

nvflare/app_common/workflows/cyclic_ctl.py

Lines changed: 45 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515
import gc
1616
import random
17+
from typing import List, Union
1718

1819
from nvflare.apis.client import Client
1920
from nvflare.apis.fl_constant import ReturnCode
@@ -48,7 +49,7 @@ def __init__(
4849
task_check_period: float = 0.5,
4950
persist_every_n_rounds: int = 1,
5051
snapshot_every_n_rounds: int = 1,
51-
order: str = RelayOrder.FIXED,
52+
order: Union[str, List[str]] = RelayOrder.FIXED,
5253
allow_early_termination=False,
5354
):
5455
"""A sample implementation to demonstrate how to use relay method for Cyclic Federated Learning.
@@ -66,11 +67,12 @@ def __init__(
6667
If n is 0 then no persist.
6768
snapshot_every_n_rounds (int, optional): persist the server state every n rounds. Defaults to 1.
6869
If n is 0 then no persist.
69-
order (str, optional): the order of relay.
70-
If FIXED means the same order for every round.
71-
If RANDOM means random order for every round.
72-
If RANDOM_WITHOUT_SAME_IN_A_ROW means every round the order gets shuffled but a client will never be
73-
run twice in a row (in different round).
70+
order (Union[str, List[str]], optional): The order of relay.
71+
- If a string is provided:
72+
- "FIXED": Same order for every round.
73+
- "RANDOM": Random order for every round.
74+
- "RANDOM_WITHOUT_SAME_IN_A_ROW": Shuffled order, no repetition in consecutive rounds.
75+
- If a list of strings is provided, it represents a custom order for relay.
7476
allow_early_termination: whether to allow early workflow termination from clients
7577
7678
Raises:
@@ -90,8 +92,8 @@ def __init__(
9092
if not isinstance(task_name, str):
9193
raise TypeError("task_name must be a string but got {}".format(type(task_name)))
9294

93-
if order not in SUPPORTED_ORDERS:
94-
raise ValueError(f"order must be in {SUPPORTED_ORDERS}")
95+
if order not in SUPPORTED_ORDERS and not isinstance(order, list):
96+
raise ValueError(f"order must be in {SUPPORTED_ORDERS} or a list")
9597

9698
self._num_rounds = num_rounds
9799
self._start_round = 0
@@ -131,21 +133,34 @@ def start_controller(self, fl_ctx: FLContext):
131133
fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=True)
132134
self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)
133135

134-
self._participating_clients = self._engine.get_clients()
136+
self._participating_clients: List[Client] = self._engine.get_clients()
135137
if len(self._participating_clients) <= 1:
136-
self.system_panic("Not enough client sites.", fl_ctx)
138+
self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx)
137139
self._last_client = None
138140

139-
def _get_relay_orders(self, fl_ctx: FLContext):
140-
targets = list(self._participating_clients)
141-
if len(targets) <= 1:
142-
self.system_panic("Not enough client sites.", fl_ctx)
143-
if self._order == RelayOrder.RANDOM:
144-
random.shuffle(targets)
145-
elif self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
146-
random.shuffle(targets)
147-
if self._last_client == targets[0]:
148-
targets = targets.append(targets.pop(0))
141+
def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
142+
if len(self._participating_clients) <= 1:
143+
self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx)
144+
return None
145+
146+
active_clients_map = {}
147+
for t in self._participating_clients:
148+
if not self.get_client_death_time(t.name):
149+
active_clients_map[t.name] = t
150+
151+
if isinstance(self._order, list):
152+
targets = []
153+
for c_name in self._order:
154+
if c_name not in active_clients_map:
155+
self.system_panic(f"Required client site ({c_name}) is not in active clients.", fl_ctx)
156+
return None
157+
targets.append(active_clients_map[c_name])
158+
else:
159+
targets = list(active_clients_map.values())
160+
if self._order == RelayOrder.RANDOM or self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
161+
random.shuffle(targets)
162+
if self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW and self._last_client == targets[0]:
163+
targets.append(targets.pop(0))
149164
self._last_client = targets[-1]
150165
return targets
151166

@@ -191,22 +206,27 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
191206

192207
def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
193208
try:
194-
self.log_debug(fl_ctx, "Cyclic starting.")
209+
self.log_info(fl_ctx, "Cyclic starting.")
195210

196211
for self._current_round in range(self._start_round, self._end_round):
197212
if self._is_done:
213+
self.log_info(fl_ctx, "Cyclic is done.")
198214
return
199215

200216
if abort_signal.triggered:
217+
self.log_info(fl_ctx, "abort signal triggered.")
201218
return
202219

203-
self.log_debug(fl_ctx, "Starting current round={}.".format(self._current_round))
220+
self.log_info(fl_ctx, "Starting current round={}.".format(self._current_round))
204221
fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True)
205222

206223
# Task for one cyclic
207224
targets = self._get_relay_orders(fl_ctx)
225+
if not targets:
226+
self.log_info(fl_ctx, "No cyclic targets.")
227+
return
208228
targets_names = [t.name for t in targets]
209-
self.log_debug(fl_ctx, f"Relay on {targets_names}")
229+
self.log_info(fl_ctx, f"Relay on {targets_names}")
210230

211231
shareable = self.shareable_generator.learnable_to_shareable(self._last_learnable, fl_ctx)
212232
shareable.set_header(AppConstants.CURRENT_ROUND, self._current_round)
@@ -241,10 +261,10 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
241261
# Call the self._engine to persist the snapshot of all the FLComponents
242262
self._engine.persist_components(fl_ctx, completed=False)
243263

244-
self.log_debug(fl_ctx, "Ending current round={}.".format(self._current_round))
264+
self.log_info(fl_ctx, "Ending current round={}.".format(self._current_round))
245265
gc.collect()
246266

247-
self.log_debug(fl_ctx, "Cyclic ended.")
267+
self.log_info(fl_ctx, "Cyclic ended.")
248268
except Exception as e:
249269
error_msg = f"Cyclic control_flow exception: {secure_format_exception(e)}"
250270
self.log_error(fl_ctx, error_msg)
@@ -279,12 +299,3 @@ def restore(self, state_data: dict, fl_ctx: FLContext):
279299
self._start_round = self._current_round
280300
finally:
281301
pass
282-
283-
def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
284-
super().handle_dead_job(client_name, fl_ctx)
285-
286-
new_client_list = []
287-
for client in self._participating_clients:
288-
if client_name != client.name:
289-
new_client_list.append(client)
290-
self._participating_clients = new_client_list

0 commit comments

Comments
 (0)