Skip to content

Commit 73b75c9

Browse files
Add back metric callback and fix examples based on new xgboost version (#2787)
1 parent 138387e commit 73b75c9

File tree

40 files changed

+472
-188
lines changed

40 files changed

+472
-188
lines changed

.readthedocs.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ version: 2
99
build:
1010
os: ubuntu-22.04
1111
tools:
12-
python: "3.8"
12+
python: "3.10"
1313

1414
# Build documentation in the docs/ directory with Sphinx
1515
sphinx:
@@ -26,6 +26,6 @@ sphinx:
2626
python:
2727
install:
2828
- method: pip
29-
path: .[doc]
29+
path: .[dev]
3030
# system_packages: true
3131

build_doc.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function clean_docs() {
4949
}
5050

5151
function build_html_docs() {
52-
pip install -e .[doc]
52+
pip install -e .[dev]
5353
sphinx-apidoc --module-first -f -o docs/apidocs/ nvflare "*poc" "*private"
5454
sphinx-build -b html docs docs/_build
5555
}

examples/advanced/vertical_xgboost/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ The model will be saved to `test.model.json`.
8989
## Results
9090
Model accuracy can be visualized in tensorboard:
9191
```
92-
tensorboard --logdir /tmp/nvflare/vertical_xgb/simulate_job/tb_events
92+
tensorboard --logdir /tmp/nvflare/vertical_xgb/server/simulate_job/tb_events
9393
```
9494

9595
An example training (pink) and validation (orange) AUC graph from running vertical XGBoost on HIGGS:

examples/advanced/vertical_xgboost/code/vertical_xgb/vertical_data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(self, data_split_path, psi_path, id_col, label_owner, train_proport
6262
self.label_owner = label_owner
6363
self.train_proportion = train_proportion
6464

65-
def load_data(self, client_id: str):
65+
def load_data(self, client_id: str, training_mode: str = ""):
6666
client_data_split_path = self.data_split_path.replace("site-x", client_id)
6767
client_psi_path = self.psi_path.replace("site-x", client_id)
6868

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
nvflare~=2.4.0rc
1+
nvflare~=2.5.0rc
22
openmined.psi==1.1.1
33
pandas
4-
tensorboard
54
torch
6-
xgboost>=2.0.0
5+
tensorboard
6+
# Requires xgboost 2.2; for now, install it from the nightly binary builds index below.
7+
# "xgboost>=2.2"
8+
9+
--extra-index-url https://s3-us-west-2.amazonaws.com/xgboost-nightly-builds/list.html?prefix=federated-secure/
10+
xgboost

examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_server.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
{
22
"format_version": 2,
3-
"server": {
4-
"heart_beat_timeout": 600
5-
},
63
"task_data_filters": [],
74
"task_result_filters": [],
85
"components": [

examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_client.json

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,16 @@
11
{
22
"format_version": 2,
3-
"num_rounds": 100,
43
"executors": [
54
{
65
"tasks": [
76
"config", "start"
87
],
98
"executor": {
109
"id": "Executor",
11-
"path": "nvflare.app_opt.xgboost.histogram_based_v2.executor.FedXGBHistogramExecutor",
10+
"path": "nvflare.app_opt.xgboost.histogram_based_v2.fed_executor.FedXGBHistogramExecutor",
1211
"args": {
1312
"data_loader_id": "dataloader",
14-
"metrics_writer_id": "metrics_writer",
15-
"early_stopping_rounds": 2,
16-
"xgb_params": {
17-
"max_depth": 8,
18-
"eta": 0.1,
19-
"objective": "binary:logistic",
20-
"eval_metric": "auc",
21-
"tree_method": "hist",
22-
"nthread": 16
23-
}
13+
"metrics_writer_id": "metrics_writer"
2414
}
2515
}
2616
}

examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_server.json

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,21 @@
1515
"workflows": [
1616
{
1717
"id": "xgb_controller",
18-
"path": "nvflare.app_opt.xgboost.histogram_based_v2.controller.XGBFedController",
18+
"path": "nvflare.app_opt.xgboost.histogram_based_v2.fed_controller.XGBFedController",
1919
"args": {
20-
"num_rounds": "{num_rounds}"
20+
"num_rounds": "{num_rounds}",
21+
"training_mode": "horizontal",
22+
"xgb_params": {
23+
"max_depth": 8,
24+
"eta": 0.1,
25+
"objective": "binary:logistic",
26+
"eval_metric": "auc",
27+
"tree_method": "hist",
28+
"nthread": 16
29+
},
30+
"xgb_options": {
31+
"early_stopping_rounds": 2
32+
}
2133
}
2234
}
2335
]
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
17+
import pandas as pd
18+
import xgboost as xgb
19+
20+
from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
21+
22+
23+
def _read_higgs_with_pandas(data_path, start: int, end: int):
    """Load the rows in ``[start, end)`` from a headerless HIGGS CSV file.

    Args:
        data_path: path of the CSV file to read
        start: number of leading rows to skip
        end: index one past the last row to read

    Returns:
        A tuple ``(x, y, data_num)`` holding a copy of the feature columns
        (every column after the first), a copy of the label column (the
        first column), and the number of rows that were actually read.
    """
    frame = pd.read_csv(data_path, header=None, skiprows=start, nrows=end - start)

    # column 0 carries the label, the remaining columns are the features
    labels = frame.iloc[:, 0].copy()
    features = frame.iloc[:, 1:].copy()

    return features, labels, frame.shape[0]
33+
34+
35+
class HIGGSDataLoader(XGBDataLoader):
    """Data loader that builds XGBoost matrices for HIGGS from a JSON split file."""

    def __init__(self, data_split_filename):
        """Reads HIGGS dataset and return XGB data matrix.

        Args:
            data_split_filename: file name to data splits
        """
        self.data_split_filename = data_split_filename

    def load_data(self, client_id: str, training_mode: str = ""):
        """Build the training and validation matrices for one client.

        The split file is a JSON document with a ``"data_path"`` entry (the
        CSV file to read) and a ``"data_index"`` mapping. The mapping must
        contain an entry for ``client_id`` and one for ``"valid"``, each
        providing the ``"start"`` and ``"end"`` row bounds of its split.

        Args:
            client_id: key of this client's split inside ``"data_index"``
            training_mode: accepted as part of the loader signature; not used
                by this loader

        Returns:
            A tuple ``(dmat_train, dmat_valid)`` of ``xgb.DMatrix`` objects.

        Raises:
            ValueError: if the split file has no entry for ``client_id`` or
                no ``"valid"`` entry.
        """
        # JSON is UTF-8 text; do not depend on the platform's locale encoding
        with open(self.data_split_filename, "r", encoding="utf-8") as file:
            data_split = json.load(file)

        data_path = data_split["data_path"]
        data_index = data_split["data_index"]

        # make sure both this client's split and the validation split exist
        if client_id not in data_index:
            raise ValueError(
                f"Data does not contain Client {client_id} split",
            )

        if "valid" not in data_index:
            raise ValueError(
                "Data does not contain Validation split",
            )

        site_index = data_index[client_id]
        valid_index = data_index["valid"]

        # training (the row count returned by the reader is not needed here)
        x_train, y_train, _ = _read_higgs_with_pandas(
            data_path=data_path, start=site_index["start"], end=site_index["end"]
        )
        dmat_train = xgb.DMatrix(x_train, label=y_train)

        # validation
        x_valid, y_valid, _ = _read_higgs_with_pandas(
            data_path=data_path, start=valid_index["start"], end=valid_index["end"]
        )
        dmat_valid = xgb.DMatrix(x_valid, label=y_valid)

        return dmat_train, dmat_valid
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"name": "xgboost_histogram_based_v2",
3+
"resource_spec": {},
4+
"deploy_map": {
5+
"app": [
6+
"@ALL"
7+
]
8+
},
9+
"min_clients": 2
10+
}

0 commit comments

Comments
 (0)