Skip to content

Commit 426f252

Browse files
[2.5] Add TF based TBAnalyticsReceiver (#3035)
* Remove TBReceiver from tf job api * Add tf based receiver
1 parent 68c7eb8 commit 426f252

File tree

2 files changed

+147
-1
lines changed

2 files changed

+147
-1
lines changed

nvflare/app_opt/tf/job_config/base_fed_job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
2424
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
2525
from nvflare.app_opt.tf.job_config.model import TFModel
26-
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
26+
from nvflare.app_opt.tf.tb_receiver import TBAnalyticsReceiver
2727
from nvflare.job_config.api import FedJob, validate_object_for_job
2828

2929

nvflare/app_opt/tf/tb_receiver.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from typing import List, Optional
17+
18+
import tensorflow as tf
19+
20+
from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType
21+
from nvflare.apis.dxo import from_shareable
22+
from nvflare.apis.fl_context import FLContext
23+
from nvflare.apis.shareable import Shareable
24+
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
25+
26+
27+
def _create_new_data(key, value, sender):
    """Wrap a single (key, value) pair in a new AnalyticsData record.

    Strings become TEXT records and numeric values become SCALAR records;
    any other value type is unsupported and yields None so the caller can
    warn and skip it.
    """
    if isinstance(value, str):
        dtype = AnalyticsDataType.TEXT
    elif isinstance(value, (int, float)):
        dtype = AnalyticsDataType.SCALAR
    else:
        # Unsupported payload type — signal the caller rather than raising.
        return None
    return AnalyticsData(key=key, value=value, data_type=dtype, sender=sender)
36+
37+
38+
class TBAnalyticsReceiver(AnalyticsReceiver):
    def __init__(self, tb_folder="tb_events", events: Optional[List[str]] = None):
        """Receives analytics data to save to TensorBoard.

        Args:
            tb_folder (str): the folder to store tensorboard files.
            events (optional, List[str]): A list of events to be handled by this receiver.

        .. code-block:: text
            :caption: Folder structure

            Inside run_XX folder:
              - workspace
                - run_01 (already created):
                  - output_dir (default: tb_events):
                    - peer_name_1:
                    - peer_name_2:

                - run_02 (already created):
                  - output_dir (default: tb_events):
                    - peer_name_1:
                    - peer_name_2:

        """
        super().__init__(events=events)
        # Maps sender origin -> tf.summary file writer, populated lazily in save().
        self.writers_table = {}
        self.tb_folder = tb_folder
        # Set in initialize() once the job workspace is known.
        self.root_log_dir = None

    def initialize(self, fl_ctx: FLContext):
        """Create the root TensorBoard event directory under the job's run dir."""
        workspace = fl_ctx.get_engine().get_workspace()
        log_dir = os.path.join(workspace.get_run_dir(fl_ctx.get_job_id()), self.tb_folder)
        os.makedirs(log_dir, exist_ok=True)
        self.root_log_dir = log_dir
        self.log_info(
            fl_ctx,
            f"Tensorboard records can be found in {self.root_log_dir} you can view it using `tensorboard --logdir={self.root_log_dir}`",
        )

    def _convert_to_records(self, analytic_data: AnalyticsData, fl_ctx: FLContext) -> List[AnalyticsData]:
        """Flatten a compound record into a list of individually writable records.

        PARAMETER/PARAMETERS and SCALARS/METRICS payloads are split into one
        record per entry; every other data type passes through unchanged.
        """
        dtype = analytic_data.data_type

        if dtype in (AnalyticsDataType.PARAMETER, AnalyticsDataType.PARAMETERS):
            if dtype == AnalyticsDataType.PARAMETERS:
                pairs = analytic_data.value.items()
            else:
                pairs = [(analytic_data.tag, analytic_data.value)]
            converted = []
            for k, v in pairs:
                # NOTE(review): records built here carry no explicit step —
                # presumably AnalyticsData supplies a default; confirm.
                new_data = _create_new_data(k, v, analytic_data.sender)
                if new_data is None:
                    self.log_warning(fl_ctx, f"Entry {k} of type {type(v)} is not supported.", fire_event=False)
                else:
                    converted.append(new_data)
            return converted

        if dtype in (AnalyticsDataType.SCALARS, AnalyticsDataType.METRICS):
            per_item_type = (
                AnalyticsDataType.SCALAR if dtype == AnalyticsDataType.SCALARS else AnalyticsDataType.METRIC
            )
            return [
                AnalyticsData(key=k, value=v, data_type=per_item_type, sender=analytic_data.sender)
                for k, v in analytic_data.value.items()
            ]

        return [analytic_data]

    def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin):
        """Write one incoming analytics shareable to the sending peer's event file."""
        analytic_data = AnalyticsData.from_dxo(from_shareable(shareable))
        if not analytic_data:
            # Nothing decodable in the shareable — silently ignore.
            return

        # Lazily create one writer per peer so each gets its own subfolder.
        if record_origin not in self.writers_table:
            peer_log_dir = os.path.join(self.root_log_dir, record_origin)
            self.writers_table[record_origin] = tf.summary.create_file_writer(peer_log_dir)
        writer = self.writers_table[record_origin]

        self.log_info(
            fl_ctx,
            f"try to save data {analytic_data} from {record_origin}",
            fire_event=False,
        )

        # Dispatch each flattened record to the matching tf.summary op.
        with writer.as_default():
            for data_record in self._convert_to_records(analytic_data, fl_ctx):
                record_type = data_record.data_type
                if record_type in (AnalyticsDataType.METRIC, AnalyticsDataType.SCALAR):
                    tf.summary.scalar(data_record.tag, data_record.value, data_record.step)
                elif record_type == AnalyticsDataType.TEXT:
                    tf.summary.text(data_record.tag, data_record.value, data_record.step)
                elif record_type == AnalyticsDataType.IMAGE:
                    tf.summary.image(data_record.tag, data_record.value, data_record.step)
                else:
                    self.log_warning(
                        fl_ctx, f"The data_type {data_record.data_type} is not supported.", fire_event=False
                    )

    def finalize(self, fl_ctx: FLContext):
        """Flush buffered summaries for every peer's writer."""
        for writer in self.writers_table.values():
            tf.summary.flush(writer)

0 commit comments

Comments
 (0)