Skip to content

Commit f208127

Browse files
authored
Support 3rd-party training system integration with FLARE (#2074)
* support av ipc based model exchange * polish * add license text * support child-based comm * reformat * support client side listening * reorg * formatting * added license text * fix f-str * address PR comments
1 parent bd4468f commit f208127

File tree

10 files changed

+1190
-4
lines changed

10 files changed

+1190
-4
lines changed

integration/av/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

integration/av/trainer.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import logging
17+
18+
from nvflare.client.defs import RC, AgentClosed, MetaKey, TaskResult
19+
from nvflare.client.ipc_agent import IPCAgent
20+
21+
NUMPY_KEY = "numpy_key"
22+
23+
24+
def main():
25+
26+
logging.basicConfig()
27+
logging.getLogger().setLevel(logging.INFO)
28+
29+
parser = argparse.ArgumentParser()
30+
parser.add_argument("--workspace", "-w", type=str, help="workspace folder", required=False, default=".")
31+
parser.add_argument("--site_name", "-s", type=str, help="flare site name", required=True)
32+
parser.add_argument("--agent_id", "-a", type=str, help="agent id", required=True)
33+
parser.add_argument("--job_id", "-j", type=str, help="flare job id", required=False, default="")
34+
parser.add_argument("--site_url", "-u", type=str, help="flare site url", required=False, default="")
35+
36+
args = parser.parse_args()
37+
38+
agent = IPCAgent(
39+
root_url="grpc://server:8002",
40+
flare_site_name=args.site_name,
41+
agent_id=args.agent_id,
42+
workspace_dir=args.workspace,
43+
secure_mode=True,
44+
submit_result_timeout=2.0,
45+
flare_site_heartbeat_timeout=120.0,
46+
job_id=args.job_id,
47+
flare_site_url=args.site_url,
48+
)
49+
50+
agent.start()
51+
52+
while True:
53+
print("getting task ...")
54+
try:
55+
task = agent.get_task()
56+
except AgentClosed:
57+
print("agent closed - exit")
58+
break
59+
60+
print(f"got task: {task}")
61+
rc, meta, result = train(task.meta, task.data)
62+
submitted = agent.submit_result(TaskResult(data=result, meta=meta, return_code=rc))
63+
print(f"result submitted: {submitted}")
64+
65+
agent.stop()
66+
67+
68+
def train(meta, model):
69+
current_round = meta.get(MetaKey.CURRENT_ROUND)
70+
total_rounds = meta.get(MetaKey.TOTAL_ROUND)
71+
72+
# Ensure that data is of type weights. Extract model data
73+
np_data = model
74+
75+
# Display properties.
76+
print(f"Model: \n{np_data}")
77+
print(f"Current Round: {current_round}")
78+
print(f"Total Rounds: {total_rounds}")
79+
80+
# Doing some dummy training.
81+
if np_data:
82+
if NUMPY_KEY in np_data:
83+
np_data[NUMPY_KEY] += 1.0
84+
else:
85+
print("error: numpy_key not found in model.")
86+
return RC.BAD_TASK_DATA, None, None
87+
else:
88+
print("No model weights found in shareable.")
89+
return RC.BAD_TASK_DATA, None, None
90+
91+
# Save local numpy model
92+
print(f"Model after training: {np_data}")
93+
94+
# Prepare a DXO for our updated model. Create shareable and return
95+
return RC.OK, {MetaKey.NUM_STEPS_CURRENT_ROUND: 1}, np_data
96+
97+
98+
if __name__ == "__main__":
99+
main()

0 commit comments

Comments
 (0)