Skip to content

Commit e956fea

Browse files
Add support of just doing metrics streaming with client api (#2763)
* Add support of just doing metrics streaming with client api * Address review comments
1 parent 7b01b0f commit e956fea

File tree

5 files changed

+66
-35
lines changed

5 files changed

+66
-35
lines changed

nvflare/app_common/widgets/metric_relay.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
pipe_id: str,
3434
read_interval=0.1,
3535
heartbeat_interval=5.0,
36+
heartbeat_timeout=60.0,
3637
pipe_channel_name=PipeChannelName.METRIC,
3738
event_type: str = ANALYTIC_EVENT_TYPE,
3839
fed_event: bool = True,
@@ -41,6 +42,7 @@ def __init__(
4142
self.pipe_id = pipe_id
4243
self._read_interval = read_interval
4344
self._heartbeat_interval = heartbeat_interval
45+
self._heartbeat_timeout = heartbeat_timeout
4446
self.pipe_channel_name = pipe_channel_name
4547
self.pipe = None
4648
self.pipe_handler = None
@@ -62,7 +64,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
6264
pipe=self.pipe,
6365
read_interval=self._read_interval,
6466
heartbeat_interval=self._heartbeat_interval,
65-
heartbeat_timeout=0,
67+
heartbeat_timeout=self._heartbeat_timeout,
6668
)
6769
self.pipe_handler.set_status_cb(self._pipe_status_cb)
6870
self.pipe_handler.set_message_cb(self._pipe_msg_cb)

nvflare/client/api.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
import logging
1416
import os
1517
from enum import Enum
1618
from typing import Any, Dict, Optional
@@ -45,12 +47,14 @@ def init(rank: Optional[str] = None):
4547
api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value)
4648
api_type = ClientAPIType(api_type_name)
4749
global client_api
48-
if api_type == ClientAPIType.IN_PROCESS_API:
49-
client_api = data_bus.get_data(CLIENT_API_KEY)
50+
if client_api is None:
51+
if api_type == ClientAPIType.IN_PROCESS_API:
52+
client_api = data_bus.get_data(CLIENT_API_KEY)
53+
else:
54+
client_api = ExProcessClientAPI()
55+
client_api.init(rank=rank)
5056
else:
51-
client_api = ExProcessClientAPI()
52-
53-
client_api.init(rank=rank)
57+
logging.warning("Warning: called init() more than once. The subsequence calls are ignored")
5458

5559

5660
def receive(timeout: Optional[float] = None) -> Optional[FLModel]:

nvflare/client/config.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class ConfigKey:
4444
TASK_NAME = "TASK_NAME"
4545
TASK_EXCHANGE = "TASK_EXCHANGE"
4646
METRICS_EXCHANGE = "METRICS_EXCHANGE"
47+
HEARTBEAT_TIMEOUT = "HEARTBEAT_TIMEOUT"
4748

4849

4950
class ClientConfig:
@@ -133,19 +134,19 @@ def get_pipe_class(self, section: str) -> str:
133134
return self.config[section][ConfigKey.PIPE][ConfigKey.CLASS_NAME]
134135

135136
def get_exchange_format(self) -> str:
136-
return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EXCHANGE_FORMAT]
137+
return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.EXCHANGE_FORMAT, "")
137138

138139
def get_transfer_type(self) -> str:
139140
return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRANSFER_TYPE, "FULL")
140141

141142
def get_train_task(self):
142-
return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.TRAIN_TASK_NAME]
143+
return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRAIN_TASK_NAME, "")
143144

144145
def get_eval_task(self):
145-
return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EVAL_TASK_NAME]
146+
return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.EVAL_TASK_NAME, "")
146147

147148
def get_submit_model_task(self):
148-
return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.SUBMIT_MODEL_TASK_NAME]
149+
return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.SUBMIT_MODEL_TASK_NAME, "")
149150

150151
def to_json(self, config_file: str):
151152
with open(config_file, "w") as f:

nvflare/client/ex_process/api.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@
2828
from nvflare.client.model_registry import ModelRegistry
2929
from nvflare.fuel.utils import fobs
3030
from nvflare.fuel.utils.import_utils import optional_import
31+
from nvflare.fuel.utils.obj_utils import get_logger
3132
from nvflare.fuel.utils.pipe.pipe import Pipe
3233

3334

3435
def _create_client_config(config: str) -> ClientConfig:
3536
if isinstance(config, str):
3637
client_config = from_file(config_file=config)
3738
else:
38-
raise ValueError("config should be a string but got: {type(config)}")
39+
raise ValueError(f"config should be a string but got: {type(config)}")
3940
return client_config
4041

4142

@@ -62,6 +63,7 @@ def _register_tensor_decomposer():
6263
class ExProcessClientAPI(APISpec):
6364
def __init__(self):
6465
self.process_model_registry = None
66+
self.logger = get_logger(self)
6567

6668
def get_model_registry(self) -> ModelRegistry:
6769
"""Gets the ModelRegistry."""
@@ -81,20 +83,23 @@ def init(self, rank: Optional[str] = None):
8183
rank = os.environ.get("RANK", "0")
8284

8385
if self.process_model_registry:
84-
print("Warning: called init() more than once. The subsequence calls are ignored")
86+
self.logger.warning("Warning: called init() more than once. The subsequence calls are ignored")
8587
return
8688

87-
client_config = _create_client_config(config=f"config/{CLIENT_API_CONFIG}")
89+
config_file = f"config/{CLIENT_API_CONFIG}"
90+
client_config = _create_client_config(config=config_file)
8891

8992
flare_agent = None
9093
try:
9194
if rank == "0":
9295
if client_config.get_exchange_format() == ExchangeFormat.PYTORCH:
9396
_register_tensor_decomposer()
9497

95-
pipe, task_channel_name = _create_pipe_using_config(
96-
client_config=client_config, section=ConfigKey.TASK_EXCHANGE
97-
)
98+
pipe, task_channel_name = None, ""
99+
if ConfigKey.TASK_EXCHANGE in client_config.config:
100+
pipe, task_channel_name = _create_pipe_using_config(
101+
client_config=client_config, section=ConfigKey.TASK_EXCHANGE
102+
)
98103
metric_pipe, metric_channel_name = None, ""
99104
if ConfigKey.METRICS_EXCHANGE in client_config.config:
100105
metric_pipe, metric_channel_name = _create_pipe_using_config(
@@ -106,12 +111,13 @@ def init(self, rank: Optional[str] = None):
106111
task_channel_name=task_channel_name,
107112
metric_pipe=metric_pipe,
108113
metric_channel_name=metric_channel_name,
114+
heartbeat_timeout=client_config.config.get(ConfigKey.HEARTBEAT_TIMEOUT, 60),
109115
)
110116
flare_agent.start()
111117

112118
self.process_model_registry = ModelRegistry(client_config, rank, flare_agent)
113119
except Exception as e:
114-
print(f"flare.init failed: {e}")
120+
self.logger.error(f"flare.init failed: {e}")
115121
raise e
116122

117123
def receive(self, timeout: Optional[float] = None) -> Optional[FLModel]:

nvflare/client/flare_agent.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@ def __init__(self, task_id, task_name: str, msg_id):
6464
class FlareAgent:
6565
def __init__(
6666
self,
67-
pipe: Pipe,
67+
pipe: Optional[Pipe] = None,
6868
read_interval=0.1,
6969
heartbeat_interval=5.0,
70-
heartbeat_timeout=30.0,
70+
heartbeat_timeout=60.0,
7171
resend_interval=2.0,
7272
max_resends=None,
73-
submit_result_timeout=30.0,
74-
metric_pipe=None,
73+
submit_result_timeout=60.0,
74+
metric_pipe: Optional[Pipe] = None,
7575
task_channel_name: str = PipeChannelName.TASK,
7676
metric_channel_name: str = PipeChannelName.METRIC,
7777
close_pipe: bool = True,
@@ -103,21 +103,27 @@ def __init__(
103103
Usually for ``FilePipe`` we set to False, for ``CellPipe`` we set to True.
104104
decomposer_module (str): the module name which contains the external decomposers.
105105
"""
106+
if pipe is None and metric_pipe is None:
107+
raise RuntimeError(
108+
"Please configure at least one pipe. Both the task pipe and the metric pipe are set to None."
109+
)
106110
flare_decomposers.register()
107111
common_decomposers.register()
108112
if decomposer_module:
109113
register_ext_decomposers(decomposer_module)
110114

111115
self.logger = logging.getLogger(self.__class__.__name__)
112116
self.pipe = pipe
113-
self.pipe_handler = PipeHandler(
114-
pipe=self.pipe,
115-
read_interval=read_interval,
116-
heartbeat_interval=heartbeat_interval,
117-
heartbeat_timeout=heartbeat_timeout,
118-
resend_interval=resend_interval,
119-
max_resends=max_resends,
120-
)
117+
self.pipe_handler = None
118+
if self.pipe:
119+
self.pipe_handler = PipeHandler(
120+
pipe=self.pipe,
121+
read_interval=read_interval,
122+
heartbeat_interval=heartbeat_interval,
123+
heartbeat_timeout=heartbeat_timeout,
124+
resend_interval=resend_interval,
125+
max_resends=max_resends,
126+
)
121127
self.submit_result_timeout = submit_result_timeout
122128
self.task_channel_name = task_channel_name
123129
self.metric_channel_name = metric_channel_name
@@ -148,14 +154,17 @@ def start(self):
148154
Returns: None
149155
150156
"""
151-
self.pipe.open(self.task_channel_name)
152-
self.pipe_handler.set_status_cb(self._status_cb, pipe_handler=self.pipe_handler, channel=self.task_channel_name)
153-
self.pipe_handler.start()
157+
if self.pipe:
158+
self.pipe.open(self.task_channel_name)
159+
self.pipe_handler.set_status_cb(
160+
self._status_cb, pipe_handler=self.pipe_handler, channel=self.task_channel_name
161+
)
162+
self.pipe_handler.start()
154163

155164
if self.metric_pipe:
156165
self.metric_pipe.open(self.metric_channel_name)
157166
self.metric_pipe_handler.set_status_cb(
158-
self._status_cb, pipe_handler=self.metric_pipe_handler, channel=self.metric_channel_name
167+
self._metrics_status_cb, pipe_handler=self.metric_pipe_handler, channel=self.metric_channel_name
159168
)
160169
self.metric_pipe_handler.start()
161170

@@ -164,6 +173,11 @@ def _status_cb(self, msg: Message, pipe_handler: PipeHandler, channel):
164173
self.asked_to_stop = True
165174
pipe_handler.stop(self._close_pipe)
166175

176+
def _metrics_status_cb(self, msg: Message, pipe_handler: PipeHandler, channel):
177+
self.logger.info(f"{channel} pipe status changed to {msg.topic}: {msg.data}")
178+
self.asked_to_stop = True
179+
pipe_handler.stop(self._close_metric_pipe)
180+
167181
def stop(self):
168182
"""Stop the agent.
169183
@@ -172,9 +186,9 @@ def stop(self):
172186
Returns: None
173187
174188
"""
175-
self.logger.info("Calling flare agent stop")
176189
self.asked_to_stop = True
177-
self.pipe_handler.stop(self._close_pipe)
190+
if self.pipe_handler:
191+
self.pipe_handler.stop(self._close_pipe)
178192
if self.metric_pipe_handler:
179193
self.metric_pipe_handler.stop(self._close_metric_pipe)
180194

@@ -226,6 +240,8 @@ def get_task(self, timeout: Optional[float] = None) -> Optional[Task]:
226240
has been submitted.
227241
228242
"""
243+
if not self.pipe_handler:
244+
raise RuntimeError("task pipe is not available")
229245
start_time = time.time()
230246
while True:
231247
if self.asked_to_stop:
@@ -278,6 +294,8 @@ def submit_result(self, result, rc=RC.OK) -> bool:
278294
made a single time regardless whether the submission is successful.
279295
280296
"""
297+
if not self.pipe_handler:
298+
raise RuntimeError("task pipe is not available")
281299
with self.task_lock:
282300
current_task = self.current_task
283301
if not current_task:

0 commit comments

Comments
 (0)