Skip to content

Commit ba391aa

Browse files
Update flower examples (#2871)
1 parent 7a843fb commit ba391aa

File tree

8 files changed

+38
-31
lines changed

8 files changed

+38
-31
lines changed

examples/hello-world/hello-flower/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,19 @@ pip install ./flwr-pt/
3333

3434
Next, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator.
3535
```bash
36-
python job.py
36+
python job.py --job_name "flwr-pt" --content_dir "./flwr-pt"
3737
```
3838

3939
## 2.2 Run a simulation with TensorBoard streaming
4040

4141
To run flwr-pt_tb_streaming job with NVFlare, we first need to install its dependencies.
4242
```bash
43-
pip install ./flwr-pt-metrics/
43+
pip install ./flwr-pt-tb/
4444
```
4545

4646
Next, we run 2 Flower clients and Flower Server in parallel using NVFlare while streaming
4747
the TensorBoard metrics to the server at each iteration using NVFlare's metric streaming.
4848

4949
```bash
50-
python job_with_metric.py
50+
python job.py --job_name "flwr-pt-tb" --content_dir "./flwr-pt-tb" --stream_metrics --use_client_api
5151
```

examples/hello-world/hello-flower/flwr-pt-metrics/pyproject.toml renamed to examples/hello-world/hello-flower/flwr-pt-tb/pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ requires = ["hatchling"]
33
build-backend = "hatchling.build"
44

55
[project]
6-
name = "flwr_pt_tb_streaming"
6+
name = "flwr_pt_tb"
77
version = "1.0.0"
88
description = ""
99
license = "Apache-2.0"
@@ -12,6 +12,7 @@ dependencies = [
1212
"nvflare~=2.5.0rc",
1313
"torch==2.2.1",
1414
"torchvision==0.17.1",
15+
"tensorboard"
1516
]
1617

1718
[tool.hatch.build.targets.wheel]
@@ -21,8 +22,8 @@ packages = ["."]
2122
publisher = "nvidia"
2223

2324
[tool.flwr.app.components]
24-
serverapp = "flwr_pt_tb_streaming.server:app"
25-
clientapp = "flwr_pt_tb_streaming.client:app"
25+
serverapp = "flwr_pt_tb.server:app"
26+
clientapp = "flwr_pt_tb.client:app"
2627

2728
[tool.flwr.app.config]
2829
num-server-rounds = 3

examples/hello-world/hello-flower/job.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,37 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from argparse import ArgumentParser
16+
1517
from nvflare.app_opt.flower.flower_job import FlowerJob
18+
from nvflare.client.api import ClientAPIType
19+
from nvflare.client.api_spec import CLIENT_API_TYPE_KEY
1620

17-
if __name__ == "__main__":
18-
job = FlowerJob(name="flwr-pt", flower_content="./flwr-pt")
1921

20-
job.export_job("jobs")
21-
job.simulator_run("/tmp/nvflare/flwr-pt", gpu="0", n_clients=2)
22+
def main():
23+
parser = ArgumentParser()
24+
parser.add_argument("--job_name", type=str, required=True)
25+
parser.add_argument("--content_dir", type=str, required=True)
26+
parser.add_argument("--stream_metrics", action="store_true")
27+
parser.add_argument("--use_client_api", action="store_true")
28+
parser.add_argument("--export_dir", type=str, default="jobs")
29+
parser.add_argument("--workdir", type=str, default="/tmp/nvflare/hello-flower")
30+
args = parser.parse_args()
31+
32+
env = {}
33+
if args.use_client_api:
34+
env = {CLIENT_API_TYPE_KEY: ClientAPIType.EX_PROCESS_API.value}
35+
36+
job = FlowerJob(
37+
name=args.job_name,
38+
flower_content=args.content_dir,
39+
stream_metrics=args.stream_metrics,
40+
extra_env=env,
41+
)
42+
43+
job.export_job(args.export_dir)
44+
job.simulator_run(args.workdir, gpu="0", n_clients=2)
45+
46+
47+
if __name__ == "__main__":
48+
main()

examples/hello-world/hello-flower/job_with_metric.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)