
Commit 9ca4e4f

Add financial example with xgboost (#2054)
* Add financial examples, update xgboost to account for xgboost's API change
* add readme
* format fix
* format fix
* print message fix
* change vertical config settings to match horizontal and update readme
1 parent ad9245e commit 9ca4e4f

38 files changed: +1736 −15 lines changed

38 files changed

+1736
-15
lines changed
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Financial Application with Federated XGBoost Methods

This example illustrates the use of [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) on a financial application.
It shows how to use [XGBoost](https://github.com/dmlc/xgboost) in various ways to train a model in a federated manner to perform fraud detection with a
[finance dataset](https://www.kaggle.com/datasets/mlg-ulb/creditcardfraud).

## Federated Training of XGBoost

Several mechanisms have been proposed for training an XGBoost model in a federated learning setting.
In these examples, we illustrate the use of NVFlare to carry out the following four approaches:
- *vertical* federated learning using histogram-based collaboration
- *horizontal* federated learning using three approaches:
  - histogram-based collaboration
  - tree-based collaboration with cyclic federation
  - tree-based collaboration with bagging federation

For more details, please refer to the READMEs for the
[vertical](https://github.com/NVIDIA/NVFlare/blob/main/examples/advanced/vertical_xgboost/README.md),
[histogram-based](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/xgboost/histogram-based/README.md),
and [tree-based](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/xgboost/tree-based/README.md)
methods.

## Data Preparation

### Download and Store Data

To run the examples, first download the dataset from the link above; it is a single `.csv` file.
By default, we assume the dataset is downloaded, uncompressed, and stored in `${PWD}/dataset/creditcard.csv`.

> **_NOTE:_** If the dataset is downloaded to another location,
> make sure to modify the corresponding `DATASET_PATH` inside `prepare_data.sh`.

### Data Split

We first split the dataset into two parts, training and testing, and then generate the per-client data splits under both horizontal and vertical settings.

The data splits used in this example can be generated with
```
bash prepare_data.sh
```

This will generate data splits for 2 clients under all experimental settings. Note that the overlap ratio between clients in the vertical setting is 1.0 by default, so that the amount of training data matches the horizontal experiments.
If you want to customize your experiments to simulate more realistic scenarios, please check the corresponding scripts under `utils/`.

> **_NOTE:_** The generated data files will be stored in the folder `/tmp/dataset/`,
> and will be used by the jobs by specifying the path within `config_fed_client.json`, as shown below.
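For reference, the horizontal client configuration in this commit (shown further below) wires its data loader to one of the generated split files:
```json
{
  "id": "dataloader",
  "path": "data_loader.DataLoader",
  "args": {
    "data_split_filename": "/tmp/dataset/horizontal_xgb_data/data_site-1.json"
  }
}
```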

## Run experiments for all settings

To run all experiments, we provide a script covering all settings.
```
bash run_training.sh
```
This covers baseline centralized training; horizontal FL with histogram-based, tree-based cyclic, and tree-based bagging collaborations; and vertical FL.
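If you want to launch a single job directly instead, NVFlare's simulator can run one job folder at a time. The job folder name below is an assumption about this example's layout, since the script's actual job names are not shown in this diff:
```
nvflare simulator ./jobs/xgboost_horizontal_bagging -w /tmp/nvflare/workspace -n 2 -t 2
```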

Then we test the resulting models on the test dataset with
```
bash run_testing.sh
```
The results are as follows:
```
Testing baseline_xgboost
AUC score: 0.965017768854869
Testing xgboost_vertical
AUC score: 0.9650650531737737
Testing xgboost_horizontal_histogram
AUC score: 0.9579533839422094
Testing xgboost_horizontal_cyclic
AUC score: 0.9688269828190139
Testing xgboost_horizontal_bagging
AUC score: 0.9713936151275366
```
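As a minimal sketch of what each testing step does, the snippet below loads one saved model and scores the held-out test set. It assumes the conventions used elsewhere in this example (label in the first CSV column, model saved as JSON); the exact paths `run_testing.sh` uses are not shown in this diff, so the ones here are illustrative.
```python
import pandas as pd
import xgboost as xgb
from sklearn.metrics import roc_auc_score

# Load the test split; the label is the first column, features follow.
df = pd.read_csv("./dataset/test.csv")
dmat_test = xgb.DMatrix(df.iloc[:, 1:], label=df.iloc[:, 0])

# Load a trained model saved by one of the training jobs (path is illustrative).
bst = xgb.Booster()
bst.load_model("./workspaces/xgboost_workspace_centralized/model_centralized.json")

# Under binary:logistic, predict() returns fraud probabilities.
y_pred = bst.predict(dmat_test)
print(f"AUC score: {roc_auc_score(df.iloc[:, 0], y_pred)}")
```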
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time

import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split


def xgboost_args_parser():
    parser = argparse.ArgumentParser(description="Centralized XGBoost training with random forest option")
    parser.add_argument(
        "--train_data_path",
        type=str,
        default="./dataset/train.csv",
        help="path to the training dataset file",
    )
    parser.add_argument(
        "--test_data_path",
        type=str,
        default="./dataset/test.csv",
        help="path to the testing dataset file",
    )
    parser.add_argument("--valid_ratio", type=float, default=0.1, help="ratio of validation split")
    parser.add_argument("--num_rounds", type=int, default=100, help="number of boosting rounds")
    parser.add_argument("--num_parallel_tree", type=int, default=1, help="number of parallel trees")
    parser.add_argument(
        "--output_folder",
        type=str,
        default="./workspaces/xgboost_workspace_centralized",
        help="model output folder",
    )
    return parser


def prepare_data(data_path: str):
    df = pd.read_csv(data_path)
    print(df.info())
    print(df.head())
    total_data_num = df.shape[0]
    print(f"Total data count: {total_data_num}")
    # Split into features and label (the label is the first column)
    X = df.iloc[:, 1:]
    y = df.iloc[:, 0]
    print(y.value_counts())
    return X, y


def get_training_parameters(args):
    # use logistic regression loss for binary classification
    # use AUC as the evaluation metric
    param = {
        "objective": "binary:logistic",
        "eta": 0.1,
        "max_depth": 8,
        "eval_metric": "auc",
        "nthread": 16,
        "num_parallel_tree": args.num_parallel_tree,
    }
    return param


def main():
    parser = xgboost_args_parser()
    args = parser.parse_args()

    train_data_path = args.train_data_path
    valid_ratio = args.valid_ratio
    num_rounds = args.num_rounds
    output_folder = args.output_folder

    # Set model file path
    model_path = os.path.join(output_folder, "model_centralized.json")

    # Load data
    start = time.time()
    X, y = prepare_data(train_data_path)

    # Split into training and validation
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=valid_ratio, random_state=77)
    print(
        f"TRAINING: X_train: {X_train.shape}, y_train: {y_train.shape}, fraudulent transactions: {y_train.value_counts()[1]}"
    )
    print(
        f"VALIDATION: X_valid: {X_valid.shape}, y_valid: {y_valid.shape}, fraudulent transactions: {y_valid.value_counts()[1]}"
    )

    # Construct xgboost DMatrix
    dmat_train = xgb.DMatrix(X_train, label=y_train)
    dmat_valid = xgb.DMatrix(X_valid, label=y_valid)

    end = time.time()
    elapsed_time = end - start
    print(f"Data loading time: {elapsed_time}")

    # xgboost training
    start = time.time()
    xgb_params = get_training_parameters(args)
    bst = xgb.train(
        xgb_params,
        dmat_train,
        num_boost_round=num_rounds,
        evals=[(dmat_valid, "validate"), (dmat_train, "train")],
    )
    end = time.time()
    elapsed_time = end - start
    print(f"Training time: {elapsed_time}")

    # Save model
    os.makedirs(output_folder, exist_ok=True)
    bst.save_model(model_path)


if __name__ == "__main__":
    main()
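A note on the "random forest option" in the parser description: in XGBoost, `num_parallel_tree` grows that many trees per boosting round, so combining a large `num_parallel_tree` with a single round trains a classic random forest rather than a boosted ensemble. A hypothetical invocation (the script's file name is not shown in this diff, so `train_centralized.py` is an assumption):
```
python3 train_centralized.py --num_parallel_tree 100 --num_rounds 1
```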
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
{
  "format_version": 2,
  "server": {
    "heart_beat_timeout": 600,
    "task_request_interval": 0.05
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "persistor",
      "path": "nvflare.app_opt.xgboost.tree_based.model_persistor.XGBModelPersistor",
      "args": {
        "save_name": "xgboost_model.json"
      }
    },
    {
      "id": "shareable_generator",
      "path": "nvflare.app_opt.xgboost.tree_based.shareable_generator.XGBModelShareableGenerator",
      "args": {}
    },
    {
      "id": "aggregator",
      "path": "nvflare.app_opt.xgboost.tree_based.bagging_aggregator.XGBBaggingAggregator",
      "args": {}
    }
  ],
  "workflows": [
    {
      "id": "scatter_and_gather",
      "name": "ScatterAndGather",
      "args": {
        "min_clients": 2,
        "num_rounds": 100,
        "start_round": 0,
        "wait_time_after_min_received": 0,
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train",
        "train_timeout": 0,
        "allow_empty_global_weights": true,
        "task_check_period": 0.01,
        "persist_every_n_rounds": 0,
        "snapshot_every_n_rounds": 0
      }
    }
  ]
}
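For orientation, the workflow above repeats a scatter/gather cycle `num_rounds` times: broadcast the global model, wait for at least `min_clients` results, aggregate, and update. The toy sketch below mirrors that control flow for the tree-based bagging case; it is illustrative pseudologic, not NVFlare's implementation, and `client_train` is a hypothetical stand-in for `FedXGBTreeExecutor`.
```python
def client_train(client_id, global_trees, round_num):
    # Stand-in for FedXGBTreeExecutor: boost one new tree on top of the
    # received global model and return it as the client's result.
    return f"tree(client={client_id}, round={round_num})"


def scatter_and_gather(num_rounds=3, min_clients=2):
    global_trees = []  # empty initial model (cf. allow_empty_global_weights)
    for round_num in range(num_rounds):
        # Scatter: send the current global model to every client; gather results.
        results = [client_train(cid, global_trees, round_num) for cid in ("site-1", "site-2")]
        assert len(results) >= min_clients
        # Bagging aggregation: append each client's new tree to the global model.
        global_trees.extend(results)
    return global_trees


print(scatter_and_gather())
```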
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
{
  "format_version": 2,
  "executors": [
    {
      "tasks": [
        "train"
      ],
      "executor": {
        "id": "Executor",
        "name": "FedXGBTreeExecutor",
        "args": {
          "data_loader_id": "dataloader",
          "training_mode": "bagging",
          "num_client_bagging": 2,
          "num_local_parallel_tree": 1,
          "local_subsample": 1,
          "lr_mode": "uniform",
          "local_model_path": "model.json",
          "global_model_path": "model_global.json",
          "learning_rate": 0.1,
          "objective": "binary:logistic",
          "max_depth": 8,
          "eval_metric": "auc",
          "tree_method": "hist",
          "nthread": 16,
          "lr_scale": 0.49999756170115234
        }
      }
    }
  ],
  "task_result_filters": [],
  "task_data_filters": [],
  "components": [
    {
      "id": "dataloader",
      "path": "data_loader.DataLoader",
      "args": {
        "data_split_filename": "/tmp/dataset/horizontal_xgb_data/data_site-1.json"
      }
    }
  ]
}
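The `data_split_filename` above points at a JSON file produced by `prepare_data.sh`. Its exact contents are not part of this commit, but the data loader below reads a `data_path` plus per-site and validation row ranges, so a compatible file can be written like this (the row numbers are illustrative assumptions):
```python
import json

split = {
    "data_path": "/tmp/dataset/horizontal_xgb_data/data.csv",  # assumed CSV location
    "data_index": {
        "valid": {"start": 0, "end": 10000},        # shared validation rows
        "site-1": {"start": 10000, "end": 120000},  # site-1 training rows
        "site-2": {"start": 120000, "end": 230000},  # site-2 training rows
    },
}
with open("/tmp/dataset/horizontal_xgb_data/data_site-1.json", "w") as f:
    json.dump(split, f, indent=2)
```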
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import pandas as pd
import xgboost as xgb

from nvflare.app_opt.xgboost.data_loader import XGBDataLoader


def _read_with_pandas(data_path, start: int, end: int):
    data_size = end - start
    # Skip rows belonging to other sites but keep the header row
    data = pd.read_csv(data_path, skiprows=range(1, start), nrows=data_size)
    data_num = data.shape[0]
    # Split into features and label (the label is the first column)
    x = data.iloc[:, 1:].copy()
    y = data.iloc[:, 0].copy()

    return x, y, data_num


class DataLoader(XGBDataLoader):
    def __init__(self, data_split_filename):
        """Reads the dataset and returns XGBoost data matrices.

        Args:
            data_split_filename: file name of the data split JSON
        """
        self.data_split_filename = data_split_filename

    def load_data(self, client_id: str):
        with open(self.data_split_filename, "r") as file:
            data_split = json.load(file)

        data_path = data_split["data_path"]
        data_index = data_split["data_index"]

        # Check that client_id and "valid" are both in the mapping dict
        if client_id not in data_index.keys():
            raise ValueError(
                f"Data does not contain Client {client_id} split",
            )

        if "valid" not in data_index.keys():
            raise ValueError(
                "Data does not contain Validation split",
            )

        site_index = data_index[client_id]
        valid_index = data_index["valid"]

        # Training set
        x_train, y_train, total_train_data_num = _read_with_pandas(
            data_path=data_path, start=site_index["start"], end=site_index["end"]
        )
        dmat_train = xgb.DMatrix(x_train, label=y_train)

        # Validation set
        x_valid, y_valid, total_valid_data_num = _read_with_pandas(
            data_path=data_path, start=valid_index["start"], end=valid_index["end"]
        )
        dmat_valid = xgb.DMatrix(x_valid, label=y_valid)

        return dmat_train, dmat_valid
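In a federated run, NVFlare instantiates this loader from `config_fed_client.json` and calls `load_data` with the site name. Outside of NVFlare it can be exercised directly as a sanity check, assuming the split file generated by `prepare_data.sh` exists:
```python
# Hypothetical standalone check; normally NVFlare drives this loader.
loader = DataLoader(data_split_filename="/tmp/dataset/horizontal_xgb_data/data_site-1.json")
dmat_train, dmat_valid = loader.load_data(client_id="site-1")
print(f"train rows: {dmat_train.num_row()}, valid rows: {dmat_valid.num_row()}")
```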
