Skip to content

Commit 4946ac6

Browse files
[2.4] Add xgboost example, unit tests, integration tests (#2392)
1 parent b66da2a commit 4946ac6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1388
-96
lines changed

.readthedocs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@ sphinx:
2626
python:
2727
install:
2828
- method: pip
29-
path: .[doc]
29+
path: .[dev]
3030
# system_packages: true
3131

build_doc.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function clean_docs() {
4949
}
5050

5151
function build_html_docs() {
52-
pip install -e .[doc]
52+
pip install -e .[dev]
5353
sphinx-apidoc --module-first -f -o docs/apidocs/ nvflare "*poc" "*private"
5454
sphinx-build -b html docs docs/_build
5555
}

examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
{
22
"format_version": 2,
3+
"num_rounds": 100,
34
"executors": [
45
{
56
"tasks": [
67
"train"
78
],
89
"executor": {
910
"id": "Executor",
10-
"name": "FedXGBHistogramExecutor",
11+
"path": "nvflare.app_opt.xgboost.histogram_based.executor.FedXGBHistogramExecutor",
1112
"args": {
1213
"data_loader_id": "dataloader",
13-
"num_rounds": 100,
14+
"num_rounds": "{num_rounds}",
1415
"early_stopping_rounds": 2,
1516
"xgb_params": {
1617
"max_depth": 8,
1718
"eta": 0.1,
1819
"objective": "binary:logistic",
1920
"eval_metric": "auc",
20-
"tree_method": "gpu_hist",
21+
"tree_method": "hist",
2122
"nthread": 16
2223
}
2324
}

examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_server.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
{
22
"format_version": 2,
3-
"server": {
4-
"heart_beat_timeout": 600
5-
},
63
"task_data_filters": [],
74
"task_result_filters": [],
85
"components": [],
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
2+
"format_version": 2,
3+
"num_rounds": 100,
4+
"executors": [
5+
{
6+
"tasks": [
7+
"config", "start"
8+
],
9+
"executor": {
10+
"id": "Executor",
11+
"path": "nvflare.app_opt.xgboost.histogram_based_v2.executor.FedXGBHistogramExecutor",
12+
"args": {
13+
"data_loader_id": "dataloader",
14+
"early_stopping_rounds": 2,
15+
"xgb_params": {
16+
"max_depth": 8,
17+
"eta": 0.1,
18+
"objective": "binary:logistic",
19+
"eval_metric": "auc",
20+
"tree_method": "hist",
21+
"nthread": 16
22+
}
23+
}
24+
}
25+
}
26+
],
27+
"task_result_filters": [],
28+
"task_data_filters": [],
29+
"components": [
30+
{
31+
"id": "dataloader",
32+
"path": "higgs_data_loader.HIGGSDataLoader",
33+
"args": {
34+
"data_split_filename": "data_split.json"
35+
}
36+
}
37+
]
38+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"format_version": 2,
3+
"num_rounds": 100,
4+
"task_data_filters": [],
5+
"task_result_filters": [],
6+
"components": [],
7+
"workflows": [
8+
{
9+
"id": "xgb_controller",
10+
"path": "nvflare.app_opt.xgboost.histogram_based_v2.controller.XGBFedController",
11+
"args": {
12+
"num_rounds": "{num_rounds}"
13+
}
14+
}
15+
]
16+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
17+
import pandas as pd
18+
import xgboost as xgb
19+
20+
from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
21+
22+
23+
def _read_higgs_with_pandas(data_path, start: int, end: int):
24+
data_size = end - start
25+
data = pd.read_csv(data_path, header=None, skiprows=start, nrows=data_size)
26+
data_num = data.shape[0]
27+
28+
# split to feature and label
29+
x = data.iloc[:, 1:].copy()
30+
y = data.iloc[:, 0].copy()
31+
32+
return x, y, data_num
33+
34+
35+
class HIGGSDataLoader(XGBDataLoader):
    def __init__(self, data_split_filename):
        """Reads HIGGS dataset and returns XGB data matrices.

        Args:
            data_split_filename: file name of the JSON file describing the data splits
        """
        self.data_split_filename = data_split_filename

    def load_data(self, client_id: str):
        """Build the training and validation matrices for one client.

        The split file is parsed as JSON and must provide ``data_path`` and
        ``data_index``. ``data_index`` must contain an entry for ``client_id``
        and a ``"valid"`` entry, each carrying ``start`` and ``end`` values.

        Args:
            client_id: key of this client's entry inside ``data_index``

        Returns:
            A tuple ``(dmat_train, dmat_valid)`` of ``xgb.DMatrix`` objects.

        Raises:
            ValueError: if ``data_index`` has no entry for ``client_id`` or no
                ``"valid"`` entry.
        """
        with open(self.data_split_filename, "r") as split_file:
            split_spec = json.load(split_file)

        csv_path = split_spec["data_path"]
        index_by_split = split_spec["data_index"]

        # both the client's own split and the shared validation split must be listed
        split_names = index_by_split.keys()
        if client_id not in split_names:
            raise ValueError(
                f"Data does not contain Client {client_id} split",
            )
        if "valid" not in split_names:
            raise ValueError(
                "Data does not contain Validation split",
            )

        train_bounds = index_by_split[client_id]
        valid_bounds = index_by_split["valid"]

        def _as_dmatrix(bounds):
            # read the requested slice and wrap it; the row count is not needed here
            features, labels, _ = _read_higgs_with_pandas(
                data_path=csv_path, start=bounds["start"], end=bounds["end"]
            )
            return xgb.DMatrix(features, label=labels)

        # training first, then validation
        dmat_train = _as_dmatrix(train_bounds)
        dmat_valid = _as_dmatrix(valid_bounds)

        return dmat_train, dmat_valid
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"name": "xgboost_histogram_based_v2",
3+
"resource_spec": {},
4+
"deploy_map": {
5+
"app": [
6+
"@ALL"
7+
]
8+
},
9+
"min_clients": 2
10+
}

examples/advanced/xgboost/prepare_job_config.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@ prepare_job_config 20 cyclic uniform uniform $TREE_METHOD
2222

2323
prepare_job_config 2 histogram uniform uniform $TREE_METHOD
2424
prepare_job_config 5 histogram uniform uniform $TREE_METHOD
25+
prepare_job_config 2 histogram_v2 uniform uniform $TREE_METHOD
26+
prepare_job_config 5 histogram_v2 uniform uniform $TREE_METHOD
2527
echo "Job configs generated"

examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
{
22
"format_version": 2,
3-
4-
"server": {
5-
"heart_beat_timeout": 600,
6-
"task_request_interval": 0.05
7-
},
8-
3+
"num_rounds": 101,
94
"task_data_filters": [],
105
"task_result_filters": [],
116

@@ -34,7 +29,7 @@
3429
"name": "ScatterAndGather",
3530
"args": {
3631
"min_clients": 5,
37-
"num_rounds": 101,
32+
"num_rounds": "{num_rounds}",
3833
"start_round": 0,
3934
"wait_time_after_min_received": 0,
4035
"aggregator_id": "aggregator",

0 commit comments

Comments
 (0)