Skip to content

Commit 2d731b9

Browse files
Pre-trained Model and training_mode changes (#2793)
* Updated FOBS readme to add DatumManager, added agrpcs as secure scheme * Added support for pre-trained model * Changed training_mode to split_mode + secure_training * split_mode => data_split_mode * Format error * Fixed a format error * Addressed PR comments * Fixed format * Changed all xgboost controller/executor to use new XGBoost --------- Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
1 parent a3fb1e5 commit 2d731b9

File tree

17 files changed

+87
-102
lines changed

17 files changed

+87
-102
lines changed

examples/advanced/vertical_xgboost/code/vertical_xgb/vertical_data_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(self, data_split_path, psi_path, id_col, label_owner, train_proport
6262
self.label_owner = label_owner
6363
self.train_proportion = train_proportion
6464

65-
def load_data(self, client_id: str, training_mode: str = ""):
65+
def load_data(self, client_id: str, split_mode: int = 1):
6666
client_data_split_path = self.data_split_path.replace("site-x", client_id)
6767
client_psi_path = self.psi_path.replace("site-x", client_id)
6868

@@ -84,7 +84,7 @@ def load_data(self, client_id: str, training_mode: str = ""):
8484
label = ""
8585

8686
# for Vertical XGBoost, read from csv with label_column and set data_split_mode to 1 for column mode
87-
dtrain = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=1)
88-
dvalid = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=1)
87+
dtrain = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=split_mode)
88+
dvalid = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=split_mode)
8989

9090
return dtrain, dvalid

examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self, data_split_filename):
4141
"""
4242
self.data_split_filename = data_split_filename
4343

44-
def load_data(self, client_id: str, training_mode: str = ""):
44+
def load_data(self, client_id: str, split_mode: int):
4545
with open(self.data_split_filename, "r") as file:
4646
data_split = json.load(file)
4747

examples/advanced/xgboost/utils/prepare_job_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def _update_server_config(config: dict, args):
152152
config["num_rounds"] = args.round_num
153153
config["workflows"][0]["args"]["xgb_params"]["nthread"] = args.nthread
154154
config["workflows"][0]["args"]["xgb_params"]["tree_method"] = args.tree_method
155-
config["workflows"][0]["args"]["training_mode"] = args.training_mode
155+
config["workflows"][0]["args"]["split_mode"] = args.split_mode
156+
config["workflows"][0]["args"]["secure_training"] = args.secure_training
156157

157158

158159
def _copy_custom_files(src_job_path, src_app_name, dst_job_path, dst_app_name):

nvflare/app_opt/xgboost/constant.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

nvflare/app_opt/xgboost/data_loader.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,10 @@
1818

1919
import xgboost as xgb
2020

21-
from .constant import TrainingMode
22-
2321

2422
class XGBDataLoader(ABC):
2523
@abstractmethod
26-
def load_data(
27-
self, client_id: str, training_mode: str = TrainingMode.HORIZONTAL
28-
) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
24+
def load_data(self, client_id: str, split_mode: int) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
2925
"""Loads data for xgboost.
3026
3127
Returns:

nvflare/app_opt/xgboost/histogram_based/controller.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,10 @@ def start_controller(self, fl_ctx: FLContext):
107107
if not self._get_certificates(fl_ctx):
108108
self.log_error(fl_ctx, "Can't get required certificates for XGB FL server in secure mode.")
109109
return
110+
self.log_info(fl_ctx, "Running XGB FL server in secure mode.")
110111
self._xgb_fl_server = multiprocessing.Process(
111112
target=xgb_federated.run_federated_server,
112-
args=(self._port, len(clients), self._server_key_path, self._server_cert_path, self._ca_cert_path),
113+
args=(len(clients), self._port, self._server_key_path, self._server_cert_path, self._ca_cert_path),
113114
)
114115
else:
115116
self._xgb_fl_server = multiprocessing.Process(

nvflare/app_opt/xgboost/histogram_based/executor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,9 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -
269269
if not self._get_certificates(fl_ctx):
270270
return make_reply(ReturnCode.ERROR)
271271

272-
communicator_env["federated_server_cert"] = self._ca_cert_path
273-
communicator_env["federated_client_key"] = self._client_key_path
274-
communicator_env["federated_client_cert"] = self._client_cert_path
272+
communicator_env["federated_server_cert_path"] = self._ca_cert_path
273+
communicator_env["federated_client_key_path"] = self._client_key_path
274+
communicator_env["federated_client_cert_path"] = self._client_cert_path
275275

276276
try:
277277
with xgb.collective.CommunicatorContext(**communicator_env):

nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ class since the self object contains a sender that contains a Core Cell which ca
9898
Constant.RUNNER_CTX_SERVER_ADDR: server_addr,
9999
Constant.RUNNER_CTX_RANK: self.rank,
100100
Constant.RUNNER_CTX_NUM_ROUNDS: self.num_rounds,
101-
Constant.RUNNER_CTX_TRAINING_MODE: self.training_mode,
101+
Constant.RUNNER_CTX_SPLIT_MODE: self.split_mode,
102+
Constant.RUNNER_CTX_SECURE_TRAINING: self.secure_training,
102103
Constant.RUNNER_CTX_XGB_PARAMS: self.xgb_params,
103104
Constant.RUNNER_CTX_XGB_OPTIONS: self.xgb_options,
104105
Constant.RUNNER_CTX_MODEL_DIR: self._run_dir,

nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ def __init__(self, in_process: bool, per_msg_timeout: float, tx_timeout: float):
150150
self.stopped = False
151151
self.rank = None
152152
self.num_rounds = None
153-
self.training_mode = None
153+
self.split_mode = None
154+
self.secure_training = None
154155
self.xgb_params = None
155156
self.xgb_options = None
156157
self.world_size = None
@@ -196,10 +197,15 @@ def configure(self, config: dict, fl_ctx: FLContext):
196197
check_positive_int(Constant.CONF_KEY_NUM_ROUNDS, num_rounds)
197198
self.num_rounds = num_rounds
198199

199-
self.training_mode = config.get(Constant.CONF_KEY_TRAINING_MODE)
200-
if self.training_mode is None:
201-
raise RuntimeError("training_mode is not configured")
202-
fl_ctx.set_prop(key=Constant.PARAM_KEY_TRAINING_MODE, value=self.training_mode, private=True, sticky=True)
200+
self.split_mode = config.get(Constant.CONF_KEY_SPLIT_MODE)
201+
if self.split_mode is None:
202+
raise RuntimeError("split_mode is not configured")
203+
fl_ctx.set_prop(key=Constant.PARAM_KEY_SPLIT_MODE, value=self.split_mode, private=True, sticky=True)
204+
205+
self.secure_training = config.get(Constant.CONF_KEY_SECURE_TRAINING)
206+
if self.secure_training is None:
207+
raise RuntimeError("secure_training is not configured")
208+
fl_ctx.set_prop(key=Constant.PARAM_KEY_SECURE_TRAINING, value=self.secure_training, private=True, sticky=True)
203209

204210
self.xgb_params = config.get(Constant.CONF_KEY_XGB_PARAMS)
205211
if not self.xgb_params:

nvflare/app_opt/xgboost/histogram_based_v2/controller.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from nvflare.fuel.utils.validation_utils import check_number_range, check_object_type, check_positive_number, check_str
2828
from nvflare.security.logging import secure_format_exception
2929

30-
from .defs import TRAINING_MODE_MAPPING, Constant
30+
from .defs import Constant
3131

3232

3333
class ClientStatus:
@@ -59,7 +59,8 @@ def __init__(
5959
self,
6060
adaptor_component_id: str,
6161
num_rounds: int,
62-
training_mode: str,
62+
split_mode: int,
63+
secure_training: bool,
6364
xgb_params: dict,
6465
xgb_options: Optional[dict] = None,
6566
configure_task_name=Constant.CONFIG_TASK_NAME,
@@ -80,7 +81,8 @@ def __init__(
8081
Args:
8182
adaptor_component_id - the component ID of server target adaptor
8283
num_rounds - number of rounds
83-
training_mode - Split mode (horizontal, vertical, horizontal_secure, vertical_secure)
84+
split_mode - 0 for horizontal/row-split, 1 for vertical/column-split
85+
secure_training - If true, secure training is enabled
8486
xgb_params - The params argument for train method
8587
xgb_options - All other arguments for train method are passed through this dictionary
8688
configure_task_name - name of the config task
@@ -100,7 +102,8 @@ def __init__(
100102
Controller.__init__(self)
101103
self.adaptor_component_id = adaptor_component_id
102104
self.num_rounds = num_rounds
103-
self.training_mode = training_mode.lower()
105+
self.split_mode = split_mode
106+
self.secure_training = secure_training
104107
self.xgb_params = xgb_params
105108
self.xgb_options = xgb_options
106109
self.configure_task_name = configure_task_name
@@ -118,10 +121,8 @@ def __init__(
118121
self.client_statuses = {} # client name => ClientStatus
119122
self.abort_signal = None
120123

121-
check_str("training_mode", training_mode)
122-
valid_mode = TRAINING_MODE_MAPPING.keys()
123-
if training_mode not in valid_mode:
124-
raise ValueError(f"training_mode must be one of following values: {valid_mode}")
124+
if split_mode not in {0, 1}:
125+
raise ValueError("split_mode must be either 0 or 1")
125126

126127
if not self.xgb_params:
127128
raise ValueError("xgb_params can't be empty")
@@ -462,7 +463,8 @@ def _configure_clients(self, abort_signal: Signal, fl_ctx: FLContext):
462463

463464
shareable[Constant.CONF_KEY_CLIENT_RANKS] = self.client_ranks
464465
shareable[Constant.CONF_KEY_NUM_ROUNDS] = self.num_rounds
465-
shareable[Constant.CONF_KEY_TRAINING_MODE] = self.training_mode
466+
shareable[Constant.CONF_KEY_SPLIT_MODE] = self.split_mode
467+
shareable[Constant.CONF_KEY_SECURE_TRAINING] = self.secure_training
466468
shareable[Constant.CONF_KEY_XGB_PARAMS] = self.xgb_params
467469
shareable[Constant.CONF_KEY_XGB_OPTIONS] = self.xgb_options
468470

0 commit comments

Comments
 (0)