This example shows how to use vertical federated learning with [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) on tabular data.
Here we use the optimized gradient boosting library [XGBoost](https://github.com/dmlc/xgboost) and leverage its federated learning support.

Before starting, please make sure you set up a [virtual environment](../../README.md#set-up-a-virtual-environment) and install the additional requirements:
```
python3 -m pip install -r requirements.txt
```
### Private Set Intersection (PSI)
Since not every site will have the same set of data samples (rows), we can use PSI to compare encrypted versions of the sites' datasets in order to jointly compute the intersection based on common IDs. In this example, the HIGGS dataset does not contain unique identifiers, so we add a temporary `uid_{idx}` to each instance and give each site a portion of the HIGGS dataset that includes a common overlap. Afterwards, the identifiers are dropped since they are only used for matching, and training is then done on the intersected data. To learn more about our PSI protocol implementation, see our [psi example](../psi/README.md).

> **_NOTE:_** The uid can be a composition of multiple variables with a transformation; however, in this example we use indices for simplicity. PSI can also be used for computing the intersection of overlapping features, but here we give each site unique features.

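As a rough illustration of this data preparation, here is a minimal pandas sketch (the function, the `label` column name, and the split sizes are illustrative assumptions, not the example's actual data-preparation script): every site receives a shared overlap of rows tagged with a temporary uid, plus its own disjoint slice of feature columns.

```
# Illustrative sketch only: tag rows with a temporary uid and give each site a
# shared overlap of rows plus a unique, non-overlapping subset of features.
import pandas as pd


def split_for_sites(df: pd.DataFrame, num_sites: int = 2, overlap: int = 50000) -> dict:
    df = df.copy()
    df["uid"] = [f"uid_{idx}" for idx in range(len(df))]  # used only for PSI matching

    feature_cols = [c for c in df.columns if c not in ("label", "uid")]
    rows_per_site = (len(df) - overlap) // num_sites
    splits = {}
    for site in range(num_sites):
        # All sites share the first `overlap` rows; the remainder is site-specific.
        own = df.iloc[overlap + site * rows_per_site : overlap + (site + 1) * rows_per_site]
        rows = pd.concat([df.iloc[:overlap], own])

        # Each site keeps a disjoint slice of the feature columns.
        lo = site * len(feature_cols) // num_sites
        hi = (site + 1) * len(feature_cols) // num_sites
        cols = ["uid"] + feature_cols[lo:hi]
        if site == 0:
            cols.append("label")  # assume only one site holds the label
        splits[f"site-{site + 1}"] = rows[cols]
    return splits
```

After PSI matches the sites' uid columns, the uid is dropped and training runs only on the intersected rows.
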
Create the PSI job using the predefined `psi_csv` template.
By default, CPU-based training is used.

In order to enable GPU-accelerated training, first ensure that your machine has CUDA installed and has at least one GPU.
In `config_fed_client.json`, set `"use_gpus": true` and `"tree_method": "hist"` in `xgb_params`.
Then, in `FedXGBHistogramExecutor`, we use the `device` parameter to map each rank to a GPU device ordinal in `xgb_params`.
If using multiple GPUs, we can map each rank to a different GPU device; with a single GPU, every rank can be mapped to the same device.

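As a rough sketch of that mapping (the helper function below is hypothetical and not NVFlare's actual implementation; it assumes XGBoost's `hist` tree method and its `device` parameter), each rank could be assigned a device ordinal like this:

```
# Hypothetical helper: build per-rank XGBoost parameters for GPU training.
def build_xgb_params(rank: int, use_gpus: bool = True, num_gpus: int = 1) -> dict:
    params = {
        "objective": "binary:logistic",
        "eval_metric": "auc",
        "tree_method": "hist",
    }
    if use_gpus:
        # With multiple GPUs each rank gets its own device;
        # with a single GPU every rank maps to "cuda:0".
        params["device"] = f"cuda:{rank % num_gpus}"
    return params
```
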
We can create a GPU-enabled job using the job CLI.

## Results
Model accuracy can be visualized in TensorBoard:
```
tensorboard --logdir /tmp/nvflare/vertical_xgb/simulate_job/tb_events
```

An example training (pink) and validation (orange) AUC graph from running vertical XGBoost on HIGGS, using an intersection of 50000 samples across 5 clients (each with different features) and running for ~50 rounds due to early stopping:

![Vertical XGBoost graph](./figs/vertical_xgboost_graph.png)