Skip to content

Commit fb8f7c5

Browse files
[2.4] Add custom order and early termination to cyclic controller (#2422)
* Add custom order and early termination to CyclicController and add tests * Add more error handling
1 parent cd9237f commit fb8f7c5

File tree

3 files changed

+214
-35
lines changed

3 files changed

+214
-35
lines changed

nvflare/apis/fl_constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class ReturnCode(object):
3939
VALIDATE_TYPE_UNKNOWN = "VALIDATE_TYPE_UNKNOWN"
4040
EMPTY_RESULT = "EMPTY_RESULT"
4141
UNSAFE_JOB = "UNSAFE_JOB"
42+
EARLY_TERMINATION = "EARLY_TERMINATION"
4243
SERVER_NOT_READY = "SERVER_NOT_READY"
4344
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
4445

nvflare/app_common/workflows/cyclic_ctl.py

Lines changed: 79 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import gc
1616
import random
17+
from typing import List, Union
1718

1819
from nvflare.apis.client import Client
1920
from nvflare.apis.fl_constant import ReturnCode
@@ -48,7 +49,8 @@ def __init__(
4849
task_check_period: float = 0.5,
4950
persist_every_n_rounds: int = 1,
5051
snapshot_every_n_rounds: int = 1,
51-
order: str = RelayOrder.FIXED,
52+
order: Union[str, List[str]] = RelayOrder.FIXED,
53+
allow_early_termination=False,
5254
):
5355
"""A sample implementation to demonstrate how to use relay method for Cyclic Federated Learning.
5456
@@ -65,11 +67,13 @@ def __init__(
6567
If n is 0 then no persist.
6668
snapshot_every_n_rounds (int, optional): persist the server state every n rounds. Defaults to 1.
6769
If n is 0 then no persist.
68-
order (str, optional): the order of relay.
69-
If FIXED means the same order for every round.
70-
If RANDOM means random order for every round.
71-
If RANDOM_WITHOUT_SAME_IN_A_ROW means every round the order gets shuffled but a client will never be
72-
run twice in a row (in different round).
70+
order (Union[str, List[str]], optional): The order of relay.
71+
- If a string is provided:
72+
- "FIXED": Same order for every round.
73+
- "RANDOM": Random order for every round.
74+
- "RANDOM_WITHOUT_SAME_IN_A_ROW": Shuffled order, no repetition in consecutive rounds.
75+
- If a list of strings is provided, it represents a custom order for relay.
76+
allow_early_termination: whether to allow early workflow termination from clients
7377
7478
Raises:
7579
TypeError: when any of input arguments does not have correct type
@@ -88,13 +92,14 @@ def __init__(
8892
if not isinstance(task_name, str):
8993
raise TypeError("task_name must be a string but got {}".format(type(task_name)))
9094

91-
if order not in SUPPORTED_ORDERS:
92-
raise ValueError(f"order must be in {SUPPORTED_ORDERS}")
95+
if order not in SUPPORTED_ORDERS and not isinstance(order, list):
96+
raise ValueError(f"order must be in {SUPPORTED_ORDERS} or a list")
9397

9498
self._num_rounds = num_rounds
9599
self._start_round = 0
96100
self._end_round = self._start_round + self._num_rounds
97101
self._current_round = 0
102+
self._is_done = False
98103
self._last_learnable = None
99104
self.persistor_id = persistor_id
100105
self.shareable_generator_id = shareable_generator_id
@@ -107,6 +112,7 @@ def __init__(
107112
self._participating_clients = None
108113
self._last_client = None
109114
self._order = order
115+
self._allow_early_termination = allow_early_termination
110116

111117
def start_controller(self, fl_ctx: FLContext):
112118
self.log_debug(fl_ctx, "starting controller")
@@ -127,46 +133,79 @@ def start_controller(self, fl_ctx: FLContext):
127133
fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=True)
128134
self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)
129135

130-
self._participating_clients = self._engine.get_clients()
136+
self._participating_clients: List[Client] = self._engine.get_clients()
131137
if len(self._participating_clients) <= 1:
132138
self.system_panic("Not enough client sites.", fl_ctx)
133139
self._last_client = None
134140

135-
def _get_relay_orders(self, fl_ctx: FLContext):
136-
targets = list(self._participating_clients)
137-
if len(targets) <= 1:
138-
self.system_panic("Not enough client sites.", fl_ctx)
139-
if self._order == RelayOrder.RANDOM:
140-
random.shuffle(targets)
141-
elif self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
142-
random.shuffle(targets)
143-
if self._last_client == targets[0]:
144-
targets = targets.append(targets.pop(0))
141+
def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
    """Build the ordered list of clients to relay through for one round.

    When ``self._order`` is a list it is treated as a custom order of client
    names; each name is looked up among ``self._participating_clients``.
    Otherwise the participating clients are used as-is, shuffled for the
    ``RANDOM`` and ``RANDOM_WITHOUT_SAME_IN_A_ROW`` orders.

    Args:
        fl_ctx: FL context passed to ``system_panic`` when the order cannot be built.

    Returns:
        The relay targets in order, or ``None`` after calling ``system_panic``
        when there are not enough participating clients, the custom order is
        empty, or the custom order names a client that is not participating.
    """
    num_clients = len(self._participating_clients)
    if num_clients <= 1:
        self.system_panic(f"Not enough client sites ({num_clients}).", fl_ctx)
        return None

    if isinstance(self._order, list):
        if not self._order:
            # An empty custom order would otherwise fail below on targets[-1].
            self.system_panic("Custom relay order must not be empty.", fl_ctx)
            return None

        active_clients_map = {t.name: t for t in self._participating_clients}
        targets = []
        for c_name in self._order:
            if c_name not in active_clients_map:
                self.system_panic(f"Required client site ({c_name}) is not in active clients.", fl_ctx)
                return None
            targets.append(active_clients_map[c_name])
    else:
        targets = list(self._participating_clients)
        if self._order == RelayOrder.RANDOM or self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
            random.shuffle(targets)
            if self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW and self._last_client == targets[0]:
                # The client that closed the previous relay must not open this one:
                # move it from the front to the back.
                targets.append(targets.pop(0))

    # Remembered so the next call can compare it against its first target.
    self._last_client = targets[-1]
    return targets
147162

148-
def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
149-
result = client_task.result
150-
rc = result.get_return_code()
151-
client_name = client_task.client.name
152-
153-
# Raise errors if ReturnCode is not OK.
154-
if rc and rc != ReturnCode.OK:
155-
self.system_panic(
156-
f"Result from {client_name} is bad, error code: {rc}. "
157-
f"{self.__class__.__name__} exiting at round {self._current_round}.",
158-
fl_ctx=fl_ctx,
159-
)
160-
return False
163+
def _stop_workflow(self, task: Task):
    """Cancel the given task and mark the workflow as finished.

    Args:
        task: the task handed to ``cancel_task``.
    """
    self.cancel_task(task)
    # control_flow checks this flag at the top of every round and returns when set.
    self._is_done = True
161166

167+
def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
162168
# submitted shareable is stored in client_task.result
163169
# we need to update task.data with that shareable so the next target
164170
# will get the updated shareable
165171
task = client_task.task
166172

167-
# update the global learnable with the received result (shareable)
168-
# e.g. the received result could be weight_diffs, the learnable could be full weights.
169-
self._last_learnable = self.shareable_generator.shareable_to_learnable(client_task.result, fl_ctx)
173+
result = client_task.result
174+
if isinstance(result, Shareable):
175+
# update the global learnable with the received result (shareable)
176+
# e.g. the received result could be weight_diffs, the learnable could be full weights.
177+
rc = result.get_return_code()
178+
try:
179+
self._last_learnable = self.shareable_generator.shareable_to_learnable(result, fl_ctx)
180+
except Exception as ex:
181+
if rc != ReturnCode.EARLY_TERMINATION:
182+
self._stop_workflow(task)
183+
self.log_error(fl_ctx, f"exception {secure_format_exception(ex)} from shareable_to_learnable")
184+
return
185+
else:
186+
self.log_warning(
187+
fl_ctx,
188+
f"ignored {secure_format_exception(ex)} from shareable_to_learnable in early termination",
189+
)
190+
191+
if rc == ReturnCode.EARLY_TERMINATION:
192+
if self._allow_early_termination:
193+
# the workflow is done
194+
self._stop_workflow(task)
195+
self.log_info(fl_ctx, f"Stopping workflow due to {rc} from client {client_task.client.name}")
196+
return
197+
else:
198+
self.log_warning(
199+
fl_ctx,
200+
f"Ignored {rc} from client {client_task.client.name} because early termination is not allowed",
201+
)
202+
else:
203+
self._stop_workflow(task)
204+
self.log_error(
205+
fl_ctx,
206+
f"Stopping workflow due to result from client {client_task.client.name} is not a Shareable",
207+
)
208+
return
170209

171210
# prepare task shareable data for next client
172211
task.data = self.shareable_generator.learnable_to_shareable(self._last_learnable, fl_ctx)
@@ -179,6 +218,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
179218
self.log_debug(fl_ctx, "Cyclic starting.")
180219

181220
for self._current_round in range(self._start_round, self._end_round):
221+
if self._is_done:
222+
return
223+
182224
if abort_signal.triggered:
183225
return
184226

@@ -187,6 +229,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
187229

188230
# Task for one cyclic
189231
targets = self._get_relay_orders(fl_ctx)
232+
if targets is None:
233+
return
190234
targets_names = [t.name for t in targets]
191235
self.log_debug(fl_ctx, f"Relay on {targets_names}")
192236

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import uuid
17+
from unittest.mock import Mock, patch
18+
19+
import pytest
20+
21+
from nvflare.apis.client import Client
22+
from nvflare.apis.controller_spec import ClientTask, Task
23+
from nvflare.apis.fl_constant import ReturnCode
24+
from nvflare.apis.fl_context import FLContext
25+
from nvflare.apis.shareable import Shareable
26+
from nvflare.apis.signal import Signal
27+
from nvflare.app_common.abstract.learnable import Learnable
28+
from nvflare.app_common.workflows.cyclic_ctl import CyclicController, RelayOrder
29+
30+
# One id per site, generated once at import so the "active" and "expected"
# Client objects built below for the same site carry the same value.
SITE_1_ID = uuid.uuid4()
SITE_2_ID = uuid.uuid4()
SITE_3_ID = uuid.uuid4()

# Each case is (order, active_clients, expected_result):
#   order           - value passed to CyclicController(order=...)
#   active_clients  - assigned to the controller's _participating_clients
#   expected_result - clients expected back from _get_relay_orders, in order
ORDER_TEST_CASES = [
    # built-in FIXED order keeps the participating-client order
    (
        RelayOrder.FIXED,
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
    ),
    # custom order identical to the participating-client order
    (
        ["site-1", "site-2"],
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
    ),
    # custom order reversing two clients
    (
        ["site-2", "site-1"],
        [Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
        [Client("site-2", SITE_2_ID), Client("site-1", SITE_1_ID)],
    ),
    # custom order over three clients listed in a different order
    (
        ["site-2", "site-1", "site-3"],
        [Client("site-3", SITE_3_ID), Client("site-1", SITE_1_ID), Client("site-2", SITE_2_ID)],
        [Client("site-2", SITE_2_ID), Client("site-1", SITE_1_ID), Client("site-3", SITE_3_ID)],
    ),
]
56+
57+
58+
def gen_shareable(is_early_termination: bool = False, is_not_shareable: bool = False):
    """Build a fake client result for the ``_process_result`` tests.

    Args:
        is_early_termination: when True, the returned Shareable carries the
            ``ReturnCode.EARLY_TERMINATION`` return code.
        is_not_shareable: when True, a plain list is returned instead of a
            Shareable; this takes precedence over ``is_early_termination``.

    Returns:
        ``[1, 2, 3]`` when ``is_not_shareable`` is set, otherwise a new Shareable.
    """
    if not is_not_shareable:
        shareable = Shareable()
        if is_early_termination:
            shareable.set_return_code(ReturnCode.EARLY_TERMINATION)
        return shareable
    return [1, 2, 3]
65+
66+
67+
PROCESS_RESULT_TEST_CASES = [gen_shareable(is_early_termination=True), gen_shareable(is_not_shareable=True)]
68+
69+
70+
class TestCyclicController:
    """Unit tests for relay ordering and early termination in CyclicController."""

    @pytest.mark.parametrize("order,active_clients,expected_result", ORDER_TEST_CASES)
    def test_get_relay_orders(self, order, active_clients, expected_result):
        """_get_relay_orders returns exactly the expected clients, in the expected order."""
        ctl = CyclicController(order=order)
        ctx = FLContext()
        ctl._participating_clients = active_clients
        targets = ctl._get_relay_orders(ctx)

        # zip() stops at the shorter input, so check None and length explicitly;
        # otherwise a truncated or over-long relay list would pass silently.
        assert targets is not None
        assert len(targets) == len(expected_result)
        for c, e_c in zip(targets, expected_result):
            assert c.name == e_c.name
            assert c.token == e_c.token

    def test_control_flow_call_relay_and_wait(self):
        """A single-round control_flow calls relay_and_wait exactly once."""
        with patch("nvflare.app_common.workflows.cyclic_ctl.CyclicController.relay_and_wait") as mock_method:
            ctl = CyclicController(persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1)
            ctl.shareable_generator = Mock()
            ctl._participating_clients = [
                Client("site-3", SITE_3_ID),
                Client("site-1", SITE_1_ID),
                Client("site-2", SITE_2_ID),
            ]

            abort_signal = Signal()
            fl_ctx = FLContext()

            with patch.object(ctl.shareable_generator, "learnable_to_shareable") as mock_method1, patch.object(
                ctl.shareable_generator, "shareable_to_learnable"
            ) as mock_method2:
                mock_method1.return_value = Shareable()
                mock_method2.return_value = Learnable()

                ctl.control_flow(abort_signal, fl_ctx)

                mock_method.assert_called_once()

    @pytest.mark.parametrize("return_result", PROCESS_RESULT_TEST_CASES)
    def test_process_result(self, return_result):
        """An early-termination result or a non-Shareable result cancels the task and ends the workflow."""
        ctl = CyclicController(
            persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1, allow_early_termination=True
        )
        ctl.shareable_generator = Mock()
        ctl._participating_clients = [
            Client("site-3", SITE_3_ID),
            Client("site-1", SITE_1_ID),
            Client("site-2", SITE_2_ID),
        ]

        fl_ctx = FLContext()
        with patch.object(ctl, "cancel_task") as mock_method, patch.object(
            ctl.shareable_generator, "learnable_to_shareable"
        ) as mock_method1, patch.object(ctl.shareable_generator, "shareable_to_learnable") as mock_method2:
            mock_method1.return_value = Shareable()
            mock_method2.return_value = Learnable()

            client_task = ClientTask(
                client=Mock(),
                task=Task(
                    name="__test_task",
                    data=Shareable(),
                ),
            )
            client_task.result = return_result
            ctl._process_result(client_task, fl_ctx)

            # _stop_workflow must have cancelled the task and flagged completion.
            mock_method.assert_called_once()
            assert ctl._is_done is True

0 commit comments

Comments
 (0)