|
14 | 14 | import builtins |
15 | 15 | import inspect |
16 | 16 | import json |
| 17 | +import logging |
17 | 18 | import os |
| 19 | +import shlex |
18 | 20 | import shutil |
| 21 | +import subprocess |
19 | 22 | import sys |
20 | 23 | from enum import Enum |
21 | 24 | from tempfile import TemporaryDirectory |
|
25 | 28 | from nvflare.job_config.base_app_config import BaseAppConfig |
26 | 29 | from nvflare.job_config.fed_app_config import FedAppConfig |
27 | 30 | from nvflare.private.fed.app.fl_conf import FL_PACKAGES |
28 | | -from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner |
| 31 | +from nvflare.private.fed.app.utils import kill_child_processes |
29 | 32 |
|
30 | 33 | CONFIG = "config" |
31 | 34 | CUSTOM = "custom" |
@@ -58,6 +61,7 @@ def __init__(self, job_name, min_clients, mandatory_clients=None) -> None: |
58 | 61 | self.resource_specs: Dict[str, Dict] = {} |
59 | 62 |
|
60 | 63 | self.custom_modules = [] |
| 64 | + self.logger = logging.getLogger(self.__class__.__name__) |
61 | 65 |
|
62 | 66 | def add_fed_app(self, app_name: str, fed_app: FedAppConfig): |
63 | 67 | if not isinstance(fed_app, FedAppConfig): |
@@ -136,15 +140,31 @@ def simulator_run(self, workspace, clients=None, n_clients=None, threads=None, g |
136 | 140 | with TemporaryDirectory() as job_root: |
137 | 141 | self.generate_job_config(job_root) |
138 | 142 |
|
139 | | - simulator = SimulatorRunner( |
140 | | - job_folder=os.path.join(job_root, self.job_name), |
141 | | - workspace=workspace, |
142 | | - clients=clients, |
143 | | - n_clients=n_clients, |
144 | | - threads=threads, |
145 | | - gpu=gpu, |
146 | | - ) |
147 | | - simulator.run() |
| 143 | + try: |
| 144 | + command = ( |
| 145 | + f"{sys.executable} -m nvflare.private.fed.app.simulator.simulator " |
| 146 | + + os.path.join(job_root, self.job_name) |
| 147 | + + " -w " |
| 148 | + + workspace |
| 149 | + ) |
| 150 | + if clients: |
| 151 | + command += " -c " + str(clients) |
| 152 | + if n_clients: |
| 153 | + command += " -n " + str(n_clients) |
| 154 | + if threads: |
| 155 | + command += " -t " + str(threads) |
| 156 | + if gpu: |
| 157 | + command += " -gpu " + str(gpu) |
| 158 | + |
| 159 | + new_env = os.environ.copy() |
| 160 | + process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env) |
| 161 | + |
| 162 | + process.wait() |
| 163 | + |
| 164 | + except KeyboardInterrupt: |
| 165 | + self.logger.info("KeyboardInterrupt, terminate all the child processes.") |
| 166 | + kill_child_processes(os.getpid()) |
| 167 | + return -9 |
148 | 168 |
|
149 | 169 | def _get_server_app(self, config_dir, custom_dir, fed_app): |
150 | 170 | server_app = {"format_version": 2, "workflows": []} |
|
0 commit comments