Skip to content

Commit 4b32f27

Browse files
yhwenchesterxgchenSYangster
authored
Added id to the jobAPI swarm_script_executor_cifar10 component deploy (#2678)
* Added id to the swarm_script_executor_cifar10 component deploy. * codestyle fix. * Changed to use job.as_id(). * codestyle fix. * changed to use job.as_id(shareable_generator) for shareable_generator_id. * removed the un-necessary job.to() calls. --------- Co-authored-by: Chester Chen <[email protected]> Co-authored-by: Sean Yang <[email protected]>
1 parent 5c63229 commit 4b32f27

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

examples/getting_started/pt/swarm_script_executor_cifar10.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,22 @@
4747
executor = ScriptExecutor(task_script_path=train_script)
4848
job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"])
4949

50-
client_controller = SwarmClientController()
51-
job.to(client_controller, f"site-{i}", tasks=["swarm_*"])
52-
53-
client_controller = CrossSiteEvalClientController()
54-
job.to(client_controller, f"site-{i}", tasks=["cse_*"])
55-
5650
# In swarm learning, each client acts also as an aggregator
5751
aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS)
58-
job.to(aggregator, f"site-{i}")
5952

6053
# In swarm learning, each client uses a model persistor and shareable_generator
61-
job.to(PTFileModelPersistor(model=Net()), f"site-{i}")
62-
job.to(SimpleModelShareableGenerator(), f"site-{i}")
54+
persistor = PTFileModelPersistor(model=Net())
55+
shareable_generator = SimpleModelShareableGenerator()
56+
57+
client_controller = SwarmClientController(
58+
aggregator_id=job.as_id(aggregator),
59+
persistor_id=job.as_id(persistor),
60+
shareable_generator_id=job.as_id(shareable_generator),
61+
)
62+
job.to(client_controller, f"site-{i}", tasks=["swarm_*"])
63+
64+
client_controller = CrossSiteEvalClientController()
65+
job.to(client_controller, f"site-{i}", tasks=["cse_*"])
6366

6467
# job.export_job("/tmp/nvflare/jobs/job_config")
6568
job.simulator_run("/tmp/nvflare/jobs/workdir")

0 commit comments

Comments
 (0)