Skip to content

Commit f766f90

Browse files
authored
fix job api examples (#2823)
1 parent 16e0c27 commit f766f90

15 files changed

+144
-61
lines changed

examples/advanced/job_api/pt/fedavg_model_learner_xsite_val_cifar10.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@
2222
from pt.utils.cifar10_data_splitter import Cifar10DataSplitter
2323
from pt.utils.cifar10_data_utils import load_cifar10_data
2424

25-
from nvflare import FedJob
2625
from nvflare.app_common.executors.model_learner_executor import ModelLearnerExecutor
2726
from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval
28-
from nvflare.app_common.workflows.fedavg import FedAvg
29-
from nvflare.app_opt.pt.job_config.model import PTModel
27+
from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
3028

3129
if __name__ == "__main__":
3230
n_clients = 2
@@ -35,13 +33,9 @@
3533
alpha = 0.1
3634
train_split_root = f"/tmp/cifar10_splits/clients{n_clients}_alpha{alpha}" # avoid overwriting results
3735

38-
job = FedJob(name="cifar10_fedavg")
36+
job = FedAvgJob(name="cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=ModerateCNN())
3937

40-
ctrl1 = FedAvg(
41-
num_clients=n_clients,
42-
num_rounds=num_rounds,
43-
)
44-
ctrl2 = CrossSiteModelEval()
38+
ctrl = CrossSiteModelEval()
4539

4640
load_cifar10_data() # preload CIFAR10 data
4741
data_splitter = Cifar10DataSplitter(
@@ -50,17 +44,18 @@
5044
alpha=alpha,
5145
)
5246

53-
job.to(ctrl1, "server")
54-
job.to(ctrl2, "server")
47+
job.to(ctrl, "server")
5548
job.to(data_splitter, "server")
5649

57-
# Define the initial global model and send to server
58-
job.to(PTModel(ModerateCNN()), "server")
59-
6050
for i in range(n_clients):
61-
learner = CIFAR10ModelLearner(train_idx_root=train_split_root, aggregation_epochs=aggregation_epochs, lr=0.01)
62-
executor = ModelLearnerExecutor(learner_id=job.as_id(learner))
63-
job.to(executor, f"site-{i+1}") # data splitter assumes client names start from 1
51+
site_name = f"site-{i+1}"
52+
learner_id = job.to(
53+
CIFAR10ModelLearner(train_idx_root=train_split_root, aggregation_epochs=aggregation_epochs, lr=0.01),
54+
site_name,
55+
id="learner",
56+
)
57+
executor = ModelLearnerExecutor(learner_id=learner_id)
58+
job.to(executor, site_name) # data splitter assumes client names start from 1
6459

6560
# job.export_job("/tmp/nvflare/jobs/job_config")
6661
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")

examples/advanced/job_api/pt/fedavg_script_runner_cifar10.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from src.net import Net
1616

17+
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
1718
from nvflare.app_common.workflows.fedavg import FedAvg
1819
from nvflare.app_opt.pt.job_config.model import PTModel
1920

@@ -39,6 +40,8 @@
3940
# Define the initial global model and send to server
4041
job.to(PTModel(Net()), "server")
4142

43+
job.to(IntimeModelSelector(key_metric="accuracy"), "server")
44+
4245
# Note: We can optionally replace the above code with the FedAvgJob, which is a pattern to simplify FedAvg job creations
4346
# job = FedAvgJob(name="cifar10_fedavg", num_rounds=num_rounds, n_clients=n_clients, initial_model=Net())
4447

examples/advanced/job_api/pt/fedavg_script_runner_dp_filter_cifar10.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@
2727
job = FedAvgJob(name="cifar10_fedavg_privacy", num_rounds=num_rounds, n_clients=n_clients, initial_model=Net())
2828

2929
for i in range(n_clients):
30+
site_name = f"site-{i}"
3031
executor = ScriptRunner(script=train_script, script_args="")
31-
job.to(executor, f"site-{i}", tasks=["train"])
32+
job.to(executor, site_name, tasks=["train"])
3233

3334
# add privacy filter.
3435
pp_filter = PercentilePrivacy(percentile=10, gamma=0.01)
35-
job.to(pp_filter, f"site-{i}", tasks=["train"], filter_type=FilterType.TASK_RESULT)
36+
job.to(pp_filter, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
3637

3738
# job.export_job("/tmp/nvflare/jobs/job_config")
3839
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")

examples/advanced/job_api/sklearn/kmeans_script_runner_higgs.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from nvflare import FedJob
2222
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
2323
from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator
24+
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
2425
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
2526
from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
2627
from nvflare.client.config import ExchangeFormat
@@ -117,24 +118,29 @@ def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample
117118
# ScatterAndGather also expects an "aggregator" which we define here.
118119
# The actual aggregation function is defined by an "assembler" to specify how to handle the collected updates.
119120
# We use KMeansAssembler which is the assembler designed for k-Means algorithm.
120-
aggregator = CollectAndAssembleAggregator(assembler_id=job.as_id(KMeansAssembler()))
121+
assembler_id = job.to_server(KMeansAssembler(), id="assembler")
122+
aggregator_id = job.to_server(CollectAndAssembleAggregator(assembler_id=assembler_id), id="aggregator")
121123

122124
# For kmeans with sklean, we need a custom persistor
123125
# JoblibModelParamPersistor is a persistor which save/read the model to/from file with JobLib format.
124-
persistor = JoblibModelParamPersistor(initial_params={"n_clusters": 2})
126+
persistor_id = job.to_server(JoblibModelParamPersistor(initial_params={"n_clusters": 2}), id="persistor")
127+
128+
shareable_generator_id = job.to_server(FullModelShareableGenerator(), id="shareable_generator")
125129

126130
controller = ScatterAndGather(
127131
min_clients=n_clients,
128132
num_rounds=num_rounds,
129133
wait_time_after_min_received=0,
130-
aggregator_id=job.as_id(aggregator),
131-
persistor_id=job.as_id(persistor),
132-
shareable_generator_id=job.as_id(FullModelShareableGenerator()),
134+
aggregator_id=aggregator_id,
135+
persistor_id=persistor_id,
136+
shareable_generator_id=shareable_generator_id,
133137
train_task_name="train", # Client will start training once received such task.
134138
train_timeout=0,
135139
)
136140
job.to(controller, "server")
137141

142+
job.to(IntimeModelSelector(key_metric="accuracy"), "server")
143+
138144
# Add clients
139145
for i in range(n_clients):
140146
executor = ScriptRunner(

examples/advanced/job_api/tf/tf_fl_script_runner_cifar10.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from src.tf_net import ModerateTFNet
2222

2323
from nvflare import FedJob
24+
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
2425
from nvflare.app_opt.tf.job_config.model import TFModel
2526
from nvflare.job_config.script_runner import ScriptRunner
2627

@@ -153,6 +154,8 @@
153154
# Define the initial global model and send to server
154155
job.to(TFModel(ModerateTFNet(input_shape=(None, 32, 32, 3))), "server")
155156

157+
job.to(IntimeModelSelector(key_metric="accuracy"), "server")
158+
156159
# Add clients
157160
for i, train_idx_path in enumerate(train_idx_paths):
158161
curr_task_script_args = task_script_args + f" --train_idx_path {train_idx_path}"

examples/getting_started/pt/nvflare_lightning_getting_started.ipynb

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,27 @@
383383
"job.to(PTModel(LitNet()), \"server\")"
384384
]
385385
},
386+
{
387+
"cell_type": "markdown",
388+
"id": "72eefb39",
389+
"metadata": {},
390+
"source": [
391+
"#### 5. Add ModelSelector\n",
392+
"Add IntimeModelSelector for global best model selection."
393+
]
394+
},
395+
{
396+
"cell_type": "code",
397+
"execution_count": null,
398+
"id": "091beb78",
399+
"metadata": {},
400+
"outputs": [],
401+
"source": [
402+
"from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector\n",
403+
"\n",
404+
"job.to(IntimeModelSelector(key_metric=\"accuracy\"), \"server\")"
405+
]
406+
},
386407
{
387408
"cell_type": "markdown",
388409
"id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070",
@@ -397,7 +418,7 @@
397418
"metadata": {},
398419
"source": [
399420
"#### 5. Add clients\n",
400-
"Next, we can use the `ScriptExecutor` and send it to each of the clients to run our training script.\n",
421+
"Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n",
401422
"\n",
402423
"Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity.\n",
403424
"We can also specify, which GPU should be used to run this client, which is helpful for simulated environments."

examples/getting_started/pt/nvflare_pt_getting_started.ipynb

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,27 @@
325325
"job.to(PTModel(Net()), \"server\")"
326326
]
327327
},
328+
{
329+
"cell_type": "markdown",
330+
"id": "eccae908",
331+
"metadata": {},
332+
"source": [
333+
"#### 5. Add ModelSelector\n",
334+
"Add IntimeModelSelector for global best model selection."
335+
]
336+
},
337+
{
338+
"cell_type": "code",
339+
"execution_count": null,
340+
"id": "d52dd194",
341+
"metadata": {},
342+
"outputs": [],
343+
"source": [
344+
"from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector\n",
345+
"\n",
346+
"job.to(IntimeModelSelector(key_metric=\"accuracy\"), \"server\")"
347+
]
348+
},
328349
{
329350
"cell_type": "markdown",
330351
"id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070",

examples/getting_started/sklearn/kmeans_script_runner_higgs.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from nvflare import FedJob
2222
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
2323
from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator
24+
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
2425
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
2526
from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
2627
from nvflare.client.config import ExchangeFormat
@@ -117,24 +118,29 @@ def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample
117118
# ScatterAndGather also expects an "aggregator" which we define here.
118119
# The actual aggregation function is defined by an "assembler" to specify how to handle the collected updates.
119120
# We use KMeansAssembler which is the assembler designed for k-Means algorithm.
120-
aggregator = CollectAndAssembleAggregator(assembler_id=job.as_id(KMeansAssembler()))
121+
assembler_id = job.to_server(KMeansAssembler(), id="assembler")
122+
aggregator_id = job.to_server(CollectAndAssembleAggregator(assembler_id=assembler_id), id="aggregator")
121123

122124
# For kmeans with sklean, we need a custom persistor
123125
# JoblibModelParamPersistor is a persistor which save/read the model to/from file with JobLib format.
124-
persistor = JoblibModelParamPersistor(initial_params={"n_clusters": 2})
126+
persistor_id = job.to_server(JoblibModelParamPersistor(initial_params={"n_clusters": 2}), id="persistor")
127+
128+
shareable_generator_id = job.to_server(FullModelShareableGenerator(), id="shareable_generator")
125129

126130
controller = ScatterAndGather(
127131
min_clients=n_clients,
128132
num_rounds=num_rounds,
129133
wait_time_after_min_received=0,
130-
aggregator_id=job.as_id(aggregator),
131-
persistor_id=job.as_id(persistor),
132-
shareable_generator_id=job.as_id(FullModelShareableGenerator()),
134+
aggregator_id=aggregator_id,
135+
persistor_id=persistor_id,
136+
shareable_generator_id=shareable_generator_id,
133137
train_task_name="train", # Client will start training once received such task.
134138
train_timeout=0,
135139
)
136140
job.to(controller, "server")
137141

142+
job.to(IntimeModelSelector(key_metric="accuracy"), "server")
143+
138144
# Add clients
139145
for i in range(n_clients):
140146
executor = ScriptRunner(

examples/getting_started/tf/nvflare_tf_getting_started.ipynb

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,27 @@
315315
"job.to(TFModel(TFNet()), \"server\")"
316316
]
317317
},
318+
{
319+
"cell_type": "markdown",
320+
"id": "25c6eada",
321+
"metadata": {},
322+
"source": [
323+
"#### 5. Add ModelSelector\n",
324+
"Add IntimeModelSelector for global best model selection."
325+
]
326+
},
327+
{
328+
"cell_type": "code",
329+
"execution_count": null,
330+
"id": "0ae73e50",
331+
"metadata": {},
332+
"outputs": [],
333+
"source": [
334+
"from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector\n",
335+
"\n",
336+
"job.to(IntimeModelSelector(key_metric=\"accuracy\"), \"server\")"
337+
]
338+
},
318339
{
319340
"cell_type": "markdown",
320341
"id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070",
@@ -391,7 +412,7 @@
391412
},
392413
"outputs": [],
393414
"source": [
394-
"job.simulator_run(\"/tmp/nvflare/jobs/workdir\", , gpu=\"0\")"
415+
"job.simulator_run(\"/tmp/nvflare/jobs/workdir\", gpu=\"0\")"
395416
]
396417
},
397418
{

examples/getting_started/tf/tf_fl_script_runner_cifar10.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from src.tf_net import ModerateTFNet
2222

2323
from nvflare import FedJob
24+
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
2425
from nvflare.app_opt.tf.job_config.model import TFModel
2526
from nvflare.job_config.script_runner import ScriptRunner
2627

@@ -153,6 +154,8 @@
153154
# Define the initial global model and send to server
154155
job.to(TFModel(ModerateTFNet(input_shape=(None, 32, 32, 3))), "server")
155156

157+
job.to(IntimeModelSelector(key_metric="accuracy"), "server")
158+
156159
# Add clients
157160
for i, train_idx_path in enumerate(train_idx_paths):
158161
curr_task_script_args = task_script_args + f" --train_idx_path {train_idx_path}"

0 commit comments

Comments
 (0)