Skip to content

Commit 2b10b9f

Browse files
Client api use exchange task (#2070)
1 parent b5ec419 commit 2b10b9f

File tree

18 files changed

+439
-271
lines changed

18 files changed

+439
-271
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Optional
16+
17+
18+
class ExchangeTask:
19+
def __init__(self, task_name: str, task_id: str, data: Any, meta: Optional[dict] = None, return_code: str = "ok"):
20+
self.task_name = task_name
21+
self.task_id = task_id
22+
self.meta = meta
23+
self.data = data
24+
self.return_code = return_code
25+
26+
def __str__(self):
27+
return f"Task(name:{self.task_name},id:{self.task_id})"

nvflare/app_common/model_exchange/constants.py renamed to nvflare/app_common/data_exchange/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from enum import Enum
1616

1717

18-
class ModelExchangeFormat(str, Enum):
18+
class ExchangeFormat(str, Enum):
1919
RAW = "raw"
2020
PYTORCH = "pytorch"
2121
NUMPY = "numpy"

nvflare/app_common/model_exchange/model_exchanger.py renamed to nvflare/app_common/data_exchange/data_exchanger.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import logging
1616
import time
17-
from typing import Any, Optional, Tuple
17+
from typing import Any, List, Optional, Tuple
1818

1919
from nvflare.fuel.utils.pipe.pipe import Message, Pipe
2020
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler, Topic
@@ -40,31 +40,33 @@ class ExchangePeerGoneException(DataExchangeException):
4040
pass
4141

4242

43-
class ModelExchanger:
43+
class DataExchanger:
4444
def __init__(
4545
self,
46+
supported_topics: List[str],
4647
pipe: Pipe,
4748
pipe_name: str = "pipe",
48-
topic: str = "data",
4949
get_poll_interval: float = 0.5,
5050
read_interval: float = 0.1,
5151
heartbeat_interval: float = 5.0,
5252
heartbeat_timeout: float = 30.0,
5353
):
54-
"""Initializes the ModelExchanger.
54+
"""Initializes the DataExchanger.
5555
5656
Args:
57+
supported_topics (list[str]): Supported topics for data exchange. This allows the sender and receiver to identify
58+
the purpose or content of the data being exchanged.
5759
pipe (Pipe): The pipe used for data exchange.
5860
pipe_name (str): Name of the pipe. Defaults to "pipe".
59-
topic (str): Topic for data exchange. Defaults to "data".
6061
get_poll_interval (float): Interval for checking if the other side has sent data. Defaults to 0.5.
6162
read_interval (float): Interval for reading from the pipe. Defaults to 0.1.
6263
heartbeat_interval (float): Interval for sending heartbeat to the peer. Defaults to 5.0.
6364
heartbeat_timeout (float): Timeout for waiting for a heartbeat from the peer. Defaults to 30.0.
6465
"""
6566
self.logger = logging.getLogger(self.__class__.__name__)
6667
self._req_id: Optional[str] = None
67-
self._topic = topic
68+
self.current_topic: Optional[str] = None
69+
self._supported_topics = supported_topics
6870

6971
pipe.open(pipe_name)
7072
self.pipe_handler = PipeHandler(
@@ -76,51 +78,56 @@ def __init__(
7678
self.pipe_handler.start()
7779
self._get_poll_interval = get_poll_interval
7880

79-
def submit_model(self, model: Any) -> None:
80-
"""Submits a model for exchange.
81+
def submit_data(self, data: Any) -> None:
82+
"""Submits data for exchange.
8183
8284
Args:
83-
model (Any): The model to be submitted.
85+
data (Any): The data to be submitted.
8486
8587
Raises:
86-
DataExchangeException: If there is no request ID available (needs to pull model from server first).
88+
DataExchangeException: If there is no request ID available (needs to pull data from server first).
8789
"""
8890
if self._req_id is None:
89-
raise DataExchangeException("need to pull a model first.")
90-
self._send_reply(data=model, req_id=self._req_id)
91+
raise DataExchangeException("Missing req_id, need to pull a data first.")
9192

92-
def receive_model(self, timeout: Optional[float] = None) -> Any:
93-
"""Receives a model.
93+
if self.current_topic is None:
94+
raise DataExchangeException("Missing current_topic, need to pull a data first.")
95+
96+
self._send_reply(data=data, topic=self.current_topic, req_id=self._req_id)
97+
98+
def receive_data(self, timeout: Optional[float] = None) -> Tuple[str, Any]:
99+
"""Receives data.
94100
95101
Args:
96-
timeout (Optional[float]): Timeout for waiting to receive a model. Defaults to None.
102+
timeout (Optional[float]): Timeout for waiting to receive data. Defaults to None.
97103
98104
Returns:
99-
Any: The received model.
105+
A tuple of (topic, data): The received data.
100106
101107
Raises:
102108
ExchangeTimeoutException: If the data cannot be received within the specified timeout.
103109
ExchangeAbortException: If the other endpoint of the pipe requests to abort.
104110
ExchangeEndException: If the other endpoint has ended.
105111
ExchangePeerGoneException: If the other endpoint is gone.
106112
"""
107-
model, req_id = self._receive_request(timeout)
108-
self._req_id = req_id
109-
return model
113+
msg = self._receive_request(timeout)
114+
self._req_id = msg.msg_id
115+
self.current_topic = msg.topic
116+
return msg.topic, msg.data
110117

111118
def finalize(self, close_pipe: bool = True) -> None:
112119
if self.pipe_handler is None:
113120
raise RuntimeError("PipeMonitor is not initialized.")
114121
self.pipe_handler.stop(close_pipe=close_pipe)
115122

116-
def _receive_request(self, timeout: Optional[float] = None) -> Tuple[Any, str]:
123+
def _receive_request(self, timeout: Optional[float] = None) -> Message:
117124
"""Receives a request.
118125
119126
Args:
120127
timeout: how long to wait for the request to come.
121128
122129
Returns:
123-
A tuple of (data, request id).
130+
A Message.
124131
125132
Raises:
126133
ExchangeTimeoutException: If can't receive data within timeout seconds.
@@ -138,24 +145,21 @@ def _receive_request(self, timeout: Optional[float] = None) -> Tuple[Any, str]:
138145
self.pipe_handler.notify_abort(msg)
139146
raise ExchangeTimeoutException(f"get data timeout after {timeout} secs")
140147
elif msg.topic == Topic.ABORT:
141-
raise ExchangeAbortException("the other end is aborted")
148+
raise ExchangeAbortException("the other end ask to abort")
142149
elif msg.topic == Topic.END:
143-
raise ExchangeEndException(
144-
f"received {msg.topic}: {msg.data} while waiting for result for {self._topic}"
145-
)
150+
raise ExchangeEndException(f"received msg: '{msg}' while waiting for requests")
146151
elif msg.topic == Topic.PEER_GONE:
147-
raise ExchangePeerGoneException(
148-
f"received {msg.topic}: {msg.data} while waiting for result for {self._topic}"
149-
)
150-
elif msg.topic == self._topic:
151-
return msg.data, msg.msg_id
152+
raise ExchangePeerGoneException(f"received msg: '{msg}' while waiting for requests")
153+
elif msg.topic in self._supported_topics:
154+
return msg
152155
time.sleep(self._get_poll_interval)
153156

154-
def _send_reply(self, data: Any, req_id: str, timeout: Optional[float] = None) -> bool:
157+
def _send_reply(self, data: Any, topic: str, req_id: str, timeout: Optional[float] = None) -> bool:
155158
"""Sends a reply.
156159
157160
Args:
158161
data: The data exchange object to be sent.
162+
topic: message topic.
159163
req_id: request ID.
160164
timeout: how long to wait for the peer to read the data.
161165
If not specified, return False immediately.
@@ -165,6 +169,6 @@ def _send_reply(self, data: Any, req_id: str, timeout: Optional[float] = None) -
165169
"""
166170
if self.pipe_handler is None:
167171
raise RuntimeError("PipeMonitor is not initialized.")
168-
msg = Message.new_reply(topic=self._topic, data=data, req_msg_id=req_id)
172+
msg = Message.new_reply(topic=topic, data=data, req_msg_id=req_id)
169173
has_been_read = self.pipe_handler.send_to_peer(msg, timeout)
170174
return has_been_read

nvflare/app_common/model_exchange/file_pipe_model_exchanger.py renamed to nvflare/app_common/data_exchange/file_pipe_data_exchanger.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,23 @@
1313
# limitations under the License.
1414

1515
import os
16-
from typing import Optional
16+
from typing import List, Optional
1717

1818
from nvflare.apis.utils.decomposers import flare_decomposers
19+
from nvflare.app_common.data_exchange.data_exchanger import DataExchanger
1920
from nvflare.app_common.decomposers import common_decomposers as app_common_decomposers
20-
from nvflare.app_common.model_exchange.model_exchanger import ModelExchanger
2121
from nvflare.fuel.utils.constants import Mode
2222
from nvflare.fuel.utils.pipe.file_accessor import FileAccessor
2323
from nvflare.fuel.utils.pipe.file_pipe import FilePipe
2424

2525

26-
class FilePipeModelExchanger(ModelExchanger):
26+
class FilePipeDataExchanger(DataExchanger):
2727
def __init__(
2828
self,
2929
data_exchange_path: str,
30+
supported_topics: List[str],
3031
file_accessor: Optional[FileAccessor] = None,
3132
pipe_name: str = "pipe",
32-
topic: str = "data",
3333
get_poll_interval: float = 0.5,
3434
read_interval: float = 0.1,
3535
heartbeat_interval: float = 5.0,
@@ -40,14 +40,14 @@ def __init__(
4040
Args:
4141
data_exchange_path (str): The path for data exchange. This is the location where the data
4242
will be read from or written to.
43+
supported_topics (list[str]): Supported topics for data exchange. This allows the sender and receiver to identify
44+
the purpose or content of the data being exchanged.
4345
file_accessor (Optional[FileAccessor]): The file accessor for reading and writing files.
4446
If not provided, the default file accessor (FobsFileAccessor) will be used.
4547
Please refer to the docstring of the FileAccessor class for more information
4648
on implementing a custom file accessor. Defaults to None.
4749
pipe_name (str): The name of the pipe to be used for communication. This pipe will be used
4850
for transmitting data between the sender and receiver. Defaults to "pipe".
49-
topic (str): The topic for data exchange. This allows the sender and receiver to identify
50-
the purpose or content of the data being exchanged. Defaults to "data".
5151
get_poll_interval (float): The interval (in seconds) for checking if the other side has sent data.
5252
This determines how often the receiver checks for incoming data. Defaults to 0.5.
5353
read_interval (float): The interval (in seconds) for reading from the pipe. This determines
@@ -66,9 +66,9 @@ def __init__(
6666
file_pipe.set_file_accessor(file_accessor)
6767

6868
super().__init__(
69+
supported_topics=supported_topics,
6970
pipe=file_pipe,
7071
pipe_name=pipe_name,
71-
topic=topic,
7272
get_poll_interval=get_poll_interval,
7373
read_interval=read_interval,
7474
heartbeat_interval=heartbeat_interval,

nvflare/app_common/decomposers/common_decomposers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import numpy as np
2222

23+
from nvflare.app_common.abstract.exchange_task import ExchangeTask
2324
from nvflare.app_common.abstract.fl_model import FLModel
2425
from nvflare.app_common.abstract.learnable import Learnable
2526
from nvflare.app_common.abstract.model import ModelLearnable
@@ -29,6 +30,33 @@
2930
from nvflare.fuel.utils.fobs.decomposer import Decomposer, DictDecomposer, Externalizer, Internalizer
3031

3132

33+
class ExchangeTaskDecomposer(fobs.Decomposer):
34+
def supported_type(self):
35+
return ExchangeTask
36+
37+
def decompose(self, b: ExchangeTask, manager: DatumManager = None) -> Any:
38+
externalizer = Externalizer(manager)
39+
return (
40+
b.task_id,
41+
b.task_name,
42+
externalizer.externalize(b.data),
43+
externalizer.externalize(b.meta),
44+
b.return_code,
45+
)
46+
47+
def recompose(self, data: tuple, manager: DatumManager = None) -> ExchangeTask:
48+
assert isinstance(data, tuple)
49+
task_id, task_name, task_data, meta, return_code = data
50+
internalizer = Internalizer(manager)
51+
return ExchangeTask(
52+
task_name=task_name,
53+
task_id=task_id,
54+
data=internalizer.internalize(task_data),
55+
meta=internalizer.internalize(meta),
56+
return_code=return_code,
57+
)
58+
59+
3260
class FLModelDecomposer(fobs.Decomposer):
3361
def supported_type(self):
3462
return FLModel

nvflare/app_opt/lightning/api.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torch import Tensor
2020

2121
from nvflare.app_common.abstract.fl_model import FLModel, MetaKey
22-
from nvflare.client.api import clear, get_config, init, receive, send
22+
from nvflare.client.api import clear, get_config, init, is_evaluate, is_train, receive, send
2323
from nvflare.client.config import ConfigKey
2424

2525
FL_META_KEY = "__fl_meta__"
@@ -41,8 +41,7 @@ class FLCallback(Callback):
4141
def __init__(self, rank: int = 0):
4242
super(FLCallback, self).__init__()
4343
init(rank=str(rank))
44-
self.has_global_eval = get_config().get(ConfigKey.GLOBAL_EVAL, False)
45-
self.has_training = get_config().get(ConfigKey.TRAINING, False)
44+
self.train_with_evaluation = get_config().get(ConfigKey.TRAIN_WITH_EVAL, False)
4645
self.current_round = None
4746
self.metrics = None
4847
self.total_local_epochs = 0
@@ -59,15 +58,15 @@ def reset_state(self, trainer):
5958
"""
6059
# set states for next round
6160
if self.current_round is not None:
62-
if self.current_round == 0:
61+
if self.max_epochs_per_round is None:
6362
if trainer.max_epochs and trainer.max_epochs > 0:
6463
self.max_epochs_per_round = trainer.max_epochs
6564
if trainer.max_steps and trainer.max_steps > 0:
6665
self.max_steps_per_round = trainer.max_steps
6766

6867
# record total local epochs/steps
6968
self.total_local_epochs = trainer.current_epoch
70-
self.total_local_steps += trainer.estimated_stepping_batches
69+
self.total_local_steps = trainer.estimated_stepping_batches
7170

7271
# for next round
7372
trainer.num_sanity_val_steps = 0 # Turn off sanity validation steps in following rounds of FL
@@ -82,11 +81,11 @@ def reset_state(self, trainer):
8281

8382
def on_train_start(self, trainer, pl_module):
8483
# receive the global model and update the local model with global model
85-
if self.has_training:
84+
if is_train():
8685
self._receive_and_update_model(trainer, pl_module)
8786

8887
def on_train_end(self, trainer, pl_module):
89-
if self.has_training:
88+
if is_train():
9089
if hasattr(pl_module, FL_META_KEY):
9190
fl_meta = getattr(pl_module, FL_META_KEY)
9291
if not isinstance(fl_meta, dict):
@@ -105,13 +104,15 @@ def on_validation_start(self, trainer, pl_module):
105104
# the metrics will be set.
106105
# The subsequent validate() calls will not trigger receiving and updating the model.
107106
# Hence the validate() will be validating the local model.
108-
if pl_module and self.has_global_eval and self.metrics is None:
109-
self._receive_and_update_model(trainer, pl_module)
107+
if (is_train() and self.train_with_evaluation) or is_evaluate():
108+
if pl_module and self.metrics is None:
109+
self._receive_and_update_model(trainer, pl_module)
110110

111111
def on_validation_end(self, trainer, pl_module):
112-
if pl_module and self.has_global_eval and self.metrics is None:
113-
self.metrics = _extract_metrics(trainer.callback_metrics)
114-
self._send_model(FLModel(metrics=self.metrics))
112+
if (is_train() and self.train_with_evaluation) or is_evaluate():
113+
if pl_module and self.metrics is None:
114+
self.metrics = _extract_metrics(trainer.callback_metrics)
115+
self._send_model(FLModel(metrics=self.metrics))
115116

116117
def _receive_and_update_model(self, trainer, pl_module):
117118
model = self._receive_model(trainer)

nvflare/client/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@
1818
from nvflare.app_common.abstract.fl_model import FLModel as FLModel
1919
from nvflare.app_common.abstract.fl_model import ParamsType as ParamsType
2020

21+
from .api import DataExchangeException as DataExchangeException
2122
from .api import clear as clear
2223
from .api import get_config as get_config
2324
from .api import get_job_id as get_job_id
2425
from .api import get_site_name as get_site_name
2526
from .api import get_total_rounds as get_total_rounds
2627
from .api import init as init
28+
from .api import is_evaluate as is_evaluate
29+
from .api import is_running as is_running
30+
from .api import is_submit_model as is_submit_model
31+
from .api import is_train as is_train
2732
from .api import params_diff as params_diff
2833
from .api import receive as receive
2934
from .api import receive_global_model as receive_global_model

0 commit comments

Comments
 (0)