Skip to content

Commit b6ccae6

Browse files
Refactor Run and ExecEnv (#3699)
### Description - Rename POCEnv -> PocEnv - Refactor Run method to delegate to each ExecEnv - Add unit test and integration tests ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Co-authored-by: Chester Chen <[email protected]>
1 parent ec66533 commit b6ccae6

File tree

12 files changed

+549
-321
lines changed

12 files changed

+549
-321
lines changed

examples/hello-world/hello-lr/job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from nvflare.app_common.np.recipes.lr.fedavg import FedAvgLrRecipe
1818
from nvflare.recipe import SimEnv
1919

20-
# from nvflare.recipe import POCEnv
20+
# from nvflare.recipe import PocEnv
2121

2222

2323
def define_parser():
@@ -45,7 +45,7 @@ def main():
4545
train_args=f"--data_root {data_root}",
4646
)
4747
env = SimEnv(num_clients=n_clients, num_threads=n_clients)
48-
# env = POCEnv(num_clients=n_clients)
48+
# env = PocEnv(num_clients=n_clients)
4949
run = recipe.execute(env)
5050
w = run.get_result()
5151
print("result location =", w)

nvflare/recipe/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .poc_env import POCEnv
15+
from .poc_env import PocEnv
1616
from .prod_env import ProdEnv
1717
from .run import Run
1818
from .sim_env import SimEnv
1919
from .utils import add_experiment_tracking
2020

21-
__all__ = ["SimEnv", "POCEnv", "ProdEnv", "Run", "add_experiment_tracking"]
21+
__all__ = ["SimEnv", "PocEnv", "ProdEnv", "Run", "add_experiment_tracking"]

nvflare/recipe/poc_env.py

Lines changed: 23 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414

1515
import os
1616
import shutil
17-
import tempfile
1817
import time
1918
from typing import Optional
2019

2120
from pydantic import BaseModel, conint, model_validator
2221

23-
from nvflare.fuel.flare_api.flare_api import new_secure_session
2422
from nvflare.job_config.api import FedJob
23+
from nvflare.recipe.spec import ExecEnv
2524
from nvflare.tool.poc.poc_commands import (
2625
_clean_poc,
2726
_start_poc,
@@ -34,7 +33,7 @@
3433
)
3534
from nvflare.tool.poc.service_constants import FlareServiceConstants as SC
3635

37-
from .spec import ExecEnv, ExecEnvType
36+
from .session_mgr import SessionManager
3837

3938
STOP_POC_TIMEOUT = 10
4039
SERVICE_START_TIMEOUT = 3
@@ -70,7 +69,7 @@ def check_client_configuration(self):
7069
return self
7170

7271

73-
class POCEnv(ExecEnv):
72+
class PocEnv(ExecEnv):
7473
"""Proof of Concept execution environment for local testing and development.
7574
7675
This environment sets up a POC deployment on a single machine with multiple
@@ -123,18 +122,7 @@ def __init__(
123122
self.project_conf_path = v.project_conf_path
124123
self.docker_image = v.docker_image
125124
self.username = v.username
126-
127-
def get_env_info(self) -> dict:
128-
return {
129-
"env_type": ExecEnvType.POC,
130-
"startup_kit_location": self._get_admin_startup_kit_path(),
131-
"num_clients": self.num_clients,
132-
"gpu_ids": self.gpu_ids,
133-
"use_he": self.use_he,
134-
"docker_image": self.docker_image,
135-
"project_conf_path": self.project_conf_path,
136-
"username": self.username,
137-
}
125+
self._session_manager = None # Lazy initialization
138126

139127
def deploy(self, job: FedJob):
140128
"""Deploy a FedJob to the POC environment.
@@ -170,12 +158,8 @@ def deploy(self, job: FedJob):
170158
# Give services time to start up
171159
time.sleep(SERVICE_START_TIMEOUT)
172160

173-
# Submit job using Flare API like ProdEnv
174-
with tempfile.TemporaryDirectory() as temp_dir:
175-
job.export_job(temp_dir)
176-
job_path = os.path.join(temp_dir, job.name)
177-
178-
return self._submit_and_monitor_job(job_path, job.name)
161+
# Submit job using SessionManager
162+
return self._get_session_manager().submit_job(job)
179163

180164
def _check_poc_running(self) -> bool:
181165
try:
@@ -225,37 +209,14 @@ def stop(self, clean_poc: bool = False):
225209
print(f"Removing POC workspace: {self.poc_workspace}")
226210
shutil.rmtree(self.poc_workspace, ignore_errors=True)
227211

228-
def _submit_and_monitor_job(self, job_path: str, job_name: str) -> str:
229-
"""Submit and monitor job via Flare API using a single session.
212+
def get_job_status(self, job_id: str) -> Optional[str]:
213+
return self._get_session_manager().get_job_status(job_id)
230214

231-
Args:
232-
job_path: Path to the exported job directory.
233-
job_name: Name of the job for logging.
215+
def abort_job(self, job_id: str) -> None:
216+
self._get_session_manager().abort_job(job_id)
234217

235-
Returns:
236-
str: Job ID returned by the system.
237-
"""
238-
sess = None
239-
try:
240-
# Get the admin startup kit path for POC
241-
admin_dir = self._get_admin_startup_kit_path()
242-
243-
# Create secure session with POC admin (reuse for both submit and monitor)
244-
sess = new_secure_session(
245-
username=self.username,
246-
startup_kit_location=admin_dir,
247-
)
248-
249-
# Submit the job
250-
job_id = sess.submit_job(job_path)
251-
252-
return job_id
253-
except Exception as e:
254-
raise RuntimeError(f"Failed to submit/monitor job via Flare API: {e}")
255-
256-
finally:
257-
if sess:
258-
sess.close()
218+
def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]:
219+
return self._get_session_manager().get_job_result(job_id, timeout)
259220

260221
def _get_admin_startup_kit_path(self) -> str:
261222
"""Get the path to the admin startup kit for POC.
@@ -279,3 +240,14 @@ def _get_admin_startup_kit_path(self) -> str:
279240

280241
except Exception as e:
281242
raise RuntimeError(f"Failed to locate admin startup kit: {e}")
243+
244+
def _get_session_manager(self):
245+
"""Get or create SessionManager with lazy initialization."""
246+
if self._session_manager is None:
247+
session_params = {
248+
"username": self.username,
249+
"startup_kit_location": self._get_admin_startup_kit_path(),
250+
"timeout": self.get_extra_prop("login_timeout", 10),
251+
}
252+
self._session_manager = SessionManager(session_params)
253+
return self._session_manager

nvflare/recipe/prod_env.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,18 @@
1313
# limitations under the License.
1414

1515
import os.path
16-
import tempfile
16+
from typing import Optional
1717

1818
from pydantic import BaseModel, PositiveFloat, model_validator
1919

20-
from nvflare.fuel.flare_api.flare_api import Session, new_secure_session
2120
from nvflare.job_config.api import FedJob
21+
from nvflare.recipe.spec import ExecEnv
2222

23-
from .spec import ExecEnv, ExecEnvType
23+
from .session_mgr import SessionManager
2424

2525
DEFAULT_ADMIN_USER = "[email protected]"
2626

2727

28-
def status_monitor_cb(session: Session, job_id: str, job_meta, *cb_args, **cb_kwargs) -> bool:
29-
if job_meta["status"] == "RUNNING":
30-
if cb_kwargs["cb_run_counter"]["count"] < 3 or cb_kwargs["cb_run_counter"]["count"] % 15 == 0:
31-
print(job_meta)
32-
else:
33-
# avoid printing job_meta repeatedly to save space on the screen and not overwhelm the user
34-
print(".", end="")
35-
else:
36-
print("\n" + str(job_meta))
37-
38-
cb_kwargs["cb_run_counter"]["count"] += 1
39-
return True
40-
41-
4228
# Internal — not part of the public API
4329
class _ProdEnvValidator(BaseModel):
4430
startup_kit_location: str
@@ -81,30 +67,31 @@ def __init__(
8167
self.startup_kit_location = v.startup_kit_location
8268
self.login_timeout = v.login_timeout
8369
self.username = v.username
70+
self._session_manager = None # Lazy initialization
71+
72+
def get_job_status(self, job_id: str) -> Optional[str]:
73+
return self._get_session_manager().get_job_status(job_id)
74+
75+
def abort_job(self, job_id: str) -> None:
76+
self._get_session_manager().abort_job(job_id)
77+
78+
def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]:
79+
return self._get_session_manager().get_job_result(job_id, timeout)
8480

8581
def deploy(self, job: FedJob):
86-
sess = None
82+
"""Deploy a job using SessionManager."""
8783
try:
88-
sess = new_secure_session(
89-
username=self.username, startup_kit_location=self.startup_kit_location, timeout=self.login_timeout
90-
)
91-
with tempfile.TemporaryDirectory() as temp_dir:
92-
job.export_job(temp_dir)
93-
job_path = os.path.join(temp_dir, job.name)
94-
job_id = sess.submit_job(job_path)
95-
print(f"Submitted job '{job.name}' with ID: {job_id}")
96-
97-
return job_id
84+
return self._get_session_manager().submit_job(job)
9885
except Exception as e:
99-
raise RuntimeError(f"Failed to submit/monitor job via Flare API: {e}")
100-
finally:
101-
if sess:
102-
sess.close()
103-
104-
def get_env_info(self) -> dict:
105-
return {
106-
"env_type": ExecEnvType.PROD,
107-
"startup_kit_location": self.startup_kit_location,
108-
"login_timeout": self.login_timeout,
109-
"username": self.username,
110-
}
86+
raise RuntimeError(f"Failed to submit job via Flare API: {e}")
87+
88+
def _get_session_manager(self):
89+
"""Get or create SessionManager with lazy initialization."""
90+
if self._session_manager is None:
91+
session_params = {
92+
"username": self.username,
93+
"startup_kit_location": self.startup_kit_location,
94+
"timeout": self.login_timeout,
95+
}
96+
self._session_manager = SessionManager(session_params)
97+
return self._session_manager

nvflare/recipe/run.py

Lines changed: 7 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -12,110 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
16-
from contextlib import contextmanager
17-
from typing import Generator, Optional
15+
from typing import Optional
1816

19-
from nvflare.fuel.flare_api.api_spec import MonitorReturnCode
20-
from nvflare.fuel.flare_api.flare_api import Session, new_secure_session
21-
22-
23-
def _cb_with_print(session: Session, job_id: str, job_meta, *cb_args, **cb_kwargs) -> bool:
24-
"""Callback to print job meta."""
25-
# cb_run_counter is a dictionary that is passed to the callback and is used to keep track of the number of times the callback has been called
26-
if cb_kwargs["cb_run_counter"]["count"] == 0:
27-
print("Job ID: ", job_id)
28-
print("Job Meta: ", job_meta)
29-
30-
if job_meta["status"] == "RUNNING":
31-
print(".", end="")
32-
else:
33-
print("\n" + str(job_meta))
34-
35-
cb_kwargs["cb_run_counter"]["count"] += 1
36-
return True
17+
from nvflare.recipe.spec import ExecEnv
3718

3819

3920
class Run:
40-
def __init__(self, env_info: dict, job_id: str):
41-
self.env_info = env_info
21+
def __init__(self, exec_env: ExecEnv, job_id: str):
22+
self.exec_env = exec_env
4223
self.job_id = job_id
43-
self.handlers = {
44-
"sim": self._get_sim_result,
45-
"poc": self._get_prod_result,
46-
"prod": self._get_prod_result,
47-
}
4824

4925
def get_job_id(self) -> str:
5026
return self.job_id
5127

52-
def _is_sim_env(self) -> bool:
53-
"""Check if this is a simulation environment."""
54-
return self.env_info.get("env_type") == "sim"
55-
56-
def _get_session_params(self) -> dict:
57-
"""Get session parameters from env_info."""
58-
return {
59-
"startup_kit_location": self.env_info.get("startup_kit_location"),
60-
"username": self.env_info.get("username"),
61-
"timeout": self.env_info.get("login_timeout", 10),
62-
}
63-
64-
@contextmanager
65-
def _secure_session(self) -> Generator:
66-
"""Context manager for secure session handling."""
67-
sess = None
68-
try:
69-
sess = new_secure_session(**self._get_session_params())
70-
yield sess
71-
except Exception as e:
72-
raise RuntimeError(f"Failed to create/use session: {e}")
73-
finally:
74-
if sess:
75-
sess.close()
76-
7728
def get_status(self) -> Optional[str]:
7829
"""Get the status of the run.
7930
8031
Returns:
8132
Optional[str]: The status of the run, or None if called in a simulation environment.
8233
"""
83-
if self._is_sim_env():
84-
print(
85-
"get_status is not supported in a simulation environment, please check the log inside the workspace returned by get_result()"
86-
)
87-
return None
88-
89-
with self._secure_session() as sess:
90-
return sess.get_job_status(self.job_id)
91-
92-
def _get_sim_result(self, **kwargs) -> str:
93-
workspace_root = self.env_info.get("workspace_root")
94-
if workspace_root is None:
95-
raise RuntimeError("Simulation workspace_root is None - SimEnv may not be properly initialized")
96-
return os.path.join(workspace_root, self.job_id)
97-
98-
def _get_prod_result(self, timeout: float = 0.0) -> Optional[str]:
99-
with self._secure_session() as sess:
100-
cb_run_counter = {"count": 0}
101-
rc = sess.monitor_job(self.job_id, timeout=timeout, cb=_cb_with_print, cb_run_counter=cb_run_counter)
102-
print(f"job monitor done: {rc=}")
103-
if rc == MonitorReturnCode.JOB_FINISHED:
104-
return sess.download_job_result(self.job_id)
105-
elif rc == MonitorReturnCode.TIMEOUT:
106-
print(
107-
f"Job {self.job_id} did not complete within {timeout} seconds. "
108-
"Job is still running. Try calling get_result() again with a longer timeout."
109-
)
110-
return None
111-
elif rc == MonitorReturnCode.ENDED_BY_CB:
112-
print(
113-
"Job monitoring was stopped early by callback. "
114-
"Result may not be available yet. Check job status and try again."
115-
)
116-
return None
117-
else:
118-
raise RuntimeError(f"Unexpected monitor return code: {rc}")
34+
return self.exec_env.get_job_status(self.job_id)
11935

12036
def get_result(self, timeout: float = 0.0) -> Optional[str]:
12137
"""Get the result workspace of the run.
@@ -127,18 +43,8 @@ def get_result(self, timeout: float = 0.0) -> Optional[str]:
12743
Returns:
12844
Optional[str]: The result workspace path if job completed, None if still running or stopped early.
12945
"""
130-
env_type = self.env_info.get("env_type")
131-
return self.handlers[env_type](timeout=timeout)
46+
return self.exec_env.get_job_result(self.job_id, timeout=timeout)
13247

13348
def abort(self):
13449
"""Abort the running job."""
135-
if self._is_sim_env():
136-
print("abort is not supported in a simulation environment, it will always run to completion.")
137-
return
138-
139-
try:
140-
with self._secure_session() as sess:
141-
msg = sess.abort_job(self.job_id)
142-
print(f"Job {self.job_id} aborted successfully with message: {msg}")
143-
except Exception as e:
144-
print(f"Failed to abort job {self.job_id}: {e}")
50+
self.exec_env.abort_job(self.job_id)

0 commit comments

Comments
 (0)