
Commit f4bd1b4

Add he support to pt params converter (#2238)
* Add he support to pt params converter
* change to path, add condition

Co-authored-by: Chester Chen <[email protected]>
1 parent 2ad8da2 commit f4bd1b4

6 files changed: 289 additions, 1 deletion

job_templates/sag_pt_he/config_fed_client.conf

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
{
  # version of the configuration
  format_version = 2

  # This is the application script that will be invoked. Clients can replace this script with their own training script.
  app_script = "train.py"

  # Additional arguments needed by the training code. For example, in Lightning these can be --trainer.batch_size=xxx.
  app_config = ""

  # Client computing executors.
  executors = [
    {
      # tasks the executors are defined to handle
      tasks = ["train"]

      # This particular executor
      executor {

        # This is an executor for the Client API. The underlying data exchange uses a Pipe.
        path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"

        args {
          # launcher_id is used to locate the Launcher object in "components"
          launcher_id = "launcher"

          # pipe_id is used to locate the Pipe object in "components"
          pipe_id = "pipe"

          # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
          # Please refer to the class docstring for all available arguments.
          heartbeat_timeout = 60

          # format of the exchanged parameters
          params_exchange_format = "pytorch"

          # if the transfer_type is FULL, the parameters are sent directly;
          # if the transfer_type is DIFF, the difference vs. the received parameters
          # is calculated and sent instead
          params_transfer_type = "DIFF"

          # if train_with_evaluation is true, the executor expects the custom code
          # to send back both the trained parameters and the evaluation metric;
          # otherwise only the trained parameters are expected
          train_with_evaluation = true
        }
      }
    }
  ],

  task_data_filters = [
    {
      tasks = ["train"]
      filters = [
        {
          path = "nvflare.app_opt.he.model_decryptor.HEModelDecryptor"
          args {
          }
        }
      ]
    }
  ]
  task_result_filters = [
    {
      tasks = ["train"]
      filters = [
        {
          path = "nvflare.app_opt.he.model_encryptor.HEModelEncryptor"
          args {
            weigh_by_local_iter = true
          }
        }
      ]
    },
  ]

  components = [
    {
      # component id is "launcher"
      id = "launcher"

      # the class path of this component
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

      args {
        # the launcher will invoke this script
        script = "python3 custom/{app_script} {app_config} "
        # if launch_once is true, the SubprocessLauncher launches once for the whole job;
        # if launch_once is false, the SubprocessLauncher launches a process for each task it receives from the server
        launch_once = true
      }
    }
    {
      id = "pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    }
    {
      id = "metrics_pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    },
    {
      id = "metric_relay"
      path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
      args {
        pipe_id = "metrics_pipe"
        event_type = "fed.analytix_log_stats"
        # how fast it should read from the peer
        read_interval = 0.1
      }
    },
    {
      # we use this component so the Client API `flare.init()` can get the required information
      id = "config_preparer"
      path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
      args {
        component_ids = ["metric_relay"]
      }
    }
  ]
}
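The executor above launches `custom/train.py` and exchanges PyTorch weights with it through the CellPipe. Below is a minimal sketch of what such a training script could look like, assuming the standard `nvflare.client` (Client API) calls; the placeholder model, training loop, and accuracy value are illustrative only and are not part of this commit.

```python
# Illustrative sketch of a Client API training script (custom/train.py) for this template.
# The model, local training, and metric computation are placeholders; only the
# nvflare.client calls reflect the flow the PTClientAPILauncherExecutor expects.
import torch

import nvflare.client as flare
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.client.tracking import SummaryWriter


def main():
    flare.init()  # reads the site/job info prepared by the ExternalConfigurator component
    writer = SummaryWriter()  # streams scalars via metrics_pipe/metric_relay to the server

    # placeholder model; a real job would use the class referenced by model_class_path (net.Net)
    model = torch.nn.Linear(10, 2)

    while flare.is_running():
        input_model = flare.receive()              # global weights for the current round
        model.load_state_dict(input_model.params)  # params arrive as torch tensors ("pytorch" format)

        # ... run local training on the client's data here ...

        accuracy = 0.0  # placeholder: compute validation accuracy here
        writer.add_scalar("accuracy", accuracy, input_model.current_round)

        flare.send(
            FLModel(
                params=model.state_dict(),       # the executor converts to numpy and applies DIFF transfer
                metrics={"accuracy": accuracy},  # expected because train_with_evaluation = true
            )
        )


if __name__ == "__main__":
    main()
```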
job_templates/sag_pt_he/config_fed_server.conf

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
{
  # version of the configuration
  format_version = 2

  # task data filter: if filters are provided, they filter the data flowing from server to client.
  task_data_filters = []

  # task result filter: if filters are provided, they filter the results flowing from client to server.
  task_result_filters = []

  # This assumes that there will be a "net.py" file with class name "Net".
  # If your model code is not in "net.py" or the class name is not "Net", please modify this path.
  model_class_path = "net.Net"

  # workflows: array of workflows that control the federated learning lifecycle.
  # One can specify multiple workflows; NVFlare will run them in the order specified.
  workflows = [
    {
      # 1st workflow
      id = "scatter_and_gather"

      # path is the class path of the ScatterAndGather controller.
      path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather"
      args {
        # arguments of the ScatterAndGather class.
        # minimum number of clients required for the ScatterAndGather controller to move to the next round
        # of the workflow cycle. The controller will wait until min_clients results have been returned
        # before moving to the next step.
        min_clients = 2

        # number of global training rounds.
        num_rounds = 2

        # starting round is 0-based
        start_round = 0

        # after the minimum number of client results have been received,
        # how much longer to wait before moving to the next step
        wait_time_after_min_received = 0

        # For ScatterAndGather, the server aggregates the weights based on the clients' results.
        # The aggregator component id is named here. One can use this ID to find the corresponding
        # aggregator component listed below.
        aggregator_id = "aggregator"

        # The ScatterAndGather controller uses a persistor to load and save the model.
        # The persistor component is identified by the component ID specified here.
        persistor_id = "persistor"

        # A Shareable is the communication message exchanged between clients and server.
        # The shareable generator is the component responsible for converting the model to/from a Shareable.
        # The component is identified via "shareable_generator_id".
        shareable_generator_id = "shareable_generator"

        # train task name: the client side needs an executor that handles this task
        train_task_name = "train"

        # train timeout in seconds. Zero means no timeout.
        train_timeout = 0
      }
    }
  ]

  # List of components used in the server-side workflow.
  components = [
    {
      # This is the persistor component used in the above workflow.
      # PTFileModelPersistor is a PyTorch persistor that saves/loads the model to/from file.

      id = "persistor"
      path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"

      # the persistor class takes the model class as an argument.
      # This implies that the model is initialized on the server side.
      # The initialized model will be broadcast to all the clients to start the training.
      args {
        model {
          path = "{model_class_path}"
        }
        filter_id = "serialize_filter"
      }
    },
    {
      id = "shareable_generator"
      path = "nvflare.app_opt.he.model_shareable_generator.HEModelShareableGenerator"
      args {}
    }
    {
      id = "aggregator"
      path = "nvflare.app_opt.he.intime_accumulate_model_aggregator.HEInTimeAccumulateWeightedAggregator"
      args {
        weigh_by_local_iter = false
        expected_data_kind = "WEIGHT_DIFF"
      }
    }
    {
      id = "serialize_filter"
      path = "nvflare.app_opt.he.model_serialize_filter.HEModelSerializeFilter"
      args {
      }
    }
    {
      # This component is not directly used in the workflow.
      # It selects the best model based on the incoming global validation metrics.
      id = "model_selector"
      path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
      # need to make sure this "key_metric" matches what the server side receives
      args.key_metric = "accuracy"
    },
    {
      id = "receiver"
      path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver"
      args.events = ["fed.analytix_log_stats"]
    }
  ]

}
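The HE components referenced here work because CKKS ciphertexts can be added, and scaled by plaintext constants, without being decrypted: clients encrypt their weight differences, the HEInTimeAccumulateWeightedAggregator combines them on the server in encrypted form, and only the clients (which hold the HE context from provisioning) ever see plaintext. The standalone sketch below illustrates that property with TenSEAL, the library NVFlare's HE support builds on; it is not NVFlare code, and the CKKS parameters shown are common illustrative values, not necessarily those produced by provisioning.

```python
# Standalone TenSEAL sketch (not NVFlare code) of homomorphic aggregation:
# the "server" combines encrypted vectors without ever decrypting them.
# CKKS parameters below are illustrative values only.
import tenseal as ts

context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40],
)
context.global_scale = 2**40

# two "clients" encrypt their flattened weight differences
diff_site_1 = ts.ckks_vector(context, [0.10, -0.20, 0.30])
diff_site_2 = ts.ckks_vector(context, [0.05, 0.15, -0.10])

# "server" side: weighted combination computed directly on ciphertexts
encrypted_avg = diff_site_1 * 0.5 + diff_site_2 * 0.5

# back on a client holding the secret key, decryption recovers the plaintext average
print(encrypted_avg.decrypt())  # approximately [0.075, -0.025, 0.1]
```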

job_templates/sag_pt_he/info.conf

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
{
  description = "scatter & gather workflow using pytorch and homomorphic encryption"
  client_category = "client_api"
  controller_type = "server"
}

job_templates/sag_pt_he/info.md

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Job Template Information Card

## sag_pt_he
name = "sag_pt_he"
description = "Scatter and Gather Workflow using pytorch and homomorphic encryption"
class_name = "ScatterAndGather"
controller_type = "server"
executor_type = "launcher_executor"
contributor = "NVIDIA"
init_publish_date = "2023-12-20"
last_updated_date = "2023-12-20"  # yyyy-mm-dd

job_templates/sag_pt_he/meta.conf

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
{
  name = "sag_pt_he"
  resource_spec = {}
  deploy_map {
    # change deploy map as needed.
    app = ["@ALL"]
  }
  min_clients = 2
  mandatory_clients = []
}

nvflare/app_opt/pt/params_converter.py

Lines changed: 10 additions & 1 deletion
@@ -14,16 +14,25 @@

 from typing import Dict

+import numpy as np
 import torch

 from nvflare.app_common.abstract.params_converter import ParamsConverter


 class NumpyToPTParamsConverter(ParamsConverter):
     def convert(self, params: Dict, fl_ctx) -> Dict:
-        return {k: torch.as_tensor(v) for k, v in params.items()}
+        tensor_shapes = fl_ctx.get_prop("tensor_shapes")
+        if tensor_shapes:
+            return {
+                k: torch.as_tensor(np.reshape(v, tensor_shapes[k])) if k in tensor_shapes else torch.as_tensor(v)
+                for k, v in params.items()
+            }
+        else:
+            return {k: torch.as_tensor(v) for k, v in params.items()}


 class PTToNumpyParamsConverter(ParamsConverter):
     def convert(self, params: Dict, fl_ctx) -> Dict:
+        fl_ctx.set_prop("tensor_shapes", {k: v.shape for k, v in params.items()})
         return {k: v.cpu().numpy() for k, v in params.items()}
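The shape bookkeeping added here matters for the HE path: once weights have been encrypted, serialized, and decrypted again, the arrays handed back to the client can be flat, so `PTToNumpyParamsConverter` now records each tensor's shape in the FLContext and `NumpyToPTParamsConverter` restores it. The sketch below exercises that round trip; the `DummyContext` class is a hypothetical stand-in that only mimics `FLContext.get_prop`/`set_prop`, while the two converter classes are the ones changed in this commit.

```python
# Round-trip sketch of the updated converters. DummyContext is a hypothetical
# test double standing in for FLContext; only get_prop/set_prop are mimicked.
import torch

from nvflare.app_opt.pt.params_converter import NumpyToPTParamsConverter, PTToNumpyParamsConverter


class DummyContext:
    def __init__(self):
        self._props = {}

    def set_prop(self, key, value, **kwargs):
        self._props[key] = value

    def get_prop(self, key, default=None):
        return self._props.get(key, default)


ctx = DummyContext()
pt_weights = {"fc.weight": torch.randn(4, 3), "fc.bias": torch.randn(4)}

# PT -> numpy: tensor shapes are recorded in the context under "tensor_shapes"
np_weights = PTToNumpyParamsConverter().convert(pt_weights, ctx)

# downstream HE encrypt/serialize/decrypt steps may hand back flattened arrays
flat = {k: v.ravel() for k, v in np_weights.items()}

# numpy -> PT: the recorded shapes are used to restore the original tensors
restored = NumpyToPTParamsConverter().convert(flat, ctx)
assert all(restored[k].shape == pt_weights[k].shape for k in pt_weights)
```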
