Skip to content

Commit d66d14e

Browse files
Enhance recipe experiment tracking (#3655)
For MLFlowReceiver, it will be good to include job name. recipe.job.to_server(receiver, "receiver") this one is hardcoded for now to be consistent with BaseFedJob, in next release we need to fix this as well. ### Description MLFlowReceiver: - If no run name is provided, include the job name for the default run name. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Co-authored-by: Copilot <[email protected]>
1 parent cf90640 commit d66d14e

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

examples/hello-world/hello-tf/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
from model import Net
1717

1818
import nvflare.client as flare
19+
from nvflare.client.tracking import SummaryWriter
1920

2021
WEIGHTS_PATH = "./tf_model.weights.h5"
2122

2223

2324
def main():
2425
flare.init()
26+
writer = SummaryWriter()
2527

2628
sys_info = flare.system_info()
2729
print(f"system info is: {sys_info}", flush=True)
@@ -69,6 +71,7 @@ def main():
6971
print(
7072
f"Accuracy of the received model on round {input_model.current_round} on the test images: {test_global_acc * 100} %"
7173
)
74+
writer.add_scalar(tag="local_acc", scalar=test_global_acc)
7275

7376
# training
7477
model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels))

nvflare/app_opt/tracking/mlflow/mlflow_receiver.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323

2424
from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsData, AnalyticsDataType, LogWriterName, TrackConst
2525
from nvflare.apis.dxo import from_shareable
26-
from nvflare.apis.fl_constant import ProcessType
26+
from nvflare.apis.fl_constant import ProcessType, ReservedKey
2727
from nvflare.apis.fl_context import FLContext
28+
from nvflare.apis.job_def import JobMetaKey
2829
from nvflare.apis.shareable import Shareable
2930
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
3031

@@ -41,6 +42,14 @@ def get_current_time_millis():
4142
return int(round(time.time() * 1000))
4243

4344

45+
def _get_job_name_from_fl_ctx(fl_ctx: FLContext, default=None):
46+
# TODO: it might be good to have a function in fl_context to get the job name
47+
job_meta = fl_ctx.get_prop(ReservedKey.JOB_META)
48+
if job_meta and isinstance(job_meta, dict):
49+
return job_meta.get(JobMetaKey.JOB_NAME, default)
50+
return default
51+
52+
4453
class MLflowReceiver(AnalyticsReceiver):
4554
def __init__(
4655
self,
@@ -73,14 +82,15 @@ def __init__(
7382
less delay. Keep in mind that reducing the buffer_flush_time will potentially cause high
7483
traffic to the MLflow tracking server, which in some cases can actually cause more latency.
7584
"""
85+
if not isinstance(tracking_uri, (str, type(None))):
86+
raise ValueError("tracking_uri needs to be either None or str")
7687
if events is None:
7788
events = ["fed." + ANALYTIC_EVENT_TYPE]
7889
super().__init__(events=events)
7990
self.artifact_location = artifact_location if artifact_location is not None else "artifacts"
8091

8192
self.kw_args = kw_args if kw_args else {}
8293
self.tracking_uri = tracking_uri
83-
self.mlflow = mlflow
8494
self.mlflow_clients: Dict[str, MlflowClient] = {}
8595
self.experiment_id = None
8696
self.run_ids = {}
@@ -164,8 +174,9 @@ def _mlflow_setup(self, art_full_path, experiment_name, experiment_tags, site_na
164174
)
165175

166176
job_id_tag = self._get_job_id_tag(fl_ctx)
177+
job_name = _get_job_name_from_fl_ctx(fl_ctx)
167178

168-
run_name = self._get_run_name(self.kw_args, site_name, job_id_tag)
179+
run_name = self._get_run_name(self.kw_args, site_name, job_id_tag, job_name)
169180
tags = self._get_run_tags(self.kw_args, job_id_tag, run_name)
170181
run = mlflow_client.create_run(experiment_id=self.experiment_id, run_name=run_name, tags=tags)
171182
self.run_ids[site_name] = run.info.run_id
@@ -179,9 +190,10 @@ def _init_buffer(self, site_names: List[str]):
179190
AnalyticsDataType.TAGS: [],
180191
}
181192

182-
def _get_run_name(self, kwargs: dict, site_name: str, job_id_tag: str):
193+
def _get_run_name(self, kwargs: dict, site_name: str, job_id_tag: str, job_name: str):
183194
run_name = kwargs.get(TrackConst.RUN_NAME, DEFAULT_RUN_NAME)
184-
return f"{site_name}-{job_id_tag[:6]}-{run_name}"
195+
job_name_str = job_name if job_name is not None else "unknown_job"
196+
return f"{site_name}-{job_id_tag[:6]}-{job_name_str}-{run_name}"
185197

186198
def _get_run_tags(self, kwargs, job_id_tag: str, run_name: str):
187199
run_tags = self._get_tags(TrackConst.RUN_TAGS, kwargs=kwargs)

nvflare/recipe/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ def add_experiment_tracking(recipe: Recipe, tracking_type: str, tracking_config:
5656
module = importlib.import_module(TRACKING_REGISTRY[tracking_type]["receiver_module"])
5757
receiver_class = getattr(module, TRACKING_REGISTRY[tracking_type]["receiver_class"])
5858
receiver = receiver_class(**tracking_config)
59-
recipe.job.to_server(receiver)
59+
recipe.job.to_server(receiver, "receiver")

0 commit comments

Comments
 (0)